fix card
This commit is contained in:
@ -239,6 +239,85 @@ class OpenAIProvider:
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Add ai.card tools if available
|
||||
if hasattr(self.mcp_client, 'has_card_tools') and self.mcp_client.has_card_tools:
|
||||
card_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "card_get_user_cards",
|
||||
"description": "ユーザーが所有するカードの一覧を取得します",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"did": {
|
||||
"type": "string",
|
||||
"description": "ユーザーのDID"
|
||||
},
|
||||
"limit": {
|
||||
"type": "integer",
|
||||
"description": "取得するカード数の上限",
|
||||
"default": 10
|
||||
}
|
||||
},
|
||||
"required": ["did"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "card_draw_card",
|
||||
"description": "ガチャを引いてカードを取得します",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"did": {
|
||||
"type": "string",
|
||||
"description": "ユーザーのDID"
|
||||
},
|
||||
"is_paid": {
|
||||
"type": "boolean",
|
||||
"description": "有料ガチャかどうか",
|
||||
"default": False
|
||||
}
|
||||
},
|
||||
"required": ["did"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "card_analyze_collection",
|
||||
"description": "ユーザーのカードコレクションを分析します",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"did": {
|
||||
"type": "string",
|
||||
"description": "ユーザーのDID"
|
||||
}
|
||||
},
|
||||
"required": ["did"]
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "card_get_gacha_stats",
|
||||
"description": "ガチャの統計情報を取得します",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
tools.extend(card_tools)
|
||||
|
||||
return tools
|
||||
|
||||
async def generate_response(
|
||||
@ -298,7 +377,7 @@ Recent memories:
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=[
|
||||
{"role": "system", "content": self.config_system_prompt or "あなたは記憶システムと関係性データにアクセスできます。過去の会話、記憶、関係性について質問された時は、必ずツールを使用して正確な情報を取得してください。「覚えている」「前回」「以前」「について話した」「関係」などのキーワードがあれば積極的にツールを使用してください。"},
|
||||
{"role": "system", "content": self.config_system_prompt or "あなたは記憶システムと関係性データ、カードゲームシステムにアクセスできます。過去の会話、記憶、関係性について質問された時は、必ずツールを使用して正確な情報を取得してください。「覚えている」「前回」「以前」「について話した」「関係」などのキーワードがあれば積極的にツールを使用してください。カード関連の質問(「カード」「コレクション」「ガチャ」「見せて」「持っている」など)では、必ずcard_get_user_cardsやcard_analyze_collectionなどのツールを使用してください。didパラメータには現在会話しているユーザーのID(例:'syui')を使用してください。"},
|
||||
{"role": "user", "content": prompt}
|
||||
],
|
||||
tools=tools,
|
||||
@ -384,6 +463,49 @@ Recent memories:
|
||||
print(f"🔍 [DEBUG] MCP result: {result}")
|
||||
return result or {"error": "関係性の取得に失敗しました"}
|
||||
|
||||
# ai.card tools
|
||||
elif function_name == "card_get_user_cards":
|
||||
did = arguments.get("did", context_user_id)
|
||||
limit = arguments.get("limit", 10)
|
||||
result = await self.mcp_client.card_get_user_cards(did, limit)
|
||||
# Check if ai.card server is not running
|
||||
if result and result.get("error") == "ai.card server is not running":
|
||||
return {
|
||||
"error": "ai.cardサーバーが起動していません",
|
||||
"message": "カードシステムを使用するには、別のターミナルで以下のコマンドを実行してください:\ncd card && ./start_server.sh"
|
||||
}
|
||||
return result or {"error": "カード一覧の取得に失敗しました"}
|
||||
|
||||
elif function_name == "card_draw_card":
|
||||
did = arguments.get("did", context_user_id)
|
||||
is_paid = arguments.get("is_paid", False)
|
||||
result = await self.mcp_client.card_draw_card(did, is_paid)
|
||||
if result and result.get("error") == "ai.card server is not running":
|
||||
return {
|
||||
"error": "ai.cardサーバーが起動していません",
|
||||
"message": "カードシステムを使用するには、別のターミナルで以下のコマンドを実行してください:\ncd card && ./start_server.sh"
|
||||
}
|
||||
return result or {"error": "ガチャに失敗しました"}
|
||||
|
||||
elif function_name == "card_analyze_collection":
|
||||
did = arguments.get("did", context_user_id)
|
||||
result = await self.mcp_client.card_analyze_collection(did)
|
||||
if result and result.get("error") == "ai.card server is not running":
|
||||
return {
|
||||
"error": "ai.cardサーバーが起動していません",
|
||||
"message": "カードシステムを使用するには、別のターミナルで以下のコマンドを実行してください:\ncd card && ./start_server.sh"
|
||||
}
|
||||
return result or {"error": "コレクション分析に失敗しました"}
|
||||
|
||||
elif function_name == "card_get_gacha_stats":
|
||||
result = await self.mcp_client.card_get_gacha_stats()
|
||||
if result and result.get("error") == "ai.card server is not running":
|
||||
return {
|
||||
"error": "ai.cardサーバーが起動していません",
|
||||
"message": "カードシステムを使用するには、別のターミナルで以下のコマンドを実行してください:\ncd card && ./start_server.sh"
|
||||
}
|
||||
return result or {"error": "ガチャ統計の取得に失敗しました"}
|
||||
|
||||
else:
|
||||
return {"error": f"未知のツール: {function_name}"}
|
||||
|
||||
|
134
src/aigpt/cli.py
134
src/aigpt/cli.py
@ -41,6 +41,7 @@ class MCPClient:
|
||||
self.auto_detect = self.config.get("mcp.auto_detect", True)
|
||||
self.servers = self.config.get("mcp.servers", {})
|
||||
self.available = False
|
||||
self.has_card_tools = False
|
||||
|
||||
if self.enabled:
|
||||
self._check_availability()
|
||||
@ -75,6 +76,16 @@ class MCPClient:
|
||||
self.available = True
|
||||
self.active_server = "ai_gpt"
|
||||
print(f"✅ [MCP Client] ai_gpt server connected successfully")
|
||||
|
||||
# Check if card tools are available
|
||||
try:
|
||||
card_status = client.get(f"{base_url}/card_system_status")
|
||||
if card_status.status_code == 200:
|
||||
self.has_card_tools = True
|
||||
print(f"✅ [MCP Client] ai.card tools detected and available")
|
||||
except:
|
||||
print(f"🔍 [MCP Client] ai.card tools not available")
|
||||
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"🚨 [MCP Client] ai_gpt connection failed: {e}")
|
||||
@ -224,8 +235,70 @@ class MCPClient:
|
||||
"display_name": server_config.get("name", self.active_server),
|
||||
"base_url": server_config.get("base_url", ""),
|
||||
"timeout": server_config.get("timeout", 5.0),
|
||||
"endpoints": len(server_config.get("endpoints", {}))
|
||||
"endpoints": len(server_config.get("endpoints", {})),
|
||||
"has_card_tools": self.has_card_tools
|
||||
}
|
||||
|
||||
# ai.card MCP methods
|
||||
async def card_get_user_cards(self, did: str, limit: int = 10) -> Optional[Dict[str, Any]]:
|
||||
"""Get user's card collection via MCP"""
|
||||
if not self.has_card_tools:
|
||||
return {"error": "ai.card tools not available"}
|
||||
|
||||
url = self._get_url("card_get_user_cards")
|
||||
if not url:
|
||||
return None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._get_timeout()) as client:
|
||||
response = await client.get(f"{url}?did={did}&limit={limit}")
|
||||
return response.json() if response.status_code == 200 else None
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to get cards: {str(e)}"}
|
||||
|
||||
async def card_draw_card(self, did: str, is_paid: bool = False) -> Optional[Dict[str, Any]]:
|
||||
"""Draw a card from gacha system via MCP"""
|
||||
if not self.has_card_tools:
|
||||
return {"error": "ai.card tools not available"}
|
||||
|
||||
url = self._get_url("card_draw_card")
|
||||
if not url:
|
||||
return None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._get_timeout()) as client:
|
||||
response = await client.post(url, json={"did": did, "is_paid": is_paid})
|
||||
return response.json() if response.status_code == 200 else None
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to draw card: {str(e)}"}
|
||||
|
||||
async def card_analyze_collection(self, did: str) -> Optional[Dict[str, Any]]:
|
||||
"""Analyze card collection via MCP"""
|
||||
if not self.has_card_tools:
|
||||
return {"error": "ai.card tools not available"}
|
||||
|
||||
url = self._get_url("card_analyze_collection")
|
||||
if not url:
|
||||
return None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._get_timeout()) as client:
|
||||
response = await client.get(f"{url}?did={did}")
|
||||
return response.json() if response.status_code == 200 else None
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to analyze collection: {str(e)}"}
|
||||
|
||||
async def card_get_gacha_stats(self) -> Optional[Dict[str, Any]]:
|
||||
"""Get gacha statistics via MCP"""
|
||||
if not self.has_card_tools:
|
||||
return {"error": "ai.card tools not available"}
|
||||
|
||||
url = self._get_url("card_get_gacha_stats")
|
||||
if not url:
|
||||
return None
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self._get_timeout()) as client:
|
||||
response = await client.get(url)
|
||||
return response.json() if response.status_code == 200 else None
|
||||
except Exception as e:
|
||||
return {"error": f"Failed to get gacha stats: {str(e)}"}
|
||||
|
||||
|
||||
def get_persona(data_dir: Optional[Path] = None) -> Persona:
|
||||
@ -248,15 +321,34 @@ def chat(
|
||||
"""Chat with the AI"""
|
||||
persona = get_persona(data_dir)
|
||||
|
||||
# Create AI provider if specified
|
||||
# Get config instance
|
||||
config_instance = Config()
|
||||
|
||||
# Get defaults from config if not provided
|
||||
if not provider:
|
||||
provider = config_instance.get("default_provider", "ollama")
|
||||
if not model:
|
||||
if provider == "ollama":
|
||||
model = config_instance.get("providers.ollama.default_model", "qwen2.5")
|
||||
else:
|
||||
model = config_instance.get("providers.openai.default_model", "gpt-4o-mini")
|
||||
|
||||
# Create AI provider with MCP client if needed
|
||||
ai_provider = None
|
||||
if provider and model:
|
||||
try:
|
||||
ai_provider = create_ai_provider(provider=provider, model=model)
|
||||
console.print(f"[dim]Using {provider} with model {model}[/dim]\n")
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not create AI provider: {e}[/yellow]")
|
||||
console.print("[yellow]Falling back to simple responses[/yellow]\n")
|
||||
mcp_client = None
|
||||
|
||||
try:
|
||||
# Create MCP client for OpenAI provider
|
||||
if provider == "openai":
|
||||
mcp_client = MCPClient(config_instance)
|
||||
if mcp_client.available:
|
||||
console.print(f"[dim]MCP client connected to {mcp_client.active_server}[/dim]")
|
||||
|
||||
ai_provider = create_ai_provider(provider=provider, model=model, mcp_client=mcp_client)
|
||||
console.print(f"[dim]Using {provider} with model {model}[/dim]\n")
|
||||
except Exception as e:
|
||||
console.print(f"[yellow]Warning: Could not create AI provider: {e}[/yellow]")
|
||||
console.print("[yellow]Falling back to simple responses[/yellow]\n")
|
||||
|
||||
# Process interaction
|
||||
response, relationship_delta = persona.process_interaction(user_id, message, ai_provider)
|
||||
@ -465,6 +557,10 @@ def server(
|
||||
system_endpoints = ["get_persona_state", "get_fortune", "run_maintenance"]
|
||||
shell_endpoints = ["execute_command", "analyze_file", "write_file", "list_files", "read_project_file"]
|
||||
remote_endpoints = ["remote_shell", "ai_bot_status", "isolated_python", "isolated_analysis"]
|
||||
card_endpoints = ["card_get_user_cards", "card_draw_card", "card_get_card_details", "card_analyze_collection", "card_get_gacha_stats", "card_system_status"]
|
||||
|
||||
# Check if ai.card tools are available
|
||||
has_card_tools = mcp_server.has_card
|
||||
|
||||
# Build endpoint summary
|
||||
endpoint_summary = f"""🧠 Memory System: {len(memory_endpoints)} tools
|
||||
@ -473,10 +569,18 @@ def server(
|
||||
💻 Shell Integration: {len(shell_endpoints)} tools
|
||||
🔒 Remote Execution: {len(remote_endpoints)} tools"""
|
||||
|
||||
if has_card_tools:
|
||||
endpoint_summary += f"\n🎴 Card Game System: {len(card_endpoints)} tools"
|
||||
|
||||
# Check MCP client connectivity
|
||||
mcp_client = MCPClient(config_instance)
|
||||
mcp_status = "✅ MCP Client Ready" if mcp_client.available else "⚠️ MCP Client Disabled"
|
||||
|
||||
# Add ai.card status if available
|
||||
card_status = ""
|
||||
if has_card_tools:
|
||||
card_status = "\n🎴 ai.card: ./card directory detected"
|
||||
|
||||
# Provider configuration check
|
||||
provider_status = "✅ Ready"
|
||||
if provider == "openai":
|
||||
@ -500,7 +604,7 @@ def server(
|
||||
f"{endpoint_summary}\n\n"
|
||||
f"[green]Integration Status:[/green]\n"
|
||||
f"{mcp_status}\n"
|
||||
f"🔗 Config: {config_instance.config_file}\n\n"
|
||||
f"🔗 Config: {config_instance.config_file}{card_status}\n\n"
|
||||
f"[dim]Press Ctrl+C to stop server[/dim]",
|
||||
title="🔧 MCP Server Startup",
|
||||
border_style="green",
|
||||
@ -1367,7 +1471,15 @@ def conversation(
|
||||
console.print(" /search <keywords> - Search memories")
|
||||
console.print(" /context <query> - Get contextual memories")
|
||||
console.print(" /relationship - Show relationship via MCP")
|
||||
console.print(" <message> - Chat with AI\n")
|
||||
|
||||
if mcp_client.has_card_tools:
|
||||
console.print(f"\n[cyan]Card Commands:[/cyan]")
|
||||
console.print(" AI can answer questions about cards:")
|
||||
console.print(" - 'Show my cards'")
|
||||
console.print(" - 'Draw a card' / 'Gacha'")
|
||||
console.print(" - 'Analyze my collection'")
|
||||
console.print(" - 'Show gacha stats'")
|
||||
console.print("\n <message> - Chat with AI\n")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == '/clear':
|
||||
|
@ -34,7 +34,15 @@ class AIGptMcpServer:
|
||||
# Create MCP server with FastAPI app
|
||||
self.server = FastApiMCP(self.app)
|
||||
|
||||
# Check if ai.card exists
|
||||
self.card_dir = Path("./card")
|
||||
self.has_card = self.card_dir.exists() and self.card_dir.is_dir()
|
||||
|
||||
self._register_tools()
|
||||
|
||||
# Register ai.card tools if available
|
||||
if self.has_card:
|
||||
self._register_card_tools()
|
||||
|
||||
def _register_tools(self):
|
||||
"""Register all MCP tools"""
|
||||
@ -484,6 +492,148 @@ class AIGptMcpServer:
|
||||
# Python コードを /sh 経由で実行
|
||||
python_command = f'python3 -c "{code.replace('"', '\\"')}"'
|
||||
return await remote_shell(python_command, ai_bot_url)
|
||||
|
||||
def _register_card_tools(self):
|
||||
"""Register ai.card MCP tools when card directory exists"""
|
||||
logger.info("Registering ai.card tools...")
|
||||
|
||||
@self.app.get("/card_get_user_cards", operation_id="card_get_user_cards")
|
||||
async def card_get_user_cards(did: str, limit: int = 10) -> Dict[str, Any]:
|
||||
"""Get user's card collection from ai.card system"""
|
||||
logger.info(f"🎴 [ai.card] Getting cards for did: {did}, limit: {limit}")
|
||||
try:
|
||||
url = "http://localhost:8000/get_user_cards"
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
logger.info(f"🎴 [ai.card] Calling: {url}")
|
||||
response = await client.get(
|
||||
url,
|
||||
params={"did": did, "limit": limit}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
cards = response.json()
|
||||
return {
|
||||
"cards": cards,
|
||||
"count": len(cards),
|
||||
"did": did
|
||||
}
|
||||
else:
|
||||
return {"error": f"Failed to get cards: {response.status_code}"}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"error": "ai.card server is not running",
|
||||
"hint": "Please start ai.card server: cd card && ./start_server.sh",
|
||||
"details": "Connection refused to http://localhost:8000"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"ai.card connection failed: {str(e)}"}
|
||||
|
||||
@self.app.post("/card_draw_card", operation_id="card_draw_card")
|
||||
async def card_draw_card(did: str, is_paid: bool = False) -> Dict[str, Any]:
|
||||
"""Draw a card from gacha system"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.post(
|
||||
f"http://localhost:8000/draw_card?did={did}&is_paid={is_paid}"
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"error": f"Failed to draw card: {response.status_code}"}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"error": "ai.card server is not running",
|
||||
"hint": "Please start ai.card server: cd card && ./start_server.sh",
|
||||
"details": "Connection refused to http://localhost:8000"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"ai.card connection failed: {str(e)}"}
|
||||
|
||||
@self.app.get("/card_get_card_details", operation_id="card_get_card_details")
|
||||
async def card_get_card_details(card_id: int) -> Dict[str, Any]:
|
||||
"""Get detailed information about a specific card"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"http://localhost:8000/get_card_details",
|
||||
params={"card_id": card_id}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"error": f"Failed to get card details: {response.status_code}"}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"error": "ai.card server is not running",
|
||||
"hint": "Please start ai.card server: cd card && ./start_server.sh",
|
||||
"details": "Connection refused to http://localhost:8000"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"ai.card connection failed: {str(e)}"}
|
||||
|
||||
@self.app.get("/card_analyze_collection", operation_id="card_analyze_collection")
|
||||
async def card_analyze_collection(did: str) -> Dict[str, Any]:
|
||||
"""Analyze user's card collection statistics"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(
|
||||
"http://localhost:8000/analyze_card_collection",
|
||||
params={"did": did}
|
||||
)
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"error": f"Failed to analyze collection: {response.status_code}"}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"error": "ai.card server is not running",
|
||||
"hint": "Please start ai.card server: cd card && ./start_server.sh",
|
||||
"details": "Connection refused to http://localhost:8000"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"ai.card connection failed: {str(e)}"}
|
||||
|
||||
@self.app.get("/card_get_gacha_stats", operation_id="card_get_gacha_stats")
|
||||
async def card_get_gacha_stats() -> Dict[str, Any]:
|
||||
"""Get gacha system statistics"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get("http://localhost:8000/get_gacha_stats")
|
||||
if response.status_code == 200:
|
||||
return response.json()
|
||||
else:
|
||||
return {"error": f"Failed to get gacha stats: {response.status_code}"}
|
||||
except httpx.ConnectError:
|
||||
return {
|
||||
"error": "ai.card server is not running",
|
||||
"hint": "Please start ai.card server: cd card && ./start_server.sh",
|
||||
"details": "Connection refused to http://localhost:8000"
|
||||
}
|
||||
except Exception as e:
|
||||
return {"error": f"ai.card connection failed: {str(e)}"}
|
||||
|
||||
@self.app.get("/card_system_status", operation_id="card_system_status")
|
||||
async def card_system_status() -> Dict[str, Any]:
|
||||
"""Check ai.card system status"""
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get("http://localhost:8000/health")
|
||||
if response.status_code == 200:
|
||||
return {
|
||||
"status": "online",
|
||||
"health": response.json(),
|
||||
"card_dir": str(self.card_dir)
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"status": "error",
|
||||
"error": f"Health check failed: {response.status_code}"
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "offline",
|
||||
"error": f"ai.card is not running: {str(e)}",
|
||||
"hint": "Start ai.card with: cd card && ./start_server.sh"
|
||||
}
|
||||
|
||||
@self.app.post("/isolated_analysis", operation_id="isolated_analysis")
|
||||
async def isolated_analysis(file_path: str, analysis_type: str = "structure", ai_bot_url: str = "http://localhost:8080") -> Dict[str, Any]:
|
||||
|
@ -133,7 +133,15 @@ FORTUNE: {state.fortune.fortune_value}/10
|
||||
if context_parts:
|
||||
context_prompt += "RELEVANT CONTEXT:\n" + "\n\n".join(context_parts) + "\n\n"
|
||||
|
||||
context_prompt += f"""Respond to this message while staying true to your personality and the established relationship context:
|
||||
context_prompt += f"""IMPORTANT: You have access to the following tools:
|
||||
- Memory tools: get_memories, search_memories, get_contextual_memories
|
||||
- Relationship tools: get_relationship
|
||||
- Card game tools: card_get_user_cards, card_draw_card, card_analyze_collection
|
||||
|
||||
When asked about cards, collections, or anything card-related, YOU MUST use the card tools.
|
||||
For "カードコレクションを見せて" or similar requests, use card_get_user_cards with did='{user_id}'.
|
||||
|
||||
Respond to this message while staying true to your personality and the established relationship context:
|
||||
|
||||
User: {current_message}
|
||||
|
||||
|
15
src/aigpt/shared/__init__.py
Normal file
15
src/aigpt/shared/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
"""Shared modules for AI ecosystem"""
|
||||
|
||||
from .ai_provider import (
|
||||
AIProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
create_ai_provider
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
'AIProvider',
|
||||
'OllamaProvider',
|
||||
'OpenAIProvider',
|
||||
'create_ai_provider'
|
||||
]
|
139
src/aigpt/shared/ai_provider.py
Normal file
139
src/aigpt/shared/ai_provider.py
Normal file
@ -0,0 +1,139 @@
|
||||
"""Shared AI Provider implementation for ai ecosystem"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Optional, Dict, List, Any, Protocol
|
||||
from abc import abstractmethod
|
||||
import httpx
|
||||
from openai import OpenAI
|
||||
import ollama
|
||||
|
||||
|
||||
class AIProvider(Protocol):
|
||||
"""Protocol for AI providers"""
|
||||
|
||||
@abstractmethod
|
||||
async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str:
|
||||
"""Generate a response based on prompt"""
|
||||
pass
|
||||
|
||||
|
||||
class OllamaProvider:
|
||||
"""Ollama AI provider - shared implementation"""
|
||||
|
||||
def __init__(self, model: str = "qwen3", host: Optional[str] = None, config_system_prompt: Optional[str] = None):
|
||||
self.model = model
|
||||
# Use environment variable OLLAMA_HOST if available
|
||||
self.host = host or os.getenv('OLLAMA_HOST', 'http://127.0.0.1:11434')
|
||||
# Ensure proper URL format
|
||||
if not self.host.startswith('http'):
|
||||
self.host = f'http://{self.host}'
|
||||
self.client = ollama.Client(host=self.host, timeout=60.0)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.logger.info(f"OllamaProvider initialized with host: {self.host}, model: {self.model}")
|
||||
self.config_system_prompt = config_system_prompt
|
||||
|
||||
async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str:
|
||||
"""Simple chat interface"""
|
||||
try:
|
||||
messages = []
|
||||
# Use provided system_prompt, fall back to config_system_prompt
|
||||
final_system_prompt = system_prompt or self.config_system_prompt
|
||||
if final_system_prompt:
|
||||
messages.append({"role": "system", "content": final_system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = self.client.chat(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
options={
|
||||
"num_predict": 2000,
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
},
|
||||
stream=False
|
||||
)
|
||||
return self._clean_response(response['message']['content'])
|
||||
except Exception as e:
|
||||
self.logger.error(f"Ollama chat failed (host: {self.host}): {e}")
|
||||
return "I'm having trouble connecting to the AI model."
|
||||
|
||||
def _clean_response(self, response: str) -> str:
|
||||
"""Clean response by removing think tags and other unwanted content"""
|
||||
import re
|
||||
# Remove <think></think> tags and their content
|
||||
response = re.sub(r'<think>.*?</think>', '', response, flags=re.DOTALL)
|
||||
# Remove any remaining whitespace at the beginning/end
|
||||
response = response.strip()
|
||||
return response
|
||||
|
||||
|
||||
class OpenAIProvider:
|
||||
"""OpenAI API provider - shared implementation"""
|
||||
|
||||
def __init__(self, model: str = "gpt-4o-mini", api_key: Optional[str] = None,
|
||||
config_system_prompt: Optional[str] = None, mcp_client=None):
|
||||
self.model = model
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
if not self.api_key:
|
||||
raise ValueError("OpenAI API key not provided")
|
||||
self.client = OpenAI(api_key=self.api_key)
|
||||
self.logger = logging.getLogger(__name__)
|
||||
self.config_system_prompt = config_system_prompt
|
||||
self.mcp_client = mcp_client
|
||||
|
||||
async def chat(self, prompt: str, system_prompt: Optional[str] = None) -> str:
|
||||
"""Simple chat interface without MCP tools"""
|
||||
try:
|
||||
messages = []
|
||||
# Use provided system_prompt, fall back to config_system_prompt
|
||||
final_system_prompt = system_prompt or self.config_system_prompt
|
||||
if final_system_prompt:
|
||||
messages.append({"role": "system", "content": final_system_prompt})
|
||||
messages.append({"role": "user", "content": prompt})
|
||||
|
||||
response = self.client.chat.completions.create(
|
||||
model=self.model,
|
||||
messages=messages,
|
||||
max_tokens=2000,
|
||||
temperature=0.7
|
||||
)
|
||||
return response.choices[0].message.content
|
||||
except Exception as e:
|
||||
self.logger.error(f"OpenAI chat failed: {e}")
|
||||
return "I'm having trouble connecting to the AI model."
|
||||
|
||||
def _get_mcp_tools(self) -> List[Dict[str, Any]]:
|
||||
"""Override this method in subclasses to provide MCP tools"""
|
||||
return []
|
||||
|
||||
async def chat_with_mcp(self, prompt: str, **kwargs) -> str:
|
||||
"""Chat interface with MCP function calling support
|
||||
|
||||
This method should be overridden in subclasses to provide
|
||||
specific MCP functionality.
|
||||
"""
|
||||
if not self.mcp_client:
|
||||
return await self.chat(prompt)
|
||||
|
||||
# Default implementation - subclasses should override
|
||||
return await self.chat(prompt)
|
||||
|
||||
async def _execute_mcp_tool(self, tool_call, **kwargs) -> Dict[str, Any]:
|
||||
"""Execute MCP tool call - override in subclasses"""
|
||||
return {"error": "MCP tool execution not implemented"}
|
||||
|
||||
|
||||
def create_ai_provider(provider: str = "ollama", model: Optional[str] = None,
|
||||
config_system_prompt: Optional[str] = None, mcp_client=None, **kwargs) -> AIProvider:
|
||||
"""Factory function to create AI providers"""
|
||||
if provider == "ollama":
|
||||
model = model or "qwen3"
|
||||
return OllamaProvider(model=model, config_system_prompt=config_system_prompt, **kwargs)
|
||||
elif provider == "openai":
|
||||
model = model or "gpt-4o-mini"
|
||||
return OpenAIProvider(model=model, config_system_prompt=config_system_prompt,
|
||||
mcp_client=mcp_client, **kwargs)
|
||||
else:
|
||||
raise ValueError(f"Unknown provider: {provider}")
|
Reference in New Issue
Block a user