fix login
This commit is contained in:
@@ -69,7 +69,7 @@ pub async fn login(handle: &str, password: &str, pds: &str, is_bot: bool) -> Res
|
|||||||
} else {
|
} else {
|
||||||
token::save_session(&session)?;
|
token::save_session(&session)?;
|
||||||
}
|
}
|
||||||
println!("Logged in as {} ({})", session.handle, session.did);
|
eprintln!("Logged in as {} ({})", session.handle, session.did);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -91,14 +91,66 @@ async fn do_refresh(session: &Session, pds: &str) -> Result<Session> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Refresh access token (OAuth-aware: tries OAuth first, falls back to legacy)
|
/// Check if a JWT access token is still valid (with 60s margin)
|
||||||
pub async fn refresh_session() -> Result<Session> {
|
fn is_token_valid(token: &str) -> bool {
|
||||||
if oauth::has_oauth_session(false) {
|
let parts: Vec<&str> = token.split('.').collect();
|
||||||
match oauth::refresh_oauth_session(false).await {
|
if parts.len() < 2 {
|
||||||
Ok((_oauth, session)) => return Ok(session),
|
return false;
|
||||||
Err(_) => { /* OAuth failed, fall back to legacy */ }
|
}
|
||||||
|
// Decode JWT payload
|
||||||
|
use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine};
|
||||||
|
let payload = match URL_SAFE_NO_PAD.decode(parts[1]) {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(_) => return false,
|
||||||
|
};
|
||||||
|
let json: serde_json::Value = match serde_json::from_slice(&payload) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => return false,
|
||||||
|
};
|
||||||
|
let exp = json["exp"].as_i64().unwrap_or(0);
|
||||||
|
let now = std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i64;
|
||||||
|
exp > now + 60 // valid if expires more than 60s from now
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get OAuth session and return it if access_token is still valid.
|
||||||
|
/// If expired, refresh it. If refresh fails, remove the OAuth session file
|
||||||
|
/// so we cleanly fall back to legacy on next call.
|
||||||
|
async fn try_oauth_session(is_bot: bool) -> Option<Session> {
|
||||||
|
if !oauth::has_oauth_session(is_bot) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
let oauth_session = match oauth::load_oauth_session(is_bot) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => return None,
|
||||||
|
};
|
||||||
|
// If token is still valid, use it without refreshing
|
||||||
|
if is_token_valid(&oauth_session.access_token) {
|
||||||
|
let session = if is_bot {
|
||||||
|
token::load_bot_session().ok()
|
||||||
|
} else {
|
||||||
|
token::load_session().ok()
|
||||||
|
};
|
||||||
|
return session;
|
||||||
|
}
|
||||||
|
// Token expired, try refresh
|
||||||
|
match oauth::refresh_oauth_session(is_bot).await {
|
||||||
|
Ok((_oauth, session)) => Some(session),
|
||||||
|
Err(_) => {
|
||||||
|
// Refresh failed — remove broken OAuth session so legacy works
|
||||||
|
oauth::remove_oauth_session(is_bot);
|
||||||
|
None
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Refresh access token (OAuth-aware: uses cached token if still valid)
|
||||||
|
pub async fn refresh_session() -> Result<Session> {
|
||||||
|
if let Some(session) = try_oauth_session(false).await {
|
||||||
|
return Ok(session);
|
||||||
|
}
|
||||||
|
|
||||||
let session = token::load_session()?;
|
let session = token::load_session()?;
|
||||||
let pds = session.pds.as_deref().unwrap_or("bsky.social");
|
let pds = session.pds.as_deref().unwrap_or("bsky.social");
|
||||||
@@ -109,13 +161,10 @@ pub async fn refresh_session() -> Result<Session> {
|
|||||||
Ok(new_session)
|
Ok(new_session)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Refresh bot access token (OAuth-aware)
|
/// Refresh bot access token (OAuth-aware: uses cached token if still valid)
|
||||||
pub async fn refresh_bot_session() -> Result<Session> {
|
pub async fn refresh_bot_session() -> Result<Session> {
|
||||||
if oauth::has_oauth_session(true) {
|
if let Some(session) = try_oauth_session(true).await {
|
||||||
match oauth::refresh_oauth_session(true).await {
|
return Ok(session);
|
||||||
Ok((_oauth, session)) => return Ok(session),
|
|
||||||
Err(_) => { /* OAuth failed, fall back to legacy */ }
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let session = token::load_bot_session()?;
|
let session = token::load_bot_session()?;
|
||||||
|
|||||||
@@ -89,7 +89,7 @@ pub async fn get_memory(download: bool) -> Result<()> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Downloaded {} memory records", count);
|
eprintln!("Downloaded {} memory records", count);
|
||||||
} else {
|
} else {
|
||||||
// Show latest only
|
// Show latest only
|
||||||
let result: ListRecordsResponse = client
|
let result: ListRecordsResponse = client
|
||||||
@@ -186,7 +186,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
|
|||||||
anyhow::bail!("Collection directory not found: {}", collection_dir.display());
|
anyhow::bail!("Collection directory not found: {}", collection_dir.display());
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Pushing {} records from {}", collection_name, collection_dir.display());
|
eprintln!("Pushing {} records from {}", collection_name, collection_dir.display());
|
||||||
|
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for entry in fs::read_dir(&collection_dir)? {
|
for entry in fs::read_dir(&collection_dir)? {
|
||||||
@@ -218,7 +218,7 @@ pub async fn push(collection_name: &str) -> Result<()> {
|
|||||||
record,
|
record,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Pushing: {}", rkey);
|
eprintln!("Pushing: {}", rkey);
|
||||||
|
|
||||||
match client
|
match client
|
||||||
.call::<_, PutRecordResponse>(
|
.call::<_, PutRecordResponse>(
|
||||||
@@ -229,16 +229,16 @@ pub async fn push(collection_name: &str) -> Result<()> {
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
println!(" OK: {}", result.uri);
|
eprintln!(" OK: {}", result.uri);
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!(" Failed: {}", e);
|
eprintln!(" Failed: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Pushed {} records to {}", count, collection);
|
eprintln!("Pushed {} records to {}", count, collection);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,7 +259,7 @@ fn save_record(did: &str, collection: &str, rkey: &str, record: &Record) -> Resu
|
|||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
json.serialize(&mut ser)?;
|
json.serialize(&mut ser)?;
|
||||||
fs::write(&path, String::from_utf8(buf)?)?;
|
fs::write(&path, String::from_utf8(buf)?)?;
|
||||||
println!("Saved: {}", path.display());
|
eprintln!("Saved: {}", path.display());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -620,6 +620,20 @@ pub fn has_oauth_session(is_bot: bool) -> bool {
|
|||||||
config_dir.join(filename).exists()
|
config_dir.join(filename).exists()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove OAuth session file (used when refresh fails, to allow legacy fallback)
|
||||||
|
pub fn remove_oauth_session(is_bot: bool) {
|
||||||
|
let config_dir = match dirs::config_dir() {
|
||||||
|
Some(d) => d.join(BUNDLE_ID),
|
||||||
|
None => return,
|
||||||
|
};
|
||||||
|
let filename = if is_bot {
|
||||||
|
"oauth_bot_session.json"
|
||||||
|
} else {
|
||||||
|
"oauth_session.json"
|
||||||
|
};
|
||||||
|
let _ = std::fs::remove_file(config_dir.join(filename));
|
||||||
|
}
|
||||||
|
|
||||||
// --- Save ---
|
// --- Save ---
|
||||||
|
|
||||||
fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> {
|
fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> {
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
|
|||||||
anyhow::bail!("Collection directory not found: {}", collection_dir);
|
anyhow::bail!("Collection directory not found: {}", collection_dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Pushing records from {} to {}", collection_dir, collection);
|
eprintln!("Pushing records from {} to {}", collection_dir, collection);
|
||||||
|
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for entry in fs::read_dir(&collection_dir)? {
|
for entry in fs::read_dir(&collection_dir)? {
|
||||||
@@ -59,7 +59,7 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
|
|||||||
record,
|
record,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Pushing: {}", rkey);
|
eprintln!("Pushing: {}", rkey);
|
||||||
|
|
||||||
match client
|
match client
|
||||||
.call::<_, PutRecordResponse>(
|
.call::<_, PutRecordResponse>(
|
||||||
@@ -70,16 +70,16 @@ pub async fn push_to_remote(input: &str, collection: &str, is_bot: bool) -> Resu
|
|||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
println!(" OK: {}", result.uri);
|
eprintln!(" OK: {}", result.uri);
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
println!(" Failed: {}", e);
|
eprintln!(" Failed: {}", e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Pushed {} records to {}", count, collection);
|
eprintln!("Pushed {} records to {}", count, collection);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -31,16 +31,14 @@ pub async fn put_record(file: &str, collection: &str, rkey: Option<&str>) -> Res
|
|||||||
record,
|
record,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Posting to {} with rkey: {}", collection, rkey);
|
|
||||||
println!("{}", serde_json::to_string_pretty(&req)?);
|
|
||||||
|
|
||||||
let result: PutRecordResponse = client
|
let result: PutRecordResponse = client
|
||||||
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
|
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
println!("Success!");
|
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
|
||||||
println!(" URI: {}", result.uri);
|
"uri": result.uri,
|
||||||
println!(" CID: {}", result.cid);
|
"cid": result.cid,
|
||||||
|
}))?);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -67,16 +65,14 @@ pub async fn put_lexicon(file: &str) -> Result<()> {
|
|||||||
record: lexicon,
|
record: lexicon,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Putting lexicon: {}", lexicon_id);
|
|
||||||
println!("{}", serde_json::to_string_pretty(&req)?);
|
|
||||||
|
|
||||||
let result: PutRecordResponse = client
|
let result: PutRecordResponse = client
|
||||||
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
|
.call(&com_atproto_repo::PUT_RECORD, &req, &session.access_jwt)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
println!("Success!");
|
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
|
||||||
println!(" URI: {}", result.uri);
|
"uri": result.uri,
|
||||||
println!(" CID: {}", result.cid);
|
"cid": result.cid,
|
||||||
|
}))?);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -100,13 +96,7 @@ pub async fn get_records(collection: &str, limit: u32) -> Result<()> {
|
|||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
println!("Found {} records in {}", result.records.len(), collection);
|
println!("{}", serde_json::to_string_pretty(&result)?);
|
||||||
for record in &result.records {
|
|
||||||
println!("---");
|
|
||||||
println!("URI: {}", record.uri);
|
|
||||||
println!("CID: {}", record.cid);
|
|
||||||
println!("{}", serde_json::to_string_pretty(&record.value)?);
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
@@ -127,13 +117,15 @@ pub async fn delete_record(collection: &str, rkey: &str, is_bot: bool) -> Result
|
|||||||
rkey: rkey.to_string(),
|
rkey: rkey.to_string(),
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("Deleting {} from {}", rkey, collection);
|
|
||||||
|
|
||||||
client
|
client
|
||||||
.call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt)
|
.call_no_response(&com_atproto_repo::DELETE_RECORD, &req, &session.access_jwt)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
println!("Deleted successfully");
|
println!("{}", serde_json::to_string_pretty(&serde_json::json!({
|
||||||
|
"deleted": true,
|
||||||
|
"collection": collection,
|
||||||
|
"rkey": rkey,
|
||||||
|
}))?);
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ pub async fn sync_to_local(
|
|||||||
let session = token::load_bot_session()?;
|
let session = token::load_bot_session()?;
|
||||||
let pds = session.pds.as_deref().unwrap_or("bsky.social");
|
let pds = session.pds.as_deref().unwrap_or("bsky.social");
|
||||||
let collection = collection_override.unwrap_or("ai.syui.log.chat");
|
let collection = collection_override.unwrap_or("ai.syui.log.chat");
|
||||||
println!(
|
eprintln!(
|
||||||
"Syncing bot data for {} ({})",
|
"Syncing bot data for {} ({})",
|
||||||
session.handle, session.did
|
session.handle, session.did
|
||||||
);
|
);
|
||||||
@@ -33,7 +33,7 @@ pub async fn sync_to_local(
|
|||||||
let config_value = super::token::load_config()?;
|
let config_value = super::token::load_config()?;
|
||||||
let config: Config = serde_json::from_value(config_value)?;
|
let config: Config = serde_json::from_value(config_value)?;
|
||||||
|
|
||||||
println!("Syncing data for {}", config.handle);
|
eprintln!("Syncing data for {}", config.handle);
|
||||||
|
|
||||||
// Resolve handle to DID
|
// Resolve handle to DID
|
||||||
let resolve_url = format!(
|
let resolve_url = format!(
|
||||||
@@ -79,8 +79,8 @@ pub async fn sync_to_local(
|
|||||||
(did, pds, config.handle.clone(), collection)
|
(did, pds, config.handle.clone(), collection)
|
||||||
};
|
};
|
||||||
|
|
||||||
println!("DID: {}", did);
|
eprintln!("DID: {}", did);
|
||||||
println!("PDS: {}", pds);
|
eprintln!("PDS: {}", pds);
|
||||||
|
|
||||||
// Remove https:// prefix for lexicons::url
|
// Remove https:// prefix for lexicons::url
|
||||||
let pds_host = pds.trim_start_matches("https://");
|
let pds_host = pds.trim_start_matches("https://");
|
||||||
@@ -105,7 +105,7 @@ pub async fn sync_to_local(
|
|||||||
"collections": describe.collections,
|
"collections": describe.collections,
|
||||||
}))?;
|
}))?;
|
||||||
fs::write(&describe_path, &describe_json)?;
|
fs::write(&describe_path, &describe_json)?;
|
||||||
println!("Saved: {}", describe_path);
|
eprintln!("Saved: {}", describe_path);
|
||||||
|
|
||||||
// 2. Sync profile
|
// 2. Sync profile
|
||||||
let profile_url = format!(
|
let profile_url = format!(
|
||||||
@@ -120,7 +120,7 @@ pub async fn sync_to_local(
|
|||||||
fs::create_dir_all(&profile_dir)?;
|
fs::create_dir_all(&profile_dir)?;
|
||||||
let profile_path = format!("{}/self.json", profile_dir);
|
let profile_path = format!("{}/self.json", profile_dir);
|
||||||
fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?;
|
fs::write(&profile_path, serde_json::to_string_pretty(&profile)?)?;
|
||||||
println!("Saved: {}", profile_path);
|
eprintln!("Saved: {}", profile_path);
|
||||||
|
|
||||||
// Download avatar blob if present
|
// Download avatar blob if present
|
||||||
if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() {
|
if let Some(avatar_cid) = profile["value"]["avatar"]["ref"]["$link"].as_str() {
|
||||||
@@ -132,14 +132,14 @@ pub async fn sync_to_local(
|
|||||||
"{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}",
|
"{}/xrpc/com.atproto.sync.getBlob?did={}&cid={}",
|
||||||
pds, did, avatar_cid
|
pds, did, avatar_cid
|
||||||
);
|
);
|
||||||
println!("Downloading avatar: {}", avatar_cid);
|
eprintln!("Downloading avatar: {}", avatar_cid);
|
||||||
let blob_res = client.get(&blob_url).send().await?;
|
let blob_res = client.get(&blob_url).send().await?;
|
||||||
if blob_res.status().is_success() {
|
if blob_res.status().is_success() {
|
||||||
let blob_bytes = blob_res.bytes().await?;
|
let blob_bytes = blob_res.bytes().await?;
|
||||||
fs::write(&blob_path, &blob_bytes)?;
|
fs::write(&blob_path, &blob_bytes)?;
|
||||||
println!("Saved: {}", blob_path);
|
eprintln!("Saved: {}", blob_path);
|
||||||
} else {
|
} else {
|
||||||
println!("Failed to download avatar: {}", blob_res.status());
|
eprintln!("Failed to download avatar: {}", blob_res.status());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -165,7 +165,7 @@ pub async fn sync_to_local(
|
|||||||
|
|
||||||
let res = client.get(&records_url).send().await?;
|
let res = client.get(&records_url).send().await?;
|
||||||
if !res.status().is_success() {
|
if !res.status().is_success() {
|
||||||
println!("Failed to fetch records: {}", res.status());
|
eprintln!("Failed to fetch records: {}", res.status());
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -182,7 +182,7 @@ pub async fn sync_to_local(
|
|||||||
"value": record.value,
|
"value": record.value,
|
||||||
});
|
});
|
||||||
fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?;
|
fs::write(&record_path, serde_json::to_string_pretty(&record_json)?)?;
|
||||||
println!("Saved: {}", record_path);
|
eprintln!("Saved: {}", record_path);
|
||||||
}
|
}
|
||||||
|
|
||||||
total_fetched += count;
|
total_fetched += count;
|
||||||
@@ -211,15 +211,15 @@ pub async fn sync_to_local(
|
|||||||
}
|
}
|
||||||
|
|
||||||
fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?;
|
fs::write(&index_path, serde_json::to_string_pretty(&merged_rkeys)?)?;
|
||||||
println!("Saved: {}", index_path);
|
eprintln!("Saved: {}", index_path);
|
||||||
|
|
||||||
println!(
|
eprintln!(
|
||||||
"Synced {} records from {}",
|
"Synced {} records from {}",
|
||||||
total_fetched, collection
|
total_fetched, collection
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("Sync complete!");
|
eprintln!("Sync complete!");
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ pub struct PutRecordResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// ATProto listRecords response
|
/// ATProto listRecords response
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct ListRecordsResponse {
|
pub struct ListRecordsResponse {
|
||||||
pub records: Vec<Record>,
|
pub records: Vec<Record>,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
@@ -34,7 +34,7 @@ pub struct ListRecordsResponse {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// A single ATProto record (from listRecords / getRecord)
|
/// A single ATProto record (from listRecords / getRecord)
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
pub struct Record {
|
pub struct Record {
|
||||||
pub uri: String,
|
pub uri: String,
|
||||||
pub cid: String,
|
pub cid: String,
|
||||||
|
|||||||
Reference in New Issue
Block a user