diff --git a/Cargo.toml b/Cargo.toml
index 3a39fc9..c3a32b5 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -31,3 +31,7 @@ axum = "0.7"
 tower = "0.4"
 tower-http = { version = "0.5", features = ["cors"] }
 hyper = "1.0"
+
+# OpenAI API client
+async-openai = "0.23"
+openai_api_rust = "0.1"
diff --git a/src/http_client.rs b/src/http_client.rs
index 82bd62a..b11efc1 100644
--- a/src/http_client.rs
+++ b/src/http_client.rs
@@ -65,6 +65,22 @@ impl ServiceClient {
         let json: Value = response.json().await?;
         Ok(json)
     }
+
+    /// Get user's card collection from ai.card service
+    pub async fn get_user_cards(&self, user_did: &str) -> Result<Value> {
+        let url = format!("http://localhost:8000/api/v1/cards/user/{}", user_did);
+        self.get_request(&url).await
+    }
+
+    /// Draw a card for user from ai.card service
+    pub async fn draw_card(&self, user_did: &str, is_paid: bool) -> Result<Value> {
+        let payload = serde_json::json!({
+            "user_did": user_did,
+            "is_paid": is_paid
+        });
+
+        self.post_request("http://localhost:8000/api/v1/cards/draw", &payload).await
+    }
 }
 
 /// Service status enum
diff --git a/src/lib.rs b/src/lib.rs
index d772e66..542bd53 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -9,6 +9,7 @@ pub mod http_client;
 pub mod import;
 pub mod mcp_server;
 pub mod memory;
+pub mod openai_provider;
 pub mod persona;
 pub mod relationship;
 pub mod scheduler;
diff --git a/src/main.rs b/src/main.rs
index cd94f1a..82c870d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -13,6 +13,7 @@ mod http_client;
 mod import;
 mod mcp_server;
 mod memory;
+mod openai_provider;
 mod persona;
 mod relationship;
 mod scheduler;
diff --git a/src/mcp_server.rs b/src/mcp_server.rs
index 0fd1b7e..9bdd7e1 100644
--- a/src/mcp_server.rs
+++ b/src/mcp_server.rs
@@ -378,6 +378,52 @@ impl MCPServer {
                     }
                 }),
             },
+            MCPTool {
+                name: "get_user_cards".to_string(),
+                description: "Get user's card collection from ai.card service".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "user_did": {
+                            "type": "string",
+                            "description": "User DID to get cards for"
+                        }
+                    },
+                    "required": ["user_did"]
+                }),
+            },
+            MCPTool {
+                name: "draw_card".to_string(),
+                description: "Draw a card from ai.card gacha system".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "user_did": {
+                            "type": "string",
+                            "description": "User DID to draw card for"
+                        },
+                        "is_paid": {
+                            "type": "boolean",
+                            "description": "Whether this is a premium draw (default: false)"
+                        }
+                    },
+                    "required": ["user_did"]
+                }),
+            },
+            MCPTool {
+                name: "get_draw_status".to_string(),
+                description: "Check if user can draw cards (daily limit check)".to_string(),
+                input_schema: serde_json::json!({
+                    "type": "object",
+                    "properties": {
+                        "user_did": {
+                            "type": "string",
+                            "description": "User DID to check draw status for"
+                        }
+                    },
+                    "required": ["user_did"]
+                }),
+            },
         ]
     }

@@ -437,6 +483,9 @@ impl MCPServer {
             "run_scheduler" => self.tool_run_scheduler(arguments).await,
             "get_scheduler_status" => self.tool_get_scheduler_status(arguments).await,
             "get_transmission_history" => self.tool_get_transmission_history(arguments).await,
+            "get_user_cards" => self.tool_get_user_cards(arguments).await,
+            "draw_card" => self.tool_draw_card(arguments).await,
+            "get_draw_status" => self.tool_get_draw_status(arguments).await,
             _ => Err(anyhow::anyhow!("Unknown tool: {}", tool_name)),
         }
     }
