2
0

fix login

This commit is contained in:
2026-03-30 14:01:34 +09:00
parent fbb44abad5
commit 9da77980b5
7 changed files with 118 additions and 63 deletions

View File

@@ -69,7 +69,7 @@ pub async fn login(handle: &str, password: &str, pds: &str, is_bot: bool) -> Res
} else {
token::save_session(&session)?;
}
println!("Logged in as {} ({})", session.handle, session.did);
eprintln!("Logged in as {} ({})", session.handle, session.did);
Ok(())
}
@@ -91,14 +91,66 @@ async fn do_refresh(session: &Session, pds: &str) -> Result<Session> {
})
}
/// Refresh access token (OAuth-aware: tries OAuth first, falls back to legacy)
pub async fn refresh_session() -> Result<Session> {
if oauth::has_oauth_session(false) {
match oauth::refresh_oauth_session(false).await {
Ok((_oauth, session)) => return Ok(session),
Err(_) => { /* OAuth failed, fall back to legacy */ }
/// Check if a JWT access token is still valid (with 60s margin)
fn is_token_valid(token: &str) -> bool {
let parts: Vec<&str> = token.split('.').collect();
if parts.len() < 2 {
return false;
}
// Decode JWT payload
use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine};
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
Ok(p) => p,
Err(_) => return false,
};
let json: serde_json::Value = match serde_json::from_slice(&payload) {
Ok(v) => v,
Err(_) => return false,
};
let exp = json["exp"].as_i64().unwrap_or(0);
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs() as i64;
exp > now + 60 // valid if expires more than 60s from now
}
/// Get OAuth session and return it if access_token is still valid.
/// If expired, refresh it. If refresh fails, remove the OAuth session file
/// so we cleanly fall back to legacy on next call.
async fn try_oauth_session(is_bot: bool) -> Option<Session> {
if !oauth::has_oauth_session(is_bot) {
return None;
}
let oauth_session = match oauth::load_oauth_session(is_bot) {
Ok(s) => s,
Err(_) => return None,
};
// If token is still valid, use it without refreshing
if is_token_valid(&oauth_session.access_token) {
let session = if is_bot {
token::load_bot_session().ok()
} else {
token::load_session().ok()
};
return session;
}
// Token expired, try refresh
match oauth::refresh_oauth_session(is_bot).await {
Ok((_oauth, session)) => Some(session),
Err(_) => {
// Refresh failed — remove broken OAuth session so legacy works
oauth::remove_oauth_session(is_bot);
None
}
}
}
/// Refresh access token (OAuth-aware: uses cached token if still valid)
pub async fn refresh_session() -> Result<Session> {
if let Some(session) = try_oauth_session(false).await {
return Ok(session);
}
let session = token::load_session()?;
let pds = session.pds.as_deref().unwrap_or("bsky.social");
@@ -109,13 +161,10 @@ pub async fn refresh_session() -> Result<Session> {
Ok(new_session)
}
/// Refresh bot access token (OAuth-aware)
/// Refresh bot access token (OAuth-aware: uses cached token if still valid)
pub async fn refresh_bot_session() -> Result<Session> {
if oauth::has_oauth_session(true) {
match oauth::refresh_oauth_session(true).await {
Ok((_oauth, session)) => return Ok(session),
Err(_) => { /* OAuth failed, fall back to legacy */ }
}
if let Some(session) = try_oauth_session(true).await {
return Ok(session);
}
let session = token::load_bot_session()?;

View File

@@ -89,7 +89,7 @@ pub async fn get_memory(download: bool) -> Result<()> {
}
}
println!("Downloaded {} memory records", count);
eprintln!("Downloaded {} memory records", count);
} else {
// Show latest only
let result: ListRecordsResponse = client
@@ -186,7 +186,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
anyhow::bail!("Collection directory not found: {}", collection_dir.display());
}
println!("Pushing {} records from {}", collection_name, collection_dir.display());
eprintln!("Pushing {} records from {}", collection_name, collection_dir.display());
let mut count = 0;
for entry in fs::read_dir(&collection_dir)? {
@@ -218,7 +218,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
record,
};
println!("Pushing: {}", rkey);
eprintln!("Pushing: {}", rkey);
match client
.call::<_, PutRecordResponse>(
@@ -229,16 +229,16 @@ pub async fn push(collection_name: &str) -> Result<()> {
.await
{
Ok(result) => {
println!(" OK: {}", result.uri);
eprintln!(" OK: {}", result.uri);
count += 1;
}
Err(e) => {
println!(" Failed: {}", e);
eprintln!(" Failed: {}", e);
}
}
}
println!("Pushed {} records to {}", count, collection);
eprintln!("Pushed {} records to {}", count, collection);
Ok(())
}
@@ -259,7 +259,7 @@ fn save_record(did: &str, collection: &str, rkey: &str, record: &Record) -> Resu
use serde::Serialize;
json.serialize(&mut ser)?;
fs::write(&path, String::from_utf8(buf)?)?;
println!("Saved: {}", path.display());
eprintln!("Saved: {}", path.display());
Ok(())
}

View File

@@ -620,6 +620,20 @@ pub fn has_oauth_session(is_bot: bool) -> bool {
config_dir.join(filename).exists()
}
/// Remove OAuth session file (used when refresh fails, to allow legacy fallback)
pub fn remove_oauth_session(is_bot: bool) {
let config_dir = match dirs::config_dir() {
Some(d) => d.join(BUNDLE_ID),
None => return,
};
let filename = if is_bot {
"oauth_bot_session.json"
} else {
"oauth_session.json"
};
let _ = std::fs::remove_file(config_dir.join(filename));
}
// --- Save ---
fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> {

View File

@@ -25,7 +25,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
anyhow::bail!("Collection directory not found: {}", collection_dir);
}
println!("Pushing records from {} to {}", collection_dir, collection);
eprintln!("Pushing records from {} to {}", collection_dir, collection);
let mut count = 0;
for entry in fs::read_dir(&collection_dir)? {
@@ -59,7 +59,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
record,
};
println!("Pushing: {}", rkey);
eprintln!("Pushing: {}", rkey);
match client
.call::<_, PutRecordResponse>(
@@ -70,16 +70,16 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
.await
{
Ok(result) => {
println!(" OK: {}", result.uri);
eprintln!(" OK: {}", result.uri);
count += 1;
}
Err(e) => {
println!(" Failed: {}", e);
eprintln!(" Failed: {}", e);
}
}
}
println!("Pushed {} records to {}", count, collection);
eprintln!("Pushed {} records to {}", count, collection);
Ok(())
}

