2
0

fix login

This commit is contained in:
2026-03-30 14:01:34 +09:00
parent fbb44abad5
commit 9da77980b5
7 changed files with 118 additions and 63 deletions

View File

@@ -69,7 +69,7 @@ pub async fn login(handle: &str, password: &str, pds: &str, is_bot: bool) -> Res
} else { } else {
token::save_session(&session)?; token::save_session(&session)?;
} }
println!("Logged in as {} ({})", session.handle, session.did); eprintln!("Logged in as {} ({})", session.handle, session.did);
Ok(()) Ok(())
} }
@@ -91,14 +91,66 @@ async fn do_refresh(session: &Session, pds: &str) -> Result<Session> {
}) })
} }
/// Refresh access token (OAuth-aware: tries OAuth first, falls back to legacy) /// Check if a JWT access token is still valid (with 60s margin)
pub async fn refresh_session() -> Result<Session> { fn is_token_valid(token: &str) -> bool {
if oauth::has_oauth_session(false) { let parts: Vec<&str> = token.split('.').collect();
match oauth::refresh_oauth_session(false).await { if parts.len() < 2 {
Ok((_oauth, session)) => return Ok(session), return false;
Err(_) => { /* OAuth failed, fall back to legacy */ } }
// Decode JWT payload
use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine};
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
Ok(p) => p,
Err(_) => return false,
};
let json: serde_json::Value = match serde_json::from_slice(&payload) {
Ok(v) => v,
Err(_) => return false,
};
let exp = json["exp"].as_i64().unwrap_or(0);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
exp > now + 60 // valid if expires more than 60s from now
}
/// Get OAuth session and return it if access_token is still valid.
/// If expired, refresh it. If refresh fails, remove the OAuth session file
/// so we cleanly fall back to legacy on next call.
async fn try_oauth_session(is_bot: bool) -> Option<Session> {
if !oauth::has_oauth_session(is_bot) {
return None;
}
let oauth_session = match oauth::load_oauth_session(is_bot) {
Ok(s) => s,
Err(_) => return None,
};
// If token is still valid, use it without refreshing
if is_token_valid(&oauth_session.access_token) {
let session = if is_bot {
token::load_bot_session().ok()
} else {
token::load_session().ok()
};
return session;
}
// Token expired, try refresh
match oauth::refresh_oauth_session(is_bot).await {
Ok((_oauth, session)) => Some(session),
Err(_) => {
// Refresh failed — remove broken OAuth session so legacy works
oauth::remove_oauth_session(is_bot);
None
} }
} }
}
/// Refresh access token (OAuth-aware: uses cached token if still valid)
pub async fn refresh_session() -> Result<Session> {
if let Some(session) = try_oauth_session(false).await {
return Ok(session);
}
let session = token::load_session()?; let session = token::load_session()?;
let pds = session.pds.as_deref().unwrap_or("bsky.social"); let pds = session.pds.as_deref().unwrap_or("bsky.social");
@@ -109,13 +161,10 @@ pub async fn refresh_session() -> Result<Session> {
Ok(new_session) Ok(new_session)
} }
/// Refresh bot access token (OAuth-aware) /// Refresh bot access token (OAuth-aware: uses cached token if still valid)
pub async fn refresh_bot_session() -> Result<Session> { pub async fn refresh_bot_session() -> Result<Session> {
if oauth::has_oauth_session(true) { if let Some(session) = try_oauth_session(true).await {
match oauth::refresh_oauth_session(true).await { return Ok(session);
Ok((_oauth, session)) => return Ok(session),
Err(_) => { /* OAuth failed, fall back to legacy */ }
}
} }
let session = token::load_bot_session()?; let session = token::load_bot_session()?;

View File

@@ -89,7 +89,7 @@ pub async fn get_memory(download: bool) -> Result<()> {
} }
} }
println!("Downloaded {} memory records", count); eprintln!("Downloaded {} memory records", count);
} else { } else {
// Show latest only // Show latest only
let result: ListRecordsResponse = client let result: ListRecordsResponse = client
@@ -186,7 +186,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
anyhow::bail!("Collection directory not found: {}", collection_dir.display()); anyhow::bail!("Collection directory not found: {}", collection_dir.display());
} }
println!("Pushing {} records from {}", collection_name, collection_dir.display()); eprintln!("Pushing {} records from {}", collection_name, collection_dir.display());
let mut count = 0; let mut count = 0;
for entry in fs::read_dir(&collection_dir)? { for entry in fs::read_dir(&collection_dir)? {
@@ -218,7 +218,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
record, record,
}; };
println!("Pushing: {}", rkey); eprintln!("Pushing: {}", rkey);
match client match client
.call::<_, PutRecordResponse>( .call::<_, PutRecordResponse>(
@@ -229,16 +229,16 @@ pub async fn push(collection_name: &str) -> Result<()> {
.await .await
{ {
Ok(result) => { Ok(result) => {
println!(" OK: {}", result.uri); eprintln!(" OK: {}", result.uri);
count += 1; count += 1;
} }
Err(e) => { Err(e) => {
println!(" Failed: {}", e); eprintln!(" Failed: {}", e);
} }
} }
} }
println!("Pushed {} records to {}", count, collection); eprintln!("Pushed {} records to {}", count, collection);
Ok(()) Ok(())
} }
@@ -259,7 +259,7 @@ fn save_record(did: &str, collection: &str, rkey: &str, record: &Record) -> Resu
use serde::Serialize; use serde::Serialize;
json.serialize(&mut ser)?; json.serialize(&mut ser)?;
fs::write(&path, String::from_utf8(buf)?)?; fs::write(&path, String::from_utf8(buf)?)?;
println!("Saved: {}", path.display()); eprintln!("Saved: {}", path.display());
Ok(()) Ok(())
} }

