fix mcp
@@ -1,6 +1,7 @@
 """AI Provider integration for response generation"""
 
 import os
+import json
 from typing import Optional, Dict, List, Any, Protocol
 from abc import abstractmethod
 import logging
@@ -128,9 +129,9 @@ Recent memories:
 
 
 class OpenAIProvider:
-    """OpenAI API provider"""
+    """OpenAI API provider with MCP function calling support"""
 
-    def __init__(self, model: str = "gpt-4o-mini", api_key: Optional[str] = None):
+    def __init__(self, model: str = "gpt-4o-mini", api_key: Optional[str] = None, mcp_client=None):
         self.model = model
         # Try to get API key from config first
         config = Config()
@@ -139,6 +140,90 @@ class OpenAIProvider:
             raise ValueError("OpenAI API key not provided. Set it with: aigpt config set providers.openai.api_key YOUR_KEY")
         self.client = OpenAI(api_key=self.api_key)
         self.logger = logging.getLogger(__name__)
+        self.mcp_client = mcp_client  # For MCP function calling
+
+    def _get_mcp_tools(self) -> List[Dict[str, Any]]:
+        """Generate OpenAI tools from MCP endpoints"""
+        if not self.mcp_client or not self.mcp_client.available:
+            return []
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_memories",
+                    "description": "Retrieves memories of past conversations. Always use this for questions containing words like 'remember', 'last time', or 'before'",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "limit": {
+                                "type": "integer",
+                                "description": "Number of memories to retrieve",
+                                "default": 5
+                            }
+                        }
+                    }
+                }
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "search_memories",
+                    "description": "Searches memories of conversations about a specific topic. Use this for questions like 'about programming' or 'we talked about X'",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "keywords": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "description": "Array of search keywords"
+                            }
+                        },
+                        "required": ["keywords"]
+                    }
+                }
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_contextual_memories",
+                    "description": "Retrieves contextual memories related to a query",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "Search query"
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "Number of memories to retrieve",
+                                "default": 5
+                            }
+                        },
+                        "required": ["query"]
+                    }
+                }
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_relationship",
+                    "description": "Retrieves relationship information for a specific user",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "user_id": {
+                                "type": "string",
+                                "description": "User ID"
+                            }
+                        },
+                        "required": ["user_id"]
+                    }
+                }
+            }
+        ]
+        return tools
 
     async def generate_response(
         self,
@@ -184,6 +269,127 @@ Recent memories:
             self.logger.error(f"OpenAI generation failed: {e}")
             return self._fallback_response(persona_state)
 
+    async def chat_with_mcp(self, prompt: str, max_tokens: int = 2000, user_id: str = "user") -> str:
+        """Chat interface with MCP function calling support"""
+        if not self.mcp_client or not self.mcp_client.available:
+            return self.chat(prompt, max_tokens)
+
+        try:
+            # Prepare tools
+            tools = self._get_mcp_tools()
+
+            # Initial request with tools
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "system", "content": "You have access to a memory system and relationship data. When asked about past conversations, memories, or relationships, always use the tools to retrieve accurate information. Use the tools proactively whenever keywords such as 'remember', 'last time', 'before', 'talked about', or 'relationship' appear."},
+                    {"role": "user", "content": prompt}
+                ],
+                tools=tools,
+                tool_choice="auto",
+                max_tokens=max_tokens,
+                temperature=0.7
+            )
+
+            message = response.choices[0].message
+
+            # Handle tool calls
+            if message.tool_calls:
+                print(f"🔧 [OpenAI] {len(message.tool_calls)} tools called:")
+                for tc in message.tool_calls:
+                    print(f"  - {tc.function.name}({tc.function.arguments})")
+
+                messages = [
+                    {"role": "system", "content": "Use the available tools as needed to provide a more accurate and detailed answer."},
+                    {"role": "user", "content": prompt},
+                    {
+                        "role": "assistant",
+                        "content": message.content,
+                        "tool_calls": [tc.model_dump() for tc in message.tool_calls]
+                    }
+                ]
+
+                # Execute each tool call
+                for tool_call in message.tool_calls:
+                    print(f"🌐 [MCP] Executing {tool_call.function.name}...")
+                    tool_result = await self._execute_mcp_tool(tool_call, user_id)
+                    print(f"✅ [MCP] Result: {str(tool_result)[:100]}...")
+                    messages.append({
+                        "role": "tool",
+                        "tool_call_id": tool_call.id,
+                        "name": tool_call.function.name,
+                        "content": json.dumps(tool_result, ensure_ascii=False)
+                    })
+
+                # Get final response with tool outputs
+                final_response = self.client.chat.completions.create(
+                    model=self.model,
+                    messages=messages,
+                    max_tokens=max_tokens,
+                    temperature=0.7
+                )
+
+                return final_response.choices[0].message.content
+            else:
+                return message.content
+
+        except Exception as e:
+            self.logger.error(f"OpenAI MCP chat failed: {e}")
+            return f"Sorry, an error occurred: {e}"
+
+    async def _execute_mcp_tool(self, tool_call, context_user_id: str = "user") -> Dict[str, Any]:
+        """Execute MCP tool call"""
+        try:
+            import json
+            function_name = tool_call.function.name
+            arguments = json.loads(tool_call.function.arguments)
+
+            if function_name == "get_memories":
+                limit = arguments.get("limit", 5)
+                return await self.mcp_client.get_memories(limit) or {"error": "Failed to retrieve memories"}
+
+            elif function_name == "search_memories":
+                keywords = arguments.get("keywords", [])
+                return await self.mcp_client.search_memories(keywords) or {"error": "Memory search failed"}
+
+            elif function_name == "get_contextual_memories":
+                query = arguments.get("query", "")
+                limit = arguments.get("limit", 5)
+                return await self.mcp_client.get_contextual_memories(query, limit) or {"error": "Failed to retrieve contextual memories"}
+
+            elif function_name == "get_relationship":
+                # Fall back to the context user_id when the arguments do not include one
+                user_id = arguments.get("user_id", context_user_id)
+                if not user_id or user_id == "user":
+                    user_id = context_user_id
+                # Debug logging
+                print(f"🔍 [DEBUG] get_relationship called with user_id: '{user_id}' (context: '{context_user_id}')")
+                result = await self.mcp_client.get_relationship(user_id)
+                print(f"🔍 [DEBUG] MCP result: {result}")
+                return result or {"error": "Failed to retrieve relationship"}
+
+            else:
+                return {"error": f"Unknown tool: {function_name}"}
+
+        except Exception as e:
+            return {"error": f"Tool execution error: {str(e)}"}
+
+    def chat(self, prompt: str, max_tokens: int = 2000) -> str:
+        """Simple chat interface without MCP tools"""
+        try:
+            response = self.client.chat.completions.create(
+                model=self.model,
+                messages=[
+                    {"role": "user", "content": prompt}
+                ],
+                max_tokens=max_tokens,
+                temperature=0.7
+            )
+            return response.choices[0].message.content
+        except Exception as e:
+            self.logger.error(f"OpenAI chat failed: {e}")
+            return "I'm having trouble connecting to the AI model."
+
     def _fallback_response(self, persona_state: PersonaState) -> str:
         """Fallback response based on mood"""
         mood_responses = {
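Note: the methods added in this hunk assume an MCP client object with an `available` flag and async accessors matching the tool names; that client is not part of this diff. The following is a minimal hypothetical stub, shown only to illustrate the interface that `_get_mcp_tools`, `chat_with_mcp`, and `_execute_mcp_tool` expect — the class name and return shapes are assumptions, not code from this commit.

# Hypothetical stub of the MCP client interface used above; the real client
# lives elsewhere in the project and talks to the MCP server.
class StubMCPClient:
    available = True  # checked before any tool definitions are offered

    async def get_memories(self, limit: int = 5):
        return [{"content": "example memory"}][:limit]

    async def search_memories(self, keywords):
        return [{"content": "example memory", "keywords": keywords}]

    async def get_contextual_memories(self, query: str, limit: int = 5):
        return [{"content": f"memory related to {query}"}][:limit]

    async def get_relationship(self, user_id: str):
        return {"user_id": user_id, "score": 0}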
@@ -196,7 +402,7 @@ Recent memories:
         return mood_responses.get(persona_state.current_mood, "I see.")
 
 
-def create_ai_provider(provider: str = "ollama", model: Optional[str] = None, **kwargs) -> AIProvider:
+def create_ai_provider(provider: str = "ollama", model: Optional[str] = None, mcp_client=None, **kwargs) -> AIProvider:
     """Factory function to create AI providers"""
     if provider == "ollama":
         # Get model from config if not provided
@@ -228,6 +434,6 @@ def create_ai_provider(provider: str = "ollama", model: Optional[str] = None, **
                 model = config.get('providers.openai.default_model', 'gpt-4o-mini')
             except:
                 model = 'gpt-4o-mini'  # Fallback to default
-        return OpenAIProvider(model=model, **kwargs)
+        return OpenAIProvider(model=model, mcp_client=mcp_client, **kwargs)
     else:
         raise ValueError(f"Unknown provider: {provider}")
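A minimal usage sketch based on the factory and method signatures in this diff. The import path for `create_ai_provider` is not shown in the commit and is therefore omitted; the MCP client construction and the user ID are placeholders.

import asyncio

# from <project module> import create_ai_provider   # module path not shown in this diff
# mcp_client would normally be the project's real MCP client (see the stub above)
mcp_client = StubMCPClient()

provider = create_ai_provider("openai", model="gpt-4o-mini", mcp_client=mcp_client)

# chat_with_mcp is async and falls back to the plain chat() when no MCP client is available
reply = asyncio.run(provider.chat_with_mcp("What did we talk about last time?", user_id="alice"))
print(reply)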