View File

@@ -31,16 +31,14 @@ pub async fn put_record(file: &str, collection: &str, rkey: Option<&str>) -> Res
record,
};
println!("Posting to {} with rkey: {}", collection, rkey);
println!("{}", serde_json::to_string_pretty(&req)?);
let result: PutRecordResponse = client
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
.await?;
println!("Success!");
println!(" URI: {}", result.uri);
println!(" CID: {}", result.cid);
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
"uri": result.uri,
"cid": result.cid,
}))?);
Ok(())
}
@@ -67,16 +65,14 @@ pub async fn put_lexicon(file: &str) -> Result<()> {
record: lexicon,
};
println!("Putting lexicon: {}", lexicon_id);
println!("{}", serde_json::to_string_pretty(&req)?);
let result: PutRecordResponse = client
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
.await?;
println!("Success!");
println!(" URI: {}", result.uri);
println!(" CID: {}", result.cid);
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
"uri": result.uri,
"cid": result.cid,
}))?);
Ok(())
}
@@ -100,13 +96,7 @@ pub async fn get_records(collection: &str, limit: u32) -> Result<()> {
)
.await?;
println!("Found {} records in {}", result.records.len(), collection);
for record in &result.records {
println!("---");
println!("URI: {}", record.uri);
println!("CID: {}", record.cid);
println!("{}", serde_json::to_string_pretty(&record.value)?);
}
println!("{}", serde_json::to_string_pretty(&result)?);
Ok(())
}
@@ -127,13 +117,15 @@ pub async fn delete_record(collection: &str, rkey: &str, is_bot: bool) -> Result
rkey: rkey.to_string(),
};
println!("Deleting {} from {}", rkey, collection);
client
.call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt)
.await?;
println!("Deleted successfully");
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
"deleted": true,
"collection": collection,
"rkey": rkey,
}))?);
Ok(())
}

