diff --git a/src/data.rs b/src/data.rs index 0a49911..8eb3efb 100644 --- a/src/data.rs +++ b/src/data.rs @@ -5,7 +5,7 @@ use std::fs; use std::fs::OpenOptions; use std::io::Read; use std::io::Write; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::env; /// ホームディレクトリパスを展開するユーティリティ関数 @@ -20,35 +20,51 @@ fn expand_home_path(path: &str) -> PathBuf { } } -pub fn data_file(s: &str) -> String { - let path = expand_home_path("~/.config/ai"); +/// 設定ディレクトリのベースパスを取得し、必要に応じて作成する +fn get_config_base_path() -> PathBuf { + let path = expand_home_path("~/.config/syui/ai/bot"); + if !path.is_dir() { + let _ = fs::create_dir_all(&path); + } + path +} + +/// サブディレクトリを含む設定パスを取得し、必要に応じて作成する +fn get_config_path(subdir: &str) -> PathBuf { + let base_path = get_config_base_path(); + let path = if subdir.is_empty() { + base_path + } else { + base_path.join(subdir) + }; if !path.is_dir() { let _ = fs::create_dir_all(&path); } + path +} + +pub fn data_file(s: &str) -> String { + let path = get_config_base_path(); + let path_str = path.to_string_lossy(); - let mut path_str = path.to_string_lossy().to_string(); - match &*s { - "toml" => path_str + "/token.toml", - "json" => path_str + "/token.json", - "refresh" => path_str + "/refresh.toml", - _ => path_str + "/." + &s, + match s { + "toml" => format!("{}/token.toml", path_str), + "json" => format!("{}/token.json", path_str), + "refresh" => format!("{}/refresh.toml", path_str), + _ => format!("{}/.{}", path_str, s), } } pub fn log_file(s: &str) -> String { - let path = expand_home_path("~/.config/ai/txt"); + let path = get_config_path("txt"); + let path_str = path.to_string_lossy(); - if !path.is_dir() { - let _ = fs::create_dir_all(&path); - } - - let mut path_str = path.to_string_lossy().to_string(); - match &*s { - "n1" => path_str + "/notify_cid.txt", - "n2" => path_str + "/notify_cid_run.txt", - "c1" => path_str + "/comment_cid.txt", - _ => path_str + "/" + &s, + match s { + "n1" => format!("{}/notify_cid.txt", path_str), + "n2" => format!("{}/notify_cid_run.txt", path_str), + "c1" => format!("{}/comment_cid.txt", path_str), + _ => format!("{}/{}", path_str, s), } } @@ -275,7 +291,7 @@ pub fn data_refresh(s: &str) -> String { } pub fn data_scpt(s: &str) -> String { - let mut path = expand_home_path("~/.config/ai/scpt"); + let mut path = expand_home_path("~/.config/syui/ai/bot/scpt"); path.push(format!("{}.zsh", s)); path.to_string_lossy().to_string() } @@ -545,7 +561,20 @@ pub fn w_cfg(h: &str, res: &str, password: &str) { let mut f = fs::File::create(f.clone()).unwrap(); let mut ff = fs::File::create(ff.clone()).unwrap(); f.write_all(&res.as_bytes()).unwrap(); - let json: Token = serde_json::from_str(&res).unwrap(); + // Check if response contains an error + if res.contains("\"error\"") { + eprintln!("Authentication error: {}", res); + return; + } + + let json: Token = match serde_json::from_str(&res) { + Ok(token) => token, + Err(e) => { + eprintln!("JSON parse error: {}", e); + eprintln!("Response: {}", res); + return; + } + }; let datas = Data { host: h.to_string(), password: password.to_string(), @@ -614,7 +643,7 @@ pub fn w_cid(cid: String, file: String, t: bool) -> bool { } pub fn c_follow_all() { - let path = expand_home_path("~/.config/ai/scpt/follow_all.zsh"); + let path = expand_home_path("~/.config/syui/ai/bot/scpt/follow_all.zsh"); use std::process::Command; let output = Command::new(path.to_str().unwrap()).output().expect("zsh"); @@ -628,7 +657,7 @@ pub fn c_openai_key(c: &Context) { let o = "api='".to_owned() + &api.to_string() + &"'".to_owned(); let o = o.to_string(); - let path = expand_home_path("~/.config/ai/openai.toml"); + let path = expand_home_path("~/.config/syui/ai/bot/openai.toml"); let mut l = fs::File::create(&path).unwrap(); if o != "" { @@ -639,7 +668,7 @@ pub fn c_openai_key(c: &Context) { impl Open { pub fn new() -> Result { - let path = expand_home_path("~/.config/ai/openai.toml"); + let path = expand_home_path("~/.config/syui/ai/bot/openai.toml"); let s = Config::builder() .add_source(File::with_name(path.to_str().unwrap())) diff --git a/src/main.rs b/src/main.rs index ec02541..3749623 100644 --- a/src/main.rs +++ b/src/main.rs @@ -105,7 +105,7 @@ fn main() { .command( Command::new("login") .alias("l") - .description("l -p \n\t\t\tl -p -s ") + .description("l -p \n\t\t\tl -p -s \n\t\t\tl -p -c <2fa_code>") .action(token) .flag( Flag::new("password", FlagType::String) @@ -117,6 +117,11 @@ fn main() { .description("server flag") .alias("s"), ) + .flag( + Flag::new("code", FlagType::String) + .description("2FA authentication code") + .alias("c"), + ) ) .command( Command::new("refresh") @@ -506,15 +511,11 @@ fn token(c: &Context) { let m = c.args[0].to_string(); let h = async { if let Ok(p) = c.string_flag("password") { - if let Ok(s) = c.string_flag("server") { - let res = token::post_request(m.to_string(), p.to_string(), s.to_string()).await; - w_cfg(&s, &res, &p); - } else { - let res = - token::post_request(m.to_string(), p.to_string(), "bsky.social".to_string()) - .await; - w_cfg(&"bsky.social", &res, &p); - } + let server = c.string_flag("server").unwrap_or_else(|_| "bsky.social".to_string()); + let code = c.string_flag("code").ok(); + + let res = token::post_request(m.to_string(), p.to_string(), server.to_string(), code).await; + w_cfg(&server, &res, &p); } }; let res = tokio::runtime::Runtime::new().unwrap().block_on(h); @@ -530,7 +531,7 @@ fn refresh(_c: &Context) { let m = data_toml(&"handle"); let p = data_toml(&"password"); let s = data_toml(&"host"); - let res = token::post_request(m.to_string(), p.to_string(), s.to_string()).await; + let res = token::post_request(m.to_string(), p.to_string(), s.to_string(), None).await; w_cfg(&s, &res, &p); } else { w_refresh(&res); diff --git a/src/token.rs b/src/token.rs index fa09b75..5547063 100644 --- a/src/token.rs +++ b/src/token.rs @@ -1,12 +1,17 @@ use crate::http_client::HttpClient; use std::collections::HashMap; -pub async fn post_request(handle: String, pass: String, host: String) -> String { +pub async fn post_request(handle: String, pass: String, host: String, auth_factor_token: Option) -> String { let url = format!("https://{}/xrpc/com.atproto.server.createSession", host); let mut map = HashMap::new(); map.insert("identifier", &handle); map.insert("password", &pass); + + // Add 2FA code if provided + if let Some(code) = &auth_factor_token { + map.insert("authFactorToken", code); + } let client = HttpClient::new();