use anyhow::{Context, Result}; use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine}; use ring::rand::SecureRandom; use ring::signature::{EcdsaKeyPair, ECDSA_P256_SHA256_FIXED_SIGNING, KeyPair}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::io::{self, Write}; use super::token::{self, Session, BUNDLE_ID}; #[derive(Debug, Deserialize)] struct SiteConfig { #[serde(rename = "siteUrl")] site_url: Option, } fn load_site_url() -> Result { // 1. Try public/config.json in current directory let local_path = std::path::Path::new("public/config.json"); if local_path.exists() { let content = std::fs::read_to_string(local_path)?; let config: SiteConfig = serde_json::from_str(&content)?; if let Some(url) = config.site_url { return Ok(url.trim_end_matches('/').to_string()); } } // 2. Fallback to ~/.config/ai.syui.log/config.json if let Some(cfg_dir) = dirs::config_dir() { let cfg_path = cfg_dir.join(BUNDLE_ID).join("config.json"); if cfg_path.exists() { let content = std::fs::read_to_string(&cfg_path)?; let config: SiteConfig = serde_json::from_str(&content)?; if let Some(url) = config.site_url { return Ok(url.trim_end_matches('/').to_string()); } } } anyhow::bail!( "No siteUrl found. Create public/config.json or run ailog oauth with --client-id" ); } fn percent_encode(s: &str) -> String { let mut result = String::with_capacity(s.len() * 2); for b in s.bytes() { match b { b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => { result.push(b as char); } _ => { result.push_str(&format!("%{:02X}", b)); } } } result } // --- Data types --- #[derive(Debug, Deserialize)] struct AuthServerMetadata { issuer: String, authorization_endpoint: String, token_endpoint: String, pushed_authorization_request_endpoint: String, } #[derive(Debug, Deserialize)] struct ParResponse { request_uri: String, #[allow(dead_code)] expires_in: Option, } #[derive(Debug, Deserialize)] struct TokenResponse { access_token: String, refresh_token: Option, token_type: String, #[allow(dead_code)] expires_in: Option, sub: Option, } #[derive(Debug, Deserialize)] struct TokenErrorResponse { error: String, #[allow(dead_code)] error_description: Option, } #[derive(Debug, Deserialize)] struct DidDocument { id: String, #[serde(default)] service: Vec, } #[derive(Debug, Deserialize)] struct DidService { id: String, #[serde(rename = "type")] service_type: String, #[serde(rename = "serviceEndpoint")] service_endpoint: String, } #[derive(Debug, Serialize, Deserialize)] struct DpopJwk { kty: String, crv: String, x: String, y: String, } #[derive(Debug, Serialize, Deserialize)] pub struct OAuthSession { pub did: String, pub handle: String, pub access_token: String, pub refresh_token: Option, pub pds: String, pub token_endpoint: String, pub issuer: String, dpop_pkcs8: String, dpop_jwk: DpopJwk, } // --- Handle / DID / PDS resolution --- async fn resolve_handle_to_did(handle: &str) -> Result { let client = reqwest::Client::new(); // Try DNS TXT _atproto.{handle} first is complex; use HTTPS resolution let url = format!( "https://public.api.bsky.app/xrpc/com.atproto.identity.resolveHandle?handle={}", handle ); let res = client.get(&url).send().await?; if !res.status().is_success() { anyhow::bail!("Failed to resolve handle '{}': {}", handle, res.status()); } #[derive(Deserialize)] struct R { did: String, } let r: R = res.json().await?; Ok(r.did) } async fn resolve_did_to_pds(did: &str) -> Result<(String, String)> { let client = reqwest::Client::new(); let doc: DidDocument = if did.starts_with("did:plc:") { let url = format!("https://plc.directory/{}", did); let res = client.get(&url).send().await?; if !res.status().is_success() { anyhow::bail!("Failed to resolve DID '{}': {}", did, res.status()); } res.json().await? } else if did.starts_with("did:web:") { let domain = did.strip_prefix("did:web:").unwrap(); let url = format!("https://{}/.well-known/did.json", domain); let res = client.get(&url).send().await?; res.json().await? } else { anyhow::bail!("Unsupported DID method: {}", did); }; // Find AtprotoPersonalDataServer service let pds_endpoint = doc .service .iter() .find(|s| s.id == "#atproto_pds" || s.service_type == "AtprotoPersonalDataServer") .map(|s| s.service_endpoint.clone()) .context("No PDS service found in DID document")?; Ok((doc.id, pds_endpoint)) } // --- OAuth metadata --- async fn fetch_auth_server_metadata(pds_url: &str) -> Result { let client = reqwest::Client::new(); let base = pds_url.trim_end_matches('/'); // Try PDS's own .well-known/oauth-authorization-server first let url = format!("{}/.well-known/oauth-authorization-server", base); let res = client.get(&url).send().await?; if res.status().is_success() { if let Ok(meta) = res.json::().await { return Ok(meta); } } // Fallback: check oauth-protected-resource for authorization_servers let pr_url = format!("{}/.well-known/oauth-protected-resource", base); let pr_res = client.get(&pr_url).send().await?; if pr_res.status().is_success() { #[derive(Deserialize)] struct ProtectedResource { authorization_servers: Vec, } if let Ok(pr) = pr_res.json::().await { for auth_server in &pr.authorization_servers { let as_url = format!( "{}/.well-known/oauth-authorization-server", auth_server.trim_end_matches('/') ); let as_res = client.get(&as_url).send().await?; if as_res.status().is_success() { return Ok(as_res.json().await?); } } } } anyhow::bail!( "Failed to fetch OAuth authorization server metadata for {}", base ); } // --- PKCE --- fn generate_pkce() -> Result<(String, String)> { let rng = ring::rand::SystemRandom::new(); let mut verifier_bytes = [0u8; 32]; rng.fill(&mut verifier_bytes) .map_err(|_| anyhow::anyhow!("Failed to generate random bytes"))?; let code_verifier = URL_SAFE_NO_PAD.encode(verifier_bytes); let challenge_hash = ring::digest::digest(&ring::digest::SHA256, code_verifier.as_bytes()); let code_challenge = URL_SAFE_NO_PAD.encode(challenge_hash.as_ref()); Ok((code_verifier, code_challenge)) } // --- DPoP key pair --- struct DpopKey { pkcs8_bytes: Vec, jwk: DpopJwk, } fn generate_dpop_keypair() -> Result { let rng = ring::rand::SystemRandom::new(); let pkcs8 = EcdsaKeyPair::generate_pkcs8(&ECDSA_P256_SHA256_FIXED_SIGNING, &rng) .map_err(|e| anyhow::anyhow!("Failed to generate ECDSA key: {}", e))?; let key_pair = EcdsaKeyPair::from_pkcs8(&ECDSA_P256_SHA256_FIXED_SIGNING, pkcs8.as_ref(), &rng) .map_err(|e| anyhow::anyhow!("Failed to parse generated key: {}", e))?; // Extract public key (uncompressed: 0x04 || x(32) || y(32)) let pub_key = key_pair.public_key().as_ref(); assert!(pub_key.len() == 65 && pub_key[0] == 0x04); let x = URL_SAFE_NO_PAD.encode(&pub_key[1..33]); let y = URL_SAFE_NO_PAD.encode(&pub_key[33..65]); Ok(DpopKey { pkcs8_bytes: pkcs8.as_ref().to_vec(), jwk: DpopJwk { kty: "EC".to_string(), crv: "P-256".to_string(), x, y, }, }) } // --- DPoP proof JWT --- fn create_dpop_proof( pkcs8_bytes: &[u8], jwk: &DpopJwk, method: &str, url: &str, nonce: Option<&str>, ath: Option<&str>, ) -> Result { let rng = ring::rand::SystemRandom::new(); let key_pair = EcdsaKeyPair::from_pkcs8(&ECDSA_P256_SHA256_FIXED_SIGNING, pkcs8_bytes, &rng) .map_err(|e| anyhow::anyhow!("Failed to load DPoP key: {}", e))?; // Header let header = serde_json::json!({ "typ": "dpop+jwt", "alg": "ES256", "jwk": { "kty": &jwk.kty, "crv": &jwk.crv, "x": &jwk.x, "y": &jwk.y, } }); // Generate jti let mut jti_bytes = [0u8; 16]; rng.fill(&mut jti_bytes) .map_err(|_| anyhow::anyhow!("Failed to generate jti"))?; let jti = URL_SAFE_NO_PAD.encode(jti_bytes); let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH)? .as_secs(); let mut payload = serde_json::json!({ "jti": jti, "htm": method, "htu": url, "iat": now, "exp": now + 120, }); if let Some(n) = nonce { payload["nonce"] = serde_json::Value::String(n.to_string()); } if let Some(a) = ath { payload["ath"] = serde_json::Value::String(a.to_string()); } let header_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&header)?.as_bytes()); let payload_b64 = URL_SAFE_NO_PAD.encode(serde_json::to_string(&payload)?.as_bytes()); let signing_input = format!("{}.{}", header_b64, payload_b64); let signature = key_pair .sign(&rng, signing_input.as_bytes()) .map_err(|e| anyhow::anyhow!("Failed to sign DPoP proof: {}", e))?; // ECDSA_P256_SHA256_FIXED_SIGNING produces r||s (64 bytes), which is exactly what JWS ES256 needs let sig_b64 = URL_SAFE_NO_PAD.encode(signature.as_ref()); Ok(format!("{}.{}", signing_input, sig_b64)) } // --- PAR --- async fn pushed_authorization_request( par_endpoint: &str, client_id: &str, redirect_uri: &str, code_challenge: &str, scope: &str, login_hint: &str, dpop_key: &DpopKey, ) -> Result { let client = reqwest::Client::new(); let mut dpop_nonce: Option = None; // Try up to 2 times (initial + nonce retry) for attempt in 0..2 { let dpop_proof = create_dpop_proof( &dpop_key.pkcs8_bytes, &dpop_key.jwk, "POST", par_endpoint, dpop_nonce.as_deref(), None, )?; let params = [ ("client_id", client_id), ("redirect_uri", redirect_uri), ("code_challenge", code_challenge), ("code_challenge_method", "S256"), ("response_type", "code"), ("scope", scope), ("login_hint", login_hint), ("state", "cli"), ]; let res = client .post(par_endpoint) .header("DPoP", &dpop_proof) .form(¶ms) .send() .await?; if res.status().is_success() { return Ok(res.json().await?); } // Check for use_dpop_nonce error let nonce_header = res .headers() .get("dpop-nonce") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); let body = res.text().await?; if attempt == 0 { if let Some(nonce) = nonce_header { // Check if it's a dpop nonce error if body.contains("use_dpop_nonce") { dpop_nonce = Some(nonce); continue; } } } anyhow::bail!("PAR request failed: {}", body); } anyhow::bail!("PAR request failed after nonce retry"); } // --- Token exchange --- async fn exchange_code( token_endpoint: &str, client_id: &str, redirect_uri: &str, code: &str, code_verifier: &str, dpop_key: &DpopKey, initial_nonce: Option<&str>, ) -> Result<(TokenResponse, Option)> { let client = reqwest::Client::new(); let mut dpop_nonce = initial_nonce.map(|s| s.to_string()); for attempt in 0..2 { let dpop_proof = create_dpop_proof( &dpop_key.pkcs8_bytes, &dpop_key.jwk, "POST", token_endpoint, dpop_nonce.as_deref(), None, )?; let mut params = HashMap::new(); params.insert("grant_type", "authorization_code"); params.insert("client_id", client_id); params.insert("redirect_uri", redirect_uri); params.insert("code", code); params.insert("code_verifier", code_verifier); let res = client .post(token_endpoint) .header("DPoP", &dpop_proof) .form(¶ms) .send() .await?; let new_nonce = res .headers() .get("dpop-nonce") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); if res.status().is_success() { let token_res: TokenResponse = res.json().await?; return Ok((token_res, new_nonce.or(dpop_nonce))); } let body = res.text().await?; if attempt == 0 { if let Some(nonce) = new_nonce { if body.contains("use_dpop_nonce") { dpop_nonce = Some(nonce); continue; } } } // Try to parse error for better message if let Ok(err) = serde_json::from_str::(&body) { anyhow::bail!("Token exchange failed: {}", err.error); } anyhow::bail!("Token exchange failed: {}", body); } anyhow::bail!("Token exchange failed after nonce retry"); } // --- Token refresh --- pub async fn refresh_oauth_session(is_bot: bool) -> Result<(OAuthSession, Session)> { let oauth = load_oauth_session(is_bot)?; let pkcs8_bytes = URL_SAFE_NO_PAD .decode(&oauth.dpop_pkcs8) .context("Failed to decode DPoP key")?; let client = reqwest::Client::new(); let mut dpop_nonce: Option = None; for attempt in 0..2 { let dpop_proof = create_dpop_proof( &pkcs8_bytes, &oauth.dpop_jwk, "POST", &oauth.token_endpoint, dpop_nonce.as_deref(), None, )?; let params = [ ("grant_type", "refresh_token"), ("refresh_token", oauth.refresh_token.as_deref().unwrap_or("")), ("client_id", &oauth.issuer), ]; let site_url = load_site_url()?; let client_id_url = format!("{}/client-metadata.json", site_url); let form_params = [ ("grant_type", "refresh_token"), ( "refresh_token", oauth.refresh_token.as_deref().unwrap_or(""), ), ("client_id", &client_id_url), ]; let res = client .post(&oauth.token_endpoint) .header("DPoP", &dpop_proof) .form(&form_params) .send() .await?; let new_nonce = res .headers() .get("dpop-nonce") .and_then(|v| v.to_str().ok()) .map(|s| s.to_string()); if res.status().is_success() { let token_res: TokenResponse = res.json().await?; let sub = token_res.sub.as_deref().unwrap_or(&oauth.did); let new_oauth = OAuthSession { did: sub.to_string(), handle: oauth.handle.clone(), access_token: token_res.access_token.clone(), refresh_token: token_res.refresh_token.or(oauth.refresh_token), pds: oauth.pds.clone(), token_endpoint: oauth.token_endpoint.clone(), issuer: oauth.issuer.clone(), dpop_pkcs8: oauth.dpop_pkcs8.clone(), dpop_jwk: DpopJwk { kty: oauth.dpop_jwk.kty.clone(), crv: oauth.dpop_jwk.crv.clone(), x: oauth.dpop_jwk.x.clone(), y: oauth.dpop_jwk.y.clone(), }, }; save_oauth_session(&new_oauth, is_bot)?; let pds_host = oauth .pds .strip_prefix("https://") .unwrap_or(&oauth.pds) .trim_end_matches('/'); let compat = Session { did: sub.to_string(), handle: oauth.handle.clone(), access_jwt: token_res.access_token, refresh_jwt: new_oauth.refresh_token.clone().unwrap_or_default(), pds: Some(pds_host.to_string()), }; if is_bot { token::save_bot_session(&compat)?; } else { token::save_session(&compat)?; } return Ok((new_oauth, compat)); } let body = res.text().await?; if attempt == 0 { if let Some(nonce) = new_nonce { if body.contains("use_dpop_nonce") { dpop_nonce = Some(nonce); let _ = params; continue; } } } anyhow::bail!("OAuth token refresh failed: {}", body); } anyhow::bail!("OAuth token refresh failed after nonce retry"); } /// Create a DPoP proof for an API request with optional nonce pub fn create_dpop_proof_for_request_with_nonce( oauth: &OAuthSession, method: &str, url: &str, nonce: Option<&str>, ) -> Result { let pkcs8_bytes = URL_SAFE_NO_PAD .decode(&oauth.dpop_pkcs8) .context("Failed to decode DPoP key")?; // Compute ath (access token hash) let ath_hash = ring::digest::digest(&ring::digest::SHA256, oauth.access_token.as_bytes()); let ath = URL_SAFE_NO_PAD.encode(ath_hash.as_ref()); create_dpop_proof(&pkcs8_bytes, &oauth.dpop_jwk, method, url, nonce, Some(&ath)) } // --- Load OAuth session --- pub fn load_oauth_session(is_bot: bool) -> Result { let config_dir = dirs::config_dir() .context("Could not find config directory")? .join(BUNDLE_ID); let filename = if is_bot { "oauth_bot_session.json" } else { "oauth_session.json" }; let path = config_dir.join(filename); let content = std::fs::read_to_string(&path) .with_context(|| format!("OAuth session not found: {:?}", path))?; let session: OAuthSession = serde_json::from_str(&content)?; Ok(session) } /// Check if OAuth session exists pub fn has_oauth_session(is_bot: bool) -> bool { let config_dir = match dirs::config_dir() { Some(d) => d.join(BUNDLE_ID), None => return false, }; let filename = if is_bot { "oauth_bot_session.json" } else { "oauth_session.json" }; config_dir.join(filename).exists() } // --- Save --- fn save_oauth_session(session: &OAuthSession, is_bot: bool) -> Result<()> { let config_dir = dirs::config_dir() .context("Could not find config directory")? .join(BUNDLE_ID); std::fs::create_dir_all(&config_dir)?; let filename = if is_bot { "oauth_bot_session.json" } else { "oauth_session.json" }; let path = config_dir.join(filename); let content = serde_json::to_string_pretty(session)?; std::fs::write(&path, content)?; println!("OAuth session saved to {:?}", path); Ok(()) } // --- Main entry --- pub async fn oauth_login(handle: &str, is_bot: bool) -> Result<()> { let account_type = if is_bot { "bot" } else { "user" }; println!("Starting OAuth login for {} ({})...", handle, account_type); // 1. Resolve handle → DID → PDS println!("Resolving handle..."); let did = resolve_handle_to_did(handle).await?; println!("DID: {}", did); let (_, pds_url) = resolve_did_to_pds(&did).await?; println!("PDS: {}", pds_url); // 2. Fetch OAuth metadata println!("Fetching OAuth metadata..."); let meta = fetch_auth_server_metadata(&pds_url).await?; // 3. Generate PKCE let (code_verifier, code_challenge) = generate_pkce()?; // 4. Generate DPoP key pair let dpop_key = generate_dpop_keypair()?; // 5. Client metadata (derived from config.json siteUrl) let site_url = load_site_url()?; let client_id = format!("{}/client-metadata.json", site_url); let scope = "atproto transition:generic"; // Try /oauth/cli first, fallback to /oauth/callback let redirect_candidates = [ format!("{}/oauth/cli", site_url), format!("{}/oauth/callback", site_url), ]; let mut redirect_uri = String::new(); let mut par_res: Option = None; // 6. PAR (try each redirect_uri) println!("Sending authorization request..."); for candidate in &redirect_candidates { match pushed_authorization_request( &meta.pushed_authorization_request_endpoint, &client_id, candidate, &code_challenge, scope, &did, &dpop_key, ) .await { Ok(res) => { redirect_uri = candidate.clone(); par_res = Some(res); break; } Err(e) => { let msg = e.to_string(); if msg.contains("Invalid redirect_uri") && candidate != redirect_candidates.last().unwrap() { println!(" {} not accepted, trying fallback...", candidate); // Regenerate DPoP key for retry (nonce may have changed) continue; } return Err(e); } } } let par_res = par_res.context("All redirect_uri candidates rejected by PDS")?; // 7. Build authorize URL let auth_url = format!( "{}?client_id={}&request_uri={}", meta.authorization_endpoint, percent_encode(&client_id), percent_encode(&par_res.request_uri), ); println!("\nOpen this URL in your browser to authorize:\n"); println!(" {}\n", auth_url); println!("After authorizing, paste the code from the browser here."); print!("Code: "); io::stdout().flush()?; let mut code = String::new(); io::stdin().read_line(&mut code)?; let code = code.trim(); if code.is_empty() { anyhow::bail!("No authorization code provided"); } // 8. Exchange code for tokens println!("Exchanging code for tokens..."); let (token_res, _dpop_nonce) = exchange_code( &meta.token_endpoint, &client_id, &redirect_uri, code, &code_verifier, &dpop_key, None, ) .await?; if token_res.token_type.to_lowercase() != "dpop" { println!( "Warning: Expected DPoP token type, got '{}'", token_res.token_type ); } let resolved_did = token_res.sub.as_deref().unwrap_or(&did); // 9. Save OAuth session (DPoP keys + tokens) let oauth_session = OAuthSession { did: resolved_did.to_string(), handle: handle.to_string(), access_token: token_res.access_token.clone(), refresh_token: token_res.refresh_token.clone(), pds: pds_url.clone(), token_endpoint: meta.token_endpoint.clone(), issuer: meta.issuer.clone(), dpop_pkcs8: URL_SAFE_NO_PAD.encode(&dpop_key.pkcs8_bytes), dpop_jwk: dpop_key.jwk, }; save_oauth_session(&oauth_session, is_bot)?; // 10. Save compatible Session (for existing commands) let pds_host = pds_url .strip_prefix("https://") .unwrap_or(&pds_url) .trim_end_matches('/'); let compat_session = Session { did: resolved_did.to_string(), handle: handle.to_string(), access_jwt: token_res.access_token, refresh_jwt: token_res.refresh_token.unwrap_or_default(), pds: Some(pds_host.to_string()), }; if is_bot { token::save_bot_session(&compat_session)?; println!("Bot session saved."); } else { token::save_session(&compat_session)?; println!("Session saved."); } println!( "Logged in as {} ({}) via OAuth", compat_session.handle, compat_session.did ); Ok(()) }