View File

@@ -620,6 +620,20 @@ pub fn has_oauth_session(is_bot: bool) -> bool {
config_dir.join(filename).exists() config_dir.join(filename).exists()
} }
/// Remove OAuth session file (used when refresh fails, to allow legacy fallback)
pub fn remove_oauth_session(is_bot: bool) {
let config_dir = match dirs::config_dir() {
Some(d) => d.join(BUNDLE_ID),
None => return,
};
let filename = if is_bot {
"oauth_bot_session.json"
} else {
"oauth_session.json"
};
let _ = std::fs::remove_file(config_dir.join(filename));
}
// --- Save --- // --- Save ---
fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> { fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> {

View File

@@ -25,7 +25,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
anyhow::bail!("Collection directory not found: {}", collection_dir); anyhow::bail!("Collection directory not found: {}", collection_dir);
} }
println!("Pushing records from {} to {}", collection_dir, collection); eprintln!("Pushing records from {} to {}", collection_dir, collection);
let mut count = 0; let mut count = 0;
for entry in fs::read_dir(&collection_dir)? { for entry in fs::read_dir(&collection_dir)? {
@@ -59,7 +59,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
record, record,
}; };
println!("Pushing: {}", rkey); eprintln!("Pushing: {}", rkey);
match client match client
.call::<_, PutRecordResponse>( .call::<_, PutRecordResponse>(
@@ -70,16 +70,16 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
.await .await
{ {
Ok(result) => { Ok(result) => {
println!(" OK: {}", result.uri); eprintln!(" OK: {}", result.uri);
count += 1; count += 1;
} }
Err(e) => { Err(e) => {
println!(" Failed: {}", e); eprintln!(" Failed: {}", e);
} }
} }
} }
println!("Pushed {} records to {}", count, collection); eprintln!("Pushed {} records to {}", count, collection);
Ok(()) Ok(())
} }

View File

@@ -31,16 +31,14 @@ pub async fn put_record(file: &str, collection: &str, rkey: Option<&str>) -> Res
record, record,
}; };
println!("Posting to {} with rkey: {}", collection, rkey);
println!("{}", serde_json::to_string_pretty(&req)?);
let result: PutRecordResponse = client let result: PutRecordResponse = client
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt) .call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
.await?; .await?;
println!("Success!"); println!("{}", serde_json::to_string_pretty(&serde_json::json!({
println!(" URI: {}", result.uri); "uri": result.uri,
println!(" CID: {}", result.cid); "cid": result.cid,
}))?);
Ok(()) Ok(())
} }
@@ -67,16 +65,14 @@ pub async fn put_lexicon(file: &str) -> Result<()> {
record: lexicon, record: lexicon,
}; };
println!("Putting lexicon: {}", lexicon_id);
println!("{}", serde_json::to_string_pretty(&req)?);
let result: PutRecordResponse = client let result: PutRecordResponse = client
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt) .call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
.await?; .await?;
println!("Success!"); println!("{}", serde_json::to_string_pretty(&serde_json::json!({
println!(" URI: {}", result.uri); "uri": result.uri,
println!(" CID: {}", result.cid); "cid": result.cid,
}))?);
Ok(()) Ok(())
} }
@@ -100,13 +96,7 @@ pub async fn get_records(collection: &str, limit: u32) -> Result<()> {
) )
.await?; .await?;
println!("Found {} records in {}", result.records.len(), collection); println!("{}", serde_json::to_string_pretty(&result)?);
for record in &result.records {
println!("---");
println!("URI: {}", record.uri);
println!("CID: {}", record.cid);
println!("{}", serde_json::to_string_pretty(&record.value)?);
}
Ok(()) Ok(())
} }
@@ -127,13 +117,15 @@ pub async fn delete_record(collection: &str, rkey: &str, is_bot: bool) -> Result
rkey: rkey.to_string(), rkey: rkey.to_string(),
}; };
println!("Deleting {} from {}", rkey, collection);
client client
.call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt) .call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt)
.await?; .await?;
println!("Deleted successfully"); println!("{}", serde_json::to_string_pretty(&serde_json::json!({
"deleted": true,
"collection": collection,
"rkey": rkey,
}))?);
Ok(()) Ok(())
} }

View File