@@ -461,6 +510,9 @@ impl MCPServer {
             "run_scheduler" => self.tool_run_scheduler(params).await,
             "get_scheduler_status" => self.tool_get_scheduler_status(params).await,
"get_transmission_history" => self.tool_get_transmission_history(params).await, + "get_user_cards" => self.tool_get_user_cards(params).await, + "draw_card" => self.tool_draw_card(params).await, + "get_draw_status" => self.tool_get_draw_status(params).await, _ => Err(anyhow::anyhow!("Unknown tool: {}", tool_name)), } } @@ -1146,6 +1198,115 @@ impl MCPServer { })) } + // MARK: - ai.card Integration Tools + + async fn tool_get_user_cards(&self, args: Value) -> Result { + let user_did = args["user_did"].as_str() + .ok_or_else(|| anyhow::anyhow!("Missing user_did"))?; + + // Use ServiceClient to call ai.card API + match self.service_client.get_user_cards(user_did).await { + Ok(cards) => { + let card_count = cards.as_array().map(|arr| arr.len()).unwrap_or(0); + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("ユーザー {} のカード一覧 ({}枚):\n\n{}", + user_did, + card_count, + serde_json::to_string_pretty(&cards)?) + } + ] + })) + } + Err(e) => { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("カード取得エラー: {}. ai.cardサーバーが起動していることを確認してください。", e) + } + ] + })) + } + } + } + + async fn tool_draw_card(&self, args: Value) -> Result { + let user_did = args["user_did"].as_str() + .ok_or_else(|| anyhow::anyhow!("Missing user_did"))?; + let is_paid = args["is_paid"].as_bool().unwrap_or(false); + + // Use ServiceClient to call ai.card API + match self.service_client.draw_card(user_did, is_paid).await { + Ok(draw_result) => { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("🎉 カードドロー結果:\n\n{}", + serde_json::to_string_pretty(&draw_result)?) + } + ] + })) + } + Err(e) => { + let error_msg = e.to_string(); + if error_msg.contains("429") { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": "⏰ カードドロー制限中です。日別制限により、現在カードを引くことができません。時間を置いてから再度お試しください。" + } + ] + })) + } else { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("カードドローエラー: {}. ai.cardサーバーが起動していることを確認してください。", e) + } + ] + })) + } + } + } + } + + async fn tool_get_draw_status(&self, args: Value) -> Result { + let user_did = args["user_did"].as_str() + .ok_or_else(|| anyhow::anyhow!("Missing user_did"))?; + + // Use ServiceClient to call ai.card API + match self.service_client.get_request(&format!("http://localhost:8000/api/v1/cards/draw-status/{}", user_did)).await { + Ok(status) => { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("ユーザー {} のドロー状況:\n\n{}", + user_did, + serde_json::to_string_pretty(&status)?) + } + ] + })) + } + Err(e) => { + Ok(serde_json::json!({ + "content": [ + { + "type": "text", + "text": format!("ドロー状況取得エラー: {}. 
+                        }
+                    ]
+                }))
+            }
+        }
+    }
+
     pub async fn start_server(&mut self, port: u16) -> Result<()> {
         println!("🚀 Starting MCP Server on port {}", port);
         println!("📋 Available tools: {}", self.get_tools().len());
diff --git a/src/openai_provider.rs b/src/openai_provider.rs
new file mode 100644
index 0000000..87f58e4
--- /dev/null
+++ b/src/openai_provider.rs
@@ -0,0 +1,390 @@
+use anyhow::Result;
+use async_openai::{
+    types::{
+        ChatCompletionRequestMessage,
+        CreateChatCompletionRequestArgs, ChatCompletionTool, ChatCompletionToolType,
+        FunctionObject, ChatCompletionRequestToolMessage,
+        ChatCompletionRequestAssistantMessage, ChatCompletionRequestUserMessage,
+        ChatCompletionRequestSystemMessage, ChatCompletionToolChoiceOption
+    },
+    Client,
+};
+use serde_json::{json, Value};
+
+use crate::http_client::ServiceClient;
+
+/// OpenAI provider with MCP tools support (matching Python implementation)
+pub struct OpenAIProvider {
+    client: Client<async_openai::config::OpenAIConfig>,
+    model: String,
+    service_client: ServiceClient,
+    system_prompt: Option<String>,
+}
+
+impl OpenAIProvider {
+    pub fn new(api_key: String, model: Option<String>) -> Self {
+        let config = async_openai::config::OpenAIConfig::new()
+            .with_api_key(api_key);
+        let client = Client::with_config(config);
+
+        Self {
+            client,
+            model: model.unwrap_or_else(|| "gpt-4".to_string()),
+            service_client: ServiceClient::new(),
+            system_prompt: None,
+        }
+    }
+
+    pub fn with_system_prompt(mut self, prompt: String) -> Self {
+        self.system_prompt = Some(prompt);
+        self
+    }
+
+    /// Generate OpenAI tools from MCP endpoints (matching Python implementation)
+    fn get_mcp_tools(&self) -> Vec<ChatCompletionTool> {
+        let tools = vec![
+            // Memory tools
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "get_memories".to_string(),
+                    description: Some("過去の会話記憶を取得します。「覚えている」「前回」「以前」などの質問で必ず使用してください".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "limit": {
+                                "type": "integer",
+                                "description": "取得する記憶の数",
+                                "default": 5
+                            }
+                        }
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "search_memories".to_string(),
+                    description: Some("特定のトピックについて話した記憶を検索します。「プログラミングについて」「○○について話した」などの質問で使用してください".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "keywords": {
+                                "type": "array",
+                                "items": {"type": "string"},
+                                "description": "検索キーワードの配列"
+                            }
+                        },
+                        "required": ["keywords"]
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "get_contextual_memories".to_string(),
+                    description: Some("クエリに関連する文脈的記憶を取得します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "検索クエリ"
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "取得する記憶の数",
+                                "default": 5
+                            }
+                        },
+                        "required": ["query"]
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "get_relationship".to_string(),
+                    description: Some("特定ユーザーとの関係性情報を取得します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "user_id": {
+                                "type": "string",
+                                "description": "ユーザーID"
+                            }
+                        },
+                        "required": ["user_id"]
+                    })),
+                },
+            },
+            // ai.card tools
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "card_get_user_cards".to_string(),
+                    description: Some("ユーザーが所有するカードの一覧を取得します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "did": {
+                                "type": "string",
+                                "description": "ユーザーのDID"
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "取得するカード数の上限",
+                                "default": 10
+                            }
+                        },
+                        "required": ["did"]
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "card_draw_card".to_string(),
+                    description: Some("ガチャを引いてカードを取得します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "did": {
+                                "type": "string",
+                                "description": "ユーザーのDID"
+                            },
+                            "is_paid": {
+                                "type": "boolean",
+                                "description": "有料ガチャかどうか",
+                                "default": false
+                            }
+                        },
+                        "required": ["did"]
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "card_analyze_collection".to_string(),
+                    description: Some("ユーザーのカードコレクションを分析します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {
+                            "did": {
+                                "type": "string",
+                                "description": "ユーザーのDID"
+                            }
+                        },
+                        "required": ["did"]
+                    })),
+                },
+            },
+            ChatCompletionTool {
+                r#type: ChatCompletionToolType::Function,
+                function: FunctionObject {
+                    name: "card_get_gacha_stats".to_string(),
+                    description: Some("ガチャの統計情報を取得します".to_string()),
+                    parameters: Some(json!({
+                        "type": "object",
+                        "properties": {}
+                    })),
+                },
+            },
+        ];
+
+        tools
+    }
+
+    /// Chat interface with MCP function calling support (matching Python implementation)
+    pub async fn chat_with_mcp(&self, prompt: String, user_id: String) -> Result<String> {
+        let tools = self.get_mcp_tools();
+
+        let system_content = self.system_prompt.as_deref().unwrap_or(
+            "あなたは記憶システムと関係性データ、カードゲームシステムにアクセスできるAIです。\n\n【重要】以下の場合は必ずツールを使用してください:\n\n1. カード関連の質問:\n- 「カード」「コレクション」「ガチャ」「見せて」「持っている」「状況」「どんなカード」などのキーワードがある場合\n- card_get_user_cardsツールを使用してユーザーのカード情報を取得\n\n2. 記憶・関係性の質問:\n- 「覚えている」「前回」「以前」「について話した」「関係」などのキーワードがある場合\n- 適切なメモリツールを使用\n\n3. パラメータの設定:\n- didパラメータには現在会話しているユーザーのID(例:'syui')を使用\n- ツールを積極的に使用して正確な情報を提供してください\n\nユーザーが何かを尋ねた時は、まず関連するツールがあるかを考え、適切なツールを使用してから回答してください。"
+        );
+
+        let request = CreateChatCompletionRequestArgs::default()
+            .model(&self.model)
+            .messages(vec![
+                ChatCompletionRequestMessage::System(
+                    ChatCompletionRequestSystemMessage {
+                        content: system_content.to_string().into(),
+                        name: None,
+                    }
+                ),
+                ChatCompletionRequestMessage::User(
+                    ChatCompletionRequestUserMessage {
+                        content: prompt.clone().into(),
+                        name: None,
+                    }
+                ),
+            ])
+            .tools(tools)
+            .tool_choice(ChatCompletionToolChoiceOption::Auto)
+            .max_tokens(2000u16)
+            .temperature(0.7)
+            .build()?;
+
+        let response = self.client.chat().create(request).await?;
+        let message = &response.choices[0].message;
+
+        // Handle tool calls
+        if let Some(tool_calls) = &message.tool_calls {
+            if tool_calls.is_empty() {
+                println!("🔧 [OpenAI] No tools called");
+            } else {
+                println!("🔧 [OpenAI] {} tools called:", tool_calls.len());
+                for tc in tool_calls {
+                    println!(" - {}({})", tc.function.name, tc.function.arguments);
+                }
+            }
+        } else {
+            println!("🔧 [OpenAI] No tools called");
+        }
+
+        // Process tool calls if any
+        if let Some(tool_calls) = &message.tool_calls {
+            if !tool_calls.is_empty() {
+
+                let mut messages = vec![
+                    ChatCompletionRequestMessage::System(
+                        ChatCompletionRequestSystemMessage {
+                            content: system_content.to_string().into(),
+                            name: None,
+                        }
+                    ),
+                    ChatCompletionRequestMessage::User(
+                        ChatCompletionRequestUserMessage {
+                            content: prompt.into(),
+                            name: None,
+                        }
+                    ),
+                    ChatCompletionRequestMessage::Assistant(
+                        ChatCompletionRequestAssistantMessage {
+                            content: message.content.clone(),
+                            name: None,
+                            tool_calls: message.tool_calls.clone(),
+                            function_call: None,
+                        }
+                    ),
+                ];
+
+                // Execute each tool call
+                for tool_call in tool_calls {
+                    println!("🌐 [MCP] Executing {}...", tool_call.function.name);
+                    let tool_result = self.execute_mcp_tool(tool_call, &user_id).await?;
+                    let result_preview = serde_json::to_string(&tool_result)?;
+                    let preview = if result_preview.chars().count() > 100 {
+                        format!("{}...", result_preview.chars().take(100).collect::<String>())
+                    } else {
+                        result_preview.clone()
+                    };
+                    println!("✅ [MCP] Result: {}", preview);
+
+                    messages.push(ChatCompletionRequestMessage::Tool(
+                        ChatCompletionRequestToolMessage {
+                            content: serde_json::to_string(&tool_result)?,
+                            tool_call_id: tool_call.id.clone(),
+                        }
+                    ));
+                }
+
+                // Get final response with tool outputs
+                let final_request = CreateChatCompletionRequestArgs::default()
+                    .model(&self.model)
+                    .messages(messages)
+                    .max_tokens(2000u16)
+                    .temperature(0.7)
+                    .build()?;
+
+                let final_response = self.client.chat().create(final_request).await?;
+                Ok(final_response.choices[0].message.content.as_ref().unwrap_or(&"".to_string()).clone())
+            } else {
+                // No tools were called
+                Ok(message.content.as_ref().unwrap_or(&"".to_string()).clone())
+            }
+        } else {
+            // No tool_calls field at all
+            Ok(message.content.as_ref().unwrap_or(&"".to_string()).clone())
+        }
+    }
+
+    /// Execute MCP tool call (matching Python implementation)
+    async fn execute_mcp_tool(&self, tool_call: &async_openai::types::ChatCompletionMessageToolCall, context_user_id: &str) -> Result<Value> {
+        let function_name = &tool_call.function.name;
+        let arguments: Value = serde_json::from_str(&tool_call.function.arguments)?;
+
+        match function_name.as_str() {
+            "get_memories" => {
+                let _limit = arguments.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
+                // TODO: Implement actual MCP call
+                Ok(json!({"info": "記憶機能は実装中です"}))
+            }
+            "search_memories" => {
+                let _keywords = arguments.get("keywords").and_then(|v| v.as_array());
+                // TODO: Implement actual MCP call
+                Ok(json!({"info": "記憶検索機能は実装中です"}))
+            }
+            "get_contextual_memories" => {
+                let _query = arguments.get("query").and_then(|v| v.as_str()).unwrap_or("");
+                let _limit = arguments.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
+                // TODO: Implement actual MCP call
+                Ok(json!({"info": "文脈記憶機能は実装中です"}))
+            }
+            "get_relationship" => {
+                let _user_id = arguments.get("user_id").and_then(|v| v.as_str()).unwrap_or(context_user_id);
+                // TODO: Implement actual MCP call
+                Ok(json!({"info": "関係性機能は実装中です"}))
+            }
+            // ai.card tools
+            "card_get_user_cards" => {
+                let did = arguments.get("did").and_then(|v| v.as_str()).unwrap_or(context_user_id);
+                let _limit = arguments.get("limit").and_then(|v| v.as_i64()).unwrap_or(10);
+
+                match self.service_client.get_user_cards(did).await {
+                    Ok(result) => Ok(result),
+                    Err(e) => {
+                        println!("❌ ai.card API error: {}", e);
+                        Ok(json!({
+                            "error": "ai.cardサーバーが起動していません",
+                            "message": "カードシステムを使用するには、ai.cardサーバーを起動してください"
+                        }))
+                    }
+                }
+            }
+            "card_draw_card" => {
+                let did = arguments.get("did").and_then(|v| v.as_str()).unwrap_or(context_user_id);
+                let is_paid = arguments.get("is_paid").and_then(|v| v.as_bool()).unwrap_or(false);
+
+                match self.service_client.draw_card(did, is_paid).await {
+                    Ok(result) => Ok(result),
+                    Err(e) => {
+                        println!("❌ ai.card API error: {}", e);
+                        Ok(json!({
+                            "error": "ai.cardサーバーが起動していません",
+                            "message": "カードシステムを使用するには、ai.cardサーバーを起動してください"
+                        }))
+                    }
+                }
+            }
+            "card_analyze_collection" => {
+                let did = arguments.get("did").and_then(|v| v.as_str()).unwrap_or(context_user_id);
+                // TODO: Implement collection analysis endpoint
+                Ok(json!({
+                    "info": "コレクション分析機能は実装中です",
+                    "user_did": did
+                }))
+            }
+            "card_get_gacha_stats" => {
+                // TODO: Implement gacha stats endpoint
+                Ok(json!({"info": "ガチャ統計機能は実装中です"}))
+            }
+            _ => {
+                Ok(json!({
+                    "error": format!("Unknown tool: {}", function_name)
+                }))
+            }
+        }
+    }
+}
\ No newline at end of file
diff --git a/src/persona.rs b/src/persona.rs
index 96a92b9..4f3549c 100644
--- a/src/persona.rs
+++ b/src/persona.rs
@@ -109,37 +109,58 @@ impl Persona {
             0.0
         };