View File

@@ -18,7 +18,7 @@ pub async fn sync_to_local(
let session = token::load_bot_session()?;
let pds = session.pds.as_deref().unwrap_or("bsky.social");
let collection = collection_override.unwrap_or("ai.syui.log.chat");
println!(
eprintln!(
"Syncing bot data for {} ({})",
session.handle, session.did
);
@@ -33,7 +33,7 @@ pub async fn sync_to_local(
let config_value = super::token::load_config()?;
let config: Config = serde_json::from_value(config_value)?;
println!("Syncing data for {}", config.handle);
eprintln!("Syncing data for {}", config.handle);
// Resolve handle to DID
let resolve_url = format!(
@@ -79,8 +79,8 @@ pub async fn sync_to_local(
(did, pds, config.handle.clone(), collection)
};
println!("DID: {}", did);
println!("PDS: {}", pds);
eprintln!("DID: {}", did);
eprintln!("PDS: {}", pds);
// Remove https:// prefix for lexicons::url
let pds_host = pds.trim_start_matches("https://");
@@ -105,7 +105,7 @@ pub async fn sync_to_local(
"collections": describe.collections,
}))?;
fs::write(&describe_path, &describe_json)?;
println!("Saved: {}", describe_path);
eprintln!("Saved: {}", describe_path);
// 2. Sync profile
let profile_url = format!(
@@ -120,7 +120,7 @@ pub async fn sync_to_local(
fs::create_dir_all(&profile_dir)?;
let profile_path = format!("{}/self.json", profile_dir);
fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?;
println!("Saved: {}", profile_path);
eprintln!("Saved: {}", profile_path);
// Download avatar blob if present
if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() {
@@ -132,14 +132,14 @@ pub async fn sync_to_local(
"{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}",
pds, did, avatar_cid
);
println!("Downloading avatar: {}", avatar_cid);
eprintln!("Downloading avatar: {}", avatar_cid);
let blob_res = client.get(&blob_url).send().await?;
if blob_res.status().is_success() {
let blob_bytes = blob_res.bytes().await?;
fs::write(&blob_path, &blob_bytes)?;
println!("Saved: {}", blob_path);
eprintln!("Saved: {}", blob_path);
} else {
println!("Failed to download avatar: {}", blob_res.status());
eprintln!("Failed to download avatar: {}", blob_res.status());
}
}
}
@@ -165,7 +165,7 @@ pub async fn sync_to_local(
let res = client.get(&records_url).send().await?;
if !res.status().is_success() {
println!("Failed to fetch records: {}", res.status());
eprintln!("Failed to fetch records: {}", res.status());
break;
}
@@ -182,7 +182,7 @@ pub async fn sync_to_local(
"value": record.value,
});
fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?;
println!("Saved: {}", record_path);
eprintln!("Saved: {}", record_path);
}
total_fetched += count;
@@ -211,15 +211,15 @@ pub async fn sync_to_local(
}
fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?;
println!("Saved: {}", index_path);
eprintln!("Saved: {}", index_path);
println!(
eprintln!(
"Synced {} records from {}",
total_fetched, collection
);
}
println!("Sync complete!");
eprintln!("Sync complete!");
Ok(())
}

View File

@@ -26,7 +26,7 @@ pub struct PutRecordResponse {
}
/// ATProto listRecords response
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct ListRecordsResponse {
pub records: Vec<Record>,
#[serde(default)]
@@ -34,7 +34,7 @@ pub struct ListRecordsResponse {
}
/// A single ATProto record (from listRecords / getRecord)
#[derive(Debug, Deserialize)]
#[derive(Debug, Serialize, Deserialize)]
pub struct Record {
pub uri: String,
pub cid: String,