From 2df28d13b28e01ebc27156c91c78d4f82848c7f9 Mon Sep 17 00:00:00 2001 From: syui Date: Thu, 26 Mar 2026 07:06:37 +0900 Subject: [PATCH] feat(voice): add voice input/output foundation with rodio, cpal, and webrtc-vad --- Cargo.toml | 5 + examples/voice_test.rs | 20 ++++ src/agent.rs | 6 ++ src/ai.rs | 6 ++ src/lib.rs | 1 + src/main.rs | 2 +- src/tui.rs | 46 +++++++++ src/voice/mod.rs | 83 ++++++++++++++++ src/voice/stt.rs | 208 +++++++++++++++++++++++++++++++++++++++++ src/voice/tts.rs | 54 +++++++++++ 10 files changed, 430 insertions(+), 1 deletion(-) create mode 100644 examples/voice_test.rs create mode 100644 src/voice/mod.rs create mode 100644 src/voice/stt.rs create mode 100644 src/voice/tts.rs diff --git a/Cargo.toml b/Cargo.toml index df7e4d6..620f597 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,3 +20,8 @@ libc = "0.2" notify = { version = "7", features = ["macos_fsevent"] } ratatui = "0.29" crossterm = "0.28" +reqwest = { version = "0.12", features = ["json", "blocking"] } +rodio = "0.19" +base64 = "0.22" +cpal = "0.15" +webrtc-vad = "0.4" diff --git a/examples/voice_test.rs b/examples/voice_test.rs new file mode 100644 index 0000000..df78c1d --- /dev/null +++ b/examples/voice_test.rs @@ -0,0 +1,20 @@ +fn main() { + eprintln!("Voice test: initializing..."); + let voice = match aishell::voice::VoiceSystem::new() { + Some(v) => v, + None => { + eprintln!("Voice system not available. 
Check ELEVENLABS_API_KEY in .env"); + return; + } + }; + + eprintln!("Voice test: listening (speak now)..."); + match voice.listen() { + Some(text) => { + eprintln!("Recognized: {text}"); + eprintln!("Speaking back..."); + voice.speak(&text); + } + None => eprintln!("No speech detected."), + } +} diff --git a/src/agent.rs b/src/agent.rs index f3611e2..02d34fd 100644 --- a/src/agent.rs +++ b/src/agent.rs @@ -179,7 +179,10 @@ impl Agent { pub fn stop(&mut self) { if self.is_running() && !self.stopped { self.stopped = true; + #[cfg(unix)] unsafe { libc::kill(self.pid as i32, libc::SIGTERM); } + #[cfg(windows)] + { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); } self.status = AgentStatus::Error("stopped".to_string()); self.dirty = true; self.log("stopped", ""); @@ -291,7 +294,10 @@ impl Drop for Agent { fn drop(&mut self) { if !self.stopped { self.stopped = true; + #[cfg(unix)] unsafe { libc::kill(self.pid as i32, libc::SIGTERM); } + #[cfg(windows)] + { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); } } } } diff --git a/src/ai.rs b/src/ai.rs index aed2455..cc5fa6e 100644 --- a/src/ai.rs +++ b/src/ai.rs @@ -44,7 +44,10 @@ impl ClaudeManager { pub fn cancel(&mut self) { self.status = StatusKind::Idle; + #[cfg(unix)] unsafe { libc::kill(self.child_pid as i32, libc::SIGINT); } + #[cfg(windows)] + { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); } while self.output_rx.try_recv().is_ok() {} } @@ -59,6 +62,9 @@ impl ClaudeManager { impl Drop for ClaudeManager { fn drop(&mut self) { + #[cfg(unix)] unsafe { libc::kill(self.child_pid as i32, libc::SIGTERM); } + #[cfg(windows)] + { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); } } } diff --git a/src/lib.rs b/src/lib.rs index 40da075..fdca225 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,3 +8,4 @@ pub 
mod agent; pub mod tui; pub mod headless; pub mod watch; +pub mod voice; diff --git a/src/main.rs b/src/main.rs index b1be52e..f3abbf7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -26,7 +26,7 @@ fn main() { println!("{}", env!("CARGO_PKG_VERSION")); } Some("help" | "--help" | "-h") => print_help(), - None | Some("tui") => { + None | Some("tui") | Some("--voice") => { // Show logo before entering alternate screen eprintln!("\x1b[38;5;226m{}\x1b[0m\n\x1b[1m aishell\x1b[0m v{}\n", LOGO, env!("CARGO_PKG_VERSION")); diff --git a/src/tui.rs b/src/tui.rs index 1acc74f..99210e8 100644 --- a/src/tui.rs +++ b/src/tui.rs @@ -51,6 +51,9 @@ pub struct App { cmd_cache: CommandCache, watch_rx: Option>>, + voice: Option>, + voice_input_tx: Option>, + voice_input_rx: Option>, mode: Mode, input: String, input_task: String, @@ -82,6 +85,13 @@ impl App { agent_scroll: 0, cmd_cache: CommandCache::new(), watch_rx: None, + voice: if std::env::args().any(|a| a == "--voice") { + crate::voice::VoiceSystem::new().map(std::sync::Arc::new) + } else { + None + }, + voice_input_tx: None, + voice_input_rx: None, mode: Mode::Ai, input: String::new(), input_task: String::new(), @@ -89,6 +99,13 @@ impl App { should_quit: false, }; + // Setup voice input channel + if app.voice.is_some() { + let (tx, rx) = std::sync::mpsc::channel(); + app.voice_input_tx = Some(tx); + app.voice_input_rx = Some(rx); + } + // Send protocol + identity + project context if let Some(ref mut claude) = app.claude { let cwd = std::env::current_dir() @@ -128,6 +145,17 @@ impl App { } fn poll_all(&mut self) { + // Check for voice input + if let Some(ref rx) = self.voice_input_rx { + if let Ok(text) = rx.try_recv() { + self.ai_output.push_str(&format!("\n---\n[voice] {text}\n")); + self.ai_scroll = u16::MAX; + if let Some(ref mut claude) = self.claude { + claude.send(&text); + } + } + } + let mut stream_ended = false; if let Some(ref mut claude) = self.claude { @@ -152,6 +180,24 @@ impl App { &format!("{STATE_DIR}/ai.txt"), 
self.ai_output.as_bytes(), ); + // Voice: speak response, then listen for next input + if let Some(ref voice) = self.voice { + let last = self.ai_output.rsplit("---").next().unwrap_or(&self.ai_output); + let text = last.trim().to_string(); + if !text.is_empty() { + let v = voice.clone(); + let tx = self.voice_input_tx.clone(); + std::thread::spawn(move || { + v.speak(&text); + // After speaking, listen for user's response + if let Some(heard) = v.listen() { + if let Some(tx) = tx { + let _ = tx.send(heard); + } + } + }); + } + } stream_ended = true; } } diff --git a/src/voice/mod.rs b/src/voice/mod.rs new file mode 100644 index 0000000..c3e0150 --- /dev/null +++ b/src/voice/mod.rs @@ -0,0 +1,83 @@ +pub mod tts; +pub mod stt; + +/// Load .env file from cwd, setting vars that aren't already set. +fn load_dotenv() { + for dir in &[".", env!("CARGO_MANIFEST_DIR")] { + let path = std::path::Path::new(dir).join(".env"); + if let Ok(content) = std::fs::read_to_string(&path) { + for line in content.lines() { + let line = line.trim(); + if line.is_empty() || line.starts_with('#') { continue; } + if let Some((key, val)) = line.split_once('=') { + let key = key.trim(); + let val = val.trim(); + if std::env::var(key).is_err() { + std::env::set_var(key, val); + } + } + } + break; + } + } +} + +pub struct VoiceConfig { + pub tts_api_key: String, + pub tts_voice_id: String, + pub tts_model: String, + pub stt_language: String, +} + +impl VoiceConfig { + pub fn load() -> Option { + // Load .env file if present (cwd or project root) + load_dotenv(); + + let tts_api_key = std::env::var("ELEVENLABS_API_KEY").ok()?; + + Some(Self { + tts_api_key, + tts_voice_id: std::env::var("ELEVENLABS_VOICE_ID").unwrap_or_default(), + tts_model: std::env::var("ELEVENLABS_MODEL_ID").unwrap_or_else(|_| "eleven_multilingual_v2".into()), + stt_language: std::env::var("STT_LANGUAGE").unwrap_or_else(|_| "ja-JP".into()), + }) + } + + pub fn is_available(&self) -> bool { + !self.tts_api_key.is_empty() + } 
}

/// High-level voice facade: owns the config and ties TTS + STT together.
pub struct VoiceSystem {
    pub config: VoiceConfig,
}

impl VoiceSystem {
    /// Build the voice system from env/.env configuration.
    /// Returns `None` when `ELEVENLABS_API_KEY` is missing or empty.
    pub fn new() -> Option<Self> {
        let config = VoiceConfig::load()?;
        if !config.is_available() {
            return None;
        }
        Some(Self { config })
    }

    /// Synthesize `text` via ElevenLabs and play it.
    /// Best-effort: errors are logged to stderr, never propagated.
    pub fn speak(&self, text: &str) {
        if text.trim().is_empty() { return; }
        let audio = match tts::synthesize(&self.config, text) {
            Ok(data) => data,
            Err(e) => { eprintln!("tts error: {e}"); return; }
        };
        if let Err(e) = tts::play_audio(&audio) {
            eprintln!("audio play error: {e}");
        }
    }

    /// Record one utterance and return its transcript.
    /// `None` on silence, empty transcript, or recognition error.
    pub fn listen(&self) -> Option<String> {
        match stt::recognize(&self.config) {
            Ok(text) if !text.is_empty() => Some(text),
            Ok(_) => None,
            Err(e) => { eprintln!("stt error: {e}"); None }
        }
    }
}
diff --git a/src/voice/stt.rs b/src/voice/stt.rs
new file mode 100644
index 0000000..262f6e9
--- /dev/null
+++ b/src/voice/stt.rs
@@ -0,0 +1,208 @@
use crate::voice::VoiceConfig;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex, mpsc};

/// Record audio via VAD and recognize speech via Google Cloud STT.
pub fn recognize(config: &VoiceConfig) -> Result<String, String> {
    let audio = record_vad().map_err(|e| format!("recording: {e}"))?;
    if audio.is_empty() {
        return Ok(String::new());
    }
    transcribe(config, &audio)
}

/// Record until speech ends (VAD-based).
+fn record_vad() -> Result, String> { + let host = cpal::default_host(); + let device = host.default_input_device() + .ok_or("No input device")?; + + let default_config = device.default_input_config() + .map_err(|e| format!("input config: {e}"))?; + let device_rate = default_config.sample_rate().0; + let device_channels = default_config.channels() as usize; + + let config = cpal::StreamConfig { + channels: default_config.channels(), + sample_rate: cpal::SampleRate(device_rate), + buffer_size: cpal::BufferSize::Default, + }; + + let (tx, rx) = mpsc::channel::>(); + let done = Arc::new(Mutex::new(false)); + let done_clone = done.clone(); + let gain: f32 = 8.0; + + let stream = device.build_input_stream( + &config, + move |data: &[f32], _: &cpal::InputCallbackInfo| { + if *done_clone.lock().unwrap() { return; } + let mono: Vec = data.chunks(device_channels) + .map(|ch| { + let avg = ch.iter().sum::() / device_channels as f32; + let amplified = (avg * gain).clamp(-1.0, 1.0); + (amplified * 32767.0) as i16 + }) + .collect(); + let _ = tx.send(mono); + }, + |err| eprintln!("audio error: {err}"), + None, + ).map_err(|e| format!("build stream: {e}"))?; + + stream.play().map_err(|e| format!("play: {e}"))?; + + let frame_ms: u32 = 30; + let device_frame_size = (device_rate * frame_ms / 1000) as usize; + let vad_frame_size: usize = 480; // 30ms @ 16kHz + let silence_frames: u32 = 500 / frame_ms; + let min_speech_frames: u32 = 200 / frame_ms; + let max_frames: u32 = 8000 / frame_ms; // 8 second max + + let mut vad = webrtc_vad::Vad::new_with_rate_and_mode( + webrtc_vad::SampleRate::Rate16kHz, + webrtc_vad::VadMode::Quality, + ); + + let mut frame_buf: Vec = Vec::with_capacity(device_frame_size); + let mut audio_buf: Vec = Vec::new(); + let mut recording = false; + let mut silence_count: u32 = 0; + let mut speech_count: u32 = 0; + let mut total_frames: u32 = 0; + + eprintln!(" listening..."); + + loop { + match rx.recv_timeout(std::time::Duration::from_millis(100)) { + 
Ok(samples) => { + for sample in samples { + frame_buf.push(sample); + if frame_buf.len() >= device_frame_size { + total_frames += 1; + let vad_frame = resample(&frame_buf, device_rate, 16000, vad_frame_size); + let is_speech = vad.is_voice_segment(&vad_frame).unwrap_or(false); + + if is_speech { + recording = true; + silence_count = 0; + speech_count += 1; + audio_buf.extend_from_slice(&frame_buf); + } else if recording { + silence_count += 1; + audio_buf.extend_from_slice(&frame_buf); + if silence_count >= silence_frames && speech_count >= min_speech_frames { + *done.lock().unwrap() = true; + break; + } + } + + if total_frames >= max_frames { + *done.lock().unwrap() = true; + break; + } + frame_buf.clear(); + } + } + } + Err(mpsc::RecvTimeoutError::Timeout) => {} + Err(mpsc::RecvTimeoutError::Disconnected) => break, + } + if *done.lock().unwrap() { break; } + } + + drop(stream); + + if audio_buf.is_empty() { + return Ok(Vec::new()); + } + + let output_len = (audio_buf.len() as f32 * 16000.0 / device_rate as f32) as usize; + Ok(resample(&audio_buf, device_rate, 16000, output_len)) +} + +/// Send audio to Google Cloud STT and return transcript. 
+fn transcribe(config: &VoiceConfig, audio: &[i16]) -> Result { + let api_key = std::env::var("GOOGLE_API_KEY") + .map_err(|_| "GOOGLE_API_KEY not set".to_string())?; + + // Convert i16 samples to WAV bytes + let wav_data = encode_wav(audio, 16000); + let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &wav_data); + + let body = serde_json::json!({ + "config": { + "encoding": "LINEAR16", + "sampleRateHertz": 16000, + "languageCode": config.stt_language, + }, + "audio": { + "content": encoded + } + }); + + let url = format!("https://speech.googleapis.com/v1/speech:recognize?key={api_key}"); + + let client = reqwest::blocking::Client::new(); + let resp = client.post(&url) + .json(&body) + .send() + .map_err(|e| format!("STT request: {e}"))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().unwrap_or_default(); + return Err(format!("STT API error {status}: {body}")); + } + + let json: serde_json::Value = resp.json() + .map_err(|e| format!("STT parse: {e}"))?; + + let transcript = json["results"][0]["alternatives"][0]["transcript"] + .as_str() + .unwrap_or("") + .to_string(); + + Ok(transcript) +} + +/// Encode i16 samples as WAV bytes. 
/// Encode mono 16-bit PCM samples as a minimal 44-byte-header WAV (RIFF) stream.
fn encode_wav(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    let data_len = (samples.len() * 2) as u32; // 2 bytes per i16 sample
    let file_len = 36 + data_len;              // RIFF size field = total bytes - 8
    let mut buf = Vec::with_capacity(file_len as usize + 8);

    // RIFF header
    buf.extend_from_slice(b"RIFF");
    buf.extend_from_slice(&file_len.to_le_bytes());
    buf.extend_from_slice(b"WAVE");
    // fmt chunk
    buf.extend_from_slice(b"fmt ");
    buf.extend_from_slice(&16u32.to_le_bytes()); // chunk size
    buf.extend_from_slice(&1u16.to_le_bytes());  // audio format: PCM
    buf.extend_from_slice(&1u16.to_le_bytes());  // channels: mono
    buf.extend_from_slice(&sample_rate.to_le_bytes());
    buf.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate = rate * block align
    buf.extend_from_slice(&2u16.to_le_bytes());  // block align (mono * 16-bit)
    buf.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
    // data chunk
    buf.extend_from_slice(b"data");
    buf.extend_from_slice(&data_len.to_le_bytes());
    for &s in samples {
        buf.extend_from_slice(&s.to_le_bytes());
    }
    buf
}

/// Nearest-neighbor resampler. Cheap and lossy but adequate for VAD/STT input;
/// indices past the end of `input` are zero-filled.
fn resample(input: &[i16], from_rate: u32, to_rate: u32, output_len: usize) -> Vec<i16> {
    if from_rate == to_rate {
        return input.to_vec();
    }
    let ratio = from_rate as f64 / to_rate as f64;
    (0..output_len)
        .map(|i| {
            let src = (i as f64 * ratio) as usize;
            input.get(src).copied().unwrap_or(0)
        })
        .collect()
}

// --- patch metadata (reconstructed from mangled diff; end of src/voice/stt.rs) ---
// diff --git a/src/voice/tts.rs b/src/voice/tts.rs
// new file mode 100644  index 0000000..1befe41
// --- /dev/null  +++ b/src/voice/tts.rs  @@ -0,0 +1,54 @@
// use std::io::{Cursor, Read}; // NOTE(review): `Read` is never used in tts.rs — drop it
// use crate::voice::VoiceConfig;
// /// Synthesize text to audio bytes via ElevenLabs API.
+pub fn synthesize(config: &VoiceConfig, text: &str) -> Result, String> { + let url = format!( + "https://api.elevenlabs.io/v1/text-to-speech/{}", + config.tts_voice_id + ); + + let body = serde_json::json!({ + "text": text, + "model_id": config.tts_model, + "voice_settings": { + "stability": 0.5, + "similarity_boost": 0.75 + } + }); + + let client = reqwest::blocking::Client::new(); + let resp = client.post(&url) + .header("xi-api-key", &config.tts_api_key) + .header("Content-Type", "application/json") + .header("Accept", "audio/mpeg") + .json(&body) + .send() + .map_err(|e| format!("TTS request failed: {e}"))?; + + if !resp.status().is_success() { + let status = resp.status(); + let body = resp.text().unwrap_or_default(); + return Err(format!("TTS API error {status}: {body}")); + } + + resp.bytes() + .map(|b| b.to_vec()) + .map_err(|e| format!("TTS read error: {e}")) +} + +/// Play audio bytes (MP3) using rodio. +pub fn play_audio(data: &[u8]) -> Result<(), String> { + let (_stream, handle) = rodio::OutputStream::try_default() + .map_err(|e| format!("audio output error: {e}"))?; + let sink = rodio::Sink::try_new(&handle) + .map_err(|e| format!("audio sink error: {e}"))?; + + let cursor = Cursor::new(data.to_vec()); + let source = rodio::Decoder::new(cursor) + .map_err(|e| format!("audio decode error: {e}"))?; + + sink.append(source); + sink.sleep_until_end(); + Ok(()) +}