From 5b2379716b156aec1ccd57069b6667f3b8eca2e8 Mon Sep 17 00:00:00 2001 From: syui Date: Tue, 3 Jun 2025 05:00:51 +0900 Subject: [PATCH] fix gpt --- api/app/ai_provider.py | 290 +++++++++++++++++++++++++++++++++++++++++ api/app/mcp_server.py | 4 + 2 files changed, 294 insertions(+) create mode 100644 api/app/ai_provider.py diff --git a/api/app/ai_provider.py b/api/app/ai_provider.py new file mode 100644 index 0000000..070e0eb --- /dev/null +++ b/api/app/ai_provider.py @@ -0,0 +1,290 @@ +"""AI Provider integration for ai.card""" + +import os +import json +from typing import Optional, Dict, List, Any +from abc import ABC, abstractmethod +import logging +import httpx +from openai import OpenAI +import ollama + + +class AIProvider(ABC): + """Base class for AI providers""" + + @abstractmethod + async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str: + """Generate a response based on prompt""" + pass + + +class OllamaProvider(AIProvider): + """Ollama AI provider for ai.card""" + + def __init__(self, model: str = "qwen3", host: Optional[str] = None): + self.model = model + self.host = host or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434') + if not self.host.startswith('http'): + self.host = f'http://{self.host}' + self.client = ollama.Client(host=self.host, timeout=60.0) + self.logger = logging.getLogger(__name__) + self.logger.info(f"OllamaProvider initialized with host: {self.host}, model: {self.model}") + + async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str: + """Simple chat interface""" + try: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = self.client.chat( + model=self.model, + messages=messages, + options={ + "num_predict": 2000, + "temperature": 0.7, + "top_p": 0.9, + }, + stream=False + ) + return response['message']['content'] + except Exception as e: + self.logger.error(f"Ollama chat failed: {e}") + return "I'm having trouble connecting to the AI model." + + +class OpenAIProvider(AIProvider): + """OpenAI API provider with MCP function calling support""" + + def __init__(self, model: str = "gpt-4o-mini", api_key: Optional[str] = None, mcp_client=None): + self.model = model + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + if not self.api_key: + raise ValueError("OpenAI API key not provided") + self.client = OpenAI(api_key=self.api_key) + self.logger = logging.getLogger(__name__) + self.mcp_client = mcp_client + + def _get_mcp_tools(self) -> List[Dict[str, Any]]: + """Generate OpenAI tools from MCP endpoints""" + if not self.mcp_client: + return [] + + tools = [ + { + "type": "function", + "function": { + "name": "get_user_cards", + "description": "ユーザーが所有するカードの一覧を取得します", + "parameters": { + "type": "object", + "properties": { + "did": { + "type": "string", + "description": "ユーザーのDID" + }, + "limit": { + "type": "integer", + "description": "取得するカード数の上限", + "default": 10 + } + }, + "required": ["did"] + } + } + }, + { + "type": "function", + "function": { + "name": "draw_card", + "description": "ガチャを引いてカードを取得します", + "parameters": { + "type": "object", + "properties": { + "did": { + "type": "string", + "description": "ユーザーのDID" + }, + "is_paid": { + "type": "boolean", + "description": "有料ガチャかどうか", + "default": False + } + }, + "required": ["did"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_card_details", + "description": "特定のカードの詳細情報を取得します", + "parameters": { + "type": "object", + "properties": { + "card_id": { + "type": "integer", + "description": "カードID" + } + }, + "required": ["card_id"] + } + } + }, + { + "type": "function", + "function": { + "name": "analyze_card_collection", + "description": "ユーザーのカードコレクションを分析します", + "parameters": { + "type": "object", + "properties": { + "did": { + "type": "string", + "description": "ユーザーのDID" + } + }, + "required": ["did"] + } + } + }, + { + "type": "function", + "function": { + "name": "get_gacha_stats", + "description": "ガチャの統計情報を取得します", + "parameters": { + "type": "object", + "properties": {} + } + } + } + ] + return tools + + async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str: + """Simple chat interface without MCP tools""" + try: + messages = [] + if system_prompt: + messages.append({"role": "system", "content": system_prompt}) + messages.append({"role": "user", "content": prompt}) + + response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=2000, + temperature=0.7 + ) + return response.choices[0].message.content + except Exception as e: + self.logger.error(f"OpenAI chat failed: {e}") + return "I'm having trouble connecting to the AI model." + + async def chat_with_mcp(self, prompt: str, did: str = "user") -> str: + """Chat interface with MCP function calling support""" + if not self.mcp_client: + return await self.chat(prompt) + + try: + tools = self._get_mcp_tools() + + response = self.client.chat.completions.create( + model=self.model, + messages=[ + {"role": "system", "content": "あなたはai.cardシステムのアシスタントです。カードゲームの情報、ガチャ、コレクション分析などについて質問されたら、必要に応じてツールを使用して正確な情報を提供してください。"}, + {"role": "user", "content": prompt} + ], + tools=tools, + tool_choice="auto", + max_tokens=2000, + temperature=0.7 + ) + + message = response.choices[0].message + + # Handle tool calls + if message.tool_calls: + messages = [ + {"role": "system", "content": "カードゲームシステムのツールを使って正確な情報を提供してください。"}, + {"role": "user", "content": prompt}, + { + "role": "assistant", + "content": message.content, + "tool_calls": [tc.model_dump() for tc in message.tool_calls] + } + ] + + # Execute each tool call + for tool_call in message.tool_calls: + tool_result = await self._execute_mcp_tool(tool_call, did) + messages.append({ + "role": "tool", + "tool_call_id": tool_call.id, + "name": tool_call.function.name, + "content": json.dumps(tool_result, ensure_ascii=False) + }) + + # Get final response + final_response = self.client.chat.completions.create( + model=self.model, + messages=messages, + max_tokens=2000, + temperature=0.7 + ) + + return final_response.choices[0].message.content + else: + return message.content + + except Exception as e: + self.logger.error(f"OpenAI MCP chat failed: {e}") + return f"申し訳ありません。エラーが発生しました: {e}" + + async def _execute_mcp_tool(self, tool_call, default_did: str = "user") -> Dict[str, Any]: + """Execute MCP tool call""" + try: + function_name = tool_call.function.name + arguments = json.loads(tool_call.function.arguments) + + if function_name == "get_user_cards": + did = arguments.get("did", default_did) + limit = arguments.get("limit", 10) + return await self.mcp_client.get_user_cards(did, limit) + + elif function_name == "draw_card": + did = arguments.get("did", default_did) + is_paid = arguments.get("is_paid", False) + return await self.mcp_client.draw_card(did, is_paid) + + elif function_name == "get_card_details": + card_id = arguments.get("card_id") + return await self.mcp_client.get_card_details(card_id) + + elif function_name == "analyze_card_collection": + did = arguments.get("did", default_did) + return await self.mcp_client.analyze_card_collection(did) + + elif function_name == "get_gacha_stats": + return await self.mcp_client.get_gacha_stats() + + else: + return {"error": f"未知のツール: {function_name}"} + + except Exception as e: + return {"error": f"ツール実行エラー: {str(e)}"} + + +def create_ai_provider(provider: str = "ollama", model: Optional[str] = None, mcp_client=None, **kwargs) -> AIProvider: + """Factory function to create AI providers""" + if provider == "ollama": + model = model or "qwen3" + return OllamaProvider(model=model, **kwargs) + elif provider == "openai": + model = model or "gpt-4o-mini" + return OpenAIProvider(model=model, mcp_client=mcp_client, **kwargs) + else: + raise ValueError(f"Unknown provider: {provider}") \ No newline at end of file diff --git a/api/app/mcp_server.py b/api/app/mcp_server.py index 4faf181..974a176 100644 --- a/api/app/mcp_server.py +++ b/api/app/mcp_server.py @@ -37,6 +37,10 @@ class AICardMcpServer: self.server = FastMCP("aicard") self._register_mcp_tools() + def get_app(self) -> FastAPI: + """Get the FastAPI app instance""" + return self.app + def _register_mcp_tools(self): """Register all MCP tools"""