@@ -18,7 +18,7 @@ pub async fn sync_to_local(
let session = token::load_bot_session()?; let session = token::load_bot_session()?;
let pds = session.pds.as_deref().unwrap_or("bsky.social"); let pds = session.pds.as_deref().unwrap_or("bsky.social");
let collection = collection_override.unwrap_or("ai.syui.log.chat"); let collection = collection_override.unwrap_or("ai.syui.log.chat");
println!( eprintln!(
"Syncing bot data for {} ({})", "Syncing bot data for {} ({})",
session.handle, session.did session.handle, session.did
); );
@@ -33,7 +33,7 @@ pub async fn sync_to_local(
let config_value = super::token::load_config()?; let config_value = super::token::load_config()?;
let config: Config = serde_json::from_value(config_value)?; let config: Config = serde_json::from_value(config_value)?;
println!("Syncing data for {}", config.handle); eprintln!("Syncing data for {}", config.handle);
// Resolve handle to DID // Resolve handle to DID
let resolve_url = format!( let resolve_url = format!(
@@ -79,8 +79,8 @@ pub async fn sync_to_local(
(did, pds, config.handle.clone(), collection) (did, pds, config.handle.clone(), collection)
}; };
println!("DID: {}", did); eprintln!("DID: {}", did);
println!("PDS: {}", pds); eprintln!("PDS: {}", pds);
// Remove https:// prefix for lexicons::url // Remove https:// prefix for lexicons::url
let pds_host = pds.trim_start_matches("https://"); let pds_host = pds.trim_start_matches("https://");
@@ -105,7 +105,7 @@ pub async fn sync_to_local(
"collections": describe.collections, "collections": describe.collections,
}))?; }))?;
fs::write(&describe_path, &describe_json)?; fs::write(&describe_path, &describe_json)?;
println!("Saved: {}", describe_path); eprintln!("Saved: {}", describe_path);
// 2. Sync profile // 2. Sync profile
let profile_url = format!( let profile_url = format!(
@@ -120,7 +120,7 @@ pub async fn sync_to_local(
fs::create_dir_all(&profile_dir)?; fs::create_dir_all(&profile_dir)?;
let profile_path = format!("{}/self.json", profile_dir); let profile_path = format!("{}/self.json", profile_dir);
fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?; fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?;
println!("Saved: {}", profile_path); eprintln!("Saved: {}", profile_path);
// Download avatar blob if present // Download avatar blob if present
if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() { if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() {
@@ -132,14 +132,14 @@ pub async fn sync_to_local(
"{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}", "{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}",
pds, did, avatar_cid pds, did, avatar_cid
); );
println!("Downloading avatar: {}", avatar_cid); eprintln!("Downloading avatar: {}", avatar_cid);
let blob_res = client.get(&blob_url).send().await?; let blob_res = client.get(&blob_url).send().await?;
if blob_res.status().is_success() { if blob_res.status().is_success() {
let blob_bytes = blob_res.bytes().await?; let blob_bytes = blob_res.bytes().await?;
fs::write(&blob_path, &blob_bytes)?; fs::write(&blob_path, &blob_bytes)?;
println!("Saved: {}", blob_path); eprintln!("Saved: {}", blob_path);
} else { } else {
println!("Failed to download avatar: {}", blob_res.status()); eprintln!("Failed to download avatar: {}", blob_res.status());
} }
} }
} }
@@ -165,7 +165,7 @@ pub async fn sync_to_local(
let res = client.get(&records_url).send().await?; let res = client.get(&records_url).send().await?;
if !res.status().is_success() { if !res.status().is_success() {
println!("Failed to fetch records: {}", res.status()); eprintln!("Failed to fetch records: {}", res.status());
break; break;
} }
@@ -182,7 +182,7 @@ pub async fn sync_to_local(
"value": record.value, "value": record.value,
}); });
fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?; fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?;
println!("Saved: {}", record_path); eprintln!("Saved: {}", record_path);
} }
total_fetched += count; total_fetched += count;
@@ -211,15 +211,15 @@ pub async fn sync_to_local(
} }
fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?; fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?;
println!("Saved: {}", index_path); eprintln!("Saved: {}", index_path);
println!( eprintln!(
"Synced {} records from {}", "Synced {} records from {}",
total_fetched, collection total_fetched, collection
); );
} }
println!("Sync complete!"); eprintln!("Sync complete!");
Ok(()) Ok(())
} }

View File

@@ -26,7 +26,7 @@ pub struct PutRecordResponse {
} }
/// ATProto listRecords response /// ATProto listRecords response
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct ListRecordsResponse { pub struct ListRecordsResponse {
pub records: Vec<Record>, pub records: Vec<Record>,
#[serde(default)] #[serde(default)]
@@ -34,7 +34,7 @@ pub struct ListRecordsResponse {
} }
/// A single ATProto record (from listRecords / getRecord) /// A single ATProto record (from listRecords / getRecord)
#[derive(Debug, Deserialize)] #[derive(Debug, Serialize, Deserialize)]
pub struct Record { pub struct Record {
pub uri: String, pub uri: String,
pub cid: String, pub cid: String,