From 9c0cea8b6633ec0808bfa940eed769ec082a8615 Mon Sep 17 00:00:00 2001 From: syui Date: Fri, 27 Feb 2026 12:24:07 +0900 Subject: [PATCH] init --- .gitignore | 23 ++++++ Cargo.toml | 47 ++++++++++++ README.md | 3 + src/cli/mod.rs | 3 + src/cli/repl.rs | 148 ++++++++++++++++++++++++++++++++++++++ src/config/mod.rs | 53 ++++++++++++++ src/lib.rs | 7 ++ src/llm/mod.rs | 18 +++++ src/llm/openai.rs | 126 ++++++++++++++++++++++++++++++++ src/llm/provider.rs | 104 +++++++++++++++++++++++++++ src/main.rs | 74 +++++++++++++++++++ src/mcp/mod.rs | 121 +++++++++++++++++++++++++++++++ src/shell/executor.rs | 112 +++++++++++++++++++++++++++++ src/shell/mod.rs | 5 ++ src/shell/tools.rs | 162 ++++++++++++++++++++++++++++++++++++++++++ 15 files changed, 1006 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.toml create mode 100644 README.md create mode 100644 src/cli/mod.rs create mode 100644 src/cli/repl.rs create mode 100644 src/config/mod.rs create mode 100644 src/lib.rs create mode 100644 src/llm/mod.rs create mode 100644 src/llm/openai.rs create mode 100644 src/llm/provider.rs create mode 100644 src/main.rs create mode 100644 src/mcp/mod.rs create mode 100644 src/shell/executor.rs create mode 100644 src/shell/mod.rs create mode 100644 src/shell/tools.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..90d3745 --- /dev/null +++ b/.gitignore @@ -0,0 +1,23 @@ +# Rust +/target/ +Cargo.lock + +# IDE +.idea/ +.vscode/ +*.swp +*.swo +*~ + +# OS +.DS_Store +Thumbs.db + +# Logs +*.log + +# Environment +.env +.env.local +/claude.md +/CLAUDE.md diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..3ba94d2 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,47 @@ +[package] +name = "aishell" +version = "0.1.0" +edition = "2021" +authors = ["syui"] +description = "AI-powered shell automation tool - A generic alternative to Claude Code" + +[lib] +name = "aishell" +path = "src/lib.rs" + +[[bin]] +name = "aishell" +path = "src/main.rs" + 
+[dependencies]
+# CLI and async (following aigpt pattern)
+clap = { version = "4.5", features = ["derive"] }
+tokio = { version = "1.40", features = ["rt", "rt-multi-thread", "macros", "io-std", "process", "fs"] }
+async-trait = "0.1"
+
+# HTTP client for LLM APIs
+reqwest = { version = "0.12", features = ["json", "stream"] }
+
+# Serialization
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+
+# Error handling
+thiserror = "1.0"
+anyhow = "1.0"
+
+# Utilities
+dirs = "5.0"
+
+# Shell execution
+duct = "0.13"
+
+# Configuration
+toml = "0.8"
+
+# Logging
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+
+# Interactive REPL
+rustyline = "14.0"
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..3f414d1
--- /dev/null
+++ b/README.md
@@ -0,0 +1,3 @@
+# aishell
+
+A single-stream shell where commands and AI coexist: type a command, it runs; type anything else, AI responds.
diff --git a/src/cli/mod.rs b/src/cli/mod.rs
new file mode 100644
index 0000000..402cb58
--- /dev/null
+++ b/src/cli/mod.rs
@@ -0,0 +1,3 @@
+pub mod repl;
+
+pub use repl::Repl;
diff --git a/src/cli/repl.rs b/src/cli/repl.rs
new file mode 100644
index 0000000..2ef3fe1
--- /dev/null
+++ b/src/cli/repl.rs
@@ -0,0 +1,148 @@
+use anyhow::{Context, Result};
+use rustyline::error::ReadlineError;
+use rustyline::DefaultEditor;
+
+use crate::llm::{create_provider, LLMProvider, Message};
+use crate::shell::{execute_tool, get_tool_definitions, ShellExecutor};
+
+pub struct Repl {
+    llm: Box<dyn LLMProvider>,
+    executor: ShellExecutor,
+    messages: Vec<Message>,
+}
+
+impl Repl {
+    pub async fn new(provider: &str, model: Option<&str>) -> Result<Self> {
+        let llm = create_provider(provider, model).await?;
+        let executor = ShellExecutor::default();
+
+        let system_prompt = Message::system(
+            "You are an AI assistant that helps users interact with their system through shell commands. 
\ + You have access to tools like bash, read, write, and list to help users accomplish their tasks. \ + When a user asks you to do something, use the appropriate tools to complete the task. \ + Always explain what you're doing and show the results to the user." + ); + + Ok(Self { + llm, + executor, + messages: vec![system_prompt], + }) + } + + pub async fn run(&mut self) -> Result<()> { + println!("aishell - AI-powered shell automation"); + println!("Type 'exit' or 'quit' to exit, 'clear' to clear history\n"); + + let mut rl = DefaultEditor::new()?; + + loop { + let readline = rl.readline("aishell> "); + + match readline { + Ok(line) => { + let line = line.trim(); + + if line.is_empty() { + continue; + } + + if line == "exit" || line == "quit" { + println!("Goodbye!"); + break; + } + + if line == "clear" { + self.messages.truncate(1); // Keep only system message + println!("History cleared."); + continue; + } + + rl.add_history_entry(line)?; + + if let Err(e) = self.process_input(line).await { + eprintln!("Error: {}", e); + } + } + Err(ReadlineError::Interrupted) => { + println!("^C"); + continue; + } + Err(ReadlineError::Eof) => { + println!("^D"); + break; + } + Err(err) => { + eprintln!("Error: {:?}", err); + break; + } + } + } + + Ok(()) + } + + pub async fn execute_once(&mut self, prompt: &str) -> Result<()> { + self.process_input(prompt).await + } + + async fn process_input(&mut self, input: &str) -> Result<()> { + // Add user message + self.messages.push(Message::user(input)); + + let tools = get_tool_definitions(); + + // Agent loop: keep calling LLM until it's done (no more tool calls) + let max_iterations = 10; + for iteration in 0..max_iterations { + tracing::debug!("Agent loop iteration {}", iteration + 1); + + let response = self + .llm + .chat(self.messages.clone(), Some(tools.clone())) + .await + .context("Failed to get LLM response")?; + + // If there are tool calls, execute them + if let Some(tool_calls) = response.tool_calls { + 
tracing::info!("LLM requested {} tool calls", tool_calls.len()); + + // Add assistant message with tool calls + let mut assistant_msg = Message::assistant(response.content.clone()); + assistant_msg.tool_calls = Some(tool_calls.clone()); + self.messages.push(assistant_msg); + + // Execute each tool call + for tool_call in tool_calls { + let tool_name = &tool_call.function.name; + let tool_args = &tool_call.function.arguments; + + println!("\n[Executing tool: {}]", tool_name); + + let result = match execute_tool(tool_name, tool_args, &self.executor) { + Ok(output) => output, + Err(e) => format!("Error executing tool: {}", e), + }; + + println!("{}", result); + + // Add tool result message + self.messages.push(Message::tool(result, tool_call.id.clone())); + } + + // Continue the loop to get the next response + continue; + } + + // No tool calls, so the LLM is done + if !response.content.is_empty() { + println!("\n{}\n", response.content); + self.messages.push(Message::assistant(response.content)); + } + + break; + } + + Ok(()) + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs new file mode 100644 index 0000000..00efa85 --- /dev/null +++ b/src/config/mod.rs @@ -0,0 +1,53 @@ +use anyhow::Result; +use serde::{Deserialize, Serialize}; +use std::path::PathBuf; + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Config { + pub llm: LLMConfig, + pub shell: ShellConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct LLMConfig { + pub default_provider: String, + pub openai: OpenAIConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OpenAIConfig { + pub model: String, + pub base_url: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ShellConfig { + pub max_execution_time: u64, + pub workdir: Option, +} + +impl Default for Config { + fn default() -> Self { + Self { + llm: LLMConfig { + default_provider: "openai".to_string(), + openai: OpenAIConfig { + model: "gpt-4".to_string(), + base_url: None, 
+ }, + }, + shell: ShellConfig { + max_execution_time: 300, + workdir: None, + }, + } + } +} + +impl Config { + pub fn load() -> Result { + // For now, just return default config + // TODO: Load from file in ~/.config/aishell/config.toml + Ok(Self::default()) + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..6c42a7a --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,7 @@ +pub mod cli; +pub mod config; +pub mod llm; +pub mod mcp; +pub mod shell; + +pub use config::Config; diff --git a/src/llm/mod.rs b/src/llm/mod.rs new file mode 100644 index 0000000..0b78207 --- /dev/null +++ b/src/llm/mod.rs @@ -0,0 +1,18 @@ +pub mod provider; +pub mod openai; + +pub use provider::{LLMProvider, Message, Role, ToolCall, ToolDefinition, ChatResponse}; +pub use openai::OpenAIProvider; + +use anyhow::Result; + +/// Create an LLM provider based on the provider name +pub async fn create_provider(provider: &str, model: Option<&str>) -> Result> { + match provider.to_lowercase().as_str() { + "openai" => { + let provider = OpenAIProvider::new(model)?; + Ok(Box::new(provider)) + } + _ => anyhow::bail!("Unsupported provider: {}", provider), + } +} diff --git a/src/llm/openai.rs b/src/llm/openai.rs new file mode 100644 index 0000000..352b98b --- /dev/null +++ b/src/llm/openai.rs @@ -0,0 +1,126 @@ +use anyhow::{Context, Result}; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::env; + +use super::provider::{ChatResponse, LLMProvider, Message, ToolCall, ToolDefinition}; + +#[derive(Debug, Serialize)] +struct ChatRequest { + model: String, + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Deserialize)] +struct ChatCompletionResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, + finish_reason: String, +} + +#[derive(Debug, Deserialize)] 
+struct ResponseMessage {
+    #[serde(default)]
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<ToolCall>>,
+}
+
+pub struct OpenAIProvider {
+    client: Client,
+    api_key: String,
+    base_url: String,
+    model: String,
+}
+
+impl OpenAIProvider {
+    pub fn new(model: Option<&str>) -> Result<Self> {
+        let api_key = env::var("OPENAI_API_KEY")
+            .context("OPENAI_API_KEY environment variable not set")?;
+
+        let base_url = env::var("OPENAI_BASE_URL")
+            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
+
+        let model = model
+            .map(|s| s.to_string())
+            .or_else(|| env::var("OPENAI_MODEL").ok())
+            .unwrap_or_else(|| "gpt-4".to_string());
+
+        Ok(Self {
+            client: Client::new(),
+            api_key,
+            base_url,
+            model,
+        })
+    }
+}
+
+#[async_trait]
+impl LLMProvider for OpenAIProvider {
+    async fn chat(
+        &self,
+        messages: Vec<Message>,
+        tools: Option<Vec<ToolDefinition>>,
+    ) -> Result<ChatResponse> {
+        let url = format!("{}/chat/completions", self.base_url);
+
+        let tool_choice = if tools.is_some() {
+            Some("auto".to_string())
+        } else {
+            None
+        };
+
+        let request = ChatRequest {
+            model: self.model.clone(),
+            messages,
+            tools,
+            tool_choice,
+        };
+
+        let response = self
+            .client
+            .post(&url)
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .header("Content-Type", "application/json")
+            .json(&request)
+            .send()
+            .await
+            .context("Failed to send request to OpenAI API")?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let error_text = response.text().await.unwrap_or_default();
+            anyhow::bail!("OpenAI API error ({}): {}", status, error_text);
+        }
+
+        let completion: ChatCompletionResponse = response
+            .json()
+            .await
+            .context("Failed to parse OpenAI API response")?;
+
+        let choice = completion
+            .choices
+            .into_iter()
+            .next()
+            .context("No choices in response")?;
+
+        Ok(ChatResponse {
+            content: choice.message.content.unwrap_or_default(),
+            tool_calls: choice.message.tool_calls,
+            finish_reason: choice.finish_reason,
+        })
+    }
+
+    fn model_name(&self) -> &str {
&self.model + } +} diff --git a/src/llm/provider.rs b/src/llm/provider.rs new file mode 100644 index 0000000..ac70479 --- /dev/null +++ b/src/llm/provider.rs @@ -0,0 +1,104 @@ +use anyhow::Result; +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +pub enum Role { + System, + User, + Assistant, + Tool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +impl Message { + pub fn system(content: impl Into) -> Self { + Self { + role: Role::System, + content: content.into(), + tool_calls: None, + tool_call_id: None, + } + } + + pub fn user(content: impl Into) -> Self { + Self { + role: Role::User, + content: content.into(), + tool_calls: None, + tool_call_id: None, + } + } + + pub fn assistant(content: impl Into) -> Self { + Self { + role: Role::Assistant, + content: content.into(), + tool_calls: None, + tool_call_id: None, + } + } + + pub fn tool(content: impl Into, tool_call_id: String) -> Self { + Self { + role: Role::Tool, + content: content.into(), + tool_calls: None, + tool_call_id: Some(tool_call_id), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: FunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolDefinition { + #[serde(rename = "type")] + pub tool_type: String, + pub function: FunctionDefinition, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct FunctionDefinition { + pub name: String, + pub description: String, + pub parameters: 
serde_json::Value, +} + +#[derive(Debug)] +pub struct ChatResponse { + pub content: String, + pub tool_calls: Option>, + pub finish_reason: String, +} + +#[async_trait] +pub trait LLMProvider: Send + Sync { + /// Send a chat completion request + async fn chat(&self, messages: Vec, tools: Option>) -> Result; + + /// Get the model name + fn model_name(&self) -> &str; +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..07e352a --- /dev/null +++ b/src/main.rs @@ -0,0 +1,74 @@ +use anyhow::Result; +use clap::{Parser, Subcommand}; +use tracing_subscriber; + +use aishell::cli::Repl; +use aishell::mcp::MCPServer; + +#[derive(Parser)] +#[command(name = "aishell")] +#[command(about = "AI-powered shell automation - A generic alternative to Claude Code")] +#[command(version)] +struct Cli { + #[command(subcommand)] + command: Commands, +} + +#[derive(Subcommand)] +enum Commands { + /// Start interactive AI shell + Shell { + /// LLM provider (openai, anthropic, ollama) + #[arg(short, long, default_value = "openai")] + provider: String, + + /// Model name + #[arg(short, long)] + model: Option, + }, + + /// Execute a single command via AI + Exec { + /// Command prompt + prompt: String, + + /// LLM provider + #[arg(short = 'p', long, default_value = "openai")] + provider: String, + }, + + /// Start MCP server (for Claude Desktop integration) + Server, +} + +#[tokio::main] +async fn main() -> Result<()> { + // Initialize logging + tracing_subscriber::fmt() + .with_env_filter( + tracing_subscriber::EnvFilter::from_default_env() + .add_directive(tracing::Level::INFO.into()), + ) + .init(); + + let cli = Cli::parse(); + + match cli.command { + Commands::Shell { provider, model } => { + let mut repl = Repl::new(&provider, model.as_deref()).await?; + repl.run().await?; + } + + Commands::Exec { prompt, provider } => { + let mut repl = Repl::new(&provider, None).await?; + repl.execute_once(&prompt).await?; + } + + Commands::Server => { + let server = 
MCPServer::new()?; + server.run().await?; + } + } + + Ok(()) +} diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs new file mode 100644 index 0000000..772789f --- /dev/null +++ b/src/mcp/mod.rs @@ -0,0 +1,121 @@ +use anyhow::Result; +use serde_json::json; +use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader}; + +use crate::shell::{execute_tool, get_tool_definitions, ShellExecutor}; + +pub struct MCPServer { + executor: ShellExecutor, +} + +impl MCPServer { + pub fn new() -> Result { + Ok(Self { + executor: ShellExecutor::default(), + }) + } + + pub async fn run(&self) -> Result<()> { + tracing::info!("Starting MCP server"); + + let stdin = io::stdin(); + let mut stdout = io::stdout(); + let mut reader = BufReader::new(stdin); + let mut line = String::new(); + + loop { + line.clear(); + let n = reader.read_line(&mut line).await?; + + if n == 0 { + break; // EOF + } + + let request: serde_json::Value = match serde_json::from_str(&line) { + Ok(v) => v, + Err(e) => { + tracing::error!("Failed to parse request: {}", e); + continue; + } + }; + + let response = self.handle_request(&request).await; + let response_str = serde_json::to_string(&response)?; + + stdout.write_all(response_str.as_bytes()).await?; + stdout.write_all(b"\n").await?; + stdout.flush().await?; + } + + Ok(()) + } + + async fn handle_request(&self, request: &serde_json::Value) -> serde_json::Value { + let method = request["method"].as_str().unwrap_or(""); + + match method { + "initialize" => { + json!({ + "protocolVersion": "2024-11-05", + "capabilities": { + "tools": {} + }, + "serverInfo": { + "name": "aishell", + "version": "0.1.0" + } + }) + } + + "tools/list" => { + let tools = get_tool_definitions(); + let tool_list: Vec<_> = tools + .iter() + .map(|t| { + json!({ + "name": t.function.name, + "description": t.function.description, + "inputSchema": t.function.parameters + }) + }) + .collect(); + + json!({ + "tools": tool_list + }) + } + + "tools/call" => { + let tool_name = 
request["params"]["name"].as_str().unwrap_or(""); + let arguments = request["params"]["arguments"].to_string(); + + let result = match execute_tool(tool_name, &arguments, &self.executor) { + Ok(output) => json!({ + "content": [{ + "type": "text", + "text": output + }] + }), + Err(e) => json!({ + "content": [{ + "type": "text", + "text": format!("Error: {}", e) + }], + "isError": true + }), + }; + + result + } + + _ => { + json!({ + "error": { + "code": -32601, + "message": format!("Method not found: {}", method) + } + }) + } + } + } +} diff --git a/src/shell/executor.rs b/src/shell/executor.rs new file mode 100644 index 0000000..c8dcb9e --- /dev/null +++ b/src/shell/executor.rs @@ -0,0 +1,112 @@ +use anyhow::{Context, Result}; +use duct::cmd; +use std::path::PathBuf; +use std::time::Duration; + +#[derive(Debug)] +pub struct ExecutionResult { + pub stdout: String, + pub stderr: String, + pub exit_code: i32, + pub success: bool, +} + +pub struct ShellExecutor { + workdir: PathBuf, + timeout: Duration, +} + +impl ShellExecutor { + pub fn new(workdir: Option) -> Result { + let workdir = workdir.unwrap_or_else(|| { + std::env::current_dir().expect("Failed to get current directory") + }); + + Ok(Self { + workdir, + timeout: Duration::from_secs(300), // 5 minutes default + }) + } + + pub fn with_timeout(mut self, timeout: Duration) -> Self { + self.timeout = timeout; + self + } + + pub fn execute(&self, command: &str) -> Result { + tracing::info!("Executing command: {}", command); + + let output = cmd!("sh", "-c", command) + .dir(&self.workdir) + .stdout_capture() + .stderr_capture() + .unchecked() + .run() + .context("Failed to execute command")?; + + let stdout = String::from_utf8_lossy(&output.stdout).to_string(); + let stderr = String::from_utf8_lossy(&output.stderr).to_string(); + let exit_code = output.status.code().unwrap_or(-1); + let success = output.status.success(); + + tracing::debug!( + "Command result: exit_code={}, stdout_len={}, stderr_len={}", + 
exit_code, + stdout.len(), + stderr.len() + ); + + Ok(ExecutionResult { + stdout, + stderr, + exit_code, + success, + }) + } + + pub fn read_file(&self, path: &str) -> Result { + let full_path = self.workdir.join(path); + std::fs::read_to_string(&full_path) + .with_context(|| format!("Failed to read file: {}", path)) + } + + pub fn write_file(&self, path: &str, content: &str) -> Result<()> { + let full_path = self.workdir.join(path); + + // Create parent directories if needed + if let Some(parent) = full_path.parent() { + std::fs::create_dir_all(parent)?; + } + + std::fs::write(&full_path, content) + .with_context(|| format!("Failed to write file: {}", path)) + } + + pub fn list_files(&self, pattern: Option<&str>) -> Result> { + let pattern = pattern.unwrap_or("*"); + + let output = cmd!("sh", "-c", format!("ls -1 {}", pattern)) + .dir(&self.workdir) + .stdout_capture() + .stderr_capture() + .unchecked() + .run()?; + + if !output.status.success() { + return Ok(vec![]); + } + + let files = String::from_utf8_lossy(&output.stdout) + .lines() + .map(|s| s.to_string()) + .collect(); + + Ok(files) + } +} + +impl Default for ShellExecutor { + fn default() -> Self { + Self::new(None).expect("Failed to create default ShellExecutor") + } +} diff --git a/src/shell/mod.rs b/src/shell/mod.rs new file mode 100644 index 0000000..12c9d60 --- /dev/null +++ b/src/shell/mod.rs @@ -0,0 +1,5 @@ +pub mod executor; +pub mod tools; + +pub use executor::{ShellExecutor, ExecutionResult}; +pub use tools::{get_tool_definitions, execute_tool, ToolArguments}; diff --git a/src/shell/tools.rs b/src/shell/tools.rs new file mode 100644 index 0000000..620fdf5 --- /dev/null +++ b/src/shell/tools.rs @@ -0,0 +1,162 @@ +use anyhow::{Context, Result}; +use serde::Deserialize; +use serde_json::json; + +use crate::llm::ToolDefinition; +use super::executor::ShellExecutor; + +#[derive(Debug, Deserialize)] +#[serde(tag = "tool", rename_all = "snake_case")] +pub enum ToolArguments { + Bash { command: String }, 
+ Read { path: String }, + Write { path: String, content: String }, + List { pattern: Option }, +} + +/// Get all available tool definitions for the LLM +pub fn get_tool_definitions() -> Vec { + vec![ + ToolDefinition { + tool_type: "function".to_string(), + function: crate::llm::provider::FunctionDefinition { + name: "bash".to_string(), + description: "Execute a bash command and return the output. Use this for running shell commands, git operations, package management, etc.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "command": { + "type": "string", + "description": "The bash command to execute" + } + }, + "required": ["command"] + }), + }, + }, + ToolDefinition { + tool_type: "function".to_string(), + function: crate::llm::provider::FunctionDefinition { + name: "read".to_string(), + description: "Read the contents of a file. Returns the file content as a string.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The path to the file to read" + } + }, + "required": ["path"] + }), + }, + }, + ToolDefinition { + tool_type: "function".to_string(), + function: crate::llm::provider::FunctionDefinition { + name: "write".to_string(), + description: "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "The path to the file to write" + }, + "content": { + "type": "string", + "description": "The content to write to the file" + } + }, + "required": ["path", "content"] + }), + }, + }, + ToolDefinition { + tool_type: "function".to_string(), + function: crate::llm::provider::FunctionDefinition { + name: "list".to_string(), + description: "List files in the current directory. 
Optionally filter by pattern.".to_string(), + parameters: json!({ + "type": "object", + "properties": { + "pattern": { + "type": "string", + "description": "Optional glob pattern to filter files (e.g., '*.rs')" + } + }, + "required": [] + }), + }, + }, + ] +} + +/// Execute a tool call +pub fn execute_tool( + tool_name: &str, + arguments: &str, + executor: &ShellExecutor, +) -> Result { + tracing::info!("Executing tool: {} with args: {}", tool_name, arguments); + + match tool_name { + "bash" => { + let args: serde_json::Value = serde_json::from_str(arguments)?; + let command = args["command"] + .as_str() + .context("Missing 'command' argument")?; + + let result = executor.execute(command)?; + + let output = if result.success { + format!("Exit code: {}\n\nStdout:\n{}\n\nStderr:\n{}", + result.exit_code, + result.stdout, + result.stderr + ) + } else { + format!("Command failed with exit code: {}\n\nStdout:\n{}\n\nStderr:\n{}", + result.exit_code, + result.stdout, + result.stderr + ) + }; + + Ok(output) + } + + "read" => { + let args: serde_json::Value = serde_json::from_str(arguments)?; + let path = args["path"] + .as_str() + .context("Missing 'path' argument")?; + + let content = executor.read_file(path)?; + Ok(content) + } + + "write" => { + let args: serde_json::Value = serde_json::from_str(arguments)?; + let path = args["path"] + .as_str() + .context("Missing 'path' argument")?; + let content = args["content"] + .as_str() + .context("Missing 'content' argument")?; + + executor.write_file(path, content)?; + Ok(format!("Successfully wrote to file: {}", path)) + } + + "list" => { + let args: serde_json::Value = serde_json::from_str(arguments)?; + let pattern = args["pattern"].as_str(); + + let files = executor.list_files(pattern)?; + Ok(files.join("\n")) + } + + _ => anyhow::bail!("Unknown tool: {}", tool_name), + } +}