diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..37ebe57
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,21 @@
+# Rust
+/target/
+Cargo.lock
+
+# IDE
+.idea/
+.vscode/
+*.swp
+*.swo
+*~
+
+# OS
+.DS_Store
+Thumbs.db
+
+# Logs
+*.log
+
+# Environment
+.env
+.env.local
diff --git a/Cargo.toml b/Cargo.toml
new file mode 100644
index 0000000..3ba94d2
--- /dev/null
+++ b/Cargo.toml
@@ -0,0 +1,47 @@
+[package]
+name = "aishell"
+version = "0.1.0"
+edition = "2021"
+authors = ["syui"]
+description = "AI-powered shell automation tool - A generic alternative to Claude Code"
+
+[lib]
+name = "aishell"
+path = "src/lib.rs"
+
+[[bin]]
+name = "aishell"
+path = "src/main.rs"
+
+[dependencies]
+# CLI and async (following aigpt pattern)
+clap = { version = "4.5", features = ["derive"] }
+tokio = { version = "1.40", features = ["rt", "rt-multi-thread", "macros", "io-std", "io-util", "process", "fs"] }
+async-trait = "0.1"
+
+# HTTP client for LLM APIs
+reqwest = { version = "0.12", features = ["json", "stream"] }
+
+# Serialization
+serde = { version = "1.0", features = ["derive"] }
+serde_json = "1.0"
+
+# Error handling
+thiserror = "1.0"
+anyhow = "1.0"
+
+# Utilities
+dirs = "5.0"
+
+# Shell execution
+duct = "0.13"
+
+# Configuration
+toml = "0.8"
+
+# Logging
+tracing = "0.1"
+tracing-subscriber = { version = "0.3", features = ["env-filter"] }
+
+# Interactive REPL
+rustyline = "14.0"
diff --git a/README.md b/README.md
index 4f6d81f..43ae160 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,177 @@
 # aishell
+
+**AI-powered shell automation tool** - A generic alternative to Claude Code
+
+## Overview
+
+aishell is a general-purpose tool that lets an AI operate your shell. It provides Claude Code-like functionality in a more flexible and extensible form.
+
+**Key features:**
+- **Multiple LLM providers**: OpenAI, Claude, local LLMs (gpt-oss, etc.)
+- **Function Calling**: the LLM calls tools directly to operate the shell
+- **MCP server**: can also integrate with Claude Desktop
+- **AIOS integration**: combined with aigpt, enables AI-driven OS management
+
+## Installation
+
+```bash
+# Requires a Rust toolchain
+cargo build --release
+
+# Install the binary
+cargo install --path .
+```
+
+## Usage
+
+### 1. Interactive Shell
+
+```bash
+# Use an OpenAI-compatible API
+export OPENAI_API_KEY="your-api-key"
+aishell shell
+
+# Specify a different model
+aishell shell -m gpt-4o
+
+# Use an OpenAI-compatible server such as gpt-oss
+export OPENAI_BASE_URL="http://localhost:8080/v1"
+aishell shell
+```
+
+**Example session:**
+```
+aishell> List all Rust files in src/
+[Executing tool: list]
+src/main.rs
+src/lib.rs
+...
+
+aishell> Create a new file hello.txt with "Hello, World!"
+[Executing tool: write]
+Successfully wrote to file: hello.txt
+
+aishell> Show me the git status
+[Executing tool: bash]
+On branch main
+...
+```
+
+### 2. One-Shot Execution (Single Command)
+
+```bash
+aishell exec "Show me the current directory structure"
+```
+
+### 3. MCP Server Mode (Claude Desktop Integration)
+
+```bash
+aishell server
+```
+
+**Claude Desktop configuration** (`~/Library/Application Support/Claude/claude_desktop_config.json`):
+```json
+{
+  "mcpServers": {
+    "aishell": {
+      "command": "/path/to/aishell",
+      "args": ["server"]
+    }
+  }
+}
+```
+
+## Architecture
+
+```
+aishell/
+├── src/
+│   ├── cli/      # Interactive interface (REPL)
+│   ├── llm/      # LLM providers (OpenAI-compatible)
+│   ├── shell/    # Shell execution engine
+│   ├── mcp/      # MCP server implementation
+│   └── config/   # Configuration management
+```
+
+**Execution flow:**
+```
+User Input → LLM (Function Calling) → Tool Execution → Shell → Result → LLM → User
+```
+
+## Available Tools
+
+aishell exposes the following tools to the LLM:
+
+- **bash**: execute shell commands
+- **read**: read files
+- **write**: write to files
+- **list**: list files
+
+## Environment Variables
+
+| Variable | Description | Default |
+|------|------|----------|
+| `OPENAI_API_KEY` | OpenAI API key | (required) |
+| `OPENAI_BASE_URL` | API base URL | `https://api.openai.com/v1` |
+| `OPENAI_MODEL` | Model to use | `gpt-4` |
+
+## Integration with AIOS
+
+Combined with [aigpt](https://github.com/syui/aigpt), aishell works as part of AIOS (AI Operating System):
+
+- **aigpt**: AI memory and personality analysis
+- **aishell**: shell operations and automation
+- **AIOS**: an AI-driven OS management system that integrates both
+
+## Comparison with Claude Code
+
+| Feature | Claude Code | aishell |
+|------|------------|---------|
+| LLM | Claude only | **Multiple providers** |
+| Runtime | Electron Desktop | **CLI/MCP** |
+| Customization | Limited | **Full control** |
+| Local LLMs | Not supported | **Supported** |
+| AIOS integration | Not available | **Native support** |
+
+## Development
+
+```bash
+# Development build
+cargo build
+
+# Run tests
+cargo test
+
+# Enable logging
+RUST_LOG=debug aishell shell
+```
+
+## Technical Stack
+
+- **Language**: Rust 2021
+- **CLI**: clap 4.5
+- **Async Runtime**: tokio 1.40
+- **HTTP Client**: reqwest 0.12
+- **Shell Execution**: duct 0.13
+- **REPL**: rustyline 14.0
+
+## Roadmap
+
+- [ ] Anthropic Claude API support
+- [ ] Ollama support (local LLMs)
+- [ ] Richer tool set (git integration, file search, etc.)
+- [ ] Configuration file support
+- [ ] Persistent session history
+- [ ] Plugin system
+
+## License
+
+MIT License
+
+## Author
+
+syui
+
+## Related Projects
+
+- [aigpt](https://github.com/syui/aigpt) - AI Memory System
diff --git a/claude.md b/claude.md
index 7a7f28e..ac92937 100644
--- a/claude.md
+++ b/claude.md
@@ -1,8 +1,45 @@
 # aishell
 
-name: aishell
-sid: ai.shell
-id: ai.syui.shell
+**ID**: ai.syui.shell
+**Name**: aishell
+**SID**: ai.shell
+**Version**: 0.1.0
 
-A tool, like Claude Code, that lets an AI operate the shell.
-It is intended to use an LLM such as gpt-oss, and to make use of MCP in some cases.
+## Overview
+
+A tool, like Claude Code, that lets an AI operate the shell.
+It is intended to work with LLMs such as gpt-oss, and makes use of MCP where appropriate.
+
+## Key Features
+
+1. **Multiple LLM providers**
+   - OpenAI API compatible (OpenAI, gpt-oss, etc.)
+   - Claude API, Ollama, and others planned
+
+2. **Function Calling (Tool use)**
+   - The LLM calls tools directly to operate the shell
+   - Provides bash, read, write, list, and similar tools
+
+3. **MCP server mode**
+   - Can integrate with Claude Desktop
+   - Implements the same MCP protocol as aigpt
+
+## Architecture
+
+```
+User → CLI → LLM Provider → Function Calling → Shell Executor → Result
+```
+
+## AIOS Integration
+
+- **aigpt**: memory and personality analysis
+- **aishell**: shell operations and automation
+- **Combined**: AI-driven OS management
+
+## Technical Stack
+
+- Rust 2021
+- tokio (async runtime)
+- reqwest (HTTP client)
+- duct (shell execution)
+- clap (CLI framework)
diff --git a/src/cli/mod.rs b/src/cli/mod.rs
new file mode 100644
index 0000000..402cb58
--- /dev/null
+++ b/src/cli/mod.rs
@@ -0,0 +1,3 @@
+pub mod repl;
+
+pub use repl::Repl;
diff --git a/src/cli/repl.rs b/src/cli/repl.rs
new file mode 100644
index 0000000..2ef3fe1
--- /dev/null
+++ b/src/cli/repl.rs
@@ -0,0 +1,148 @@
+use anyhow::{Context, Result};
+use rustyline::error::ReadlineError;
+use rustyline::DefaultEditor;
+
+use crate::llm::{create_provider, LLMProvider, Message};
+use crate::shell::{execute_tool, get_tool_definitions, ShellExecutor};
+
+pub struct Repl {
+    llm: Box<dyn LLMProvider>,
+    executor: ShellExecutor,
+    messages: Vec<Message>,
+}
+
+impl Repl {
+    pub async fn new(provider: &str, model: Option<&str>) -> Result<Self> {
+        let llm = create_provider(provider, model).await?;
+        let executor = ShellExecutor::default();
+
+        let system_prompt = Message::system(
+            "You are an AI assistant that helps users interact with their system through shell commands. \
+             You have access to tools like bash, read, write, and list to help users accomplish their tasks. \
+             When a user asks you to do something, use the appropriate tools to complete the task. \
+             Always explain what you're doing and show the results to the user."
+        );
+
+        Ok(Self {
+            llm,
+            executor,
+            messages: vec![system_prompt],
+        })
+    }
+
+    pub async fn run(&mut self) -> Result<()> {
+        println!("aishell - AI-powered shell automation");
+        println!("Type 'exit' or 'quit' to exit, 'clear' to clear history\n");
+
+        let mut rl = DefaultEditor::new()?;
+
+        loop {
+            let readline = rl.readline("aishell> ");
+
+            match readline {
+                Ok(line) => {
+                    let line = line.trim();
+
+                    if line.is_empty() {
+                        continue;
+                    }
+
+                    if line == "exit" || line == "quit" {
+                        println!("Goodbye!");
+                        break;
+                    }
+
+                    if line == "clear" {
+                        self.messages.truncate(1); // Keep only system message
+                        println!("History cleared.");
+                        continue;
+                    }
+
+                    rl.add_history_entry(line)?;
+
+                    if let Err(e) = self.process_input(line).await {
+                        eprintln!("Error: {}", e);
+                    }
+                }
+                Err(ReadlineError::Interrupted) => {
+                    println!("^C");
+                    continue;
+                }
+                Err(ReadlineError::Eof) => {
+                    println!("^D");
+                    break;
+                }
+                Err(err) => {
+                    eprintln!("Error: {:?}", err);
+                    break;
+                }
+            }
+        }
+
+        Ok(())
+    }
+
+    pub async fn execute_once(&mut self, prompt: &str) -> Result<()> {
+        self.process_input(prompt).await
+    }
+
+    async fn process_input(&mut self, input: &str) -> Result<()> {
+        // Add user message
+        self.messages.push(Message::user(input));
+
+        let tools = get_tool_definitions();
+
+        // Agent loop: keep calling the LLM until it's done (no more tool calls)
+        let max_iterations = 10;
+        for iteration in 0..max_iterations {
+            tracing::debug!("Agent loop iteration {}", iteration + 1);
+
+            let response = self
+                .llm
+                .chat(self.messages.clone(), Some(tools.clone()))
+                .await
+                .context("Failed to get LLM response")?;
+
+            // If there are tool calls, execute them
+            if let Some(tool_calls) = response.tool_calls {
+                tracing::info!("LLM requested {} tool calls", tool_calls.len());
+
+                // Add assistant message with tool calls
+                let mut assistant_msg = Message::assistant(response.content.clone());
+                assistant_msg.tool_calls = Some(tool_calls.clone());
+                self.messages.push(assistant_msg);
+
+                // Execute each tool call
+                for tool_call in tool_calls {
+                    let tool_name = &tool_call.function.name;
+                    let tool_args = &tool_call.function.arguments;
+
+                    println!("\n[Executing tool: {}]", tool_name);
+
+                    let result = match execute_tool(tool_name, tool_args, &self.executor) {
+                        Ok(output) => output,
+                        Err(e) => format!("Error executing tool: {}", e),
+                    };
+
+                    println!("{}", result);
+
+                    // Add tool result message
+                    self.messages.push(Message::tool(result, tool_call.id.clone()));
+                }
+
+                // Continue the loop to get the next response
+                continue;
+            }
+
+            // No tool calls, so the LLM is done
+            if !response.content.is_empty() {
+                println!("\n{}\n", response.content);
+                self.messages.push(Message::assistant(response.content));
+            }
+
+            break;
+        }
+
+        Ok(())
+    }
+}
diff --git a/src/config/mod.rs b/src/config/mod.rs
new file mode 100644
index 0000000..00efa85
--- /dev/null
+++ b/src/config/mod.rs
@@ -0,0 +1,53 @@
+use anyhow::Result;
+use serde::{Deserialize, Serialize};
+use std::path::PathBuf;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Config {
+    pub llm: LLMConfig,
+    pub shell: ShellConfig,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct LLMConfig {
+    pub default_provider: String,
+    pub openai: OpenAIConfig,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct OpenAIConfig {
+    pub model: String,
+    pub base_url: Option<String>,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ShellConfig {
+    pub max_execution_time: u64,
+    pub workdir: Option<PathBuf>,
+}
+
+impl Default for Config {
+    fn default() -> Self {
+        Self {
+            llm: LLMConfig {
+                default_provider: "openai".to_string(),
+                openai: OpenAIConfig {
+                    model: "gpt-4".to_string(),
+                    base_url: None,
+                },
+            },
+            shell: ShellConfig {
+                max_execution_time: 300,
+                workdir: None,
+            },
+        }
+    }
+}
+
+impl Config {
+    pub fn load() -> Result<Self> {
+        // For now, just return default config
+        // TODO: Load from file in ~/.config/aishell/config.toml
+        Ok(Self::default())
+    }
+}
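`Config::load()` currently returns defaults only. A minimal sketch of the TODO, reading `~/.config/aishell/config.toml` with the `dirs` and `toml` crates already declared in Cargo.toml, could look like the following; the function name, the fallback-to-defaults behaviour, and the extra `use anyhow::Context;` import are assumptions, not part of the current code.

```rust
// Sketch for src/config/mod.rs (assumes `use anyhow::Context;` is added to the imports).
impl Config {
    /// Hypothetical loader for the config file mentioned in the TODO.
    pub fn load_from_file() -> Result<Self> {
        // dirs::config_dir() resolves to ~/.config on Linux and to
        // ~/Library/Application Support on macOS.
        let path = dirs::config_dir()
            .context("Could not determine the user config directory")?
            .join("aishell")
            .join("config.toml");

        if !path.exists() {
            // No config file present: fall back to the built-in defaults above.
            return Ok(Self::default());
        }

        let raw = std::fs::read_to_string(&path)
            .with_context(|| format!("Failed to read {}", path.display()))?;
        toml::from_str(&raw).context("Failed to parse config.toml")
    }
}
```

Because `Config` derives `Deserialize`, `toml::from_str` can populate it directly once the file format matches the struct layout.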
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..6c42a7a
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,7 @@
+pub mod cli;
+pub mod config;
+pub mod llm;
+pub mod mcp;
+pub mod shell;
+
+pub use config::Config;
diff --git a/src/llm/mod.rs b/src/llm/mod.rs
new file mode 100644
index 0000000..0b78207
--- /dev/null
+++ b/src/llm/mod.rs
@@ -0,0 +1,18 @@
+pub mod provider;
+pub mod openai;
+
+pub use provider::{LLMProvider, Message, Role, ToolCall, ToolDefinition, ChatResponse};
+pub use openai::OpenAIProvider;
+
+use anyhow::Result;
+
+/// Create an LLM provider based on the provider name
+pub async fn create_provider(provider: &str, model: Option<&str>) -> Result<Box<dyn LLMProvider>> {
+    match provider.to_lowercase().as_str() {
+        "openai" => {
+            let provider = OpenAIProvider::new(model)?;
+            Ok(Box::new(provider))
+        }
+        _ => anyhow::bail!("Unsupported provider: {}", provider),
+    }
+}
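For orientation, here is a sketch of how a caller outside the REPL might drive a provider through this factory; it mirrors what `Repl::process_input` does. The model name and prompt are placeholders, and `OPENAI_API_KEY` (optionally `OPENAI_BASE_URL`) is assumed to be set.

```rust
use anyhow::Result;
use aishell::llm::{create_provider, LLMProvider, Message};
use aishell::shell::get_tool_definitions;

#[tokio::main]
async fn main() -> Result<()> {
    // Illustrative: "openai" is the only provider the factory accepts today.
    let provider = create_provider("openai", Some("gpt-4o")).await?;

    let messages = vec![
        Message::system("You are a helpful shell assistant."),
        Message::user("List the Rust files in src/"),
    ];

    // Passing tool definitions lets the model answer with tool calls instead
    // of plain text; a real caller would then execute them and loop.
    let response = provider
        .chat(messages, Some(get_tool_definitions()))
        .await?;

    println!("finish_reason: {}", response.finish_reason);
    println!("content: {}", response.content);
    Ok(())
}
```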
diff --git a/src/llm/openai.rs b/src/llm/openai.rs
new file mode 100644
index 0000000..352b98b
--- /dev/null
+++ b/src/llm/openai.rs
@@ -0,0 +1,126 @@
+use anyhow::{Context, Result};
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::env;
+
+use super::provider::{ChatResponse, LLMProvider, Message, ToolCall, ToolDefinition};
+
+#[derive(Debug, Serialize)]
+struct ChatRequest {
+    model: String,
+    messages: Vec<Message>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tools: Option<Vec<ToolDefinition>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tool_choice: Option<String>,
+}
+
+#[derive(Debug, Deserialize)]
+struct ChatCompletionResponse {
+    choices: Vec<Choice>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Choice {
+    message: ResponseMessage,
+    finish_reason: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct ResponseMessage {
+    #[serde(default)]
+    content: Option<String>,
+    #[serde(default)]
+    tool_calls: Option<Vec<ToolCall>>,
+}
+
+pub struct OpenAIProvider {
+    client: Client,
+    api_key: String,
+    base_url: String,
+    model: String,
+}
+
+impl OpenAIProvider {
+    pub fn new(model: Option<&str>) -> Result<Self> {
+        let api_key = env::var("OPENAI_API_KEY")
+            .context("OPENAI_API_KEY environment variable not set")?;
+
+        let base_url = env::var("OPENAI_BASE_URL")
+            .unwrap_or_else(|_| "https://api.openai.com/v1".to_string());
+
+        let model = model
+            .map(|s| s.to_string())
+            .or_else(|| env::var("OPENAI_MODEL").ok())
+            .unwrap_or_else(|| "gpt-4".to_string());
+
+        Ok(Self {
+            client: Client::new(),
+            api_key,
+            base_url,
+            model,
+        })
+    }
+}
+
+#[async_trait]
+impl LLMProvider for OpenAIProvider {
+    async fn chat(
+        &self,
+        messages: Vec<Message>,
+        tools: Option<Vec<ToolDefinition>>,
+    ) -> Result<ChatResponse> {
+        let url = format!("{}/chat/completions", self.base_url);
+
+        let tool_choice = if tools.is_some() {
+            Some("auto".to_string())
+        } else {
+            None
+        };
+
+        let request = ChatRequest {
+            model: self.model.clone(),
+            messages,
+            tools,
+            tool_choice,
+        };
+
+        let response = self
+            .client
+            .post(&url)
+            .header("Authorization", format!("Bearer {}", self.api_key))
+            .header("Content-Type", "application/json")
+            .json(&request)
+            .send()
+            .await
+            .context("Failed to send request to OpenAI API")?;
+
+        if !response.status().is_success() {
+            let status = response.status();
+            let error_text = response.text().await.unwrap_or_default();
+            anyhow::bail!("OpenAI API error ({}): {}", status, error_text);
+        }
+
+        let completion: ChatCompletionResponse = response
+            .json()
+            .await
+            .context("Failed to parse OpenAI API response")?;
+
+        let choice = completion
+            .choices
+            .into_iter()
+            .next()
+            .context("No choices in response")?;
+
+        Ok(ChatResponse {
+            content: choice.message.content.unwrap_or_default(),
+            tool_calls: choice.message.tool_calls,
+            finish_reason: choice.finish_reason,
+        })
+    }
+
+    fn model_name(&self) -> &str {
+        &self.model
+    }
+}
diff --git a/src/llm/provider.rs b/src/llm/provider.rs
new file mode 100644
index 0000000..ac70479
--- /dev/null
+++ b/src/llm/provider.rs
@@ -0,0 +1,104 @@
+use anyhow::Result;
+use async_trait::async_trait;
+use serde::{Deserialize, Serialize};
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    System,
+    User,
+    Assistant,
+    Tool,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Message {
+    pub role: Role,
+    pub content: String,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_calls: Option<Vec<ToolCall>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub tool_call_id: Option<String>,
+}
+
+impl Message {
+    pub fn system(content: impl Into<String>) -> Self {
+        Self {
+            role: Role::System,
+            content: content.into(),
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+
+    pub fn user(content: impl Into<String>) -> Self {
+        Self {
+            role: Role::User,
+            content: content.into(),
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+
+    pub fn assistant(content: impl Into<String>) -> Self {
+        Self {
+            role: Role::Assistant,
+            content: content.into(),
+            tool_calls: None,
+            tool_call_id: None,
+        }
+    }
+
+    pub fn tool(content: impl Into<String>, tool_call_id: String) -> Self {
+        Self {
+            role: Role::Tool,
+            content: content.into(),
+            tool_calls: None,
+            tool_call_id: Some(tool_call_id),
+        }
+    }
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolCall {
+    pub id: String,
+    #[serde(rename = "type")]
+    pub call_type: String,
+    pub function: FunctionCall,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionCall {
+    pub name: String,
+    pub arguments: String,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct ToolDefinition {
+    #[serde(rename = "type")]
+    pub tool_type: String,
+    pub function: FunctionDefinition,
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+    pub name: String,
+    pub description: String,
+    pub parameters: serde_json::Value,
+}
+
+#[derive(Debug)]
+pub struct ChatResponse {
+    pub content: String,
+    pub tool_calls: Option<Vec<ToolCall>>,
+    pub finish_reason: String,
+}
+
+#[async_trait]
+pub trait LLMProvider: Send + Sync {
+    /// Send a chat completion request
+    async fn chat(&self, messages: Vec<Message>, tools: Option<Vec<ToolDefinition>>) -> Result<ChatResponse>;
+
+    /// Get the model name
+    fn model_name(&self) -> &str;
+}
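The trait is the extension point for the multi-provider goal. A minimal sketch of a second implementation, an offline mock that could live in a test module inside the crate, is shown below; `MockProvider` and its canned reply are hypothetical and only illustrate the shape an Anthropic or Ollama backend would have to fill in.

```rust
use anyhow::Result;
use async_trait::async_trait;

use crate::llm::{ChatResponse, LLMProvider, Message, ToolDefinition};

/// Hypothetical offline provider: always returns a canned reply.
struct MockProvider;

#[async_trait]
impl LLMProvider for MockProvider {
    async fn chat(
        &self,
        messages: Vec<Message>,
        _tools: Option<Vec<ToolDefinition>>,
    ) -> Result<ChatResponse> {
        Ok(ChatResponse {
            content: format!("(mock) received {} messages", messages.len()),
            tool_calls: None,
            finish_reason: "stop".to_string(),
        })
    }

    fn model_name(&self) -> &str {
        "mock"
    }
}
```

Because the REPL only holds a `Box<dyn LLMProvider>`, such a type could be swapped in without touching the agent loop.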
diff --git a/src/main.rs b/src/main.rs
new file mode 100644
index 0000000..07e352a
--- /dev/null
+++ b/src/main.rs
@@ -0,0 +1,74 @@
+use anyhow::Result;
+use clap::{Parser, Subcommand};
+use tracing_subscriber;
+
+use aishell::cli::Repl;
+use aishell::mcp::MCPServer;
+
+#[derive(Parser)]
+#[command(name = "aishell")]
+#[command(about = "AI-powered shell automation - A generic alternative to Claude Code")]
+#[command(version)]
+struct Cli {
+    #[command(subcommand)]
+    command: Commands,
+}
+
+#[derive(Subcommand)]
+enum Commands {
+    /// Start interactive AI shell
+    Shell {
+        /// LLM provider (openai, anthropic, ollama)
+        #[arg(short, long, default_value = "openai")]
+        provider: String,
+
+        /// Model name
+        #[arg(short, long)]
+        model: Option<String>,
+    },
+
+    /// Execute a single command via AI
+    Exec {
+        /// Command prompt
+        prompt: String,
+
+        /// LLM provider
+        #[arg(short = 'p', long, default_value = "openai")]
+        provider: String,
+    },
+
+    /// Start MCP server (for Claude Desktop integration)
+    Server,
+}
+
+#[tokio::main]
+async fn main() -> Result<()> {
+    // Initialize logging
+    tracing_subscriber::fmt()
+        .with_env_filter(
+            tracing_subscriber::EnvFilter::from_default_env()
+                .add_directive(tracing::Level::INFO.into()),
+        )
+        .init();
+
+    let cli = Cli::parse();
+
+    match cli.command {
+        Commands::Shell { provider, model } => {
+            let mut repl = Repl::new(&provider, model.as_deref()).await?;
+            repl.run().await?;
+        }
+
+        Commands::Exec { prompt, provider } => {
+            let mut repl = Repl::new(&provider, None).await?;
+            repl.execute_once(&prompt).await?;
+        }
+
+        Commands::Server => {
+            let server = MCPServer::new()?;
+            server.run().await?;
+        }
+    }
+
+    Ok(())
+}
diff --git a/src/mcp/mod.rs b/src/mcp/mod.rs
new file mode 100644
index 0000000..772789f
--- /dev/null
+++ b/src/mcp/mod.rs
@@ -0,0 +1,121 @@
+use anyhow::Result;
+use serde_json::json;
+use tokio::io::{self, AsyncBufReadExt, AsyncWriteExt, BufReader};
+
+use crate::shell::{execute_tool, get_tool_definitions, ShellExecutor};
+
+pub struct MCPServer {
+    executor: ShellExecutor,
+}
+
+impl MCPServer {
+    pub fn new() -> Result<Self> {
+        Ok(Self {
+            executor: ShellExecutor::default(),
+        })
+    }
+
+    pub async fn run(&self) -> Result<()> {
+        tracing::info!("Starting MCP server");
+
+        let stdin = io::stdin();
+        let mut stdout = io::stdout();
+        let mut reader = BufReader::new(stdin);
+        let mut line = String::new();
+
+        loop {
+            line.clear();
+            let n = reader.read_line(&mut line).await?;
+
+            if n == 0 {
+                break; // EOF
+            }
+
+            let request: serde_json::Value = match serde_json::from_str(&line) {
+                Ok(v) => v,
+                Err(e) => {
+                    tracing::error!("Failed to parse request: {}", e);
+                    continue;
+                }
+            };
+
+            let response = self.handle_request(&request).await;
+            let response_str = serde_json::to_string(&response)?;
+
+            stdout.write_all(response_str.as_bytes()).await?;
+            stdout.write_all(b"\n").await?;
+            stdout.flush().await?;
+        }
+
+        Ok(())
+    }
+
+    async fn handle_request(&self, request: &serde_json::Value) -> serde_json::Value {
+        let method = request["method"].as_str().unwrap_or("");
+
+        match method {
+            "initialize" => {
+                json!({
+                    "protocolVersion": "2024-11-05",
+                    "capabilities": {
+                        "tools": {}
+                    },
+                    "serverInfo": {
+                        "name": "aishell",
+                        "version": "0.1.0"
+                    }
+                })
+            }
+
+            "tools/list" => {
+                let tools = get_tool_definitions();
+                let tool_list: Vec<_> = tools
+                    .iter()
+                    .map(|t| {
+                        json!({
+                            "name": t.function.name,
+                            "description": t.function.description,
+                            "inputSchema": t.function.parameters
+                        })
+                    })
+                    .collect();
+
+                json!({
+                    "tools": tool_list
+                })
+            }
+
+            "tools/call" => {
+                let tool_name = request["params"]["name"].as_str().unwrap_or("");
+                let arguments = request["params"]["arguments"].to_string();
+
+                let result = match execute_tool(tool_name, &arguments, &self.executor) {
+                    Ok(output) => json!({
+                        "content": [{
+                            "type": "text",
+                            "text": output
+                        }]
+                    }),
+                    Err(e) => json!({
+                        "content": [{
+                            "type": "text",
+                            "text": format!("Error: {}", e)
+                        }],
+                        "isError": true
+                    }),
+                };
+
+                result
+            }
+
+            _ => {
+                json!({
+                    "error": {
+                        "code": -32601,
+                        "message": format!("Method not found: {}", method)
+                    }
+                })
+            }
+        }
+    }
+}
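As written, the server speaks newline-delimited JSON over stdio: one request object per line in, one response object per line out, dispatched on the `method` field. The small sketch below only builds the kind of `tools/call` request the handler above expects; the command value is illustrative, and a real MCP client (such as Claude Desktop) would frame its messages itself.

```rust
use serde_json::json;

fn main() {
    // One request per line on stdin; values here are placeholders.
    let request = json!({
        "method": "tools/call",
        "params": {
            "name": "bash",
            "arguments": { "command": "ls -la" }
        }
    });

    // The server replies with a single JSON line whose "content" array carries
    // the tool output as text, plus "isError": true when the tool failed.
    println!("{}", request);
}
```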
diff --git a/src/shell/executor.rs b/src/shell/executor.rs
new file mode 100644
index 0000000..c8dcb9e
--- /dev/null
+++ b/src/shell/executor.rs
@@ -0,0 +1,112 @@
+use anyhow::{Context, Result};
+use duct::cmd;
+use std::path::PathBuf;
+use std::time::Duration;
+
+#[derive(Debug)]
+pub struct ExecutionResult {
+    pub stdout: String,
+    pub stderr: String,
+    pub exit_code: i32,
+    pub success: bool,
+}
+
+pub struct ShellExecutor {
+    workdir: PathBuf,
+    timeout: Duration,
+}
+
+impl ShellExecutor {
+    pub fn new(workdir: Option<PathBuf>) -> Result<Self> {
+        let workdir = workdir.unwrap_or_else(|| {
+            std::env::current_dir().expect("Failed to get current directory")
+        });
+
+        Ok(Self {
+            workdir,
+            timeout: Duration::from_secs(300), // 5 minutes default (not yet enforced by execute)
+        })
+    }
+
+    pub fn with_timeout(mut self, timeout: Duration) -> Self {
+        self.timeout = timeout;
+        self
+    }
+
+    pub fn execute(&self, command: &str) -> Result<ExecutionResult> {
+        tracing::info!("Executing command: {}", command);
+
+        let output = cmd!("sh", "-c", command)
+            .dir(&self.workdir)
+            .stdout_capture()
+            .stderr_capture()
+            .unchecked()
+            .run()
+            .context("Failed to execute command")?;
+
+        let stdout = String::from_utf8_lossy(&output.stdout).to_string();
+        let stderr = String::from_utf8_lossy(&output.stderr).to_string();
+        let exit_code = output.status.code().unwrap_or(-1);
+        let success = output.status.success();
+
+        tracing::debug!(
+            "Command result: exit_code={}, stdout_len={}, stderr_len={}",
+            exit_code,
+            stdout.len(),
+            stderr.len()
+        );
+
+        Ok(ExecutionResult {
+            stdout,
+            stderr,
+            exit_code,
+            success,
+        })
+    }
+
+    pub fn read_file(&self, path: &str) -> Result<String> {
+        let full_path = self.workdir.join(path);
+        std::fs::read_to_string(&full_path)
+            .with_context(|| format!("Failed to read file: {}", path))
+    }
+
+    pub fn write_file(&self, path: &str, content: &str) -> Result<()> {
+        let full_path = self.workdir.join(path);
+
+        // Create parent directories if needed
+        if let Some(parent) = full_path.parent() {
+            std::fs::create_dir_all(parent)?;
+        }
+
+        std::fs::write(&full_path, content)
+            .with_context(|| format!("Failed to write file: {}", path))
+    }
+
+    pub fn list_files(&self, pattern: Option<&str>) -> Result<Vec<String>> {
+        let pattern = pattern.unwrap_or("*");
+
+        let output = cmd!("sh", "-c", format!("ls -1 {}", pattern))
+            .dir(&self.workdir)
+            .stdout_capture()
+            .stderr_capture()
+            .unchecked()
+            .run()?;
+
+        if !output.status.success() {
+            return Ok(vec![]);
+        }
+
+        let files = String::from_utf8_lossy(&output.stdout)
+            .lines()
+            .map(|s| s.to_string())
+            .collect();
+
+        Ok(files)
+    }
+}
+
+impl Default for ShellExecutor {
+    fn default() -> Self {
+        Self::new(None).expect("Failed to create default ShellExecutor")
+    }
+}
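A small usage sketch of the executor outside the agent loop follows; the working directory and command are placeholders, and note that the configured timeout is currently stored but not enforced by `execute()`.

```rust
use std::path::PathBuf;
use std::time::Duration;

use aishell::shell::ShellExecutor;

fn main() -> anyhow::Result<()> {
    // Run commands in a specific directory (illustrative path).
    let executor = ShellExecutor::new(Some(PathBuf::from("/tmp")))?
        .with_timeout(Duration::from_secs(30));

    let result = executor.execute("echo hello && ls -1")?;
    println!("exit code: {}", result.exit_code);
    println!("stdout:\n{}", result.stdout);
    if !result.success {
        eprintln!("stderr:\n{}", result.stderr);
    }
    Ok(())
}
```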
diff --git a/src/shell/mod.rs b/src/shell/mod.rs
new file mode 100644
index 0000000..12c9d60
--- /dev/null
+++ b/src/shell/mod.rs
@@ -0,0 +1,5 @@
+pub mod executor;
+pub mod tools;
+
+pub use executor::{ShellExecutor, ExecutionResult};
+pub use tools::{get_tool_definitions, execute_tool, ToolArguments};
diff --git a/src/shell/tools.rs b/src/shell/tools.rs
new file mode 100644
index 0000000..620fdf5
--- /dev/null
+++ b/src/shell/tools.rs
@@ -0,0 +1,162 @@
+use anyhow::{Context, Result};
+use serde::Deserialize;
+use serde_json::json;
+
+use crate::llm::ToolDefinition;
+use super::executor::ShellExecutor;
+
+#[derive(Debug, Deserialize)]
+#[serde(tag = "tool", rename_all = "snake_case")]
+pub enum ToolArguments {
+    Bash { command: String },
+    Read { path: String },
+    Write { path: String, content: String },
+    List { pattern: Option<String> },
+}
+
+/// Get all available tool definitions for the LLM
+pub fn get_tool_definitions() -> Vec<ToolDefinition> {
+    vec![
+        ToolDefinition {
+            tool_type: "function".to_string(),
+            function: crate::llm::provider::FunctionDefinition {
+                name: "bash".to_string(),
+                description: "Execute a bash command and return the output. Use this for running shell commands, git operations, package management, etc.".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "command": {
+                            "type": "string",
+                            "description": "The bash command to execute"
+                        }
+                    },
+                    "required": ["command"]
+                }),
+            },
+        },
+        ToolDefinition {
+            tool_type: "function".to_string(),
+            function: crate::llm::provider::FunctionDefinition {
+                name: "read".to_string(),
+                description: "Read the contents of a file. Returns the file content as a string.".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "path": {
+                            "type": "string",
+                            "description": "The path to the file to read"
+                        }
+                    },
+                    "required": ["path"]
+                }),
+            },
+        },
+        ToolDefinition {
+            tool_type: "function".to_string(),
+            function: crate::llm::provider::FunctionDefinition {
+                name: "write".to_string(),
+                description: "Write content to a file. Creates the file if it doesn't exist, overwrites if it does.".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "path": {
+                            "type": "string",
+                            "description": "The path to the file to write"
+                        },
+                        "content": {
+                            "type": "string",
+                            "description": "The content to write to the file"
+                        }
+                    },
+                    "required": ["path", "content"]
+                }),
+            },
+        },
+        ToolDefinition {
+            tool_type: "function".to_string(),
+            function: crate::llm::provider::FunctionDefinition {
+                name: "list".to_string(),
+                description: "List files in the current directory. Optionally filter by pattern.".to_string(),
+                parameters: json!({
+                    "type": "object",
+                    "properties": {
+                        "pattern": {
+                            "type": "string",
+                            "description": "Optional glob pattern to filter files (e.g., '*.rs')"
+                        }
+                    },
+                    "required": []
+                }),
+            },
+        },
+    ]
+}
+
+/// Execute a tool call
+pub fn execute_tool(
+    tool_name: &str,
+    arguments: &str,
+    executor: &ShellExecutor,
+) -> Result<String> {
+    tracing::info!("Executing tool: {} with args: {}", tool_name, arguments);
+
+    match tool_name {
+        "bash" => {
+            let args: serde_json::Value = serde_json::from_str(arguments)?;
+            let command = args["command"]
+                .as_str()
+                .context("Missing 'command' argument")?;
+
+            let result = executor.execute(command)?;
+
+            let output = if result.success {
+                format!("Exit code: {}\n\nStdout:\n{}\n\nStderr:\n{}",
+                    result.exit_code,
+                    result.stdout,
+                    result.stderr
+                )
+            } else {
+                format!("Command failed with exit code: {}\n\nStdout:\n{}\n\nStderr:\n{}",
+                    result.exit_code,
+                    result.stdout,
+                    result.stderr
+                )
+            };
+
+            Ok(output)
+        }
+
+        "read" => {
+            let args: serde_json::Value = serde_json::from_str(arguments)?;
+            let path = args["path"]
+                .as_str()
+                .context("Missing 'path' argument")?;
+
+            let content = executor.read_file(path)?;
+            Ok(content)
+        }
+
+        "write" => {
+            let args: serde_json::Value = serde_json::from_str(arguments)?;
+            let path = args["path"]
+                .as_str()
+                .context("Missing 'path' argument")?;
+            let content = args["content"]
+                .as_str()
+                .context("Missing 'content' argument")?;
+
+            executor.write_file(path, content)?;
+            Ok(format!("Successfully wrote to file: {}", path))
+        }
+
+        "list" => {
+            let args: serde_json::Value = serde_json::from_str(arguments)?;
+            let pattern = args["pattern"].as_str();
+
+            let files = executor.list_files(pattern)?;
+            Ok(files.join("\n"))
+        }
+
+        _ => anyhow::bail!("Unknown tool: {}", tool_name),
+    }
+}
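To close the loop, here is a sketch of calling `execute_tool` directly with the kind of JSON argument string the LLM produces; this is exactly what the REPL and the MCP server do, and the argument values are illustrative. Note that `execute_tool` parses the raw JSON itself, so the `ToolArguments` enum above appears unused in this diff.

```rust
use aishell::shell::{execute_tool, ShellExecutor};

fn main() -> anyhow::Result<()> {
    let executor = ShellExecutor::default();

    // The arguments string is what the model would return for a "write" call.
    let args = r#"{"path": "hello.txt", "content": "Hello, World!"}"#;
    let message = execute_tool("write", args, &executor)?;
    println!("{}", message); // "Successfully wrote to file: hello.txt"

    // An unknown tool name is reported as an error rather than a panic.
    assert!(execute_tool("rm_rf", "{}", &executor).is_err());
    Ok(())
}
```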