diff --git a/.gitignore b/.gitignore index d088031..f87909d 100644 --- a/.gitignore +++ b/.gitignore @@ -2,3 +2,4 @@ **.lock output.json config/*.db +aigpt diff --git a/README.md b/README.md index 162b0b4..a1d4faa 100644 --- a/README.md +++ b/README.md @@ -27,7 +27,13 @@ $ ollama run syui/ai ```sh $ cargo build -$ ./target/debug/aigpt mcp setup -$ ./target/debug/aigpt mcp chat "hello world!" --host http://localhost:11434 --model syui/ai +$ ./aigpt mcp setup +$ ./aigpt mcp chat "hello world!" +$ ./aigpt mcp chat "hello world!" --host http://localhost:11434 --model syui/ai + +--- +# openai api +$ ./aigpt mcp set-api -api sk-abc123 +$ ./aigpt mcp chat "こんにちは" -p openai -m gpt-4o-mini ``` diff --git a/mcp/scripts/ask.py b/mcp/scripts/ask.py index c2eab45..4181ec2 100644 --- a/mcp/scripts/ask.py +++ b/mcp/scripts/ask.py @@ -1,27 +1,52 @@ -import httpx import os import json +import httpx +import openai + from context_loader import load_context_from_repo from prompt_template import PROMPT_TEMPLATE +PROVIDER = os.getenv("PROVIDER", "ollama") # "ollama" or "openai" + +# Ollama用 OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434") OLLAMA_URL = f"{OLLAMA_HOST}/api/generate" -OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "syui/ai") +OLLAMA_MODEL = os.getenv("MODEL", "syui/ai") + +# OpenAI用 +OPENAI_BASE = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1") +OPENAI_KEY = os.getenv("OPENAI_API_KEY", "") +OPENAI_MODEL = os.getenv("MODEL", "gpt-4o-mini") def ask_question(question, repo_path="."): context = load_context_from_repo(repo_path) prompt = PROMPT_TEMPLATE.format(context=context[:10000], question=question) - payload = { - "model": OLLAMA_MODEL, - "prompt": prompt, - "stream": False - } + if PROVIDER == "ollama": + payload = { + "model": OLLAMA_MODEL, + "prompt": prompt, + "stream": False + } + response = httpx.post(OLLAMA_URL, json=payload, timeout=60.0) + result = response.json() + return result.get("response", "返答がありませんでした。") + + elif PROVIDER == "openai": + import openai + openai.api_key = OPENAI_KEY + openai.api_base = OPENAI_BASE + + client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY")) + response = client.chat.completions.create( + model=OPENAI_MODEL, + messages=[{"role": "user", "content": prompt}] + ) + return response.choices[0].message.content + + else: + return f"❌ 未知のプロバイダです: {PROVIDER}" - #response = httpx.post(OLLAMA_URL, json=payload) - response = httpx.post(OLLAMA_URL, json=payload, timeout=60.0) - result = response.json() - return result.get("response", "返答がありませんでした。") if __name__ == "__main__": import sys diff --git a/src/chat.rs b/src/chat.rs index 5eff0c9..12b4e9c 100644 --- a/src/chat.rs +++ b/src/chat.rs @@ -2,9 +2,47 @@ use seahorse::Context; use std::process::Command; -//use std::env; use crate::config::ConfigPaths; +#[derive(Debug, Clone, PartialEq)] +pub enum Provider { + OpenAI, + Ollama, +} + +impl Provider { + pub fn from_str(s: &str) -> Option { + match s.to_lowercase().as_str() { + "openai" => Some(Provider::OpenAI), + "ollama" => Some(Provider::Ollama), + _ => None, + } + } + + pub fn as_str(&self) -> &'static str { + match self { + Provider::OpenAI => "openai", + Provider::Ollama => "ollama", + } + } +} + +use std::fs; +use serde::Deserialize; + +#[derive(Deserialize)] +struct OpenAIKey { + token: String, +} + +fn load_openai_api_key() -> Option { + let config = ConfigPaths::new(); + let path = config.base_dir.join("openai.json"); + let data = fs::read_to_string(path).ok()?; + let parsed: OpenAIKey = serde_json::from_str(&data).ok()?; + Some(parsed.token) +} + pub fn ask_chat(c: &Context, question: &str) { let config = ConfigPaths::new(); let base_dir = config.base_dir.join("mcp"); @@ -18,17 +56,34 @@ pub fn ask_chat(c: &Context, question: &str) { let ollama_host = c.string_flag("host").ok(); let ollama_model = c.string_flag("model").ok(); + let api_key = c.string_flag("api-key").ok() + .or_else(|| load_openai_api_key()); + use crate::chat::Provider; + + let provider_str = c.string_flag("provider").unwrap_or_else(|_| "ollama".to_string()); + let provider = Provider::from_str(&provider_str).unwrap_or(Provider::Ollama); + + println!("🔍 使用プロバイダー: {}", provider.as_str()); + + // 🛠️ command の定義をここで行う let mut command = Command::new(python_path); command.arg(script_path).arg(question); + // ✨ 環境変数をセット + command.env("PROVIDER", provider.as_str()); + if let Some(host) = ollama_host { command.env("OLLAMA_HOST", host); } if let Some(model) = ollama_model { command.env("OLLAMA_MODEL", model); } + if let Some(api_key) = api_key { + command.env("OPENAI_API_KEY", api_key); + } + // 🔁 実行 let output = command .output() .expect("❌ MCPチャットスクリプトの実行に失敗しました"); diff --git a/src/commands/mcp.rs b/src/commands/mcp.rs index ee2d93e..c18b987 100644 --- a/src/commands/mcp.rs +++ b/src/commands/mcp.rs @@ -1,13 +1,14 @@ // src/commands/mcp.rs -use seahorse::{Command, Context, Flag, FlagType}; -use crate::chat::ask_chat; -use crate::git::{git_init, git_status}; - use std::fs; use std::path::{PathBuf}; -use crate::config::ConfigPaths; use std::process::Command as OtherCommand; +use serde_json::json; +use seahorse::{Command, Context, Flag, FlagType}; + +use crate::chat::ask_chat; +use crate::git::{git_init, git_status}; +use crate::config::ConfigPaths; pub fn mcp_setup() { let config = ConfigPaths::new(); @@ -106,12 +107,52 @@ pub fn mcp_setup() { } } +fn set_api_key_cmd() -> Command { + Command::new("set-api") + .description("OpenAI APIキーを設定") + .usage("mcp set-api --api ") + .flag(Flag::new("api", FlagType::String).description("OpenAI APIキー").alias("a")) + .action(|c: &Context| { + if let Ok(api_key) = c.string_flag("api") { + let config = ConfigPaths::new(); + let path = config.base_dir.join("openai.json"); + let json_data = json!({ "token": api_key }); + + if let Err(e) = fs::write(&path, serde_json::to_string_pretty(&json_data).unwrap()) { + eprintln!("❌ ファイル書き込み失敗: {}", e); + } else { + println!("✅ APIキーを保存しました: {}", path.display()); + } + } else { + eprintln!("❗ APIキーを --api で指定してください"); + } + }) +} + fn chat_cmd() -> Command { Command::new("chat") .description("チャットで質問を送る") - .usage("mcp chat '質問内容' --host --model ") - .flag(Flag::new("host", FlagType::String).description("OLLAMAホストのURL")) - .flag(Flag::new("model", FlagType::String).description("OLLAMAモデル名")) + .usage("mcp chat '質問内容' --host --model [--provider ] [--api-key ]") + .flag( + Flag::new("host", FlagType::String) + .description("OLLAMAホストのURL") + .alias("H"), + ) + .flag( + Flag::new("model", FlagType::String) + .description("モデル名 (OLLAMA_MODEL / OPENAI_MODEL)") + .alias("m"), + ) + .flag( + Flag::new("provider", FlagType::String) + .description("使用するプロバイダ (ollama / openai)") + .alias("p"), + ) + .flag( + Flag::new("api-key", FlagType::String) + .description("OpenAI APIキー") + .alias("k"), + ) .action(|c: &Context| { if let Some(question) = c.args.get(0) { ask_chat(c, question); @@ -157,4 +198,5 @@ pub fn mcp_cmd() -> Command { .command(init_cmd()) .command(status_cmd()) .command(setup_cmd()) + .command(set_api_key_cmd()) }