add fastapi_mcp
This commit is contained in:
parent
9cbf5da3fd
commit
4f55138306
2
.gitignore
vendored
2
.gitignore
vendored
@ -3,3 +3,5 @@
|
||||
output.json
|
||||
config/*.db
|
||||
aigpt
|
||||
mcp/scripts/__*
|
||||
data
|
||||
|
@ -12,3 +12,4 @@ rusqlite = { version = "0.29", features = ["serde_json"] }
|
||||
shellexpand = "*"
|
||||
fs_extra = "1.3"
|
||||
rand = "0.9.1"
|
||||
reqwest = { version = "*", features = ["blocking", "json"] }
|
||||
|
@ -33,7 +33,7 @@ $ ./aigpt mcp chat "hello world!" --host http://localhost:11434 --model syui/ai
|
||||
|
||||
---
|
||||
# openai api
|
||||
$ ./aigpt mcp set-api -api sk-abc123
|
||||
$ ./aigpt mcp set-api --api sk-abc123
|
||||
$ ./aigpt mcp chat "こんにちは" -p openai -m gpt-4o-mini
|
||||
|
||||
---
|
||||
|
27
mcp/cli.py
27
mcp/cli.py
@ -1,3 +1,28 @@
|
||||
# cli.py
|
||||
import sys
|
||||
import subprocess
|
||||
from pathlib import Path
|
||||
|
||||
SCRIPT_DIR = Path.home() / ".config" / "aigpt" / "mcp" / "scripts"
|
||||
def run_script(name):
|
||||
script_path = SCRIPT_DIR / f"{name}.py"
|
||||
if not script_path.exists():
|
||||
print(f"❌ スクリプトが見つかりません: {script_path}")
|
||||
sys.exit(1)
|
||||
|
||||
args = sys.argv[2:] # ← "ask" の後の引数を取り出す
|
||||
result = subprocess.run(["python", str(script_path)] + args, capture_output=True, text=True)
|
||||
print(result.stdout)
|
||||
if result.stderr:
|
||||
print(result.stderr)
|
||||
def main():
|
||||
print("Hello MCP!")
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: mcp <script>")
|
||||
return
|
||||
|
||||
command = sys.argv[1]
|
||||
|
||||
if command in {"summarize", "ask", "setup", "server"}:
|
||||
run_script(command)
|
||||
else:
|
||||
print(f"❓ 未知のコマンド: {command}")
|
||||
|
@ -1,55 +1,198 @@
|
||||
import os
|
||||
## scripts/ask.py
|
||||
import sys
|
||||
import json
|
||||
import httpx
|
||||
import openai
|
||||
import requests
|
||||
from config import load_config
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from context_loader import load_context_from_repo
|
||||
from prompt_template import PROMPT_TEMPLATE
|
||||
def build_payload_openai(cfg, message: str):
|
||||
return {
|
||||
"model": cfg["model"],
|
||||
"tools": [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "ask_message",
|
||||
"description": "過去の記憶を検索します",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "検索したい語句"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
],
|
||||
"tool_choice": "auto",
|
||||
"messages": [
|
||||
{"role": "system", "content": "あなたは親しみやすいAIで、必要に応じて記憶から情報を検索して応答します。"},
|
||||
{"role": "user", "content": message}
|
||||
]
|
||||
}
|
||||
|
||||
PROVIDER = os.getenv("PROVIDER", "ollama") # "ollama" or "openai"
|
||||
|
||||
# Ollama用
|
||||
OLLAMA_HOST = os.getenv("OLLAMA_HOST", "http://localhost:11434")
|
||||
OLLAMA_URL = f"{OLLAMA_HOST}/api/generate"
|
||||
OLLAMA_MODEL = os.getenv("MODEL", "syui/ai")
|
||||
|
||||
# OpenAI用
|
||||
OPENAI_BASE = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
|
||||
OPENAI_KEY = os.getenv("OPENAI_API_KEY", "")
|
||||
OPENAI_MODEL = os.getenv("MODEL", "gpt-4o-mini")
|
||||
|
||||
def ask_question(question, repo_path="."):
|
||||
context = load_context_from_repo(repo_path)
|
||||
prompt = PROMPT_TEMPLATE.format(context=context[:10000], question=question)
|
||||
|
||||
if PROVIDER == "ollama":
|
||||
payload = {
|
||||
"model": OLLAMA_MODEL,
|
||||
"prompt": prompt,
|
||||
"stream": False
|
||||
def build_payload_mcp(message: str):
|
||||
return {
|
||||
"tool": "ask_message", # MCPサーバー側で定義されたツール名
|
||||
"input": {
|
||||
"message": message
|
||||
}
|
||||
response = httpx.post(OLLAMA_URL, json=payload, timeout=60.0)
|
||||
result = response.json()
|
||||
return result.get("response", "返答がありませんでした。")
|
||||
}
|
||||
|
||||
elif PROVIDER == "openai":
|
||||
import openai
|
||||
openai.api_key = OPENAI_KEY
|
||||
openai.api_base = OPENAI_BASE
|
||||
def build_payload_openai(cfg, message: str):
|
||||
return {
|
||||
"model": cfg["model"],
|
||||
"messages": [
|
||||
{"role": "system", "content": "あなたは思いやりのあるAIです。"},
|
||||
{"role": "user", "content": message}
|
||||
],
|
||||
"temperature": 0.7
|
||||
}
|
||||
|
||||
client = openai.OpenAI(api_key=os.getenv("OPENAI_API_KEY"))
|
||||
response = client.chat.completions.create(
|
||||
model=OPENAI_MODEL,
|
||||
messages=[{"role": "user", "content": prompt}]
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
def call_mcp(cfg, message: str):
|
||||
payload = build_payload_mcp(message)
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(cfg["url"], headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json().get("output", {}).get("response", "❓ 応答が取得できませんでした")
|
||||
|
||||
def call_openai(cfg, message: str):
|
||||
# ツール定義
|
||||
tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "memory",
|
||||
"description": "記憶を検索する",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "string",
|
||||
"description": "検索する語句"
|
||||
}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# 最初のメッセージ送信
|
||||
payload = {
|
||||
"model": cfg["model"],
|
||||
"messages": [
|
||||
{"role": "system", "content": "あなたはAIで、必要に応じてツールmemoryを使って記憶を検索します。"},
|
||||
{"role": "user", "content": message}
|
||||
],
|
||||
"tools": tools,
|
||||
"tool_choice": "auto"
|
||||
}
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {cfg['api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
res1 = requests.post(cfg["url"], headers=headers, json=payload)
|
||||
res1.raise_for_status()
|
||||
result = res1.json()
|
||||
|
||||
# 🧠 tool_call されたか確認
|
||||
if "tool_calls" in result["choices"][0]["message"]:
|
||||
tool_call = result["choices"][0]["message"]["tool_calls"][0]
|
||||
if tool_call["function"]["name"] == "memory":
|
||||
args = json.loads(tool_call["function"]["arguments"])
|
||||
query = args.get("query", "")
|
||||
print(f"🛠️ ツール実行: memory(query='{query}')")
|
||||
|
||||
# MCPエンドポイントにPOST
|
||||
memory_res = requests.post("http://127.0.0.1:5000/memory/search", json={"query": query})
|
||||
memory_json = memory_res.json()
|
||||
tool_output = memory_json.get("result", "なし")
|
||||
|
||||
# tool_outputをAIに返す
|
||||
followup = {
|
||||
"model": cfg["model"],
|
||||
"messages": [
|
||||
{"role": "system", "content": "あなたはAIで、必要に応じてツールmemoryを使って記憶を検索します。"},
|
||||
{"role": "user", "content": message},
|
||||
{"role": "assistant", "tool_calls": result["choices"][0]["message"]["tool_calls"]},
|
||||
{"role": "tool", "tool_call_id": tool_call["id"], "name": "memory", "content": tool_output}
|
||||
]
|
||||
}
|
||||
|
||||
res2 = requests.post(cfg["url"], headers=headers, json=followup)
|
||||
res2.raise_for_status()
|
||||
final_response = res2.json()
|
||||
return final_response["choices"][0]["message"]["content"]
|
||||
#print(tool_output)
|
||||
#print(cfg["model"])
|
||||
#print(final_response)
|
||||
|
||||
# ツール未使用 or 通常応答
|
||||
return result["choices"][0]["message"]["content"]
|
||||
|
||||
def call_ollama(cfg, message: str):
|
||||
payload = {
|
||||
"model": cfg["model"],
|
||||
"prompt": message, # `prompt` → `message` にすべき(変数未定義エラー回避)
|
||||
"stream": False
|
||||
}
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(cfg["url"], headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json().get("response", "❌ 応答が取得できませんでした")
|
||||
def main():
|
||||
if len(sys.argv) < 2:
|
||||
print("Usage: ask.py 'your message'")
|
||||
return
|
||||
|
||||
message = sys.argv[1]
|
||||
cfg = load_config()
|
||||
|
||||
print(f"🔍 使用プロバイダー: {cfg['provider']}")
|
||||
|
||||
try:
|
||||
if cfg["provider"] == "openai":
|
||||
response = call_openai(cfg, message)
|
||||
elif cfg["provider"] == "mcp":
|
||||
response = call_mcp(cfg, message)
|
||||
elif cfg["provider"] == "ollama":
|
||||
response = call_ollama(cfg, message)
|
||||
else:
|
||||
raise ValueError(f"未対応のプロバイダー: {cfg['provider']}")
|
||||
|
||||
print("💬 応答:")
|
||||
print(response)
|
||||
|
||||
# ログ保存(オプション)
|
||||
save_log(message, response)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ 実行エラー: {e}")
|
||||
|
||||
def save_log(user_msg, ai_msg):
|
||||
from config import MEMORY_DIR
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
if path.exists():
|
||||
with open(path, "r") as f:
|
||||
logs = json.load(f)
|
||||
else:
|
||||
return f"❌ 未知のプロバイダです: {PROVIDER}"
|
||||
logs = []
|
||||
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
logs.append({"timestamp": now, "sender": "user", "message": user_msg})
|
||||
logs.append({"timestamp": now, "sender": "ai", "message": ai_msg})
|
||||
|
||||
with open(path, "w") as f:
|
||||
json.dump(logs, f, indent=2, ensure_ascii=False)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
question = " ".join(sys.argv[1:])
|
||||
answer = ask_question(question)
|
||||
print("\n🧠 回答:\n", answer)
|
||||
main()
|
||||
|
41
mcp/scripts/config.py
Normal file
41
mcp/scripts/config.py
Normal file
@ -0,0 +1,41 @@
|
||||
# scripts/config.py
|
||||
# scripts/config.py
|
||||
import os
|
||||
from pathlib import Path
|
||||
|
||||
# ディレクトリ設定
|
||||
BASE_DIR = Path.home() / ".config" / "aigpt"
|
||||
MEMORY_DIR = BASE_DIR / "memory"
|
||||
SUMMARY_DIR = MEMORY_DIR / "summary"
|
||||
|
||||
def init_directories():
|
||||
BASE_DIR.mkdir(parents=True, exist_ok=True)
|
||||
MEMORY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
SUMMARY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def load_config():
|
||||
provider = os.getenv("PROVIDER", "ollama")
|
||||
model = os.getenv("MODEL", "syui/ai" if provider == "ollama" else "gpt-4o-mini")
|
||||
api_key = os.getenv("OPENAI_API_KEY", "")
|
||||
|
||||
if provider == "ollama":
|
||||
return {
|
||||
"provider": "ollama",
|
||||
"model": model,
|
||||
"url": f"{os.getenv('OLLAMA_HOST', 'http://localhost:11434')}/api/generate"
|
||||
}
|
||||
elif provider == "openai":
|
||||
return {
|
||||
"provider": "openai",
|
||||
"model": model,
|
||||
"api_key": api_key,
|
||||
"url": f"{os.getenv('OPENAI_API_BASE', 'https://api.openai.com/v1')}/chat/completions"
|
||||
}
|
||||
elif provider == "mcp":
|
||||
return {
|
||||
"provider": "mcp",
|
||||
"model": model,
|
||||
"url": os.getenv("MCP_URL", "http://localhost:5000/chat")
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {provider}")
|
92
mcp/scripts/memory_store.py
Normal file
92
mcp/scripts/memory_store.py
Normal file
@ -0,0 +1,92 @@
|
||||
# scripts/memory_store.py
|
||||
import json
|
||||
from pathlib import Path
|
||||
from config import MEMORY_DIR
|
||||
from datetime import datetime, timezone
|
||||
|
||||
def load_logs(date_str=None):
|
||||
if date_str is None:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
if path.exists():
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
return []
|
||||
|
||||
def save_message(sender, message):
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
logs = load_logs(date_str)
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
logs.append({"timestamp": now, "sender": sender, "message": message})
|
||||
with open(path, "w") as f:
|
||||
json.dump(logs, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def search_memory(query: str):
|
||||
from glob import glob
|
||||
all_logs = []
|
||||
pattern = re.compile(re.escape(query), re.IGNORECASE)
|
||||
|
||||
for file_path in sorted(MEMORY_DIR.glob("*.json")):
|
||||
with open(file_path, "r") as f:
|
||||
logs = json.load(f)
|
||||
matched = [entry for entry in logs if pattern.search(entry["message"])]
|
||||
all_logs.extend(matched)
|
||||
|
||||
return all_logs[-5:]
|
||||
|
||||
# scripts/memory_store.py
|
||||
import json
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from config import MEMORY_DIR
|
||||
|
||||
# ログを読み込む(指定日または当日)
|
||||
def load_logs(date_str=None):
|
||||
if date_str is None:
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
if path.exists():
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
return []
|
||||
|
||||
# メッセージを保存する
|
||||
def save_message(sender, message):
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
logs = load_logs(date_str)
|
||||
#now = datetime.utcnow().isoformat() + "Z"
|
||||
now = datetime.now(timezone.utc).isoformat()
|
||||
logs.append({"timestamp": now, "sender": sender, "message": message})
|
||||
with open(path, "w") as f:
|
||||
json.dump(logs, f, indent=2, ensure_ascii=False)
|
||||
|
||||
def search_memory(query: str):
|
||||
from glob import glob
|
||||
all_logs = []
|
||||
for file_path in sorted(MEMORY_DIR.glob("*.json")):
|
||||
with open(file_path, "r") as f:
|
||||
logs = json.load(f)
|
||||
matched = [
|
||||
entry for entry in logs
|
||||
if entry["sender"] == "user" and query in entry["message"]
|
||||
]
|
||||
all_logs.extend(matched)
|
||||
return all_logs[-5:] # 最新5件だけ返す
|
||||
def search_memory(query: str):
|
||||
from glob import glob
|
||||
all_logs = []
|
||||
seen_messages = set() # すでに見たメッセージを保持
|
||||
|
||||
for file_path in sorted(MEMORY_DIR.glob("*.json")):
|
||||
with open(file_path, "r") as f:
|
||||
logs = json.load(f)
|
||||
for entry in logs:
|
||||
if entry["sender"] == "user" and query in entry["message"]:
|
||||
# すでに同じメッセージが結果に含まれていなければ追加
|
||||
if entry["message"] not in seen_messages:
|
||||
all_logs.append(entry)
|
||||
seen_messages.add(entry["message"])
|
||||
|
||||
return all_logs[-5:] # 最新5件だけ返す
|
56
mcp/scripts/server.py
Normal file
56
mcp/scripts/server.py
Normal file
@ -0,0 +1,56 @@
|
||||
# server.py
|
||||
from fastapi import FastAPI, Body
|
||||
from fastapi_mcp import FastApiMCP
|
||||
from pydantic import BaseModel
|
||||
from memory_store import save_message, load_logs, search_memory as do_search_memory
|
||||
|
||||
app = FastAPI()
|
||||
mcp = FastApiMCP(app, name="aigpt-agent", description="MCP Server for AI memory")
|
||||
|
||||
class ChatInput(BaseModel):
|
||||
message: str
|
||||
|
||||
class MemoryInput(BaseModel):
|
||||
sender: str
|
||||
message: str
|
||||
|
||||
class MemoryQuery(BaseModel):
|
||||
query: str
|
||||
|
||||
@app.post("/chat", operation_id="chat")
|
||||
async def chat(input: ChatInput):
|
||||
save_message("user", input.message)
|
||||
response = f"AI: 「{input.message}」を受け取りました!"
|
||||
save_message("ai", response)
|
||||
return {"response": response}
|
||||
|
||||
@app.post("/memory", operation_id="save_memory")
|
||||
async def memory_post(input: MemoryInput):
|
||||
save_message(input.sender, input.message)
|
||||
return {"status": "saved"}
|
||||
|
||||
@app.get("/memory", operation_id="get_memory")
|
||||
async def memory_get():
|
||||
return {"messages": load_messages()}
|
||||
|
||||
@app.post("/ask_message", operation_id="ask_message")
|
||||
async def ask_message(input: MemoryQuery):
|
||||
results = search_memory(input.query)
|
||||
return {
|
||||
"response": f"🔎 記憶から {len(results)} 件ヒット:\n" + "\n".join([f"{r['sender']}: {r['message']}" for r in results])
|
||||
}
|
||||
|
||||
@app.post("/memory/search", operation_id="memory")
|
||||
async def memory_search(query: MemoryQuery):
|
||||
hits = do_search_memory(query.query)
|
||||
if not hits:
|
||||
return {"result": "🔍 記憶の中に該当する内容は見つかりませんでした。"}
|
||||
summary = "\n".join([f"{e['sender']}: {e['message']}" for e in hits])
|
||||
return {"result": f"🔎 見つかった記憶:\n{summary}"}
|
||||
|
||||
mcp.mount()
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
print("🚀 Starting MCP server...")
|
||||
uvicorn.run(app, host="127.0.0.1", port=5000)
|
76
mcp/scripts/summarize.py
Normal file
76
mcp/scripts/summarize.py
Normal file
@ -0,0 +1,76 @@
|
||||
# scripts/summarize.py
|
||||
import json
|
||||
from datetime import datetime
|
||||
from config import MEMORY_DIR, SUMMARY_DIR, load_config
|
||||
import requests
|
||||
|
||||
def load_memory(date_str):
|
||||
path = MEMORY_DIR / f"{date_str}.json"
|
||||
if not path.exists():
|
||||
print(f"⚠️ メモリファイルが見つかりません: {path}")
|
||||
return None
|
||||
with open(path, "r") as f:
|
||||
return json.load(f)
|
||||
|
||||
def save_summary(date_str, content):
|
||||
SUMMARY_DIR.mkdir(parents=True, exist_ok=True)
|
||||
path = SUMMARY_DIR / f"{date_str}_summary.json"
|
||||
with open(path, "w") as f:
|
||||
json.dump(content, f, indent=2, ensure_ascii=False)
|
||||
print(f"✅ 要約を保存しました: {path}")
|
||||
|
||||
def build_prompt(logs):
|
||||
messages = [
|
||||
{"role": "system", "content": "あなたは要約AIです。以下の会話ログを要約してください。"},
|
||||
{"role": "user", "content": "\n".join(f"{entry['sender']}: {entry['message']}" for entry in logs)}
|
||||
]
|
||||
return messages
|
||||
|
||||
def summarize_with_llm(messages):
|
||||
cfg = load_config()
|
||||
if cfg["provider"] == "openai":
|
||||
headers = {
|
||||
"Authorization": f"Bearer {cfg['api_key']}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
payload = {
|
||||
"model": cfg["model"],
|
||||
"messages": messages,
|
||||
"temperature": 0.7
|
||||
}
|
||||
response = requests.post(cfg["url"], headers=headers, json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["choices"][0]["message"]["content"]
|
||||
|
||||
elif cfg["provider"] == "ollama":
|
||||
payload = {
|
||||
"model": cfg["model"],
|
||||
"prompt": "\n".join(m["content"] for m in messages),
|
||||
"stream": False,
|
||||
}
|
||||
response = requests.post(cfg["url"], json=payload)
|
||||
response.raise_for_status()
|
||||
return response.json()["response"]
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {cfg['provider']}")
|
||||
|
||||
def main():
|
||||
date_str = datetime.now().strftime("%Y-%m-%d")
|
||||
logs = load_memory(date_str)
|
||||
if not logs:
|
||||
return
|
||||
|
||||
prompt_messages = build_prompt(logs)
|
||||
summary_text = summarize_with_llm(prompt_messages)
|
||||
|
||||
summary = {
|
||||
"date": date_str,
|
||||
"summary": summary_text,
|
||||
"total_messages": len(logs)
|
||||
}
|
||||
|
||||
save_summary(date_str, summary)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@ -1,8 +1,8 @@
|
||||
# setup.py
|
||||
from setuptools import setup
|
||||
|
||||
setup(
|
||||
name='mcp',
|
||||
version='0.1.0',
|
||||
name='aigpt-mcp',
|
||||
py_modules=['cli'],
|
||||
entry_points={
|
||||
'console_scripts': [
|
||||
|
108
src/chat.rs
108
src/chat.rs
@ -5,11 +5,16 @@ use serde::Deserialize;
|
||||
use seahorse::Context;
|
||||
use crate::config::ConfigPaths;
|
||||
use crate::metrics::{load_user_data, save_user_data, update_metrics_decay};
|
||||
//use std::process::Stdio;
|
||||
//use std::io::Write;
|
||||
//use std::time::Duration;
|
||||
//use std::net::TcpStream;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq)]
|
||||
pub enum Provider {
|
||||
OpenAI,
|
||||
Ollama,
|
||||
MCP,
|
||||
}
|
||||
|
||||
impl Provider {
|
||||
@ -17,6 +22,7 @@ impl Provider {
|
||||
match s.to_lowercase().as_str() {
|
||||
"openai" => Some(Provider::OpenAI),
|
||||
"ollama" => Some(Provider::Ollama),
|
||||
"mcp" => Some(Provider::MCP),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
@ -25,6 +31,7 @@ impl Provider {
|
||||
match self {
|
||||
Provider::OpenAI => "openai",
|
||||
Provider::Ollama => "ollama",
|
||||
Provider::MCP => "mcp",
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -45,19 +52,11 @@ fn load_openai_api_key() -> Option<String> {
|
||||
pub fn ask_chat(c: &Context, question: &str) -> Option<String> {
|
||||
let config = ConfigPaths::new();
|
||||
let base_dir = config.base_dir.join("mcp");
|
||||
let script_path = base_dir.join("scripts/ask.py");
|
||||
let user_path = config.base_dir.join("user.json");
|
||||
|
||||
let mut user = load_user_data(&user_path);
|
||||
user.metrics = update_metrics_decay();
|
||||
|
||||
// Python 実行パス
|
||||
let python_path = if cfg!(target_os = "windows") {
|
||||
base_dir.join(".venv/Scripts/python.exe")
|
||||
} else {
|
||||
base_dir.join(".venv/bin/python")
|
||||
};
|
||||
|
||||
// 各種オプション
|
||||
let ollama_host = c.string_flag("host").ok();
|
||||
let ollama_model = c.string_flag("model").ok();
|
||||
@ -67,38 +66,75 @@ pub fn ask_chat(c: &Context, question: &str) -> Option<String> {
|
||||
|
||||
println!("🔍 使用プロバイダー: {}", provider.as_str());
|
||||
|
||||
// Python コマンド準備
|
||||
let mut command = Command::new(python_path);
|
||||
command.arg(script_path).arg(question);
|
||||
match provider {
|
||||
Provider::MCP => {
|
||||
let client = reqwest::blocking::Client::new();
|
||||
let url = std::env::var("MCP_URL").unwrap_or("http://127.0.0.1:5000/chat".to_string());
|
||||
let res = client.post(url)
|
||||
.json(&serde_json::json!({"message": question}))
|
||||
.send();
|
||||
|
||||
if let Some(host) = ollama_host {
|
||||
command.env("OLLAMA_HOST", host);
|
||||
}
|
||||
if let Some(model) = ollama_model {
|
||||
command.env("OLLAMA_MODEL", model.clone());
|
||||
command.env("OPENAI_MODEL", model);
|
||||
}
|
||||
command.env("PROVIDER", provider.as_str());
|
||||
match res {
|
||||
Ok(resp) => {
|
||||
if resp.status().is_success() {
|
||||
let json: serde_json::Value = resp.json().ok()?;
|
||||
let text = json.get("response")?.as_str()?.to_string();
|
||||
user.metrics.intimacy += 0.01;
|
||||
user.metrics.last_updated = chrono::Utc::now();
|
||||
save_user_data(&user_path, &user);
|
||||
Some(text)
|
||||
} else {
|
||||
eprintln!("❌ MCPエラー: HTTP {}", resp.status());
|
||||
None
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
eprintln!("❌ MCP接続失敗: {}", e);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Python 実行パス
|
||||
let python_path = if cfg!(target_os = "windows") {
|
||||
base_dir.join(".venv/Scripts/mcp.exe")
|
||||
} else {
|
||||
base_dir.join(".venv/bin/mcp")
|
||||
};
|
||||
|
||||
if let Some(key) = api_key {
|
||||
command.env("OPENAI_API_KEY", key);
|
||||
}
|
||||
let mut command = Command::new(python_path);
|
||||
command.arg("ask").arg(question);
|
||||
|
||||
let output = command.output().expect("❌ MCPチャットスクリプトの実行に失敗しました");
|
||||
if let Some(host) = ollama_host {
|
||||
command.env("OLLAMA_HOST", host);
|
||||
}
|
||||
if let Some(model) = ollama_model {
|
||||
command.env("OLLAMA_MODEL", model.clone());
|
||||
command.env("OPENAI_MODEL", model);
|
||||
}
|
||||
command.env("PROVIDER", provider.as_str());
|
||||
|
||||
if output.status.success() {
|
||||
let response = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
user.metrics.intimacy += 0.01;
|
||||
user.metrics.last_updated = chrono::Utc::now();
|
||||
save_user_data(&user_path, &user);
|
||||
if let Some(key) = api_key {
|
||||
command.env("OPENAI_API_KEY", key);
|
||||
}
|
||||
|
||||
Some(response)
|
||||
} else {
|
||||
eprintln!(
|
||||
"❌ 実行エラー: {}\n{}",
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
None
|
||||
let output = command.output().expect("❌ MCPチャットスクリプトの実行に失敗しました");
|
||||
|
||||
if output.status.success() {
|
||||
let response = String::from_utf8_lossy(&output.stdout).to_string();
|
||||
user.metrics.intimacy += 0.01;
|
||||
user.metrics.last_updated = chrono::Utc::now();
|
||||
save_user_data(&user_path, &user);
|
||||
|
||||
Some(response)
|
||||
} else {
|
||||
eprintln!(
|
||||
"❌ 実行エラー: {}\n{}",
|
||||
String::from_utf8_lossy(&output.stderr),
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -32,8 +32,12 @@ pub fn mcp_setup() {
|
||||
"cli.py",
|
||||
"setup.py",
|
||||
"scripts/ask.py",
|
||||
"scripts/server.py",
|
||||
"scripts/config.py",
|
||||
"scripts/summarize.py",
|
||||
"scripts/context_loader.py",
|
||||
"scripts/prompt_template.py",
|
||||
"scripts/memory_store.py",
|
||||
];
|
||||
|
||||
for rel_path in files_to_copy {
|
||||
@ -76,6 +80,12 @@ pub fn mcp_setup() {
|
||||
let output = OtherCommand::new(&pip_path)
|
||||
.arg("install")
|
||||
.arg("openai")
|
||||
.arg("requests")
|
||||
.arg("fastmcp")
|
||||
.arg("uvicorn")
|
||||
.arg("fastapi")
|
||||
.arg("fastapi_mcp")
|
||||
.arg("mcp")
|
||||
.current_dir(&dest_dir)
|
||||
.output()
|
||||
.expect("pip install に失敗しました");
|
||||
|
Loading…
x
Reference in New Issue
Block a user