2
0

feat(voice): add voice input/output foundation with rodio, cpal, and webrtc-vad

This commit is contained in:
2026-03-26 07:06:37 +09:00
parent c4dcac1d95
commit 2df28d13b2
10 changed files with 430 additions and 1 deletions

View File

@@ -20,3 +20,8 @@ libc = "0.2"
notify = { version = "7", features = ["macos_fsevent"] }
ratatui = "0.29"
crossterm = "0.28"
reqwest = { version = "0.12", features = ["json", "blocking"] }
rodio = "0.19"
base64 = "0.22"
cpal = "0.15"
webrtc-vad = "0.4"

20
examples/voice_test.rs Normal file
View File

@@ -0,0 +1,20 @@
/// Manual smoke test for the voice subsystem: capture one utterance
/// from the microphone and echo it back through text-to-speech.
fn main() {
    eprintln!("Voice test: initializing...");
    // Bail out early when the voice stack cannot be configured
    // (e.g. no API key in the environment or .env file).
    let Some(voice) = aishell::voice::VoiceSystem::new() else {
        eprintln!("Voice system not available. Check ELEVENLABS_API_KEY in .env");
        return;
    };
    eprintln!("Voice test: listening (speak now)...");
    if let Some(text) = voice.listen() {
        eprintln!("Recognized: {text}");
        eprintln!("Speaking back...");
        voice.speak(&text);
    } else {
        eprintln!("No speech detected.");
    }
}

View File

@@ -179,7 +179,10 @@ impl Agent {
pub fn stop(&mut self) {
if self.is_running() && !self.stopped {
self.stopped = true;
#[cfg(unix)]
unsafe { libc::kill(self.pid as i32, libc::SIGTERM); }
#[cfg(windows)]
{ let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); }
self.status = AgentStatus::Error("stopped".to_string());
self.dirty = true;
self.log("stopped", "");
@@ -291,7 +294,10 @@ impl Drop for Agent {
/// Best-effort cleanup: terminate the spawned agent process when the
/// handle is dropped without an explicit stop().
/// NOTE(review): unlike stop() above, this does not check is_running()
/// before signalling — TODO confirm that is intended.
fn drop(&mut self) {
if !self.stopped {
self.stopped = true;
// Unix: request graceful shutdown with SIGTERM.
#[cfg(unix)]
// SAFETY: libc::kill is a plain FFI syscall; the only hazard is a
// stale/reused pid, mitigated (best-effort) by the `stopped` flag.
unsafe { libc::kill(self.pid as i32, libc::SIGTERM); }
// Windows has no POSIX signals; force-kill via taskkill, ignore the result.
#[cfg(windows)]
{ let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.pid.to_string()]).output(); }
}
}
}

View File

