diff --git a/Cargo.toml b/Cargo.toml index 5471892..a90413c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,3 +15,4 @@ uuid = { version = "1", features = ["v4"] } tracing = "0.1" tracing-subscriber = "0.3" base64 = "0.22" +urlencoding = "2" diff --git a/src/main.rs b/src/main.rs index 6da66c4..f607aad 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{Query, State}, + extract::{Query, RawQuery, State}, http::{HeaderMap, StatusCode}, response::Json, routing::{get, post}, @@ -8,10 +8,38 @@ use axum::{ use chrono::Utc; use rusqlite::Connection; use serde::{Deserialize, Serialize}; -use std::sync::{Arc, Mutex}; +use std::sync::{Arc, Mutex, OnceLock}; use tower_http::cors::{Any, CorsLayer}; use uuid::Uuid; +// --- Namespace --- +// CHAT_DOMAIN env (default: "bsky.chat") → reversed for XRPC/types: "chat.bsky" + +static NS: OnceLock = OnceLock::new(); + +fn ns() -> &'static str { + NS.get().expect("namespace not initialized") +} + +fn ns_type(name: &str) -> String { + format!("{}.convo.defs#{}", ns(), name) +} + +fn ns_route(method: &str) -> String { + format!("/xrpc/{}.convo.{}", ns(), method) +} + +fn init_namespace() { + let domain = std::env::var("CHAT_DOMAIN").unwrap_or_else(|_| "bsky.chat".into()); + let parts: Vec<&str> = domain.split('.').collect(); + let reversed = if parts.len() == 2 { + format!("{}.{}", parts[1], parts[0]) + } else { + domain + }; + NS.set(reversed).expect("namespace already initialized"); +} + // --- Types --- #[derive(Debug, Clone, Serialize, Deserialize)] @@ -419,7 +447,7 @@ fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView { } add_log(conn, &id, &rev, "logBeginConvo", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logBeginConvo", + "$type": ns_type("logBeginConvo"), "rev": rev, "convoId": id, })); @@ -518,7 +546,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result(0)?, "rev": row.get::<_, String>(1)?, "sender": { "did": sender_did }, @@ -528,7 +556,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result = row.get(3)?; let embed: Option = row.get(4)?; let mut msg = serde_json::json!({ - "$type": "chat.bsky.convo.defs#messageView", + "$type": ns_type("messageView"), "id": row.get::<_, String>(0)?, "rev": row.get::<_, String>(1)?, "text": row.get::<_, String>(2)?, @@ -570,11 +598,11 @@ fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInp .unwrap(); add_log(conn, convo_id, &rev, "logCreateMessage", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logCreateMessage", + "$type": ns_type("logCreateMessage"), "rev": rev, "convoId": convo_id, "message": { - "$type": "chat.bsky.convo.defs#messageView", + "$type": ns_type("messageView"), "id": msg_id, "rev": rev, "text": msg.text, "sender": { "did": did }, "sentAt": now, } @@ -645,7 +673,7 @@ async fn send_message( let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message); Ok(Json(MessageView { - typ: "chat.bsky.convo.defs#messageView".into(), + typ: ns_type("messageView"), id: msg_id, rev, text: req.message.text, facets: req.message.facets, @@ -668,7 +696,7 @@ async fn send_message_batch( for item in &req.items { let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message); items.push(serde_json::json!({ - "$type": "chat.bsky.convo.defs#messageView", + "$type": ns_type("messageView"), "id": msg_id, "rev": rev, "text": item.message.text, "sender": { "did": did }, "sentAt": now, })); @@ -718,34 +746,46 @@ async fn get_convo( Ok(Json(GetConvoResp { convo: load_convo(&db, ¶ms.convo_id, Some(&did)) })) } +/// Parse repeated query params: ?members=did1&members=did2 +fn parse_members_query(query: &Option) -> Vec { + query.as_deref().unwrap_or("").split('&') + .filter_map(|pair| { + let (key, val) = pair.split_once('=')?; + if key == "members" { Some(urlencoding::decode(val).ok()?.into_owned()) } else { None } + }) + .collect() +} + async fn get_convo_for_members( State(state): State>, headers: HeaderMap, - Query(params): Query, + RawQuery(query): RawQuery, ) -> Result, (StatusCode, Json)> { let did = require_auth(&headers)?; - if !params.members.contains(&did) { + let members = parse_members_query(&query); + if !members.contains(&did) { return Err((StatusCode::BAD_REQUEST, Json(ErrorResp { error: "InvalidRequest".into(), message: "Caller must be a member".into(), }))); } let db = state.db.lock().unwrap(); - for m in ¶ms.members { + for m in &members { ensure_account(&db, m); } - let convo = get_or_create_convo(&db, ¶ms.members); + let convo = get_or_create_convo(&db, &members); Ok(Json(GetConvoResp { convo })) } async fn get_convo_availability( State(state): State>, headers: HeaderMap, - Query(params): Query, + RawQuery(query): RawQuery, ) -> Result, (StatusCode, Json)> { let _did = require_auth(&headers)?; let db = state.db.lock().unwrap(); - let members: Vec = params.members.iter().map(|did| { + let members_raw = parse_members_query(&query); + let members: Vec = members_raw.iter().map(|did| { ensure_account(&db, did); MemberAvailability { did: did.clone(), @@ -861,7 +901,7 @@ async fn accept_convo( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logAcceptConvo", + "$type": ns_type("logAcceptConvo"), "rev": rev, "convoId": req.convo_id, })); @@ -884,7 +924,7 @@ async fn mute_convo( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logMuteConvo", + "$type": ns_type("logMuteConvo"), "rev": rev, "convoId": req.convo_id, })); @@ -907,7 +947,7 @@ async fn unmute_convo( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logUnmuteConvo", + "$type": ns_type("logUnmuteConvo"), "rev": rev, "convoId": req.convo_id, })); @@ -931,17 +971,17 @@ async fn delete_message_for_self( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logDeleteMessage", + "$type": ns_type("logDeleteMessage"), "rev": rev, "convoId": req.convo_id, "message": { - "$type": "chat.bsky.convo.defs#deletedMessageView", + "$type": ns_type("deletedMessageView"), "id": req.message_id, "rev": rev, "sender": { "did": did }, "sentAt": now, } })); Ok(Json(DeletedMessageView { - typ: "chat.bsky.convo.defs#deletedMessageView".into(), + typ: ns_type("deletedMessageView"), id: req.message_id, rev, sender: MessageViewSender { did }, sent_at: now, @@ -963,7 +1003,7 @@ async fn leave_convo( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logLeaveConvo", + "$type": ns_type("logLeaveConvo"), "rev": rev, "convoId": req.convo_id, })); @@ -987,9 +1027,9 @@ async fn add_reaction( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logAddReaction", + "$type": ns_type("logAddReaction"), "rev": rev, "convoId": req.convo_id, - "message": { "$type": "chat.bsky.convo.defs#messageView", "id": req.message_id }, + "message": { "$type": ns_type("messageView"), "id": req.message_id }, "reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now }, })); @@ -1013,7 +1053,7 @@ async fn remove_reaction( ).unwrap(); add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({ - "$type": "chat.bsky.convo.defs#logRemoveReaction", + "$type": ns_type("logRemoveReaction"), "rev": rev, "convoId": req.convo_id, })); @@ -1072,6 +1112,7 @@ async fn main() { return; } + init_namespace(); tracing_subscriber::fmt::init(); let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into()); @@ -1090,31 +1131,37 @@ async fn main() { .allow_methods(Any) .allow_headers(Any); + // Leak route strings so they live for 'static (allocated once at startup) + let r = |method: &str| -> &'static str { + Box::leak(ns_route(method).into_boxed_str()) + }; + let app = Router::new() .route("/.well-known/did.json", get(well_known_did)) .route("/xrpc/_health", get(health)) - .route("/xrpc/chat.bsky.convo.sendMessage", post(send_message)) - .route("/xrpc/chat.bsky.convo.sendMessageBatch", post(send_message_batch)) - .route("/xrpc/chat.bsky.convo.getMessages", get(get_messages)) - .route("/xrpc/chat.bsky.convo.getConvo", get(get_convo)) - .route("/xrpc/chat.bsky.convo.getConvoForMembers", get(get_convo_for_members)) - .route("/xrpc/chat.bsky.convo.getConvoAvailability", get(get_convo_availability)) - .route("/xrpc/chat.bsky.convo.listConvos", get(list_convos)) - .route("/xrpc/chat.bsky.convo.updateRead", post(update_read)) - .route("/xrpc/chat.bsky.convo.updateAllRead", post(update_all_read)) - .route("/xrpc/chat.bsky.convo.acceptConvo", post(accept_convo)) - .route("/xrpc/chat.bsky.convo.muteConvo", post(mute_convo)) - .route("/xrpc/chat.bsky.convo.unmuteConvo", post(unmute_convo)) - .route("/xrpc/chat.bsky.convo.deleteMessageForSelf", post(delete_message_for_self)) - .route("/xrpc/chat.bsky.convo.leaveConvo", post(leave_convo)) - .route("/xrpc/chat.bsky.convo.addReaction", post(add_reaction)) - .route("/xrpc/chat.bsky.convo.removeReaction", post(remove_reaction)) - .route("/xrpc/chat.bsky.convo.getLog", get(get_log)) + .route(r("sendMessage"), post(send_message)) + .route(r("sendMessageBatch"), post(send_message_batch)) + .route(r("getMessages"), get(get_messages)) + .route(r("getConvo"), get(get_convo)) + .route(r("getConvoForMembers"), get(get_convo_for_members)) + .route(r("getConvoAvailability"), get(get_convo_availability)) + .route(r("listConvos"), get(list_convos)) + .route(r("updateRead"), post(update_read)) + .route(r("updateAllRead"), post(update_all_read)) + .route(r("acceptConvo"), post(accept_convo)) + .route(r("muteConvo"), post(mute_convo)) + .route(r("unmuteConvo"), post(unmute_convo)) + .route(r("deleteMessageForSelf"), post(delete_message_for_self)) + .route(r("leaveConvo"), post(leave_convo)) + .route(r("addReaction"), post(add_reaction)) + .route(r("removeReaction"), post(remove_reaction)) + .route(r("getLog"), get(get_log)) .layer(cors) .with_state(state); - let addr = format!("0.0.0.0:{}", port); - tracing::info!("bsky-chat listening on {}", addr); + let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into()); + let addr = format!("{}:{}", host, port); + tracing::info!("{} listening on {}", ns(), addr); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); axum::serve(listener, app).await.unwrap();