From 9da77980b5823eae1c6beec2f72a70aaaef7e454 Mon Sep 17 00:00:00 2001 From: syui Date: Mon, 30 Mar 2026 14:01:34 +0900 Subject: [PATCH] fix login --- src/commands/auth.rs | 75 ++++++++++++++++++++++++++++++++++-------- src/commands/gpt.rs | 14 ++++---- src/commands/oauth.rs | 14 ++++++++ src/commands/push.rs | 10 +++--- src/commands/record.rs | 36 ++++++++------------ src/commands/sync.rs | 28 ++++++++-------- src/types.rs | 4 +-- 7 files changed, 118 insertions(+), 63 deletions(-) diff --git a/src/commands/auth.rs b/src/commands/auth.rs index 9f9305e..f713c4f 100644 --- a/src/commands/auth.rs +++ b/src/commands/auth.rs @@ -69,7 +69,7 @@ pub async fn login(handle: &str, password: &str, pds: &str, is_bot: bool) -> Res } else { token::save_session(&session)?; } - println!("Logged in as {} ({})", session.handle, session.did); + eprintln!("Logged in as {} ({})", session.handle, session.did); Ok(()) } @@ -91,14 +91,66 @@ async fn do_refresh(session: &Session, pds: &str) -> Result { }) } -/// Refresh access token (OAuth-aware: tries OAuth first, falls back to legacy) -pub async fn refresh_session() -> Result { - if oauth::has_oauth_session(false) { - match oauth::refresh_oauth_session(false).await { - Ok((_oauth, session)) => return Ok(session), - Err(_) => { /* OAuth failed, fall back to legacy */ } +/// Check if a JWT access token is still valid (with 60s margin) +fn is_token_valid(token: &str) -> bool { + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() < 2 { + return false; + } + // Decode JWT payload + use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine}; + let payload = match URL_SAFE_NO_PAD.decode(parts[1]) { + Ok(p) => p, + Err(_) => return false, + }; + let json: serde_json::Value = match serde_json::from_slice(&payload) { + Ok(v) => v, + Err(_) => return false, + }; + let exp = json["exp"].as_i64().unwrap_or(0); + let now = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() as i64; + exp > now + 60 // valid if expires more than 60s from now +} + +/// Get OAuth session and return it if access_token is still valid. +/// If expired, refresh it. If refresh fails, remove the OAuth session file +/// so we cleanly fall back to legacy on next call. +async fn try_oauth_session(is_bot: bool) -> Option { + if !oauth::has_oauth_session(is_bot) { + return None; + } + let oauth_session = match oauth::load_oauth_session(is_bot) { + Ok(s) => s, + Err(_) => return None, + }; + // If token is still valid, use it without refreshing + if is_token_valid(&oauth_session.access_token) { + let session = if is_bot { + token::load_bot_session().ok() + } else { + token::load_session().ok() + }; + return session; + } + // Token expired, try refresh + match oauth::refresh_oauth_session(is_bot).await { + Ok((_oauth, session)) => Some(session), + Err(_) => { + // Refresh failed — remove broken OAuth session so legacy works + oauth::remove_oauth_session(is_bot); + None } } +} + +/// Refresh access token (OAuth-aware: uses cached token if still valid) +pub async fn refresh_session() -> Result { + if let Some(session) = try_oauth_session(false).await { + return Ok(session); + } let session = token::load_session()?; let pds = session.pds.as_deref().unwrap_or("bsky.social"); @@ -109,13 +161,10 @@ pub async fn refresh_session() -> Result { Ok(new_session) } -/// Refresh bot access token (OAuth-aware) +/// Refresh bot access token (OAuth-aware: uses cached token if still valid) pub async fn refresh_bot_session() -> Result { - if oauth::has_oauth_session(true) { - match oauth::refresh_oauth_session(true).await { - Ok((_oauth, session)) => return Ok(session), - Err(_) => { /* OAuth failed, fall back to legacy */ } - } + if let Some(session) = try_oauth_session(true).await { + return Ok(session); } let session = token::load_bot_session()?; diff --git a/src/commands/gpt.rs b/src/commands/gpt.rs index b7357e2..1dca492 100644 --- a/src/commands/gpt.rs +++ b/src/commands/gpt.rs @@ -89,7 +89,7 @@ pub async fn get_memory(download: bool) -> Result<()> { } } - println!("Downloaded {} memory records", count); + eprintln!("Downloaded {} memory records", count); } else { // Show latest only let result: ListRecordsResponse = client @@ -186,7 +186,7 @@ pub async fn push(collection_name: &str) -> Result<()> { anyhow::bail!("Collection directory not found: {}", collection_dir.display()); } - println!("Pushing {} records from {}", collection_name, collection_dir.display()); + eprintln!("Pushing {} records from {}", collection_name, collection_dir.display()); let mut count = 0; for entry in fs::read_dir(&collection_dir)? { @@ -218,7 +218,7 @@ pub async fn push(collection_name: &str) -> Result<()> { record, }; - println!("Pushing: {}", rkey); + eprintln!("Pushing: {}", rkey); match client .call::<_, PutRecordResponse>( @@ -229,16 +229,16 @@ pub async fn push(collection_name: &str) -> Result<()> { .await { Ok(result) => { - println!(" OK: {}", result.uri); + eprintln!(" OK: {}", result.uri); count += 1; } Err(e) => { - println!(" Failed: {}", e); + eprintln!(" Failed: {}", e); } } } - println!("Pushed {} records to {}", count, collection); + eprintln!("Pushed {} records to {}", count, collection); Ok(()) } @@ -259,7 +259,7 @@ fn save_record(did: &str, collection: &str, rkey: &str, record: &Record) -> Resu use serde::Serialize; json.serialize(&mut ser)?; fs::write(&path, String::from_utf8(buf)?)?; - println!("Saved: {}", path.display()); + eprintln!("Saved: {}", path.display()); Ok(()) } diff --git a/src/commands/oauth.rs b/src/commands/oauth.rs index 5d2d67b..347f1ab 100644 --- a/src/commands/oauth.rs +++ b/src/commands/oauth.rs @@ -620,6 +620,20 @@ pub fn has_oauth_session(is_bot: bool) -> bool { config_dir.join(filename).exists() } +/// Remove OAuth session file (used when refresh fails, to allow legacy fallback) +pub fn remove_oauth_session(is_bot: bool) { + let config_dir = match dirs::config_dir() { + Some(d) => d.join(BUNDLE_ID), + None => return, + }; + let filename = if is_bot { + "oauth_bot_session.json" + } else { + "oauth_session.json" + }; + let _ = std::fs::remove_file(config_dir.join(filename)); +} + // --- Save --- fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> { diff --git a/src/commands/push.rs b/src/commands/push.rs index a0444fc..01a5579 100644 --- a/src/commands/push.rs +++ b/src/commands/push.rs @@ -25,7 +25,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu anyhow::bail!("Collection directory not found: {}", collection_dir); } - println!("Pushing records from {} to {}", collection_dir, collection); + eprintln!("Pushing records from {} to {}", collection_dir, collection); let mut count = 0; for entry in fs::read_dir(&collection_dir)? { @@ -59,7 +59,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu record, }; - println!("Pushing: {}", rkey); + eprintln!("Pushing: {}", rkey); match client .call::<_, PutRecordResponse>( @@ -70,16 +70,16 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu .await { Ok(result) => { - println!(" OK: {}", result.uri); + eprintln!(" OK: {}", result.uri); count += 1; } Err(e) => { - println!(" Failed: {}", e); + eprintln!(" Failed: {}", e); } } } - println!("Pushed {} records to {}", count, collection); + eprintln!("Pushed {} records to {}", count, collection); Ok(()) } diff --git a/src/commands/record.rs b/src/commands/record.rs index 2575219..853f75f 100644 --- a/src/commands/record.rs +++ b/src/commands/record.rs @@ -31,16 +31,14 @@ pub async fn put_record(file: &str, collection: &str, rkey: Option<&str>) -> Res record, }; - println!("Posting to {} with rkey: {}", collection, rkey); - println!("{}", serde_json::to_string_pretty(&req)?); - let result: PutRecordResponse = client .call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt) .await?; - println!("Success!"); - println!(" URI: {}", result.uri); - println!(" CID: {}", result.cid); + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "uri": result.uri, + "cid": result.cid, + }))?); Ok(()) } @@ -67,16 +65,14 @@ pub async fn put_lexicon(file: &str) -> Result<()> { record: lexicon, }; - println!("Putting lexicon: {}", lexicon_id); - println!("{}", serde_json::to_string_pretty(&req)?); - let result: PutRecordResponse = client .call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt) .await?; - println!("Success!"); - println!(" URI: {}", result.uri); - println!(" CID: {}", result.cid); + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "uri": result.uri, + "cid": result.cid, + }))?); Ok(()) } @@ -100,13 +96,7 @@ pub async fn get_records(collection: &str, limit: u32) -> Result<()> { ) .await?; - println!("Found {} records in {}", result.records.len(), collection); - for record in &result.records { - println!("---"); - println!("URI: {}", record.uri); - println!("CID: {}", record.cid); - println!("{}", serde_json::to_string_pretty(&record.value)?); - } + println!("{}", serde_json::to_string_pretty(&result)?); Ok(()) } @@ -127,13 +117,15 @@ pub async fn delete_record(collection: &str, rkey: &str, is_bot: bool) -> Result rkey: rkey.to_string(), }; - println!("Deleting {} from {}", rkey, collection); - client .call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt) .await?; - println!("Deleted successfully"); + println!("{}", serde_json::to_string_pretty(&serde_json::json!({ + "deleted": true, + "collection": collection, + "rkey": rkey, + }))?); Ok(()) } diff --git a/src/commands/sync.rs b/src/commands/sync.rs index 48b3d7c..f517160 100644 --- a/src/commands/sync.rs +++ b/src/commands/sync.rs @@ -18,7 +18,7 @@ pub async fn sync_to_local( let session = token::load_bot_session()?; let pds = session.pds.as_deref().unwrap_or("bsky.social"); let collection = collection_override.unwrap_or("ai.syui.log.chat"); - println!( + eprintln!( "Syncing bot data for {} ({})", session.handle, session.did ); @@ -33,7 +33,7 @@ pub async fn sync_to_local( let config_value = super::token::load_config()?; let config: Config = serde_json::from_value(config_value)?; - println!("Syncing data for {}", config.handle); + eprintln!("Syncing data for {}", config.handle); // Resolve handle to DID let resolve_url = format!( @@ -79,8 +79,8 @@ pub async fn sync_to_local( (did, pds, config.handle.clone(), collection) }; - println!("DID: {}", did); - println!("PDS: {}", pds); + eprintln!("DID: {}", did); + eprintln!("PDS: {}", pds); // Remove https:// prefix for lexicons::url let pds_host = pds.trim_start_matches("https://"); @@ -105,7 +105,7 @@ pub async fn sync_to_local( "collections": describe.collections, }))?; fs::write(&describe_path, &describe_json)?; - println!("Saved: {}", describe_path); + eprintln!("Saved: {}", describe_path); // 2. Sync profile let profile_url = format!( @@ -120,7 +120,7 @@ pub async fn sync_to_local( fs::create_dir_all(&profile_dir)?; let profile_path = format!("{}/self.json", profile_dir); fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?; - println!("Saved: {}", profile_path); + eprintln!("Saved: {}", profile_path); // Download avatar blob if present if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() { @@ -132,14 +132,14 @@ pub async fn sync_to_local( "{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}", pds, did, avatar_cid ); - println!("Downloading avatar: {}", avatar_cid); + eprintln!("Downloading avatar: {}", avatar_cid); let blob_res = client.get(&blob_url).send().await?; if blob_res.status().is_success() { let blob_bytes = blob_res.bytes().await?; fs::write(&blob_path, &blob_bytes)?; - println!("Saved: {}", blob_path); + eprintln!("Saved: {}", blob_path); } else { - println!("Failed to download avatar: {}", blob_res.status()); + eprintln!("Failed to download avatar: {}", blob_res.status()); } } } @@ -165,7 +165,7 @@ pub async fn sync_to_local( let res = client.get(&records_url).send().await?; if !res.status().is_success() { - println!("Failed to fetch records: {}", res.status()); + eprintln!("Failed to fetch records: {}", res.status()); break; } @@ -182,7 +182,7 @@ pub async fn sync_to_local( "value": record.value, }); fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?; - println!("Saved: {}", record_path); + eprintln!("Saved: {}", record_path); } total_fetched += count; @@ -211,15 +211,15 @@ pub async fn sync_to_local( } fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?; - println!("Saved: {}", index_path); + eprintln!("Saved: {}", index_path); - println!( + eprintln!( "Synced {} records from {}", total_fetched, collection ); } - println!("Sync complete!"); + eprintln!("Sync complete!"); Ok(()) } diff --git a/src/types.rs b/src/types.rs index bf44665..4d223fc 100644 --- a/src/types.rs +++ b/src/types.rs @@ -26,7 +26,7 @@ pub struct PutRecordResponse { } /// ATProto listRecords response -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct ListRecordsResponse { pub records: Vec, #[serde(default)] @@ -34,7 +34,7 @@ pub struct ListRecordsResponse { } /// A single ATProto record (from listRecords / getRecord) -#[derive(Debug, Deserialize)] +#[derive(Debug, Serialize, Deserialize)] pub struct Record { pub uri: String, pub cid: String,