-        // Generate AI response
-        let ai_config = self.config.get_ai_config(provider, model)?;
-        let ai_client = AIProviderClient::new(ai_config);
-
-        // Build conversation context
-        let mut messages = Vec::new();
-
-        // Get recent memories for context
-        if let Some(memory_manager) = &mut self.memory_manager {
-            let recent_memories = memory_manager.get_memories(user_id, 5);
-            if !recent_memories.is_empty() {
-                let context = recent_memories.iter()
-                    .map(|m| m.content.clone())
-                    .collect::<Vec<_>>()
-                    .join("\n");
-                messages.push(ChatMessage::system(format!("Previous conversation context:\n{}", context)));
+        // Check provider type and use appropriate client
+        let response = if provider.as_deref() == Some("openai") {
+            // Use OpenAI provider with MCP tools
+            use crate::openai_provider::OpenAIProvider;
+
+            // Get OpenAI API key from config or environment
+            let api_key = std::env::var("OPENAI_API_KEY")
+                .or_else(|_| {
+                    self.config.providers.get("openai")
+                        .and_then(|p| p.api_key.clone())
+                        .ok_or_else(|| std::env::VarError::NotPresent)
+                })
+                .map_err(|_| anyhow::anyhow!("OpenAI API key not found. Set OPENAI_API_KEY environment variable or add to config."))?;
+
+            let openai_model = model.unwrap_or_else(|| "gpt-4".to_string());
+            let openai_provider = OpenAIProvider::new(api_key, Some(openai_model));
+
+            // Use OpenAI with MCP tools support
+            openai_provider.chat_with_mcp(message.to_string(), user_id.to_string()).await?
+        } else {
+            // Use existing AI provider (Ollama)
+            let ai_config = self.config.get_ai_config(provider, model)?;
+            let ai_client = AIProviderClient::new(ai_config);
+
+            // Build conversation context
+            let mut messages = Vec::new();
+
+            // Get recent memories for context
+            if let Some(memory_manager) = &mut self.memory_manager {
+                let recent_memories = memory_manager.get_memories(user_id, 5);
+                if !recent_memories.is_empty() {
+                    let context = recent_memories.iter()
+                        .map(|m| m.content.clone())
+                        .collect::<Vec<_>>()
+                        .join("\n");
+                    messages.push(ChatMessage::system(format!("Previous conversation context:\n{}", context)));
+                }
             }
-        }
-
-        // Add current message
-        messages.push(ChatMessage::user(message));
-
-        // Generate system prompt based on personality and relationship
-        let system_prompt = self.generate_system_prompt(user_id);
-
-        // Get AI response
-        let response = match ai_client.chat(messages, Some(system_prompt)).await {
-            Ok(chat_response) => chat_response.content,
-            Err(_) => {
-                // Fallback to simple response if AI fails
-                format!("I understand your message: '{}'", message)
+
+            // Add current message
+            messages.push(ChatMessage::user(message));
+
+            // Generate system prompt based on personality and relationship
+            let system_prompt = self.generate_system_prompt(user_id);
+
+            // Get AI response
+            match ai_client.chat(messages, Some(system_prompt)).await {
+                Ok(chat_response) => chat_response.content,
+                Err(_) => {
+                    // Fallback to simple response if AI fails
+                    format!("I understand your message: '{}'", message)
+                }
             }
         };