diff --git a/Cargo.toml b/Cargo.toml
index 9d6bea4..20d1b33 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -12,3 +12,4 @@ rusqlite = { version = "0.29", features = ["serde_json"] }
 shellexpand = "*"
 fs_extra = "1.3"
 rand = "0.9.1"
+reqwest = { version = "*", features = ["blocking", "json"] }
diff --git a/mcp/scripts/ask.py b/mcp/scripts/ask.py
index fe3b6aa..f3885e0 100644
--- a/mcp/scripts/ask.py
+++ b/mcp/scripts/ask.py
@@ -1,28 +1,86 @@
 ## scripts/ask.py
 import sys
+import json
 import requests
+from datetime import datetime
 from config import load_config
 
-def ask(prompt):
+def build_payload_mcp(message: str):
+    # Request body for the /chat endpoint on the MCP server (ChatInput shape)
+    return {"message": message}
+
+def build_payload_openai(cfg, message: str):
+    return {
+        "model": cfg["model"],
+        "messages": [
+            {"role": "system", "content": "You are a compassionate AI."},
+            {"role": "user", "content": message}
+        ],
+        "temperature": 0.7
+    }
+
+def call_mcp(cfg, message: str):
+    payload = build_payload_mcp(message)
+    headers = {"Content-Type": "application/json"}
+    response = requests.post(cfg["url"], headers=headers, json=payload)
+    response.raise_for_status()
+    return response.json().get("response", "❓ Could not retrieve a response")
+
+def call_openai(cfg, message: str):
+    payload = build_payload_openai(cfg, message)
+    headers = {
+        "Authorization": f"Bearer {cfg['api_key']}",
+        "Content-Type": "application/json",
+    }
+    response = requests.post(cfg["url"], headers=headers, json=payload)
+    response.raise_for_status()
+    return response.json()["choices"][0]["message"]["content"]
+
+def main():
+    if len(sys.argv) < 2:
+        print("Usage: ask.py 'your message'")
+        return
+
+    message = sys.argv[1]
     cfg = load_config()
-    if cfg["provider"] == "ollama":
-        payload = {"model": cfg["model"], "prompt": prompt, "stream": False}
-        response = requests.post(cfg["url"], json=payload)
-        print(response.json().get("response", "❌ No response"))
+
+    print(f"🔍 Provider in use: {cfg['provider']}")
+
+    try:
+        if cfg["provider"] == "openai":
+            response = call_openai(cfg, message)
+        elif cfg["provider"] == "mcp":
+            response = call_mcp(cfg, message)
+        else:
+            raise ValueError(f"Unsupported provider: {cfg['provider']}")
+
+        print("💬 Response:")
+        print(response)
+
+        # Save the exchange to the log (optional)
+        save_log(message, response)
+
+    except Exception as e:
+        print(f"❌ Execution error: {e}")
+
+def save_log(user_msg, ai_msg):
+    from config import MEMORY_DIR
+    date_str = datetime.now().strftime("%Y-%m-%d")
+    path = MEMORY_DIR / f"{date_str}.json"
+    path.parent.mkdir(parents=True, exist_ok=True)
+
+    if path.exists():
+        with open(path, "r") as f:
+            logs = json.load(f)
     else:
-        headers = {
-            "Authorization": f"Bearer {cfg['api_key']}",
-            "Content-Type": "application/json"
-        }
-        payload = {
-            "model": cfg["model"],
-            "messages": [{"role": "user", "content": prompt}]
-        }
-        response = requests.post(cfg["url"], headers=headers, json=payload)
-        print(response.json().get("choices", [{}])[0].get("message", {}).get("content", "❌ No content"))
+        logs = []
+
+    now = datetime.utcnow().isoformat() + "Z"
+    logs.append({"timestamp": now, "sender": "user", "message": user_msg})
+    logs.append({"timestamp": now, "sender": "ai", "message": ai_msg})
+
+    with open(path, "w") as f:
+        json.dump(logs, f, indent=2, ensure_ascii=False)
 
 if __name__ == "__main__":
-    if len(sys.argv) < 2:
-        print("Usage: python ask.py 'your message'")
-        sys.exit(1)
-    ask(sys.argv[1])
+    main()
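Note: call_mcp above speaks the same {"message": ...} -> {"response": ...} contract as the /chat endpoint added in server.py further down, so it can be smoke-tested in isolation. A minimal sketch, assuming the server from this PR is already listening on the default URL; the file name smoke_ask.py is hypothetical:

    # smoke_ask.py: minimal sketch, run from mcp/scripts with the server up.
    # Assumes only the cfg shape ({"url": ...}) and call_mcp from this diff.
    from ask import call_mcp

    cfg = {"url": "http://127.0.0.1:5000/chat"}  # same default as config.py
    print(call_mcp(cfg, "hello from the smoke test"))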
diff --git a/mcp/scripts/config.py b/mcp/scripts/config.py
index d963049..63c64ec 100644
--- a/mcp/scripts/config.py
+++ b/mcp/scripts/config.py
@@ -1,7 +1,17 @@
 # scripts/config.py
 import os
 from pathlib import Path
 
+# Directory settings
+BASE_DIR = Path.home() / ".config" / "aigpt"
+MEMORY_DIR = BASE_DIR / "memory"
+SUMMARY_DIR = MEMORY_DIR / "summary"
+
+def init_directories():
+    BASE_DIR.mkdir(parents=True, exist_ok=True)
+    MEMORY_DIR.mkdir(parents=True, exist_ok=True)
+    SUMMARY_DIR.mkdir(parents=True, exist_ok=True)
+
 def load_config():
     provider = os.getenv("PROVIDER", "ollama")
     model = os.getenv("MODEL", "syui/ai" if provider == "ollama" else "gpt-4o-mini")
@@ -20,15 +30,11 @@ def load_config():
             "api_key": api_key,
             "url": f"{os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')}/chat/completions"
         }
+    elif provider == "mcp":
+        return {
+            "provider": "mcp",
+            "model": model,
+            "url": os.getenv("MCP_URL", "http://localhost:5000/chat")
+        }
     else:
         raise ValueError(f"Unsupported provider: {provider}")
-
-# Directory settings
-BASE_DIR = Path.home() / ".config" / "aigpt"
-MEMORY_DIR = BASE_DIR / "memory"
-SUMMARY_DIR = MEMORY_DIR / "summary"
-
-# Initialization (create as needed)
-BASE_DIR.mkdir(parents=True, exist_ok=True)
-MEMORY_DIR.mkdir(parents=True, exist_ok=True)
-SUMMARY_DIR.mkdir(parents=True, exist_ok=True)
diff --git a/mcp/scripts/memory_store.py b/mcp/scripts/memory_store.py
new file mode 100644
index 0000000..a2ba7d4
--- /dev/null
+++ b/mcp/scripts/memory_store.py
@@ -0,0 +1,37 @@
+# scripts/memory_store.py
+from pathlib import Path
+import json
+from datetime import datetime
+
+MEMORY_DIR = Path.home() / ".config" / "aigpt" / "memory"
+MEMORY_DIR.mkdir(parents=True, exist_ok=True)
+
+def get_today_path():
+    today = datetime.utcnow().strftime("%Y-%m-%d")
+    return MEMORY_DIR / f"{today}.json"
+
+def save_message(sender: str, message: str):
+    entry = {
+        "timestamp": datetime.utcnow().isoformat(),
+        "sender": sender,
+        "message": message
+    }
+
+    path = get_today_path()
+    data = []
+
+    if path.exists():
+        with open(path, "r") as f:
+            data = json.load(f)
+
+    data.append(entry)
+
+    with open(path, "w") as f:
+        json.dump(data, f, indent=2, ensure_ascii=False)
+
+def load_messages():
+    path = get_today_path()
+    if not path.exists():
+        return []
+    with open(path, "r") as f:
+        return json.load(f)
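Note: with directory creation moved behind init_directories(), importing config no longer touches the filesystem; callers opt in explicitly. A minimal sketch of the intended call pattern, using only names defined in this diff; the file name config_demo.py is hypothetical:

    # config_demo.py: sketch of env-driven provider selection, names from config.py.
    import os
    from config import load_config, init_directories

    init_directories()  # creates ~/.config/aigpt and its memory/summary dirs once

    os.environ["PROVIDER"] = "mcp"
    os.environ["MCP_URL"] = "http://127.0.0.1:5000/chat"  # matches the default
    cfg = load_config()
    print(cfg["provider"], cfg["url"])  # -> mcp http://127.0.0.1:5000/chat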
diff --git a/mcp/scripts/server.py b/mcp/scripts/server.py
index e9771e6..d67f9af 100644
--- a/mcp/scripts/server.py
+++ b/mcp/scripts/server.py
@@ -1,21 +1,41 @@
 # server.py
 from fastapi import FastAPI
 from fastapi_mcp import FastApiMCP
+from pydantic import BaseModel
+from memory_store import save_message, load_messages
 
 app = FastAPI()
+mcp = FastApiMCP(app, name="aigpt-agent", description="MCP Server for AI memory")
 
-@app.get("/items/{item_id}", operation_id="get_item")
-async def read_item(item_id: int):
-    return {"item_id": item_id, "name": f"Item {item_id}"}
+# --- Model definitions ---
+class ChatInput(BaseModel):
+    message: str
 
-# Create the MCP server and mount it on the FastAPI app
-mcp = FastApiMCP(
-    app,
-    name="My API MCP",
-    description="My API description"
-)
+class MemoryInput(BaseModel):
+    sender: str
+    message: str
+
+# --- Tool (endpoint) definitions ---
+@app.post("/chat", operation_id="chat")
+async def chat(input: ChatInput):
+    save_message("user", input.message)
+    response = f'AI: Received your message: "{input.message}"'
+    save_message("ai", response)
+    return {"response": response}
+
+@app.post("/memory", operation_id="save_memory")
+async def memory_post(input: MemoryInput):
+    save_message(input.sender, input.message)
+    return {"status": "saved"}
+
+@app.get("/memory", operation_id="get_memory")
+async def memory_get():
+    return {"messages": load_messages()}
+
+# --- MCP initialization ---
 mcp.mount()
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    print("🚀 Starting MCP server...")
+    uvicorn.run(app, host="127.0.0.1", port=5000)
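Note: the two /memory operations give external tools read/write access to the same day-file that /chat appends to. A minimal sketch exercising both, assuming the server above is running locally; the file name memory_demo.py is hypothetical:

    # memory_demo.py: sketch against the /memory endpoints defined above.
    import requests

    BASE = "http://127.0.0.1:5000"

    # POST /memory appends one entry (MemoryInput: sender + message).
    r = requests.post(f"{BASE}/memory", json={"sender": "user", "message": "remember this"})
    r.raise_for_status()

    # GET /memory returns every entry recorded today.
    for m in requests.get(f"{BASE}/memory").json()["messages"]:
        print(m["timestamp"], m["sender"], m["message"])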
diff --git a/src/chat.rs b/src/chat.rs
index da5792d..be1174f 100644
--- a/src/chat.rs
+++ b/src/chat.rs
@@ -5,11 +5,12 @@ use serde::Deserialize;
 use seahorse::Context;
 use crate::config::ConfigPaths;
 use crate::metrics::{load_user_data, save_user_data, update_metrics_decay};
 
 #[derive(Debug, Clone, PartialEq)]
 pub enum Provider {
     OpenAI,
     Ollama,
+    MCP,
 }
 
 impl Provider {
@@ -17,6 +18,7 @@
         match s.to_lowercase().as_str() {
             "openai" => Some(Provider::OpenAI),
             "ollama" => Some(Provider::Ollama),
+            "mcp" => Some(Provider::MCP),
             _ => None,
         }
     }
@@ -25,6 +27,7 @@
         match self {
             Provider::OpenAI => "openai",
             Provider::Ollama => "ollama",
+            Provider::MCP => "mcp",
         }
     }
 }
@@ -50,13 +53,6 @@ pub fn ask_chat(c: &Context, question: &str) -> Option<String> {
     let mut user = load_user_data(&user_path);
     user.metrics = update_metrics_decay();
 
-    // Python executable path
-    let python_path = if cfg!(target_os = "windows") {
-        base_dir.join(".venv/Scripts/mcp.exe")
-    } else {
-        base_dir.join(".venv/bin/mcp")
-    };
-
     // Various options
     let ollama_host = c.string_flag("host").ok();
     let ollama_model = c.string_flag("model").ok();
@@ -66,38 +62,75 @@ pub fn ask_chat(c: &Context, question: &str) -> Option<String> {
 
     println!("🔍 Provider in use: {}", provider.as_str());
 
-    // Prepare the Python command
-    let mut command = Command::new(python_path);
-    command.arg("ask").arg(question);
+    match provider {
+        Provider::MCP => {
+            let client = reqwest::blocking::Client::new();
+            let url = std::env::var("MCP_URL").unwrap_or_else(|_| "http://127.0.0.1:5000/chat".to_string());
+            let res = client.post(url)
+                .json(&serde_json::json!({"message": question}))
+                .send();
 
-    if let Some(host) = ollama_host {
-        command.env("OLLAMA_HOST", host);
-    }
-    if let Some(model) = ollama_model {
-        command.env("OLLAMA_MODEL", model.clone());
-        command.env("OPENAI_MODEL", model);
-    }
-    command.env("PROVIDER", provider.as_str());
+            match res {
+                Ok(resp) => {
+                    if resp.status().is_success() {
+                        let json: serde_json::Value = resp.json().ok()?;
+                        let text = json.get("response")?.as_str()?.to_string();
+                        user.metrics.intimacy += 0.01;
+                        user.metrics.last_updated = chrono::Utc::now();
+                        save_user_data(&user_path, &user);
+                        Some(text)
+                    } else {
+                        eprintln!("❌ MCP error: HTTP {}", resp.status());
+                        None
+                    }
+                }
+                Err(e) => {
+                    eprintln!("❌ MCP connection failed: {}", e);
+                    None
+                }
+            }
+        }
+        _ => {
+            // Python executable path
+            let python_path = if cfg!(target_os = "windows") {
+                base_dir.join(".venv/Scripts/mcp.exe")
+            } else {
+                base_dir.join(".venv/bin/mcp")
+            };
 
-    if let Some(key) = api_key {
-        command.env("OPENAI_API_KEY", key);
-    }
+            let mut command = Command::new(python_path);
+            command.arg("ask").arg(question);
 
-    let output = command.output().expect("❌ Failed to run the MCP chat script");
+            if let Some(host) = ollama_host {
+                command.env("OLLAMA_HOST", host);
+            }
+            if let Some(model) = ollama_model {
+                command.env("OLLAMA_MODEL", model.clone());
+                command.env("OPENAI_MODEL", model);
+            }
+            command.env("PROVIDER", provider.as_str());
 
-    if output.status.success() {
-        let response = String::from_utf8_lossy(&output.stdout).to_string();
-        user.metrics.intimacy += 0.01;
-        user.metrics.last_updated = chrono::Utc::now();
-        save_user_data(&user_path, &user);
+            if let Some(key) = api_key {
+                command.env("OPENAI_API_KEY", key);
+            }
 
-        Some(response)
-    } else {
-        eprintln!(
-            "❌ Execution error: {}\n{}",
-            String::from_utf8_lossy(&output.stderr),
-            String::from_utf8_lossy(&output.stdout),
-        );
-        None
+            let output = command.output().expect("❌ Failed to run the MCP chat script");
+
+            if output.status.success() {
+                let response = String::from_utf8_lossy(&output.stdout).to_string();
+                user.metrics.intimacy += 0.01;
+                user.metrics.last_updated = chrono::Utc::now();
+                save_user_data(&user_path, &user);
+
+                Some(response)
+            } else {
+                eprintln!(
+                    "❌ Execution error: {}\n{}",
+                    String::from_utf8_lossy(&output.stderr),
+                    String::from_utf8_lossy(&output.stdout),
+                );
+                None
+            }
+        }
     }
 }
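Note: for debugging the server without rebuilding the CLI, the request issued by the new Provider::MCP branch can be reproduced directly. This sketch mirrors the Rust reqwest call (same JSON body, same MCP_URL fallback); the file name mcp_probe.py is hypothetical:

    # mcp_probe.py: Python mirror of the Provider::MCP request in chat.rs.
    import os
    import requests

    url = os.getenv("MCP_URL", "http://127.0.0.1:5000/chat")  # same fallback as chat.rs
    resp = requests.post(url, json={"message": "ping"}, timeout=10)
    if resp.ok:
        print(resp.json().get("response", "(no response field)"))
    else:
        print(f"HTTP {resp.status_code}: {resp.text}")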