1
0
This commit is contained in:
2026-03-22 13:14:07 +09:00
parent 46a43183ee
commit 0943f2a5c6
2 changed files with 92 additions and 44 deletions

View File

@@ -15,3 +15,4 @@ uuid = { version = "1", features = ["v4"] }
tracing = "0.1"
tracing-subscriber = "0.3"
base64 = "0.22"
urlencoding = "2"

View File

@@ -1,5 +1,5 @@
use axum::{
extract::{Query, State},
extract::{Query, RawQuery, State},
http::{HeaderMap, StatusCode},
response::Json,
routing::{get, post},
@@ -8,10 +8,38 @@ use axum::{
use chrono::Utc;
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex};
use std::sync::{Arc, Mutex, OnceLock};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
// --- Namespace ---
// CHAT_DOMAIN env (default: "bsky.chat") → reversed for XRPC/types: "chat.bsky"
static NS: OnceLock<String> = OnceLock::new();
fn ns() -> &'static str {
NS.get().expect("namespace not initialized")
}
fn ns_type(name: &str) -> String {
format!("{}.convo.defs#{}", ns(), name)
}
fn ns_route(method: &str) -> String {
format!("/xrpc/{}.convo.{}", ns(), method)
}
fn init_namespace() {
let domain = std::env::var("CHAT_DOMAIN").unwrap_or_else(|_| "bsky.chat".into());
let parts: Vec<&str> = domain.split('.').collect();
let reversed = if parts.len() == 2 {
format!("{}.{}", parts[1], parts[0])
} else {
domain
};
NS.set(reversed).expect("namespace already initialized");
}
// --- Types ---
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -419,7 +447,7 @@ fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView {
}
add_log(conn, &id, &rev, "logBeginConvo", serde_json::json!({
"$type": "chat.bsky.convo.defs#logBeginConvo",
"$type": ns_type("logBeginConvo"),
"rev": rev,
"convoId": id,
}));
@@ -518,7 +546,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value
let sender_did: String = row.get(5)?;
if deleted {
Ok(serde_json::json!({
"$type": "chat.bsky.convo.defs#deletedMessageView",
"$type": ns_type("deletedMessageView"),
"id": row.get::<_, String>(0)?,
"rev": row.get::<_, String>(1)?,
"sender": { "did": sender_did },
@@ -528,7 +556,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value
let facets: Option<String> = row.get(3)?;
let embed: Option<String> = row.get(4)?;
let mut msg = serde_json::json!({
"$type": "chat.bsky.convo.defs#messageView",
"$type": ns_type("messageView"),
"id": row.get::<_, String>(0)?,
"rev": row.get::<_, String>(1)?,
"text": row.get::<_, String>(2)?,
@@ -570,11 +598,11 @@ fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInp
.unwrap();
add_log(conn, convo_id, &rev, "logCreateMessage", serde_json::json!({
"$type": "chat.bsky.convo.defs#logCreateMessage",
"$type": ns_type("logCreateMessage"),
"rev": rev,
"convoId": convo_id,
"message": {
"$type": "chat.bsky.convo.defs#messageView",
"$type": ns_type("messageView"),
"id": msg_id, "rev": rev, "text": msg.text,
"sender": { "did": did }, "sentAt": now,
}
@@ -645,7 +673,7 @@ async fn send_message(
let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message);
Ok(Json(MessageView {
typ: "chat.bsky.convo.defs#messageView".into(),
typ: ns_type("messageView"),
id: msg_id, rev,
text: req.message.text,
facets: req.message.facets,
@@ -668,7 +696,7 @@ async fn send_message_batch(
for item in &req.items {
let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message);
items.push(serde_json::json!({
"$type": "chat.bsky.convo.defs#messageView",
"$type": ns_type("messageView"),
"id": msg_id, "rev": rev, "text": item.message.text,
"sender": { "did": did }, "sentAt": now,
}));
@@ -718,34 +746,46 @@ async fn get_convo(
Ok(Json(GetConvoResp { convo: load_convo(&db, &params.convo_id, Some(&did)) }))
}
/// Parse repeated query params: ?members=did1&members=did2
fn parse_members_query(query: &Option<String>) -> Vec<String> {
query.as_deref().unwrap_or("").split('&')
.filter_map(|pair| {
let (key, val) = pair.split_once('=')?;
if key == "members" { Some(urlencoding::decode(val).ok()?.into_owned()) } else { None }
})
.collect()
}
async fn get_convo_for_members(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(params): Query<GetConvoForMembersParams>,
RawQuery(query): RawQuery,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
if !params.members.contains(&did) {
let members = parse_members_query(&query);
if !members.contains(&did) {
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
error: "InvalidRequest".into(),
message: "Caller must be a member".into(),
})));
}
let db = state.db.lock().unwrap();
for m in &params.members {
for m in &members {
ensure_account(&db, m);
}
let convo = get_or_create_convo(&db, &params.members);
let convo = get_or_create_convo(&db, &members);
Ok(Json(GetConvoResp { convo }))
}
async fn get_convo_availability(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(params): Query<GetConvoAvailabilityParams>,
RawQuery(query): RawQuery,
) -> Result<Json<GetConvoAvailabilityResp>, (StatusCode, Json<ErrorResp>)> {
let _did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let members: Vec<MemberAvailability> = params.members.iter().map(|did| {
let members_raw = parse_members_query(&query);
let members: Vec<MemberAvailability> = members_raw.iter().map(|did| {
ensure_account(&db, did);
MemberAvailability {
did: did.clone(),
@@ -861,7 +901,7 @@ async fn accept_convo(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({
"$type": "chat.bsky.convo.defs#logAcceptConvo",
"$type": ns_type("logAcceptConvo"),
"rev": rev, "convoId": req.convo_id,
}));
@@ -884,7 +924,7 @@ async fn mute_convo(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({
"$type": "chat.bsky.convo.defs#logMuteConvo",
"$type": ns_type("logMuteConvo"),
"rev": rev, "convoId": req.convo_id,
}));
@@ -907,7 +947,7 @@ async fn unmute_convo(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({
"$type": "chat.bsky.convo.defs#logUnmuteConvo",
"$type": ns_type("logUnmuteConvo"),
"rev": rev, "convoId": req.convo_id,
}));
@@ -931,17 +971,17 @@ async fn delete_message_for_self(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({
"$type": "chat.bsky.convo.defs#logDeleteMessage",
"$type": ns_type("logDeleteMessage"),
"rev": rev, "convoId": req.convo_id,
"message": {
"$type": "chat.bsky.convo.defs#deletedMessageView",
"$type": ns_type("deletedMessageView"),
"id": req.message_id, "rev": rev,
"sender": { "did": did }, "sentAt": now,
}
}));
Ok(Json(DeletedMessageView {
typ: "chat.bsky.convo.defs#deletedMessageView".into(),
typ: ns_type("deletedMessageView"),
id: req.message_id, rev,
sender: MessageViewSender { did },
sent_at: now,
@@ -963,7 +1003,7 @@ async fn leave_convo(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({
"$type": "chat.bsky.convo.defs#logLeaveConvo",
"$type": ns_type("logLeaveConvo"),
"rev": rev, "convoId": req.convo_id,
}));
@@ -987,9 +1027,9 @@ async fn add_reaction(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({
"$type": "chat.bsky.convo.defs#logAddReaction",
"$type": ns_type("logAddReaction"),
"rev": rev, "convoId": req.convo_id,
"message": { "$type": "chat.bsky.convo.defs#messageView", "id": req.message_id },
"message": { "$type": ns_type("messageView"), "id": req.message_id },
"reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now },
}));
@@ -1013,7 +1053,7 @@ async fn remove_reaction(
).unwrap();
add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({
"$type": "chat.bsky.convo.defs#logRemoveReaction",
"$type": ns_type("logRemoveReaction"),
"rev": rev, "convoId": req.convo_id,
}));
@@ -1072,6 +1112,7 @@ async fn main() {
return;
}
init_namespace();
tracing_subscriber::fmt::init();
let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into());
@@ -1090,31 +1131,37 @@ async fn main() {
.allow_methods(Any)
.allow_headers(Any);
// Leak route strings so they live for 'static (allocated once at startup)
let r = |method: &str| -> &'static str {
Box::leak(ns_route(method).into_boxed_str())
};
let app = Router::new()
.route("/.well-known/did.json", get(well_known_did))
.route("/xrpc/_health", get(health))
.route("/xrpc/chat.bsky.convo.sendMessage", post(send_message))
.route("/xrpc/chat.bsky.convo.sendMessageBatch", post(send_message_batch))
.route("/xrpc/chat.bsky.convo.getMessages", get(get_messages))
.route("/xrpc/chat.bsky.convo.getConvo", get(get_convo))
.route("/xrpc/chat.bsky.convo.getConvoForMembers", get(get_convo_for_members))
.route("/xrpc/chat.bsky.convo.getConvoAvailability", get(get_convo_availability))
.route("/xrpc/chat.bsky.convo.listConvos", get(list_convos))
.route("/xrpc/chat.bsky.convo.updateRead", post(update_read))
.route("/xrpc/chat.bsky.convo.updateAllRead", post(update_all_read))
.route("/xrpc/chat.bsky.convo.acceptConvo", post(accept_convo))
.route("/xrpc/chat.bsky.convo.muteConvo", post(mute_convo))
.route("/xrpc/chat.bsky.convo.unmuteConvo", post(unmute_convo))
.route("/xrpc/chat.bsky.convo.deleteMessageForSelf", post(delete_message_for_self))
.route("/xrpc/chat.bsky.convo.leaveConvo", post(leave_convo))
.route("/xrpc/chat.bsky.convo.addReaction", post(add_reaction))
.route("/xrpc/chat.bsky.convo.removeReaction", post(remove_reaction))
.route("/xrpc/chat.bsky.convo.getLog", get(get_log))
.route(r("sendMessage"), post(send_message))
.route(r("sendMessageBatch"), post(send_message_batch))
.route(r("getMessages"), get(get_messages))
.route(r("getConvo"), get(get_convo))
.route(r("getConvoForMembers"), get(get_convo_for_members))
.route(r("getConvoAvailability"), get(get_convo_availability))
.route(r("listConvos"), get(list_convos))
.route(r("updateRead"), post(update_read))
.route(r("updateAllRead"), post(update_all_read))
.route(r("acceptConvo"), post(accept_convo))
.route(r("muteConvo"), post(mute_convo))
.route(r("unmuteConvo"), post(unmute_convo))
.route(r("deleteMessageForSelf"), post(delete_message_for_self))
.route(r("leaveConvo"), post(leave_convo))
.route(r("addReaction"), post(add_reaction))
.route(r("removeReaction"), post(remove_reaction))
.route(r("getLog"), get(get_log))
.layer(cors)
.with_state(state);
let addr = format!("0.0.0.0:{}", port);
tracing::info!("bsky-chat listening on {}", addr);
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
let addr = format!("{}:{}", host, port);
tracing::info!("{} listening on {}", ns(), addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await.unwrap();