@@ -44,7 +44,10 @@ impl ClaudeManager {
/// Abort the in-flight request: signal the child process and drain any
/// buffered output so stale lines do not leak into the next prompt.
pub fn cancel(&mut self) {
self.status = StatusKind::Idle;
// SIGINT (not SIGTERM) so the child can unwind as if Ctrl-C was pressed.
#[cfg(unix)]
// SAFETY: plain FFI syscall; worst case is signalling a stale pid.
unsafe { libc::kill(self.child_pid as i32, libc::SIGINT); }
// Windows has no SIGINT delivery here; force-kill instead and ignore errors.
#[cfg(windows)]
{ let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); }
// Empty the channel of already-produced chunks.
while self.output_rx.try_recv().is_ok() {}
}
@@ -59,6 +62,9 @@ impl ClaudeManager {
/// Terminate the child process when the manager is dropped so no orphaned
/// child keeps running after the TUI exits.
impl Drop for ClaudeManager {
fn drop(&mut self) {
#[cfg(unix)]
// SAFETY: plain FFI syscall; best-effort — an already-exited or reused
// pid only makes the kill a no-op / misfire, and the error is ignored.
unsafe { libc::kill(self.child_pid as i32, libc::SIGTERM); }
// Windows equivalent: force-kill via taskkill, result deliberately ignored.
#[cfg(windows)]
{ let _ = std::process::Command::new("taskkill").args(["/F", "/PID", &self.child_pid.to_string()]).output(); }
}
}

View File

@@ -8,3 +8,4 @@ pub mod agent;
pub mod tui;
pub mod headless;
pub mod watch;
pub mod voice;

View File

@@ -26,7 +26,7 @@ fn main() {
println!("{}", env!("CARGO_PKG_VERSION"));
}
Some("help" | "--help" | "-h") => print_help(),
None | Some("tui") => {
None | Some("tui") | Some("--voice") => {
// Show logo before entering alternate screen
eprintln!("\x1b[38;5;226m{}\x1b[0m\n\x1b[1m aishell\x1b[0m v{}\n",
LOGO, env!("CARGO_PKG_VERSION"));

View File

@@ -51,6 +51,9 @@ pub struct App {
cmd_cache: CommandCache,
watch_rx: Option<std::sync::mpsc::Receiver<Vec<String>>>,
voice: Option<std::sync::Arc<crate::voice::VoiceSystem>>,
voice_input_tx: Option<std::sync::mpsc::Sender<String>>,
voice_input_rx: Option<std::sync::mpsc::Receiver<String>>,
mode: Mode,
input: String,
input_task: String,
@@ -82,6 +85,13 @@ impl App {
agent_scroll: 0,
cmd_cache: CommandCache::new(),
watch_rx: None,
voice: if std::env::args().any(|a| a == "--voice") {
crate::voice::VoiceSystem::new().map(std::sync::Arc::new)
} else {
None
},
voice_input_tx: None,
voice_input_rx: None,
mode: Mode::Ai,
input: String::new(),
input_task: String::new(),
@@ -89,6 +99,13 @@ impl App {
should_quit: false,
};
// Setup voice input channel
if app.voice.is_some() {
let (tx, rx) = std::sync::mpsc::channel();
app.voice_input_tx = Some(tx);
app.voice_input_rx = Some(rx);
}
// Send protocol + identity + project context
if let Some(ref mut claude) = app.claude {
let cwd = std::env::current_dir()
@@ -128,6 +145,17 @@ impl App {
}
fn poll_all(&mut self) {
// Check for voice input
if let Some(ref rx) = self.voice_input_rx {
if let Ok(text) = rx.try_recv() {
self.ai_output.push_str(&format!("\n---\n[voice] {text}\n"));
self.ai_scroll = u16::MAX;
if let Some(ref mut claude) = self.claude {
claude.send(&text);
}
}
}
let mut stream_ended = false;
if let Some(ref mut claude) = self.claude {
@@ -152,6 +180,24 @@ impl App {
&format!("{STATE_DIR}/ai.txt"),
self.ai_output.as_bytes(),
);
// Voice: speak response, then listen for next input
if let Some(ref voice) = self.voice {
let last = self.ai_output.rsplit("---").next().unwrap_or(&self.ai_output);
let text = last.trim().to_string();
if !text.is_empty() {
let v = voice.clone();
let tx = self.voice_input_tx.clone();
std::thread::spawn(move || {
v.speak(&text);
// After speaking, listen for user's response
if let Some(heard) = v.listen() {
if let Some(tx) = tx {
let _ = tx.send(heard);
}
}
});
}
}
stream_ended = true;
}
}

83
src/voice/mod.rs Normal file
View File

@@ -0,0 +1,83 @@
pub mod tts;
pub mod stt;
/// Load a `.env` file, setting vars that aren't already set.
///
/// The current directory is tried first, then the crate root
/// (`CARGO_MANIFEST_DIR`, baked in at compile time); only the first
/// `.env` found is read. Lines are `KEY=VALUE`; blank lines and `#`
/// comments are skipped, and variables already present in the
/// environment are never overwritten. Values wrapped in a matching
/// pair of single or double quotes have the quotes stripped, per
/// dotenv convention (previously `KEY="x"` kept the literal quotes,
/// which silently corrupted quoted API keys).
fn load_dotenv() {
    for dir in &[".", env!("CARGO_MANIFEST_DIR")] {
        let path = std::path::Path::new(dir).join(".env");
        if let Ok(content) = std::fs::read_to_string(&path) {
            for line in content.lines() {
                let line = line.trim();
                if line.is_empty() || line.starts_with('#') { continue; }
                if let Some((key, val)) = line.split_once('=') {
                    let key = key.trim();
                    let mut val = val.trim();
                    // Strip exactly one layer of matching quotes: KEY="v" / KEY='v'.
                    if val.len() >= 2
                        && ((val.starts_with('"') && val.ends_with('"'))
                            || (val.starts_with('\'') && val.ends_with('\'')))
                    {
                        val = &val[1..val.len() - 1];
                    }
                    if std::env::var(key).is_err() {
                        std::env::set_var(key, val);
                    }
                }
            }
            // Stop after the first .env found so cwd takes precedence.
            break;
        }
    }
}
/// Runtime configuration for the voice subsystem, sourced from
/// environment variables (optionally populated from a `.env` file).
pub struct VoiceConfig {
// ElevenLabs API key (ELEVENLABS_API_KEY); required for construction.
pub tts_api_key: String,
// ElevenLabs voice id (ELEVENLABS_VOICE_ID); may be empty.
pub tts_voice_id: String,
// ElevenLabs model id (ELEVENLABS_MODEL_ID); defaults to "eleven_multilingual_v2".
pub tts_model: String,
// Language code for speech recognition (STT_LANGUAGE); defaults to "ja-JP".
pub stt_language: String,
}
impl VoiceConfig {
    /// Build a config from the environment, loading `.env` first.
    /// Returns `None` only when `ELEVENLABS_API_KEY` is not set at all.
    pub fn load() -> Option<Self> {
        // Pull in .env (cwd or project root) before reading anything.
        load_dotenv();
        let tts_api_key = std::env::var("ELEVENLABS_API_KEY").ok()?;
        let tts_voice_id = std::env::var("ELEVENLABS_VOICE_ID").unwrap_or_default();
        let tts_model = std::env::var("ELEVENLABS_MODEL_ID")
            .unwrap_or_else(|_| "eleven_multilingual_v2".into());
        let stt_language = std::env::var("STT_LANGUAGE")
            .unwrap_or_else(|_| "ja-JP".into());
        Some(Self { tts_api_key, tts_voice_id, tts_model, stt_language })
    }

    /// Whether the config carries a usable (non-empty) TTS API key.
    pub fn is_available(&self) -> bool {
        !self.tts_api_key.is_empty()
    }
}
/// Facade over text-to-speech and speech-to-text; constructed only when
/// the required configuration (a non-empty TTS API key) is present.
pub struct VoiceSystem {
// Environment-derived settings; see VoiceConfig for the variable mapping.
pub config: VoiceConfig,
}
impl VoiceSystem {
    /// Create the voice system; `None` when no usable API key is configured.
    pub fn new() -> Option<Self> {
        VoiceConfig::load()
            .filter(|c| c.is_available())
            .map(|config| Self { config })
    }

    /// Synthesize `text` and play it to completion.
    /// Errors are logged to stderr, never propagated; blank input is a no-op.
    pub fn speak(&self, text: &str) {
        if text.trim().is_empty() {
            return;
        }
        match tts::synthesize(&self.config, text) {
            Ok(audio) => {
                if let Err(e) = tts::play_audio(&audio) {
                    eprintln!("audio play error: {e}");
                }
            }
            Err(e) => eprintln!("tts error: {e}"),
        }
    }

    /// Block until one utterance is captured and transcribed.
    /// Returns `None` on silence or on a recognition error (logged to stderr).
    pub fn listen(&self) -> Option<String> {
        match stt::recognize(&self.config) {
            Ok(text) => {
                if text.is_empty() { None } else { Some(text) }
            }
            Err(e) => {
                eprintln!("stt error: {e}");
                None
            }
        }
    }
}

208
src/voice/stt.rs Normal file
View File

@@ -0,0 +1,208 @@
use crate::voice::VoiceConfig;
use cpal::traits::{DeviceTrait, HostTrait, StreamTrait};
use std::sync::{Arc, Mutex, mpsc};
/// Record one utterance via VAD, then transcribe it with Google Cloud STT.
/// A silent/empty recording short-circuits to an empty transcript rather
/// than making a network call.
pub fn recognize(config: &VoiceConfig) -> Result<String, String> {
    let samples = record_vad().map_err(|e| format!("recording: {e}"))?;
    if samples.is_empty() {
        Ok(String::new())
    } else {
        transcribe(config, &samples)
    }
}
/// Record until speech ends (VAD-based).
///
/// Captures from the default input device, downmixes each callback buffer
/// to mono i16 with a fixed software gain, feeds 30 ms frames (resampled
/// to 16 kHz) through WebRTC VAD, and stops after ~500 ms of trailing
/// silence once at least ~200 ms of speech was heard — or after an 8 s
/// hard cap. Returns the take resampled to 16 kHz mono, or an empty Vec
/// when nothing was captured.
fn record_vad() -> Result<Vec<i16>, String> {
let host = cpal::default_host();
let device = host.default_input_device()
.ok_or("No input device")?;
let default_config = device.default_input_config()
.map_err(|e| format!("input config: {e}"))?;
let device_rate = default_config.sample_rate().0;
let device_channels = default_config.channels() as usize;
let config = cpal::StreamConfig {
channels: default_config.channels(),
sample_rate: cpal::SampleRate(device_rate),
buffer_size: cpal::BufferSize::Default,
};
// Hand samples from the audio callback thread to the processing loop below.
let (tx, rx) = mpsc::channel::<Vec<i16>>();
// Flag the callback checks so it stops sending once recording is finished.
let done = Arc::new(Mutex::new(false));
let done_clone = done.clone();
// Software gain applied before i16 conversion; clamped to avoid wrap-around.
// NOTE(review): 8x is a magic constant, presumably tuned for quiet mics — confirm.
let gain: f32 = 8.0;
let stream = device.build_input_stream(
&config,
move |data: &[f32], _: &cpal::InputCallbackInfo| {
if *done_clone.lock().unwrap() { return; }
// Downmix interleaved f32 frames to mono i16 (average channels, apply gain).
let mono: Vec<i16> = data.chunks(device_channels)
.map(|ch| {
let avg = ch.iter().sum::<f32>() / device_channels as f32;
let amplified = (avg * gain).clamp(-1.0, 1.0);
(amplified * 32767.0) as i16
})
.collect();
let _ = tx.send(mono);
},
|err| eprintln!("audio error: {err}"),
None,
).map_err(|e| format!("build stream: {e}"))?;
stream.play().map_err(|e| format!("play: {e}"))?;
// All frame accounting below is in 30 ms steps (the WebRTC VAD frame size).
let frame_ms: u32 = 30;
let device_frame_size = (device_rate * frame_ms / 1000) as usize;
let vad_frame_size: usize = 480; // 30ms @ 16kHz
let silence_frames: u32 = 500 / frame_ms;
let min_speech_frames: u32 = 200 / frame_ms;
let max_frames: u32 = 8000 / frame_ms; // 8 second max
let mut vad = webrtc_vad::Vad::new_with_rate_and_mode(
webrtc_vad::SampleRate::Rate16kHz,
webrtc_vad::VadMode::Quality,
);
let mut frame_buf: Vec<i16> = Vec::with_capacity(device_frame_size);
let mut audio_buf: Vec<i16> = Vec::new();
let mut recording = false;
let mut silence_count: u32 = 0;
let mut speech_count: u32 = 0;
let mut total_frames: u32 = 0;
eprintln!(" listening...");
loop {
match rx.recv_timeout(std::time::Duration::from_millis(100)) {
Ok(samples) => {
for sample in samples {
frame_buf.push(sample);
if frame_buf.len() >= device_frame_size {
total_frames += 1;
// The VAD only accepts 16 kHz frames, whatever the device rate is.
let vad_frame = resample(&frame_buf, device_rate, 16000, vad_frame_size);
let is_speech = vad.is_voice_segment(&vad_frame).unwrap_or(false);
if is_speech {
recording = true;
silence_count = 0;
speech_count += 1;
audio_buf.extend_from_slice(&frame_buf);
} else if recording {
// Keep trailing silence in the buffer so word endings aren't clipped.
silence_count += 1;
audio_buf.extend_from_slice(&frame_buf);
if silence_count >= silence_frames && speech_count >= min_speech_frames {
*done.lock().unwrap() = true;
break;
}
}
if total_frames >= max_frames {
*done.lock().unwrap() = true;
break;
}
frame_buf.clear();
}
}
}
Err(mpsc::RecvTimeoutError::Timeout) => {}
Err(mpsc::RecvTimeoutError::Disconnected) => break,
}
if *done.lock().unwrap() { break; }
}
// Dropping the stream stops capture.
drop(stream);
if audio_buf.is_empty() {
return Ok(Vec::new());
}
// Resample the whole take to 16 kHz, the rate the STT API expects.
let output_len = (audio_buf.len() as f32 * 16000.0 / device_rate as f32) as usize;
Ok(resample(&audio_buf, device_rate, 16000, output_len))
}
/// Send audio to Google Cloud STT and return transcript.
///
/// `audio` must be 16 kHz mono i16 samples (what `record_vad` produces).
/// Requires `GOOGLE_API_KEY` in the environment. Returns an empty string
/// when the API yields no result, `Err` on transport or API failures.
fn transcribe(config: &VoiceConfig, audio: &[i16]) -> Result<String, String> {
let api_key = std::env::var("GOOGLE_API_KEY")
.map_err(|_| "GOOGLE_API_KEY not set".to_string())?;
// Convert i16 samples to WAV bytes
// NOTE(review): the request declares LINEAR16 but sends a full WAV
// container rather than raw PCM; Google appears to tolerate the header,
// but raw samples would match the declared encoding exactly — confirm
// against the Speech-to-Text API docs.
let wav_data = encode_wav(audio, 16000);
let encoded = base64::Engine::encode(&base64::engine::general_purpose::STANDARD, &wav_data);
let body = serde_json::json!({
"config": {
"encoding": "LINEAR16",
"sampleRateHertz": 16000,
"languageCode": config.stt_language,
},
"audio": {
"content": encoded
}
});
// The API key travels as a query parameter — avoid logging this URL.
let url = format!("https://speech.googleapis.com/v1/speech:recognize?key={api_key}");
let client = reqwest::blocking::Client::new();
let resp = client.post(&url)
.json(&body)
.send()
.map_err(|e| format!("STT request: {e}"))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().unwrap_or_default();
return Err(format!("STT API error {status}: {body}"));
}
let json: serde_json::Value = resp.json()
.map_err(|e| format!("STT parse: {e}"))?;
// Only the first alternative of the first result is used; any missing
// path in the JSON falls through to an empty transcript.
let transcript = json["results"][0]["alternatives"][0]["transcript"]
.as_str()
.unwrap_or("")
.to_string();
Ok(transcript)
}
/// Serialize mono 16-bit PCM samples into a complete WAV (RIFF) byte stream
/// at the given sample rate. Output is the 44-byte header plus the
/// little-endian sample payload.
fn encode_wav(samples: &[i16], sample_rate: u32) -> Vec<u8> {
    let data_len = (samples.len() * 2) as u32;
    // 44-byte header + payload, reserved up front.
    let mut out = Vec::with_capacity(44 + samples.len() * 2);
    // RIFF container: chunk id, remaining file size, format tag.
    out.extend_from_slice(b"RIFF");
    out.extend_from_slice(&(36 + data_len).to_le_bytes());
    out.extend_from_slice(b"WAVE");
    // "fmt " subchunk: 16-byte PCM description (mono, 16-bit).
    out.extend_from_slice(b"fmt ");
    out.extend_from_slice(&16u32.to_le_bytes()); // subchunk size
    out.extend_from_slice(&1u16.to_le_bytes()); // audio format: PCM
    out.extend_from_slice(&1u16.to_le_bytes()); // channel count: mono
    out.extend_from_slice(&sample_rate.to_le_bytes());
    out.extend_from_slice(&(sample_rate * 2).to_le_bytes()); // bytes per second
    out.extend_from_slice(&2u16.to_le_bytes()); // frame (block) alignment
    out.extend_from_slice(&16u16.to_le_bytes()); // bits per sample
    // "data" subchunk followed by the raw samples.
    out.extend_from_slice(b"data");
    out.extend_from_slice(&data_len.to_le_bytes());
    out.extend(samples.iter().flat_map(|s| s.to_le_bytes()));
    out
}
/// Nearest-neighbor sample-rate conversion producing exactly `output_len`
/// samples. When the rates already match, the input is returned unchanged
/// and `output_len` is ignored. Source indices past the end map to silence.
fn resample(input: &[i16], from_rate: u32, to_rate: u32, output_len: usize) -> Vec<i16> {
    if from_rate == to_rate {
        return input.to_vec();
    }
    // Step through the source at the rate ratio, picking the nearest-below sample.
    let step = from_rate as f64 / to_rate as f64;
    let mut out = Vec::with_capacity(output_len);
    for i in 0..output_len {
        let src_idx = (i as f64 * step) as usize;
        out.push(input.get(src_idx).copied().unwrap_or(0));
    }
    out
}

54
src/voice/tts.rs Normal file
View File

@@ -0,0 +1,54 @@
use std::io::{Cursor, Read};
use crate::voice::VoiceConfig;
/// Call the ElevenLabs text-to-speech endpoint and return the MP3 bytes.
/// Errors (transport or non-2xx status) come back as descriptive strings.
pub fn synthesize(config: &VoiceConfig, text: &str) -> Result<Vec<u8>, String> {
    let endpoint = format!(
        "https://api.elevenlabs.io/v1/text-to-speech/{}",
        config.tts_voice_id
    );
    // Voice settings are fixed; only the text and model vary per request.
    let payload = serde_json::json!({
        "text": text,
        "model_id": config.tts_model,
        "voice_settings": {
            "stability": 0.5,
            "similarity_boost": 0.75
        }
    });
    let response = reqwest::blocking::Client::new()
        .post(&endpoint)
        .header("xi-api-key", &config.tts_api_key)
        .header("Content-Type", "application/json")
        .header("Accept", "audio/mpeg")
        .json(&payload)
        .send()
        .map_err(|e| format!("TTS request failed: {e}"))?;
    let status = response.status();
    if status.is_success() {
        response
            .bytes()
            .map(|b| b.to_vec())
            .map_err(|e| format!("TTS read error: {e}"))
    } else {
        let body = response.text().unwrap_or_default();
        Err(format!("TTS API error {status}: {body}"))
    }
}
/// Decode MP3 bytes and play them to completion on the default output device.
pub fn play_audio(data: &[u8]) -> Result<(), String> {
    let (_stream, handle) = rodio::OutputStream::try_default()
        .map_err(|e| format!("audio output error: {e}"))?;
    let sink = rodio::Sink::try_new(&handle)
        .map_err(|e| format!("audio sink error: {e}"))?;
    // The decoder wants an owned, seekable reader, hence the copied Cursor.
    let source = rodio::Decoder::new(Cursor::new(data.to_vec()))
        .map_err(|e| format!("audio decode error: {e}"))?;
    sink.append(source);
    // Block until playback finishes so callers can sequence speech.
    sink.sleep_until_end();
    Ok(())
}