feat(voice): add voice input/output foundation with rodio, cpal, and webrtc-vad
This commit is contained in:
@@ -20,3 +20,8 @@ libc = "0.2"
|
||||
notify = { version = "7", features = ["macos_fsevent"] }
|
||||
ratatui = "0.29"
|
||||
crossterm = "0.28"
|
||||
reqwest = { version = "0.12", features = ["json", "blocking"] }
|
||||
rodio = "0.19"
|
||||
base64 = "0.22"
|
||||
cpal = "0.15"
|
||||
webrtc-vad = "0.4"
|
||||
|
||||
20
examples/voice_test.rs
Normal file
20
examples/voice_test.rs
Normal file
@@ -0,0 +1,20 @@
|
||||
fn main() {
|
||||
eprintln!("Voice test: initializing...");
|
||||
let voice = match aishell::voice::VoiceSystem::new() {
|
||||
Some(v) => v,
|
||||
None => {
|
||||
eprintln!("Voice system not available. Check ELEVENLABS_API_KEY in .env");
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
eprintln!("Voice test: listening (speak now)...");
|
||||
match voice.listen() {
|
||||
Some(text) => {
|
||||
eprintln!("Recognized: {text}");
|
||||
eprintln!("Speaking back...");
|
||||
voice.speak(&text);
|
||||
}
|
||||
None => eprintln!("No speech detected."),
|
||||
}
|
||||
}
|
||||
@@ -179,7 +179,10 @@ impl Agent {
|
||||
pub fn stop(&mut self) {
|
||||
if self.is_running() && !self.stopped {
|
||||
self.stopped = true;
|
||||
#[cfg(unix)]
|
||||
unsafe { libc::kill(self.pid as i32, libc::SIGTERM); }
|
||||
#[cfg(windows)]
|
||||
{ let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); }
|
||||
self.status = AgentStatus::Error("stopped".to_string());
|
||||
self.dirty = true;
|
||||
self.log("stopped", "");
|
||||
@@ -291,7 +294,10 @@ impl Drop for Agent {
|
||||
    /// Best-effort cleanup: terminate the child process when the Agent is
    /// dropped, unless `stop()` already signalled it (the `stopped` flag
    /// guards against double-kill).
    fn drop(&mut self) {
        if !self.stopped {
            // Mark first so no other path signals the same PID again.
            self.stopped = true;
            #[cfg(unix)]
            // SAFETY: plain kill(2) with SIGTERM; if the PID is already gone
            // the call fails with ESRCH, which is harmless and ignored.
            // NOTE(review): the PID could have been reused by the OS by drop
            // time — confirm process lifetime guarantees upstream.
            unsafe { libc::kill(self.pid as i32, libc::SIGTERM); }
            #[cfg(windows)]
            { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); }
        }
    }
|
||||
}
|
||||
|
||||
@@ -44,7 +44,10 @@ impl ClaudeManager {
|
||||
|
||||
    /// Abort the in-flight request: mark the manager idle, interrupt the
    /// child process, and discard any output already queued on the channel.
    pub fn cancel(&mut self) {
        self.status = StatusKind::Idle;
        #[cfg(unix)]
        // SIGINT mirrors a Ctrl-C so the child may shut down gracefully.
        // SAFETY: plain kill(2); failure on a stale PID is ignored.
        unsafe { libc::kill(self.child_pid as i32, libc::SIGINT); }
        #[cfg(windows)]
        // NOTE(review): taskkill /F is a hard kill, unlike SIGINT on unix —
        // confirm the asymmetry between platforms is intentional.
        { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); }
        // Drain buffered output so the next request doesn't see stale lines.
        while self.output_rx.try_recv().is_ok() {}
    }
|
||||
|
||||
@@ -59,6 +62,9 @@ impl ClaudeManager {
|
||||
|
||||
impl Drop for ClaudeManager {
    /// Terminate the child process when the manager goes away so no
    /// orphaned claude process is left running.
    fn drop(&mut self) {
        #[cfg(unix)]
        // SAFETY: plain kill(2) with SIGTERM; failure on an already-exited
        // PID is harmless and ignored.
        // NOTE(review): unlike Agent, there is no `stopped` guard here, so
        // drop always signals even after cancel() — confirm the extra signal
        // is harmless.
        unsafe { libc::kill(self.child_pid as i32, libc::SIGTERM); }
        #[cfg(windows)]
        { let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); }
    }
}
|
||||
|
||||
@@ -8,3 +8,4 @@ pub mod agent;
|
||||
pub mod tui;
|
||||
pub mod headless;
|
||||
pub mod watch;
|
||||
pub mod voice;
|
||||
|
||||
@@ -26,7 +26,7 @@ fn main() {
|
||||
println!("{}", env!("CARGO_PKG_VERSION"));
|
||||
}
|
||||
Some("help" | "--help" | "-h") => print_help(),
|
||||
None | Some("tui") => {
|
||||
None | Some("tui") | Some("--voice") => {
|
||||
// Show logo before entering alternate screen
|
||||
eprintln!("\x1b[38;5;226m{}\x1b[0m\n\x1b[1m aishell\x1b[0m v{}\n",
|
||||
LOGO, env!("CARGO_PKG_VERSION"));
|
||||
|
||||
46
src/tui.rs
46
src/tui.rs
@@ -51,6 +51,9 @@ pub struct App {
|
||||
|
||||
cmd_cache: CommandCache,
|
||||
watch_rx: Option<std::sync::mpsc::Receiver<Vec<String>>>,
|
||||
voice: Option<std::sync::Arc<crate::voice::VoiceSystem>>,
|
||||
voice_input_tx: Option<std::sync::mpsc::Sender<String>>,
|
||||
voice_input_rx: Option<std::sync::mpsc::Receiver<String>>,
|
||||
mode: Mode,
|
||||
input: String,
|
||||
input_task: String,
|
||||
@@ -82,6 +85,13 @@ impl App {
|
||||
agent_scroll: 0,
|
||||
cmd_cache: CommandCache::new(),
|
||||
watch_rx: None,
|
||||
voice: if std::env::args().any(|a| a == "--voice") {
|
||||
crate::voice::VoiceSystem::new().map(std::sync::Arc::new)
|
||||
} else {
|
||||
None
|
||||
},
|
||||
voice_input_tx: None,
|
||||
voice_input_rx: None,
|
||||
mode: Mode::Ai,
|
||||
input: String::new(),
|
||||
input_task: String::new(),
|
||||
@@ -89,6 +99,13 @@ impl App {
|
||||
should_quit: false,
|
||||
};
|
||||
|
||||
// Setup voice input channel
|
||||
if app.voice.is_some() {
|
||||
let (tx, rx) = std::sync::mpsc::channel();
|
||||
app.voice_input_tx = Some(tx);
|
||||
app.voice_input_rx = Some(rx);
|
||||
}
|
||||
|
||||
// Send protocol + identity + project context
|
||||
if let Some(ref mut claude) = app.claude {
|
||||
let cwd = std::env::current_dir()
|
||||
@@ -128,6 +145,17 @@ impl App {
|
||||
}
|
||||
|
||||
fn poll_all(&mut self) {
|
||||
// Check for voice input
|
||||
if let Some(ref rx) = self.voice_input_rx {
|
||||
if let Ok(text) = rx.try_recv() {
|
||||
self.ai_output.push_str(&format!("\n---\n[voice] {text}\n"));
|
||||
self.ai_scroll = u16::MAX;
|
||||
if let Some(ref mut claude) = self.claude {
|
||||
claude.send(&text);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut stream_ended = false;
|
||||
|
||||
if let Some(ref mut claude) = self.claude {
|
||||
@@ -152,6 +180,24 @@ impl App {
|
||||
&format!("{STATE_DIR}/ai.txt"),
|
||||
self.ai_output.as_bytes(),
|
||||
);
|
||||
// Voice: speak response, then listen for next input
|
||||
if let Some(ref voice) = self.voice {
|
||||
let last = self.ai_output.rsplit("---").next().unwrap_or(&self.ai_output);
|
||||
let text = last.trim().to_string();
|
||||
if !text.is_empty() {
|
||||
let v = voice.clone();
|
||||
let tx = self.voice_input_tx.clone();
|
||||
std::thread::spawn(move || {
|
||||
v.speak(&text);
|
||||
// After speaking, listen for user's response
|
||||
if let Some(heard) = v.listen() {
|
||||
if let Some(tx) = tx {
|
||||
let _ = tx.send(heard);
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
stream_ended = true;
|
||||
}
|
||||
}
|
||||
|
||||
83
src/voice/mod.rs
Normal file
83
src/voice/mod.rs
Normal file
@@ -0,0 +1,83 @@
|
||||
pub mod tts;
|
||||
pub mod stt;
|
||||
|
||||
/// Load .env file from cwd, setting vars that aren't already set.
|
||||
fn load_dotenv() {
|
||||
for dir in &[".", env!("CARGO_MANIFEST_DIR")] {
|
||||
let path = std::path::Path::new(dir).join(".env");
|
||||
if let Ok(content) = std::fs::read_to_string(&path) {
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.is_empty() || line.starts_with('#') { continue; }
|
||||
if let Some((key, val)) = line.split_once('=') {
|
||||
let key = key.trim();
|
||||
let val = val.trim();
|
||||
if std::env::var(key).is_err() {
|
||||
std::env::set_var(key, val);
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runtime configuration for the voice subsystem, sourced from the
/// environment (optionally seeded from a `.env` file by `load()`).
pub struct VoiceConfig {
    // ElevenLabs API key (required; `load()` returns None without it).
    pub tts_api_key: String,
    // ElevenLabs voice id; may be empty when ELEVENLABS_VOICE_ID is unset.
    pub tts_voice_id: String,
    // ElevenLabs model id; defaults to "eleven_multilingual_v2".
    pub tts_model: String,
    // Language code passed to speech recognition; defaults to "ja-JP".
    pub stt_language: String,
}
|
||||
|
||||
impl VoiceConfig {
|
||||
pub fn load() -> Option<Self> {
|
||||
// Load .env file if present (cwd or project root)
|
||||
load_dotenv();
|
||||
|
||||
let tts_api_key = std::env::var("ELEVENLABS_API_KEY").ok()?;
|
||||
|
||||
Some(Self {
|
||||
tts_api_key,
|
||||
tts_voice_id: std::env::var("ELEVENLABS_VOICE_ID").unwrap_or_default(),
|
||||
tts_model: std::env::var("ELEVENLABS_MODEL_ID").unwrap_or_else(|_| "eleven_multilingual_v2".into()),
|
||||
stt_language: std::env::var("STT_LANGUAGE").unwrap_or_else(|_| "ja-JP".into()),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn is_available(&self) -> bool {
|
||||
!self.tts_api_key.is_empty()
|
||||
}
|
||||
}
|
||||
|
||||
/// Facade over TTS (`speak`) and STT (`listen`); constructed only when
/// the environment provides a usable configuration.
pub struct VoiceSystem {
    // Environment-derived settings shared by the tts/stt submodules.
    pub config: VoiceConfig,
}
|
||||
|
||||
impl VoiceSystem {
|
||||
pub fn new() -> Option<Self> {
|
||||
let config = VoiceConfig::load()?;
|
||||
if !config.is_available() {
|
||||
return None;
|
||||
}
|
||||
Some(Self { config })
|
||||
}
|
||||
|
||||
pub fn speak(&self, text: &str) {
|
||||
if text.trim().is_empty() { return; }
|
||||
let audio = match tts::synthesize(&self.config, text) {
|
||||
Ok(data) => data,
|
||||
Err(e) => { eprintln!("tts error: {e}"); return; }
|
||||
};
|
||||
if let Err(e) = tts::play_audio(&audio) {
|
||||
eprintln!("audio play error: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
pub fn listen(&self) -> Option<String> {
|
||||
match stt::recognize(&self.config) {
|
||||
Ok(text) if !text.is_empty() => Some(text),
|
||||
Ok(_) => None,
|
||||
Err(e) => { eprintln!("stt error: {e}"); None }
|
||||
}
|
||||
}
|
||||
}
|
||||
208
src/voice/stt.rs
Normal file
208
src/voice/stt.rs
Normal file
@@ -0,0 +1,208 @@
|
||||
use crate::voice::VoiceConfig;
|
||||
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
|
||||
use std::sync::{Arc, Mutex, mpsc};
|
||||
|
||||
/// Record audio via VAD and recognize speech via Google Cloud STT.
|
||||
pub fn recognize(config: &VoiceConfig) -> Result<String, String> {
|
||||
let audio = record_vad().map_err(|e| format!("recording: {e}"))?;
|
||||
if audio.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
transcribe(config, &audio)
|
||||
}
|
||||
|
||||
/// Record until speech ends (VAD-based).
///
/// Captures mono audio from the default input device, classifies 30 ms
/// frames with a WebRTC voice-activity detector, and returns the captured
/// utterance resampled to 16 kHz. Returns an empty Vec when no speech was
/// captured before the 8-second hard cap.
fn record_vad() -> Result<Vec<i16>, String> {
    let host = cpal::default_host();
    let device = host.default_input_device()
        .ok_or("No input device")?;

    // Capture at whatever rate/channel count the device natively offers;
    // conversion to 16 kHz mono happens in software below.
    let default_config = device.default_input_config()
        .map_err(|e| format!("input config: {e}"))?;
    let device_rate = default_config.sample_rate().0;
    let device_channels = default_config.channels() as usize;

    let config = cpal::StreamConfig {
        channels: default_config.channels(),
        sample_rate: cpal::SampleRate(device_rate),
        buffer_size: cpal::BufferSize::Default,
    };

    // Audio callback thread -> this thread: batches of downmixed samples.
    let (tx, rx) = mpsc::channel::<Vec<i16>>();
    // Shared stop flag so the callback goes quiet once we decide to finish.
    let done = Arc::new(Mutex::new(false));
    let done_clone = done.clone();
    // NOTE(review): fixed 8x software gain; loud input is clamped below and
    // will clip — confirm this suits typical microphone levels.
    let gain: f32 = 8.0;

    let stream = device.build_input_stream(
        &config,
        move |data: &[f32], _: &cpal::InputCallbackInfo| {
            if *done_clone.lock().unwrap() { return; }
            // Downmix interleaved frames to mono, amplify, convert to i16.
            let mono: Vec<i16> = data.chunks(device_channels)
                .map(|ch| {
                    let avg = ch.iter().sum::<f32>() / device_channels as f32;
                    let amplified = (avg * gain).clamp(-1.0, 1.0);
                    (amplified * 32767.0) as i16
                })
                .collect();
            let _ = tx.send(mono);
        },
        |err| eprintln!("audio error: {err}"),
        None,
    ).map_err(|e| format!("build stream: {e}"))?;

    stream.play().map_err(|e| format!("play: {e}"))?;

    // All VAD bookkeeping is done in 30 ms frames.
    let frame_ms: u32 = 30;
    let device_frame_size = (device_rate * frame_ms / 1000) as usize;
    let vad_frame_size: usize = 480; // 30ms @ 16kHz
    let silence_frames: u32 = 500 / frame_ms;    // ~500 ms of silence ends the utterance
    let min_speech_frames: u32 = 200 / frame_ms; // require ~200 ms of speech to accept
    let max_frames: u32 = 8000 / frame_ms; // 8 second max

    let mut vad = webrtc_vad::Vad::new_with_rate_and_mode(
        webrtc_vad::SampleRate::Rate16kHz,
        webrtc_vad::VadMode::Quality,
    );

    let mut frame_buf: Vec<i16> = Vec::with_capacity(device_frame_size); // current 30 ms frame
    let mut audio_buf: Vec<i16> = Vec::new();                            // accumulated utterance
    let mut recording = false;        // has any speech been heard yet?
    let mut silence_count: u32 = 0;   // consecutive non-speech frames while recording
    let mut speech_count: u32 = 0;    // total speech frames seen
    let mut total_frames: u32 = 0;    // hard-cap counter

    eprintln!(" listening...");

    loop {
        // Short timeout so the `done` flag is still re-checked even when the
        // callback stops delivering samples.
        match rx.recv_timeout(std::time::Duration::from_millis(100)) {
            Ok(samples) => {
                for sample in samples {
                    frame_buf.push(sample);
                    if frame_buf.len() >= device_frame_size {
                        total_frames += 1;
                        // The VAD expects 16 kHz frames: squeeze the
                        // device-rate frame down before classification.
                        let vad_frame = resample(&frame_buf, device_rate, 16000, vad_frame_size);
                        let is_speech = vad.is_voice_segment(&vad_frame).unwrap_or(false);

                        if is_speech {
                            recording = true;
                            silence_count = 0;
                            speech_count += 1;
                            audio_buf.extend_from_slice(&frame_buf);
                        } else if recording {
                            // Keep trailing silence in the capture so word
                            // endings aren't clipped; stop once enough of it
                            // follows a sufficiently long utterance.
                            silence_count += 1;
                            audio_buf.extend_from_slice(&frame_buf);
                            if silence_count >= silence_frames && speech_count >= min_speech_frames {
                                *done.lock().unwrap() = true;
                                break;
                            }
                        }

                        if total_frames >= max_frames {
                            *done.lock().unwrap() = true;
                            break;
                        }
                        frame_buf.clear();
                    }
                }
            }
            Err(mpsc::RecvTimeoutError::Timeout) => {}
            Err(mpsc::RecvTimeoutError::Disconnected) => break,
        }
        if *done.lock().unwrap() { break; }
    }

    // Dropping the stream stops capture and releases the input device.
    drop(stream);

    if audio_buf.is_empty() {
        return Ok(Vec::new());
    }

    // Deliver the whole utterance at the 16 kHz rate the STT API expects.
    let output_len = (audio_buf.len() as f32 * 16000.0 / device_rate as f32) as usize;
    Ok(resample(&audio_buf, device_rate, 16000, output_len))
}
|
||||
|
||||
/// Send audio to Google Cloud STT and return transcript.
|
||||
fn transcribe(config: &VoiceConfig, audio: &[i16]) -> Result<String, String> {
|
||||
let api_key = std::env::var("GOOGLE_API_KEY")
|
||||
.map_err(|_| "GOOGLE_API_KEY not set".to_string())?;
|
||||
|
||||
// Convert i16 samples to WAV bytes
|
||||
let wav_data = encode_wav(audio, 16000);
|
||||
let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &wav_data);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"config": {
|
||||
"encoding": "LINEAR16",
|
||||
"sampleRateHertz": 16000,
|
||||
"languageCode": config.stt_language,
|
||||
},
|
||||
"audio": {
|
||||
"content": encoded
|
||||
}
|
||||
});
|
||||
|
||||
let url = format!("https://speech.googleapis.com/v1/speech:recognize?key={api_key}");
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let resp = client.post(&url)
|
||||
.json(&body)
|
||||
.send()
|
||||
.map_err(|e| format!("STT request: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().unwrap_or_default();
|
||||
return Err(format!("STT API error {status}: {body}"));
|
||||
}
|
||||
|
||||
let json: serde_json::Value = resp.json()
|
||||
.map_err(|e| format!("STT parse: {e}"))?;
|
||||
|
||||
let transcript = json["results"][0]["alternatives"][0]["transcript"]
|
||||
.as_str()
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
|
||||
Ok(transcript)
|
||||
}
|
||||
|
||||
/// Encode i16 samples as WAV bytes.
///
/// Produces a minimal 44-byte RIFF/WAVE header (PCM, mono, 16-bit)
/// followed by the little-endian sample data.
fn encode_wav(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    let data_len = (samples.len() * 2) as u32;
    let file_len = 36 + data_len;
    let mut wav = Vec::with_capacity(file_len as usize + 8);

    // RIFF container header.
    wav.extend_from_slice(b"RIFF");
    wav.extend_from_slice(&file_len.to_le_bytes());
    wav.extend_from_slice(b"WAVE");

    // fmt chunk: 16-byte PCM format description.
    wav.extend_from_slice(b"fmt ");
    wav.extend_from_slice(&16u32.to_le_bytes());             // chunk size
    wav.extend_from_slice(&1u16.to_le_bytes());              // audio format: PCM
    wav.extend_from_slice(&1u16.to_le_bytes());              // channels: mono
    wav.extend_from_slice(&sample_rate.to_le_bytes());
    wav.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // byte rate = rate * block align
    wav.extend_from_slice(&2u16.to_le_bytes());              // block align: 2 bytes per frame
    wav.extend_from_slice(&16u16.to_le_bytes());             // bits per sample

    // data chunk: raw little-endian samples.
    wav.extend_from_slice(b"data");
    wav.extend_from_slice(&data_len.to_le_bytes());
    wav.extend(samples.iter().flat_map(|s| s.to_le_bytes()));

    wav
}
|
||||
|
||||
/// Nearest-neighbor resample of `input` from `from_rate` to `to_rate`,
/// producing exactly `output_len` samples (zero-filled past the end).
///
/// Equal rates return the input unchanged and ignore `output_len`.
fn resample(input: &[i16], from_rate: u32, to_rate: u32, output_len: usize) -> Vec<i16> {
    if from_rate == to_rate {
        return input.to_vec();
    }
    let step = from_rate as f64 / to_rate as f64;
    let mut out = Vec::with_capacity(output_len);
    for i in 0..output_len {
        let idx = (i as f64 * step) as usize;
        out.push(*input.get(idx).unwrap_or(&0));
    }
    out
}
|
||||
54
src/voice/tts.rs
Normal file
54
src/voice/tts.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use std::io::{Cursor, Read};
|
||||
use crate::voice::VoiceConfig;
|
||||
|
||||
/// Synthesize text to audio bytes via ElevenLabs API.
|
||||
pub fn synthesize(config: &VoiceConfig, text: &str) -> Result<Vec<u8>, String> {
|
||||
let url = format!(
|
||||
"https://api.elevenlabs.io/v1/text-to-speech/{}",
|
||||
config.tts_voice_id
|
||||
);
|
||||
|
||||
let body = serde_json::json!({
|
||||
"text": text,
|
||||
"model_id": config.tts_model,
|
||||
"voice_settings": {
|
||||
"stability": 0.5,
|
||||
"similarity_boost": 0.75
|
||||
}
|
||||
});
|
||||
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let resp = client.post(&url)
|
||||
.header("xi-api-key", &config.tts_api_key)
|
||||
.header("Content-Type", "application/json")
|
||||
.header("Accept", "audio/mpeg")
|
||||
.json(&body)
|
||||
.send()
|
||||
.map_err(|e| format!("TTS request failed: {e}"))?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let status = resp.status();
|
||||
let body = resp.text().unwrap_or_default();
|
||||
return Err(format!("TTS API error {status}: {body}"));
|
||||
}
|
||||
|
||||
resp.bytes()
|
||||
.map(|b| b.to_vec())
|
||||
.map_err(|e| format!("TTS read error: {e}"))
|
||||
}
|
||||
|
||||
/// Play audio bytes (MP3) using rodio.
|
||||
pub fn play_audio(data: &[u8]) -> Result<(), String> {
|
||||
let (_stream, handle) = rodio::OutputStream::try_default()
|
||||
.map_err(|e| format!("audio output error: {e}"))?;
|
||||
let sink = rodio::Sink::try_new(&handle)
|
||||
.map_err(|e| format!("audio sink error: {e}"))?;
|
||||
|
||||
let cursor = Cursor::new(data.to_vec());
|
||||
let source = rodio::Decoder::new(cursor)
|
||||
.map_err(|e| format!("audio decode error: {e}"))?;
|
||||
|
||||
sink.append(source);
|
||||
sink.sleep_until_end();
|
||||
Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user