fix env
This commit is contained in:
135
src/main.rs
135
src/main.rs
@@ -1,5 +1,5 @@
|
||||
use axum::{
|
||||
extract::{Query, State},
|
||||
extract::{Query, RawQuery, State},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::Json,
|
||||
routing::{get, post},
|
||||
@@ -8,10 +8,38 @@ use axum::{
|
||||
use chrono::Utc;
|
||||
use rusqlite::Connection;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
use tower_http::cors::{Any, CorsLayer};
|
||||
use uuid::Uuid;
|
||||
|
||||
// --- Namespace ---
|
||||
// CHAT_DOMAIN env (default: "bsky.chat") → reversed for XRPC/types: "chat.bsky"
|
||||
|
||||
static NS: OnceLock<String> = OnceLock::new();
|
||||
|
||||
fn ns() -> &'static str {
|
||||
NS.get().expect("namespace not initialized")
|
||||
}
|
||||
|
||||
fn ns_type(name: &str) -> String {
|
||||
format!("{}.convo.defs#{}", ns(), name)
|
||||
}
|
||||
|
||||
fn ns_route(method: &str) -> String {
|
||||
format!("/xrpc/{}.convo.{}", ns(), method)
|
||||
}
|
||||
|
||||
fn init_namespace() {
|
||||
let domain = std::env::var("CHAT_DOMAIN").unwrap_or_else(|_| "bsky.chat".into());
|
||||
let parts: Vec<&str> = domain.split('.').collect();
|
||||
let reversed = if parts.len() == 2 {
|
||||
format!("{}.{}", parts[1], parts[0])
|
||||
} else {
|
||||
domain
|
||||
};
|
||||
NS.set(reversed).expect("namespace already initialized");
|
||||
}
|
||||
|
||||
// --- Types ---
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
@@ -419,7 +447,7 @@ fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView {
|
||||
}
|
||||
|
||||
add_log(conn, &id, &rev, "logBeginConvo", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logBeginConvo",
|
||||
"$type": ns_type("logBeginConvo"),
|
||||
"rev": rev,
|
||||
"convoId": id,
|
||||
}));
|
||||
@@ -518,7 +546,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value
|
||||
let sender_did: String = row.get(5)?;
|
||||
if deleted {
|
||||
Ok(serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#deletedMessageView",
|
||||
"$type": ns_type("deletedMessageView"),
|
||||
"id": row.get::<_, String>(0)?,
|
||||
"rev": row.get::<_, String>(1)?,
|
||||
"sender": { "did": sender_did },
|
||||
@@ -528,7 +556,7 @@ fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value
|
||||
let facets: Option<String> = row.get(3)?;
|
||||
let embed: Option<String> = row.get(4)?;
|
||||
let mut msg = serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#messageView",
|
||||
"$type": ns_type("messageView"),
|
||||
"id": row.get::<_, String>(0)?,
|
||||
"rev": row.get::<_, String>(1)?,
|
||||
"text": row.get::<_, String>(2)?,
|
||||
@@ -570,11 +598,11 @@ fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInp
|
||||
.unwrap();
|
||||
|
||||
add_log(conn, convo_id, &rev, "logCreateMessage", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logCreateMessage",
|
||||
"$type": ns_type("logCreateMessage"),
|
||||
"rev": rev,
|
||||
"convoId": convo_id,
|
||||
"message": {
|
||||
"$type": "chat.bsky.convo.defs#messageView",
|
||||
"$type": ns_type("messageView"),
|
||||
"id": msg_id, "rev": rev, "text": msg.text,
|
||||
"sender": { "did": did }, "sentAt": now,
|
||||
}
|
||||
@@ -645,7 +673,7 @@ async fn send_message(
|
||||
let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message);
|
||||
|
||||
Ok(Json(MessageView {
|
||||
typ: "chat.bsky.convo.defs#messageView".into(),
|
||||
typ: ns_type("messageView"),
|
||||
id: msg_id, rev,
|
||||
text: req.message.text,
|
||||
facets: req.message.facets,
|
||||
@@ -668,7 +696,7 @@ async fn send_message_batch(
|
||||
for item in &req.items {
|
||||
let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message);
|
||||
items.push(serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#messageView",
|
||||
"$type": ns_type("messageView"),
|
||||
"id": msg_id, "rev": rev, "text": item.message.text,
|
||||
"sender": { "did": did }, "sentAt": now,
|
||||
}));
|
||||
@@ -718,34 +746,46 @@ async fn get_convo(
|
||||
Ok(Json(GetConvoResp { convo: load_convo(&db, ¶ms.convo_id, Some(&did)) }))
|
||||
}
|
||||
|
||||
/// Parse repeated query params: ?members=did1&members=did2
|
||||
fn parse_members_query(query: &Option<String>) -> Vec<String> {
|
||||
query.as_deref().unwrap_or("").split('&')
|
||||
.filter_map(|pair| {
|
||||
let (key, val) = pair.split_once('=')?;
|
||||
if key == "members" { Some(urlencoding::decode(val).ok()?.into_owned()) } else { None }
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
async fn get_convo_for_members(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(params): Query<GetConvoForMembersParams>,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
||||
let did = require_auth(&headers)?;
|
||||
if !params.members.contains(&did) {
|
||||
let members = parse_members_query(&query);
|
||||
if !members.contains(&did) {
|
||||
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
|
||||
error: "InvalidRequest".into(),
|
||||
message: "Caller must be a member".into(),
|
||||
})));
|
||||
}
|
||||
let db = state.db.lock().unwrap();
|
||||
for m in ¶ms.members {
|
||||
for m in &members {
|
||||
ensure_account(&db, m);
|
||||
}
|
||||
let convo = get_or_create_convo(&db, ¶ms.members);
|
||||
let convo = get_or_create_convo(&db, &members);
|
||||
Ok(Json(GetConvoResp { convo }))
|
||||
}
|
||||
|
||||
async fn get_convo_availability(
|
||||
State(state): State<Arc<AppState>>,
|
||||
headers: HeaderMap,
|
||||
Query(params): Query<GetConvoAvailabilityParams>,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<Json<GetConvoAvailabilityResp>, (StatusCode, Json<ErrorResp>)> {
|
||||
let _did = require_auth(&headers)?;
|
||||
let db = state.db.lock().unwrap();
|
||||
let members: Vec<MemberAvailability> = params.members.iter().map(|did| {
|
||||
let members_raw = parse_members_query(&query);
|
||||
let members: Vec<MemberAvailability> = members_raw.iter().map(|did| {
|
||||
ensure_account(&db, did);
|
||||
MemberAvailability {
|
||||
did: did.clone(),
|
||||
@@ -861,7 +901,7 @@ async fn accept_convo(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logAcceptConvo",
|
||||
"$type": ns_type("logAcceptConvo"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
}));
|
||||
|
||||
@@ -884,7 +924,7 @@ async fn mute_convo(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logMuteConvo",
|
||||
"$type": ns_type("logMuteConvo"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
}));
|
||||
|
||||
@@ -907,7 +947,7 @@ async fn unmute_convo(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logUnmuteConvo",
|
||||
"$type": ns_type("logUnmuteConvo"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
}));
|
||||
|
||||
@@ -931,17 +971,17 @@ async fn delete_message_for_self(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logDeleteMessage",
|
||||
"$type": ns_type("logDeleteMessage"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
"message": {
|
||||
"$type": "chat.bsky.convo.defs#deletedMessageView",
|
||||
"$type": ns_type("deletedMessageView"),
|
||||
"id": req.message_id, "rev": rev,
|
||||
"sender": { "did": did }, "sentAt": now,
|
||||
}
|
||||
}));
|
||||
|
||||
Ok(Json(DeletedMessageView {
|
||||
typ: "chat.bsky.convo.defs#deletedMessageView".into(),
|
||||
typ: ns_type("deletedMessageView"),
|
||||
id: req.message_id, rev,
|
||||
sender: MessageViewSender { did },
|
||||
sent_at: now,
|
||||
@@ -963,7 +1003,7 @@ async fn leave_convo(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logLeaveConvo",
|
||||
"$type": ns_type("logLeaveConvo"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
}));
|
||||
|
||||
@@ -987,9 +1027,9 @@ async fn add_reaction(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logAddReaction",
|
||||
"$type": ns_type("logAddReaction"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
"message": { "$type": "chat.bsky.convo.defs#messageView", "id": req.message_id },
|
||||
"message": { "$type": ns_type("messageView"), "id": req.message_id },
|
||||
"reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now },
|
||||
}));
|
||||
|
||||
@@ -1013,7 +1053,7 @@ async fn remove_reaction(
|
||||
).unwrap();
|
||||
|
||||
add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({
|
||||
"$type": "chat.bsky.convo.defs#logRemoveReaction",
|
||||
"$type": ns_type("logRemoveReaction"),
|
||||
"rev": rev, "convoId": req.convo_id,
|
||||
}));
|
||||
|
||||
@@ -1072,6 +1112,7 @@ async fn main() {
|
||||
return;
|
||||
}
|
||||
|
||||
init_namespace();
|
||||
tracing_subscriber::fmt::init();
|
||||
|
||||
let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into());
|
||||
@@ -1090,31 +1131,37 @@ async fn main() {
|
||||
.allow_methods(Any)
|
||||
.allow_headers(Any);
|
||||
|
||||
// Leak route strings so they live for 'static (allocated once at startup)
|
||||
let r = |method: &str| -> &'static str {
|
||||
Box::leak(ns_route(method).into_boxed_str())
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/.well-known/did.json", get(well_known_did))
|
||||
.route("/xrpc/_health", get(health))
|
||||
.route("/xrpc/chat.bsky.convo.sendMessage", post(send_message))
|
||||
.route("/xrpc/chat.bsky.convo.sendMessageBatch", post(send_message_batch))
|
||||
.route("/xrpc/chat.bsky.convo.getMessages", get(get_messages))
|
||||
.route("/xrpc/chat.bsky.convo.getConvo", get(get_convo))
|
||||
.route("/xrpc/chat.bsky.convo.getConvoForMembers", get(get_convo_for_members))
|
||||
.route("/xrpc/chat.bsky.convo.getConvoAvailability", get(get_convo_availability))
|
||||
.route("/xrpc/chat.bsky.convo.listConvos", get(list_convos))
|
||||
.route("/xrpc/chat.bsky.convo.updateRead", post(update_read))
|
||||
.route("/xrpc/chat.bsky.convo.updateAllRead", post(update_all_read))
|
||||
.route("/xrpc/chat.bsky.convo.acceptConvo", post(accept_convo))
|
||||
.route("/xrpc/chat.bsky.convo.muteConvo", post(mute_convo))
|
||||
.route("/xrpc/chat.bsky.convo.unmuteConvo", post(unmute_convo))
|
||||
.route("/xrpc/chat.bsky.convo.deleteMessageForSelf", post(delete_message_for_self))
|
||||
.route("/xrpc/chat.bsky.convo.leaveConvo", post(leave_convo))
|
||||
.route("/xrpc/chat.bsky.convo.addReaction", post(add_reaction))
|
||||
.route("/xrpc/chat.bsky.convo.removeReaction", post(remove_reaction))
|
||||
.route("/xrpc/chat.bsky.convo.getLog", get(get_log))
|
||||
.route(r("sendMessage"), post(send_message))
|
||||
.route(r("sendMessageBatch"), post(send_message_batch))
|
||||
.route(r("getMessages"), get(get_messages))
|
||||
.route(r("getConvo"), get(get_convo))
|
||||
.route(r("getConvoForMembers"), get(get_convo_for_members))
|
||||
.route(r("getConvoAvailability"), get(get_convo_availability))
|
||||
.route(r("listConvos"), get(list_convos))
|
||||
.route(r("updateRead"), post(update_read))
|
||||
.route(r("updateAllRead"), post(update_all_read))
|
||||
.route(r("acceptConvo"), post(accept_convo))
|
||||
.route(r("muteConvo"), post(mute_convo))
|
||||
.route(r("unmuteConvo"), post(unmute_convo))
|
||||
.route(r("deleteMessageForSelf"), post(delete_message_for_self))
|
||||
.route(r("leaveConvo"), post(leave_convo))
|
||||
.route(r("addReaction"), post(add_reaction))
|
||||
.route(r("removeReaction"), post(remove_reaction))
|
||||
.route(r("getLog"), get(get_log))
|
||||
.layer(cors)
|
||||
.with_state(state);
|
||||
|
||||
let addr = format!("0.0.0.0:{}", port);
|
||||
tracing::info!("bsky-chat listening on {}", addr);
|
||||
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
|
||||
let addr = format!("{}:{}", host, port);
|
||||
tracing::info!("{} listening on {}", ns(), addr);
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
|
||||
Reference in New Issue
Block a user