From 45c65e03b3910cdc50e7d2abdca5e6f836b35383 Mon Sep 17 00:00:00 2001
From: syui
Date: Thu, 12 Jun 2025 22:03:52 +0900
Subject: [PATCH] fix memory

---
 .claude/settings.local.json |  3 ++-
 src/config.rs               |  3 ++-
 src/openai_provider.rs      | 32 +++++++++++++++++++++++++-------
 src/persona.rs              | 11 +++++++++--
 4 files changed, 38 insertions(+), 11 deletions(-)

diff --git a/.claude/settings.local.json b/.claude/settings.local.json
index 36c0a34..8d55a09 100644
--- a/.claude/settings.local.json
+++ b/.claude/settings.local.json
@@ -53,7 +53,8 @@
       "Bash(cargo run:*)",
       "Bash(cargo test:*)",
       "Bash(diff:*)",
-      "Bash(cargo:*)"
+      "Bash(cargo:*)",
+      "Bash(pkill:*)"
     ],
     "deny": []
   }
diff --git a/src/config.rs b/src/config.rs
index 970b0ca..d40f4cb 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -78,8 +78,9 @@ where
 impl Config {
     pub fn new(data_dir: Option<PathBuf>) -> Result<Self> {
         let data_dir = data_dir.unwrap_or_else(|| {
-            dirs::config_dir()
+            dirs::home_dir()
                 .unwrap_or_else(|| PathBuf::from("."))
+                .join(".config")
                 .join("syui")
                 .join("ai")
                 .join("gpt")
diff --git a/src/openai_provider.rs b/src/openai_provider.rs
index 0ff0b1d..d26def9 100644
--- a/src/openai_provider.rs
+++ b/src/openai_provider.rs
@@ -12,6 +12,7 @@ use async_openai::{
 use serde_json::{json, Value};
 
 use crate::http_client::ServiceClient;
+use crate::config::Config;
 
 /// OpenAI provider with MCP tools support (matching Python implementation)
 pub struct OpenAIProvider {
@@ -19,6 +20,7 @@ pub struct OpenAIProvider {
     model: String,
     service_client: ServiceClient,
     system_prompt: Option<String>,
+    config: Option<Config>,
 }
 
 impl OpenAIProvider {
@@ -32,22 +34,36 @@ impl OpenAIProvider {
             model: model.unwrap_or_else(|| "gpt-4".to_string()),
             service_client: ServiceClient::new(),
             system_prompt: None,
+            config: None,
         }
     }
 
-    pub fn with_system_prompt(api_key: String, model: Option<String>, system_prompt: Option<String>) -> Self {
-        let config = async_openai::config::OpenAIConfig::new()
+    pub fn with_config(api_key: String, model: Option<String>, system_prompt: Option<String>, config: Config) -> Self {
+        let openai_config = async_openai::config::OpenAIConfig::new()
             .with_api_key(api_key);
-        let client = Client::with_config(config);
+        let client = Client::with_config(openai_config);
 
         Self {
             client,
             model: model.unwrap_or_else(|| "gpt-4".to_string()),
             service_client: ServiceClient::new(),
             system_prompt,
+            config: Some(config),
         }
     }
 
+    fn get_mcp_base_url(&self) -> String {
+        if let Some(config) = &self.config {
+            if let Some(mcp) = &config.mcp {
+                if let Some(ai_gpt_server) = mcp.servers.get("ai_gpt") {
+                    return ai_gpt_server.base_url.clone();
+                }
+            }
+        }
+        // Fallback to default
+        "http://localhost:8080".to_string()
+    }
+
     /// Generate OpenAI tools from MCP endpoints (matching Python implementation)
     fn get_mcp_tools(&self) -> Vec<ChatCompletionTool> {
         let tools = vec![
@@ -333,7 +349,8 @@ impl OpenAIProvider {
         let limit = arguments.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
 
         // MCP server call to get memories
-        match self.service_client.get_request(&format!("http://localhost:8080/memories/{}", context_user_id)).await {
+        let base_url = self.get_mcp_base_url();
+        match self.service_client.get_request(&format!("{}/memories/{}", base_url, context_user_id)).await {
             Ok(result) => {
                 // Extract the actual memory content from MCP response
                 if let Some(content) = result.get("result").and_then(|r| r.get("content")) {
@@ -381,7 +398,7 @@ impl OpenAIProvider {
         });
 
         match self.service_client.post_request(
-            &format!("http://localhost:8080/memories/{}/search", context_user_id),
+            &format!("{}/memories/{}/search", self.get_mcp_base_url(), context_user_id),
             &search_request
         ).await {
             Ok(result) => {
@@ -433,7 +450,7 @@ impl OpenAIProvider {
         });
 
         match self.service_client.post_request(
-            &format!("http://localhost:8080/memories/{}/contextual", context_user_id),
+            &format!("{}/memories/{}/contextual", self.get_mcp_base_url(), context_user_id),
             &contextual_request
         ).await {
             Ok(result) => {
@@ -492,7 +509,8 @@ impl OpenAIProvider {
         let target_user_id = arguments.get("user_id").and_then(|v| v.as_str()).unwrap_or(context_user_id);
 
         // MCP server call to get relationship status
-        match self.service_client.get_request(&format!("http://localhost:8080/status/{}", target_user_id)).await {
+        let base_url = self.get_mcp_base_url();
+        match self.service_client.get_request(&format!("{}/status/{}", base_url, target_user_id)).await {
             Ok(result) => {
                 // Extract relationship information from MCP response
                 if let Some(content) = result.get("result").and_then(|r| r.get("content")) {
diff --git a/src/persona.rs b/src/persona.rs
index 239b621..f94ce62 100644
--- a/src/persona.rs
+++ b/src/persona.rs
@@ -130,10 +130,17 @@ impl Persona {
                 .and_then(|p| p.system_prompt.clone());
 
-            let openai_provider = OpenAIProvider::with_system_prompt(api_key, Some(openai_model), system_prompt);
+            let openai_provider = OpenAIProvider::with_config(api_key, Some(openai_model), system_prompt, self.config.clone());
 
             // Use OpenAI with MCP tools support
-            openai_provider.chat_with_mcp(message.to_string(), user_id.to_string()).await?
+            let response = openai_provider.chat_with_mcp(message.to_string(), user_id.to_string()).await?;
+
+            // Add AI response to memory as well
+            if let Some(memory_manager) = &mut self.memory_manager {
+                memory_manager.add_memory(user_id, &format!("AI: {}", response), 0.3)?;
+            }
+
+            response
         } else {
            // Use existing AI provider (Ollama)
             let ai_config = self.config.get_ai_config(provider, model)?;
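
Note on the new config lookup: get_mcp_base_url() resolves the MCP server address from an optional mcp section of the loaded Config (keyed by an "ai_gpt" server entry) and only falls back to the previously hardcoded http://localhost:8080 when no such entry exists. The sketch below models that lookup in isolation; the real Config and MCP types live in src/config.rs and are not shown in this patch, so the struct shapes here are hypothetical, inferred from the field accesses config.mcp, mcp.servers.get("ai_gpt"), and base_url in the diff.

// Standalone sketch (not part of the patch): models the resolution order
// introduced by get_mcp_base_url(). Struct shapes are assumptions inferred
// from the field accesses visible in the diff.
use std::collections::HashMap;

struct McpServerConfig {
    base_url: String,
}

struct McpConfig {
    servers: HashMap<String, McpServerConfig>,
}

struct Config {
    mcp: Option<McpConfig>,
}

impl Config {
    // A configured base_url for the "ai_gpt" server wins; otherwise fall
    // back to the old hardcoded default, matching the patch's behavior.
    fn mcp_base_url(&self) -> String {
        self.mcp
            .as_ref()
            .and_then(|mcp| mcp.servers.get("ai_gpt"))
            .map(|server| server.base_url.clone())
            .unwrap_or_else(|| "http://localhost:8080".to_string())
    }
}

fn main() {
    // No MCP section configured: the default applies.
    let bare = Config { mcp: None };
    assert_eq!(bare.mcp_base_url(), "http://localhost:8080");

    // An "ai_gpt" server entry overrides the default.
    let mut servers = HashMap::new();
    servers.insert(
        "ai_gpt".to_string(),
        McpServerConfig { base_url: "http://localhost:9000".to_string() },
    );
    let configured = Config { mcp: Some(McpConfig { servers }) };
    assert_eq!(configured.mcp_base_url(), "http://localhost:9000");
}

Under these assumptions, a missing or partial MCP section degrades to the old default, so existing setups keep working without any config changes.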