fix memory
This commit is contained in:
@ -53,7 +53,8 @@
|
||||
"Bash(cargo run:*)",
|
||||
"Bash(cargo test:*)",
|
||||
"Bash(diff:*)",
|
||||
"Bash(cargo:*)"
|
||||
"Bash(cargo:*)",
|
||||
"Bash(pkill:*)"
|
||||
],
|
||||
"deny": []
|
||||
}
|
||||
|
@ -78,8 +78,9 @@ where
|
||||
impl Config {
|
||||
pub fn new(data_dir: Option<PathBuf>) -> Result<Self> {
|
||||
let data_dir = data_dir.unwrap_or_else(|| {
|
||||
dirs::config_dir()
|
||||
dirs::home_dir()
|
||||
.unwrap_or_else(|| PathBuf::from("."))
|
||||
.join(".config")
|
||||
.join("syui")
|
||||
.join("ai")
|
||||
.join("gpt")
|
||||
|
@ -12,6 +12,7 @@ use async_openai::{
|
||||
use serde_json::{json, Value};
|
||||
|
||||
use crate::http_client::ServiceClient;
|
||||
use crate::config::Config;
|
||||
|
||||
/// OpenAI provider with MCP tools support (matching Python implementation)
|
||||
pub struct OpenAIProvider {
|
||||
@ -19,6 +20,7 @@ pub struct OpenAIProvider {
|
||||
model: String,
|
||||
service_client: ServiceClient,
|
||||
system_prompt: Option<String>,
|
||||
config: Option<Config>,
|
||||
}
|
||||
|
||||
impl OpenAIProvider {
|
||||
@ -32,22 +34,36 @@ impl OpenAIProvider {
|
||||
model: model.unwrap_or_else(|| "gpt-4".to_string()),
|
||||
service_client: ServiceClient::new(),
|
||||
system_prompt: None,
|
||||
config: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn with_system_prompt(api_key: String, model: Option<String>, system_prompt: Option<String>) -> Self {
|
||||
let config = async_openai::config::OpenAIConfig::new()
|
||||
pub fn with_config(api_key: String, model: Option<String>, system_prompt: Option<String>, config: Config) -> Self {
|
||||
let openai_config = async_openai::config::OpenAIConfig::new()
|
||||
.with_api_key(api_key);
|
||||
let client = Client::with_config(config);
|
||||
let client = Client::with_config(openai_config);
|
||||
|
||||
Self {
|
||||
client,
|
||||
model: model.unwrap_or_else(|| "gpt-4".to_string()),
|
||||
service_client: ServiceClient::new(),
|
||||
system_prompt,
|
||||
config: Some(config),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_mcp_base_url(&self) -> String {
|
||||
if let Some(config) = &self.config {
|
||||
if let Some(mcp) = &config.mcp {
|
||||
if let Some(ai_gpt_server) = mcp.servers.get("ai_gpt") {
|
||||
return ai_gpt_server.base_url.clone();
|
||||
}
|
||||
}
|
||||
}
|
||||
// Fallback to default
|
||||
"http://localhost:8080".to_string()
|
||||
}
|
||||
|
||||
/// Generate OpenAI tools from MCP endpoints (matching Python implementation)
|
||||
fn get_mcp_tools(&self) -> Vec<ChatCompletionTool> {
|
||||
let tools = vec![
|
||||
@ -333,7 +349,8 @@ impl OpenAIProvider {
|
||||
let limit = arguments.get("limit").and_then(|v| v.as_i64()).unwrap_or(5);
|
||||
|
||||
// MCP server call to get memories
|
||||
match self.service_client.get_request(&format!("http://localhost:8080/memories/{}", context_user_id)).await {
|
||||
let base_url = self.get_mcp_base_url();
|
||||
match self.service_client.get_request(&format!("{}/memories/{}", base_url, context_user_id)).await {
|
||||
Ok(result) => {
|
||||
// Extract the actual memory content from MCP response
|
||||
if let Some(content) = result.get("result").and_then(|r| r.get("content")) {
|
||||
@ -381,7 +398,7 @@ impl OpenAIProvider {
|
||||
});
|
||||
|
||||
match self.service_client.post_request(
|
||||
&format!("http://localhost:8080/memories/{}/search", context_user_id),
|
||||
&format!("{}/memories/{}/search", self.get_mcp_base_url(), context_user_id),
|
||||
&search_request
|
||||
).await {
|
||||
Ok(result) => {
|
||||
@ -433,7 +450,7 @@ impl OpenAIProvider {
|
||||
});
|
||||
|
||||
match self.service_client.post_request(
|
||||
&format!("http://localhost:8080/memories/{}/contextual", context_user_id),
|
||||
&format!("{}/memories/{}/contextual", self.get_mcp_base_url(), context_user_id),
|
||||
&contextual_request
|
||||
).await {
|
||||
Ok(result) => {
|
||||
@ -492,7 +509,8 @@ impl OpenAIProvider {
|
||||
let target_user_id = arguments.get("user_id").and_then(|v| v.as_str()).unwrap_or(context_user_id);
|
||||
|
||||
// MCP server call to get relationship status
|
||||
match self.service_client.get_request(&format!("http://localhost:8080/status/{}", target_user_id)).await {
|
||||
let base_url = self.get_mcp_base_url();
|
||||
match self.service_client.get_request(&format!("{}/status/{}", base_url, target_user_id)).await {
|
||||
Ok(result) => {
|
||||
// Extract relationship information from MCP response
|
||||
if let Some(content) = result.get("result").and_then(|r| r.get("content")) {
|
||||
|
@ -130,10 +130,17 @@ impl Persona {
|
||||
.and_then(|p| p.system_prompt.clone());
|
||||
|
||||
|
||||
let openai_provider = OpenAIProvider::with_system_prompt(api_key, Some(openai_model), system_prompt);
|
||||
let openai_provider = OpenAIProvider::with_config(api_key, Some(openai_model), system_prompt, self.config.clone());
|
||||
|
||||
// Use OpenAI with MCP tools support
|
||||
openai_provider.chat_with_mcp(message.to_string(), user_id.to_string()).await?
|
||||
let response = openai_provider.chat_with_mcp(message.to_string(), user_id.to_string()).await?;
|
||||
|
||||
// Add AI response to memory as well
|
||||
if let Some(memory_manager) = &mut self.memory_manager {
|
||||
memory_manager.add_memory(user_id, &format!("AI: {}", response), 0.3)?;
|
||||
}
|
||||
|
||||
response
|
||||
} else {
|
||||
// Use existing AI provider (Ollama)
|
||||
let ai_config = self.config.get_ai_config(provider, model)?;
|
||||
|
Reference in New Issue
Block a user