1
0
Files
chat/src/main.rs

1179 lines
35 KiB
Rust

use axum::{
extract::{Query, RawQuery, State},
http::{HeaderMap, StatusCode},
response::Json,
routing::{get, post},
Router,
};
use chrono::Utc;
use rusqlite::Connection;
use serde::{Deserialize, Serialize};
use std::sync::{Arc, Mutex, OnceLock};
use tower_http::cors::{Any, CorsLayer};
use uuid::Uuid;
// --- Namespace ---
// The CHAT_DOMAIN env var (default: "bsky.chat") is stored reversed ("chat.bsky")
// and used as the prefix for XRPC route names and `$type` strings.
static NS: OnceLock<String> = OnceLock::new();

/// Returns the process-wide namespace string held in `NS`.
///
/// # Panics
/// Panics with "namespace not initialized" if `NS` has not been set yet.
fn ns() -> &'static str {
    match NS.get() {
        Some(namespace) => namespace.as_str(),
        None => panic!("namespace not initialized"),
    }
}
/// Builds a `$type` identifier of the form `<namespace>.convo.defs#<name>`.
fn ns_type(name: &str) -> String {
    let prefix = ns();
    format!("{prefix}.convo.defs#{name}")
}
/// Builds the XRPC route path `/xrpc/<namespace>.convo.<method>`.
fn ns_route(method: &str) -> String {
    ["/xrpc/", ns(), ".convo.", method].concat()
}
/// Reads `CHAT_DOMAIN` (default "bsky.chat") and stores the namespace in `NS`.
///
/// A domain with exactly two dot-separated labels is swapped ("bsky.chat" becomes
/// "chat.bsky"); any other shape is stored unchanged.
///
/// # Panics
/// Panics if `NS` was already set.
fn init_namespace() {
    let domain = match std::env::var("CHAT_DOMAIN") {
        Ok(value) => value,
        Err(_) => String::from("bsky.chat"),
    };
    // Compute the swapped form inside its own scope so the borrow of `domain`
    // has ended before `domain` is moved in the fallback.
    let swapped = {
        let mut labels = domain.split('.');
        match (labels.next(), labels.next(), labels.next()) {
            (Some(first), Some(second), None) => Some(format!("{second}.{first}")),
            _ => None,
        }
    };
    let reversed = swapped.unwrap_or(domain);
    NS.set(reversed).expect("namespace already initialized");
}
// --- Types ---
/// Message content supplied by the client when sending a message.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageInput {
/// Message text; written to the `messages.text` column by `insert_message`.
text: String,
/// Optional facets, carried as opaque JSON and stored as a JSON string.
facets: Option<serde_json::Value>,
/// Optional embed, carried as opaque JSON and stored as a JSON string.
embed: Option<serde_json::Value>,
}
/// Typed view of a non-deleted message; returned by `send_message`.
/// Fields are serialized in camelCase (`sent_at` becomes `sentAt`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct MessageView {
/// Serialized under the JSON key `$type`.
#[serde(rename = "$type")]
typ: String,
id: String,
rev: String,
text: String,
/// Omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
facets: Option<serde_json::Value>,
/// Omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
embed: Option<serde_json::Value>,
sender: MessageViewSender,
sent_at: String,
}
/// Sender reference embedded in message views; carries only the sender's DID.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct MessageViewSender {
did: String,
}
/// View of a deleted message (no text, facets or embed); returned by
/// `delete_message_for_self`. Fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct DeletedMessageView {
/// Serialized under the JSON key `$type`.
#[serde(rename = "$type")]
typ: String,
id: String,
rev: String,
sender: MessageViewSender,
sent_at: String,
}
/// Basic profile of a conversation member, built by `load_convo` from rows of
/// the `accounts` table. Fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ProfileViewBasic {
did: String,
handle: String,
/// Omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
display_name: Option<String>,
/// Omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
avatar: Option<String>,
/// Omitted from the JSON output when `None`; `load_convo` always sets `None`.
#[serde(skip_serializing_if = "Option::is_none")]
chat_disabled: Option<bool>,
}
/// A conversation as seen by one viewer; assembled by `load_convo`.
/// Fields are serialized in camelCase.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConvoView {
id: String,
rev: String,
members: Vec<ProfileViewBasic>,
/// Latest message as produced by `build_message_json`; omitted when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
last_message: Option<serde_json::Value>,
/// The viewer's `muted` flag from `convo_members`.
muted: bool,
/// The viewer's membership status; omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
status: Option<String>,
unread_count: i64,
}
// --- Request/Response types ---
/// JSON body of `send_message` (camelCase keys: `convoId`, `message`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct SendMessageReq {
convo_id: String,
message: MessageInput,
}
/// One entry of a batch send: the target conversation and the message to post.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct SendMessageBatchItem {
convo_id: String,
message: MessageInput,
}
/// JSON body of `send_message_batch`.
#[derive(Deserialize)]
struct SendMessageBatchReq {
/// Messages to send, processed in order.
items: Vec<SendMessageBatchItem>,
}
/// Response of `send_message_batch`.
#[derive(Serialize)]
struct SendMessageBatchResp {
/// One message-view JSON object per sent message, in request order.
items: Vec<serde_json::Value>,
}
/// Query parameters of `get_messages` (camelCase keys).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetMessagesParams {
convo_id: String,
/// Page size; `get_messages` defaults it to 50 and caps it at 100.
limit: Option<i64>,
/// `sent_at` value of the last message of the previous page; only older
/// messages are returned.
cursor: Option<String>,
}
/// Response of `get_messages`.
#[derive(Serialize)]
struct GetMessagesResp {
/// `sentAt` of the last returned message; omitted when no message was returned.
#[serde(skip_serializing_if = "Option::is_none")]
cursor: Option<String>,
/// Message JSON objects, newest first.
messages: Vec<serde_json::Value>,
}
/// Query parameters of `get_convo` (key: `convoId`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct GetConvoParams {
convo_id: String,
}
/// Response wrapper `{ "convo": ... }` shared by the handlers that return a
/// single conversation view.
#[derive(Serialize)]
struct GetConvoResp {
convo: ConvoView,
}
// members=did1&members=did2 (repeated query params)
/// Typed form of the `members` query parameter.
/// NOTE(review): in the code visible here the handler parses the raw query
/// string via `parse_members_query` instead of using this struct.
#[derive(Deserialize)]
struct GetConvoForMembersParams {
members: Vec<String>,
}
/// Query parameters of `list_convos` (camelCase keys).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ListConvosParams {
/// Page size; `list_convos` defaults it to 20 and caps it at 100.
limit: Option<i64>,
/// Accepted for compatibility; `list_convos` does not read it.
cursor: Option<String>,
/// Membership status filter; `list_convos` defaults it to "accepted".
status: Option<String>,
}
/// Response of `list_convos`.
#[derive(Serialize)]
struct ListConvosResp {
/// Omitted from the JSON output when `None`; `list_convos` always sets `None`.
#[serde(skip_serializing_if = "Option::is_none")]
cursor: Option<String>,
convos: Vec<ConvoView>,
}
/// JSON body of `update_read` (camelCase keys).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdateReadReq {
convo_id: String,
/// Message to mark as read; when absent, `update_read` uses the latest
/// message of the conversation.
message_id: Option<String>,
}
/// JSON body of `update_all_read`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct UpdateAllReadReq {
/// Membership status to select conversations by; defaults to "accepted".
status: Option<String>,
}
/// Response of `update_all_read`; serialized as `{ "updatedCount": n }`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct UpdateAllReadResp {
/// Number of conversations whose read marker was updated.
updated_count: i64,
}
/// JSON body of `delete_message_for_self` (keys: `convoId`, `messageId`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct DeleteMessageReq {
convo_id: String,
message_id: String,
}
/// JSON body carrying only a `convoId`; used by `accept_convo`, `mute_convo`,
/// `unmute_convo` and `leave_convo`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ConvoIdReq {
convo_id: String,
}
/// Response of `leave_convo`: the conversation id and the rev generated for
/// the leave event.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct LeaveConvoResp {
convo_id: String,
rev: String,
}
/// Query parameters of `get_log`.
#[derive(Deserialize)]
struct GetLogParams {
/// A rev string; `get_log` returns entries whose rev compares greater.
cursor: Option<String>,
}
/// Response of `get_log`.
#[derive(Serialize)]
struct GetLogResp {
/// Rev to pass as the next `cursor`; omitted from the JSON output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
cursor: Option<String>,
/// Log entries parsed from the `logs.data` column.
logs: Vec<serde_json::Value>,
}
/// Typed form of the `members` query parameter for the availability endpoint.
/// NOTE(review): in the code visible here `get_convo_availability` parses the
/// raw query string via `parse_members_query` instead of using this struct.
#[derive(Deserialize)]
struct GetConvoAvailabilityParams {
members: Vec<String>,
}
/// Response of `get_convo_availability`: one entry per requested member.
#[derive(Serialize)]
struct GetConvoAvailabilityResp {
members: Vec<MemberAvailability>,
}
/// Availability entry for one DID; serialized with camelCase keys.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct MemberAvailability {
did: String,
/// `get_convo_availability` always sets this to "all".
allow_incoming: String,
}
/// JSON body of `add_reaction` (keys: `convoId`, `messageId`, `value`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct AddReactionReq {
convo_id: String,
message_id: String,
/// Reaction value stored in `reactions.value`.
value: String,
}
/// JSON body of `remove_reaction` (keys: `convoId`, `messageId`, `value`).
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct RemoveReactionReq {
convo_id: String,
message_id: String,
/// Reaction value to delete; matched exactly against `reactions.value`.
value: String,
}
/// JSON error body returned alongside a non-success status code.
#[derive(Serialize)]
struct ErrorResp {
/// Short error name, e.g. "AuthMissing" or "InvalidRequest".
error: String,
/// Human-readable description.
message: String,
}
// --- App State ---
/// Shared application state handed to every handler through `State<Arc<AppState>>`.
struct AppState {
/// The single SQLite connection; handlers lock the mutex for the whole request.
db: Mutex<Connection>,
}
// --- Auth ---
/// Extracts the caller's DID from the `authorization` header.
///
/// Accepts `Bearer <token>` where the token is either a bare DID (starts with
/// `did:`) or a three-segment JWT whose payload has a string `iss` claim.
/// Returns `None` when the header is missing, malformed, or has no usable `iss`.
///
/// NOTE(review): the JWT payload is only base64-decoded; no signature check
/// happens in this function, and a bare DID is accepted as-is. Confirm that
/// tokens are verified before they reach this service.
fn extract_did(headers: &HeaderMap) -> Option<String> {
    let header_value = headers.get("authorization")?.to_str().ok()?;
    let token = header_value.strip_prefix("Bearer ")?;
    if token.starts_with("did:") {
        return Some(token.to_owned());
    }

    // A JWT is header.payload.signature: exactly three dot-separated segments.
    let mut segments = token.split('.');
    let (Some(_), Some(payload_b64), Some(_), None) = (
        segments.next(),
        segments.next(),
        segments.next(),
        segments.next(),
    ) else {
        return None;
    };

    let raw_payload = base64::Engine::decode(
        &base64::engine::general_purpose::URL_SAFE_NO_PAD,
        payload_b64,
    )
    .ok()?;
    let claims: serde_json::Value = serde_json::from_slice(&raw_payload).ok()?;
    // The issuer claim is treated as the caller's DID.
    claims.get("iss")?.as_str().map(str::to_owned)
}
/// Returns the caller's DID, or a 401 `AuthMissing` error response when
/// `extract_did` finds none.
fn require_auth(headers: &HeaderMap) -> Result<String, (StatusCode, Json<ErrorResp>)> {
    match extract_did(headers) {
        Some(did) => Ok(did),
        None => {
            let body = ErrorResp {
                error: "AuthMissing".into(),
                message: "Authentication Required".into(),
            };
            Err((StatusCode::UNAUTHORIZED, Json(body)))
        }
    }
}
/// Generates a revision string: the current UTC time in microseconds since the
/// Unix epoch, rendered as lowercase hexadecimal.
fn new_rev() -> String {
    let micros = Utc::now().timestamp_micros();
    format!("{micros:x}")
}
/// Generates a 13-character identifier: the first 13 hex digits of a fresh
/// random (v4) UUID with the hyphens removed.
fn new_id() -> String {
    Uuid::new_v4()
        .to_string()
        .chars()
        .filter(|ch| *ch != '-')
        .take(13)
        .collect()
}
/// Current UTC time as an ISO 8601 string with millisecond precision and a
/// trailing `Z`, e.g. "2026-03-22T07:21:00.448Z" (the form AT Protocol expects).
fn now_iso() -> String {
    let timestamp = Utc::now();
    format!("{}", timestamp.format("%Y-%m-%dT%H:%M:%S%.3fZ"))
}
// --- DB ---
/// Creates every table the service uses if it does not exist yet
/// (`accounts`, `convos`, `convo_members`, `messages`, `reactions`, `logs`).
///
/// # Panics
/// Panics with "Failed to initialize database" if the batch cannot be executed.
fn init_db(conn: &Connection) {
    // The whole schema is one batch so it is applied with a single call.
    const SCHEMA_SQL: &str = "
CREATE TABLE IF NOT EXISTS accounts (
did TEXT PRIMARY KEY,
handle TEXT NOT NULL DEFAULT '',
display_name TEXT,
avatar TEXT,
created_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS convos (
id TEXT PRIMARY KEY,
rev TEXT NOT NULL,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE TABLE IF NOT EXISTS convo_members (
convo_id TEXT NOT NULL,
did TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'accepted',
muted INTEGER NOT NULL DEFAULT 0,
last_read_id TEXT,
PRIMARY KEY (convo_id, did),
FOREIGN KEY (convo_id) REFERENCES convos(id)
);
CREATE TABLE IF NOT EXISTS messages (
id TEXT PRIMARY KEY,
convo_id TEXT NOT NULL,
sender_did TEXT NOT NULL,
rev TEXT NOT NULL,
text TEXT NOT NULL,
facets TEXT,
embed TEXT,
sent_at TEXT NOT NULL,
deleted INTEGER NOT NULL DEFAULT 0,
FOREIGN KEY (convo_id) REFERENCES convos(id)
);
CREATE TABLE IF NOT EXISTS reactions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
convo_id TEXT NOT NULL,
message_id TEXT NOT NULL,
sender_did TEXT NOT NULL,
value TEXT NOT NULL,
created_at TEXT NOT NULL,
UNIQUE(convo_id, message_id, sender_did, value)
);
CREATE TABLE IF NOT EXISTS logs (
id INTEGER PRIMARY KEY AUTOINCREMENT,
convo_id TEXT NOT NULL,
rev TEXT NOT NULL,
log_type TEXT NOT NULL,
data TEXT NOT NULL,
created_at TEXT NOT NULL
);
";
    let outcome = conn.execute_batch(SCHEMA_SQL);
    outcome.expect("Failed to initialize database");
}
fn ensure_account(conn: &Connection, did: &str) {
// handle must be a valid handle format, not a DID
let handle = format!("{}.chat.invalid", did.split(':').last().unwrap_or("unknown"));
conn.execute(
"INSERT OR IGNORE INTO accounts (did, handle, created_at) VALUES (?1, ?2, ?3)",
rusqlite::params![did, handle, now_iso()],
)
.ok();
}
/// Returns the conversation whose member set is exactly `members`, creating it
/// (plus a `logBeginConvo` log entry) when none exists.
///
/// `members` is treated as a set: it is sorted and de-duplicated first. Without
/// the de-duplication a repeated DID violates the `(convo_id, did)` primary key
/// on insert, and the resulting panic happens while the caller holds the
/// database mutex.
///
/// The returned view is computed for the lexicographically first member,
/// because this function has no separate viewer parameter; callers that need
/// another member's view should reload it with `load_convo`.
///
/// # Panics
/// Panics if a database statement cannot be prepared or executed.
fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView {
    let mut sorted = members.to_vec();
    sorted.sort();
    sorted.dedup();

    // Look for an existing convo with exactly the same members. Only convos
    // containing the first member can match, so start from those.
    if let Some(first) = sorted.first() {
        let candidate_ids: Vec<String> = {
            let mut stmt = conn
                .prepare("SELECT DISTINCT convo_id FROM convo_members WHERE did = ?1")
                .unwrap();
            stmt.query_map(rusqlite::params![first], |row| row.get(0))
                .unwrap()
                .filter_map(|r| r.ok())
                .collect()
        };

        // Prepared once and reused for every candidate.
        let mut member_stmt = conn
            .prepare("SELECT did FROM convo_members WHERE convo_id = ?1 ORDER BY did")
            .unwrap();
        for cid in &candidate_ids {
            let convo_members: Vec<String> = member_stmt
                .query_map(rusqlite::params![cid], |row| row.get(0))
                .unwrap()
                .filter_map(|r| r.ok())
                .collect();
            if convo_members == sorted {
                return load_convo(conn, cid, Some(first.as_str()));
            }
        }
    }

    // Create new convo
    let id = new_id();
    let rev = new_rev();
    let now = now_iso();
    conn.execute(
        "INSERT INTO convos (id, rev, created_at, updated_at) VALUES (?1, ?2, ?3, ?4)",
        rusqlite::params![id, rev, now, now],
    )
    .unwrap();
    for did in &sorted {
        ensure_account(conn, did);
        conn.execute(
            "INSERT INTO convo_members (convo_id, did) VALUES (?1, ?2)",
            rusqlite::params![id, did],
        )
        .unwrap();
    }
    add_log(
        conn,
        &id,
        &rev,
        "logBeginConvo",
        serde_json::json!({
            "$type": ns_type("logBeginConvo"),
            "rev": rev,
            "convoId": id,
        }),
    );
    load_convo(conn, &id, sorted.first().map(|s| s.as_str()))
}
/// Appends one event to the `logs` table, storing `data` as a JSON string and
/// stamping the row with the current time.
///
/// # Panics
/// Panics if the insert fails.
fn add_log(conn: &Connection, convo_id: &str, rev: &str, log_type: &str, data: serde_json::Value) {
    let created_at = now_iso();
    let payload = data.to_string();
    conn.execute(
        "INSERT INTO logs (convo_id, rev, log_type, data, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
        rusqlite::params![convo_id, rev, log_type, payload, created_at],
    )
    .unwrap();
}
/// Assembles the `ConvoView` of `convo_id` as seen by `viewer_did`.
///
/// * `rev` comes from the `convos` row; if the row is missing a fresh rev is
///   generated instead, so this function does not by itself prove that the
///   conversation exists.
/// * `members` are the members that have a row in `accounts`.
/// * `last_message` is the newest message, deleted or not.
/// * `muted` / `status` are the viewer's values from `convo_members`, falling
///   back to `false` / "accepted" when there is no viewer or no membership row.
/// * `unread_count` counts non-deleted messages inserted after the viewer's
///   `last_read_id`, or all non-deleted messages when no read marker resolves.
///
/// # Panics
/// Panics if the member query cannot be prepared or started.
fn load_convo(conn: &Connection, convo_id: &str, viewer_did: Option<&str>) -> ConvoView {
    let (rev, _created_at, _updated_at): (String, String, String) = conn
        .query_row(
            "SELECT rev, created_at, updated_at FROM convos WHERE id = ?1",
            rusqlite::params![convo_id],
            |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
        )
        .unwrap_or_else(|_| (new_rev(), String::new(), String::new()));

    let mut stmt = conn
        .prepare(
            "SELECT a.did, a.handle, a.display_name, a.avatar
             FROM convo_members cm JOIN accounts a ON a.did = cm.did
             WHERE cm.convo_id = ?1",
        )
        .unwrap();
    let members: Vec<ProfileViewBasic> = stmt
        .query_map(rusqlite::params![convo_id], |row| {
            Ok(ProfileViewBasic {
                did: row.get(0)?,
                handle: row.get(1)?,
                display_name: row.get(2)?,
                avatar: row.get(3)?,
                chat_disabled: None,
            })
        })
        .unwrap()
        .filter_map(|r| r.ok())
        .collect();

    // `sent_at` has millisecond resolution, so messages sent in quick succession
    // can tie; rowid (insertion order) breaks the tie deterministically.
    let last_message: Option<serde_json::Value> = conn
        .query_row(
            "SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
             FROM messages WHERE convo_id = ?1
             ORDER BY sent_at DESC, rowid DESC LIMIT 1",
            rusqlite::params![convo_id],
            |row| build_message_json(row),
        )
        .ok();

    let (muted, status, last_read_id): (bool, String, Option<String>) = viewer_did
        .and_then(|did| {
            conn.query_row(
                "SELECT muted, status, last_read_id FROM convo_members WHERE convo_id = ?1 AND did = ?2",
                rusqlite::params![convo_id, did],
                |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
            )
            .ok()
        })
        .unwrap_or((false, "accepted".into(), None));

    // Message ids are random, so comparing ids says nothing about which message
    // is newer. Compare rowids instead: they grow with insertion order. A NULL
    // or unknown read marker resolves to 0, which counts every message.
    let unread_count: i64 = conn
        .query_row(
            "SELECT COUNT(*) FROM messages
             WHERE convo_id = ?1 AND deleted = 0
               AND rowid > COALESCE(
                   (SELECT rowid FROM messages WHERE id = ?2 AND convo_id = ?1), 0)",
            rusqlite::params![convo_id, last_read_id],
            |row| row.get(0),
        )
        .unwrap_or(0);

    ConvoView {
        id: convo_id.to_string(),
        rev,
        members,
        last_message,
        muted,
        status: Some(status),
        unread_count,
    }
}
/// Converts a `messages` row into its JSON view.
///
/// Expects the columns in the order
/// `id, rev, text, facets, embed, sender_did, sent_at, deleted`.
/// Deleted rows become a `deletedMessageView` (no text, facets or embed); all
/// others become a `messageView`. Stored facets/embed that fail to parse as
/// JSON are left out.
fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value> {
    let is_deleted: bool = row.get(7)?;
    let sender: String = row.get(5)?;

    if is_deleted {
        let id: String = row.get(0)?;
        let rev: String = row.get(1)?;
        let sent_at: String = row.get(6)?;
        return Ok(serde_json::json!({
            "$type": ns_type("deletedMessageView"),
            "id": id,
            "rev": rev,
            "sender": { "did": sender },
            "sentAt": sent_at,
        }));
    }

    let facets_raw: Option<String> = row.get(3)?;
    let embed_raw: Option<String> = row.get(4)?;
    let id: String = row.get(0)?;
    let rev: String = row.get(1)?;
    let text: String = row.get(2)?;
    let sent_at: String = row.get(6)?;
    let mut view = serde_json::json!({
        "$type": ns_type("messageView"),
        "id": id,
        "rev": rev,
        "text": text,
        "sender": { "did": sender },
        "sentAt": sent_at,
    });

    // Optional JSON columns are attached only when present and parseable.
    for (key, raw) in [("facets", facets_raw), ("embed", embed_raw)] {
        let parsed = raw.and_then(|s| serde_json::from_str::<serde_json::Value>(&s).ok());
        if let Some(value) = parsed {
            view[key] = value;
        }
    }
    Ok(view)
}
/// Stores a new message in `convo_id`, bumps the conversation's `rev` and
/// `updated_at`, and appends a `logCreateMessage` log entry.
///
/// The logged message view includes `facets` and `embed` when the input has
/// them, so log consumers see the same message that is read back from the
/// `messages` table.
///
/// Returns `(message_id, rev, sent_at)`. Membership is not checked here.
///
/// # Panics
/// Panics if any of the database writes fails.
fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInput) -> (String, String, String) {
    let msg_id = new_id();
    let rev = new_rev();
    let now = now_iso();
    let facets_str = msg.facets.as_ref().map(|v| v.to_string());
    let embed_str = msg.embed.as_ref().map(|v| v.to_string());
    conn.execute(
        "INSERT INTO messages (id, convo_id, sender_did, rev, text, facets, embed, sent_at)
         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
        rusqlite::params![msg_id, convo_id, did, rev, msg.text, facets_str, embed_str, now],
    )
    .unwrap();
    conn.execute(
        "UPDATE convos SET rev = ?1, updated_at = ?2 WHERE id = ?3",
        rusqlite::params![rev, now, convo_id],
    )
    .unwrap();

    let mut message_view = serde_json::json!({
        "$type": ns_type("messageView"),
        "id": msg_id, "rev": rev, "text": msg.text,
        "sender": { "did": did }, "sentAt": now,
    });
    if let Some(facets) = &msg.facets {
        message_view["facets"] = facets.clone();
    }
    if let Some(embed) = &msg.embed {
        message_view["embed"] = embed.clone();
    }
    add_log(
        conn,
        convo_id,
        &rev,
        "logCreateMessage",
        serde_json::json!({
            "$type": ns_type("logCreateMessage"),
            "rev": rev,
            "convoId": convo_id,
            "message": message_view,
        }),
    );
    (msg_id, rev, now)
}
// --- Handlers ---
/// Liveness probe handler: always answers with the plain-text body "OK".
async fn health() -> &'static str {
    const HEALTHY_BODY: &str = "OK";
    HEALTHY_BODY
}
/// Serves the `did:web` document for this service.
///
/// `DID_HOST` (default "bsky.syu.is") determines the DID; `SERVICE_URL`
/// (default `https://<DID_HOST>`) is advertised as the endpoint of all three
/// listed services.
async fn well_known_did() -> Json<serde_json::Value> {
    let did_host = std::env::var("DID_HOST").unwrap_or_else(|_| "bsky.syu.is".into());
    let service_url = match std::env::var("SERVICE_URL") {
        Ok(url) => url,
        Err(_) => format!("https://{}", did_host),
    };

    // Every service entry points at the same endpoint.
    let services: Vec<serde_json::Value> = [
        ("#bsky_appview", "BskyAppView"),
        ("#bsky_chat", "BskyChat"),
        ("#bsky_notif", "BskyNotificationService"),
    ]
    .iter()
    .map(|(id, kind)| {
        serde_json::json!({
            "id": id,
            "type": kind,
            "serviceEndpoint": &service_url,
        })
    })
    .collect();

    let document = serde_json::json!({
        "@context": ["https://www.w3.org/ns/did/v1"],
        "id": format!("did:web:{}", did_host),
        "service": services,
    });
    Json(document)
}
async fn send_message(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<SendMessageReq>,
) -> Result<Json<MessageView>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
ensure_account(&db, &did);
let is_member: bool = db
.query_row(
"SELECT COUNT(*) FROM convo_members WHERE convo_id = ?1 AND did = ?2",
rusqlite::params![req.convo_id, did],
|row| row.get::<_, i64>(0),
)
.unwrap_or(0)
> 0;
if !is_member {
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
error: "InvalidRequest".into(),
message: "Not a member of this conversation".into(),
})));
}
let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message);
Ok(Json(MessageView {
typ: ns_type("messageView"),
id: msg_id, rev,
text: req.message.text,
facets: req.message.facets,
embed: req.message.embed,
sender: MessageViewSender { did },
sent_at: now,
}))
}
async fn send_message_batch(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<SendMessageBatchReq>,
) -> Result<Json<SendMessageBatchResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
ensure_account(&db, &did);
let mut items = Vec::new();
for item in &req.items {
let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message);
items.push(serde_json::json!({
"$type": ns_type("messageView"),
"id": msg_id, "rev": rev, "text": item.message.text,
"sender": { "did": did }, "sentAt": now,
}));
}
Ok(Json(SendMessageBatchResp { items }))
}
async fn get_messages(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(params): Query<GetMessagesParams>,
) -> Result<Json<GetMessagesResp>, (StatusCode, Json<ErrorResp>)> {
let _did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let limit = params.limit.unwrap_or(50).min(100);
let messages: Vec<serde_json::Value> = if let Some(ref cursor) = params.cursor {
let mut stmt = db.prepare(
"SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
FROM messages WHERE convo_id = ?1 AND deleted = 0 AND sent_at < ?2
ORDER BY sent_at DESC LIMIT ?3"
).unwrap();
stmt.query_map(rusqlite::params![params.convo_id, cursor, limit], |row| build_message_json(row))
.unwrap().filter_map(|r| r.ok()).collect()
} else {
let mut stmt = db.prepare(
"SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
FROM messages WHERE convo_id = ?1 AND deleted = 0
ORDER BY sent_at DESC LIMIT ?2"
).unwrap();
stmt.query_map(rusqlite::params![params.convo_id, limit], |row| build_message_json(row))
.unwrap().filter_map(|r| r.ok()).collect()
};
let cursor = messages.last().and_then(|m| m["sentAt"].as_str().map(String::from));
Ok(Json(GetMessagesResp { cursor, messages }))
}
async fn get_convo(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Query(params): Query<GetConvoParams>,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
Ok(Json(GetConvoResp { convo: load_convo(&db, &params.convo_id, Some(&did)) }))
}
/// Collects every `members` value from a raw query string such as
/// `?members=did1&members=did2`, percent-decoding each value.
///
/// Pairs without `=`, pairs with a different key, and values that fail to
/// decode are skipped. A missing query yields an empty list.
fn parse_members_query(query: &Option<String>) -> Vec<String> {
    let raw = match query {
        Some(q) => q.as_str(),
        None => "",
    };

    let mut members = Vec::new();
    for pair in raw.split('&') {
        let Some((key, value)) = pair.split_once('=') else {
            continue;
        };
        if key != "members" {
            continue;
        }
        if let Ok(decoded) = urlencoding::decode(value) {
            members.push(decoded.into_owned());
        }
    }
    members
}
async fn get_convo_for_members(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
RawQuery(query): RawQuery,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let members = parse_members_query(&query);
if !members.contains(&did) {
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
error: "InvalidRequest".into(),
message: "Caller must be a member".into(),
})));
}
let db = state.db.lock().unwrap();
for m in &members {
ensure_account(&db, m);
}
let convo = get_or_create_convo(&db, &members);
Ok(Json(GetConvoResp { convo }))
}
/// `getConvoAvailability` handler: reports, for every DID in the repeated
/// `members` query parameter, that incoming chats are allowed from "all".
///
/// As a side effect a placeholder account row is ensured for each listed DID.
async fn get_convo_availability(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    RawQuery(query): RawQuery,
) -> Result<Json<GetConvoAvailabilityResp>, (StatusCode, Json<ErrorResp>)> {
    let _did = require_auth(&headers)?;
    let db = state.db.lock().unwrap();

    let requested = parse_members_query(&query);
    let mut members = Vec::with_capacity(requested.len());
    for did in requested {
        ensure_account(&db, &did);
        members.push(MemberAvailability {
            did,
            allow_incoming: "all".into(),
        });
    }
    Ok(Json(GetConvoAvailabilityResp { members }))
}
/// `listConvos` handler: returns the caller's conversations with the requested
/// membership status (default "accepted"), most recently updated first.
///
/// `limit` defaults to 20 and is clamped to 1..=100; SQLite treats a negative
/// LIMIT as "no limit", so an unclamped value would return every row.
/// The `cursor` parameter is not used and the response cursor is always absent.
async fn list_convos(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Query(params): Query<ListConvosParams>,
) -> Result<Json<ListConvosResp>, (StatusCode, Json<ErrorResp>)> {
    let did = require_auth(&headers)?;
    let db = state.db.lock().unwrap();
    let limit = params.limit.unwrap_or(20).clamp(1, 100);
    let status_filter = params.status.unwrap_or_else(|| "accepted".into());
    let mut stmt = db
        .prepare(
            "SELECT cm.convo_id FROM convo_members cm
             JOIN convos c ON c.id = cm.convo_id
             WHERE cm.did = ?1 AND cm.status = ?2
             ORDER BY c.updated_at DESC LIMIT ?3",
        )
        .unwrap();
    let convo_ids: Vec<String> = stmt
        .query_map(rusqlite::params![did, status_filter, limit], |row| row.get(0))
        .unwrap()
        .filter_map(|r| r.ok())
        .collect();
    let convos: Vec<ConvoView> = convo_ids
        .iter()
        .map(|id| load_convo(&db, id, Some(&did)))
        .collect();
    Ok(Json(ListConvosResp { cursor: None, convos }))
}
async fn update_read(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<UpdateReadReq>,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let read_id = req.message_id.unwrap_or_else(|| {
db.query_row(
"SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1",
rusqlite::params![req.convo_id],
|row| row.get(0),
).unwrap_or_default()
});
db.execute(
"UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3",
rusqlite::params![read_id, req.convo_id, did],
).unwrap();
let convo = load_convo(&db, &req.convo_id, Some(&did));
Ok(Json(GetConvoResp { convo }))
}
/// `updateAllRead` handler: for each of the caller's conversations with the
/// given status (default "accepted"), moves the read marker to the latest
/// message. Conversations without messages are skipped.
///
/// Responds with the number of conversations whose marker was written.
async fn update_all_read(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Json(req): Json<UpdateAllReadReq>,
) -> Result<Json<UpdateAllReadResp>, (StatusCode, Json<ErrorResp>)> {
    let did = require_auth(&headers)?;
    let db = state.db.lock().unwrap();
    let status = match req.status {
        Some(value) => value,
        None => "accepted".into(),
    };

    // Get all convos for this user with given status
    let convo_ids: Vec<String> = {
        let mut stmt = db
            .prepare("SELECT convo_id FROM convo_members WHERE did = ?1 AND status = ?2")
            .unwrap();
        let rows = stmt
            .query_map(rusqlite::params![did, status], |row| row.get(0))
            .unwrap();
        let mut ids = Vec::new();
        for id in rows.flatten() {
            ids.push(id);
        }
        ids
    };

    let mut updated_count = 0i64;
    for convo_id in &convo_ids {
        let latest: rusqlite::Result<String> = db.query_row(
            "SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1",
            rusqlite::params![convo_id],
            |row| row.get(0),
        );
        let Ok(message_id) = latest else {
            continue;
        };
        db.execute(
            "UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3",
            rusqlite::params![message_id, convo_id, did],
        )
        .unwrap();
        updated_count += 1;
    }
    Ok(Json(UpdateAllReadResp { updated_count }))
}
async fn accept_convo(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<ConvoIdReq>,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
db.execute(
"UPDATE convo_members SET status = 'accepted' WHERE convo_id = ?1 AND did = ?2",
rusqlite::params![req.convo_id, did],
).unwrap();
db.execute(
"UPDATE convos SET rev = ?1 WHERE id = ?2",
rusqlite::params![rev, req.convo_id],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({
"$type": ns_type("logAcceptConvo"),
"rev": rev, "convoId": req.convo_id,
}));
let convo = load_convo(&db, &req.convo_id, Some(&did));
Ok(Json(GetConvoResp { convo }))
}
async fn mute_convo(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<ConvoIdReq>,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
db.execute(
"UPDATE convo_members SET muted = 1 WHERE convo_id = ?1 AND did = ?2",
rusqlite::params![req.convo_id, did],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({
"$type": ns_type("logMuteConvo"),
"rev": rev, "convoId": req.convo_id,
}));
let convo = load_convo(&db, &req.convo_id, Some(&did));
Ok(Json(GetConvoResp { convo }))
}
async fn unmute_convo(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<ConvoIdReq>,
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
db.execute(
"UPDATE convo_members SET muted = 0 WHERE convo_id = ?1 AND did = ?2",
rusqlite::params![req.convo_id, did],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({
"$type": ns_type("logUnmuteConvo"),
"rev": rev, "convoId": req.convo_id,
}));
let convo = load_convo(&db, &req.convo_id, Some(&did));
Ok(Json(GetConvoResp { convo }))
}
async fn delete_message_for_self(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<DeleteMessageReq>,
) -> Result<Json<DeletedMessageView>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
let now = now_iso();
db.execute(
"UPDATE messages SET deleted = 1, rev = ?1 WHERE id = ?2 AND convo_id = ?3",
rusqlite::params![rev, req.message_id, req.convo_id],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({
"$type": ns_type("logDeleteMessage"),
"rev": rev, "convoId": req.convo_id,
"message": {
"$type": ns_type("deletedMessageView"),
"id": req.message_id, "rev": rev,
"sender": { "did": did }, "sentAt": now,
}
}));
Ok(Json(DeletedMessageView {
typ: ns_type("deletedMessageView"),
id: req.message_id, rev,
sender: MessageViewSender { did },
sent_at: now,
}))
}
async fn leave_convo(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<ConvoIdReq>,
) -> Result<Json<LeaveConvoResp>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
db.execute(
"DELETE FROM convo_members WHERE convo_id = ?1 AND did = ?2",
rusqlite::params![req.convo_id, did],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({
"$type": ns_type("logLeaveConvo"),
"rev": rev, "convoId": req.convo_id,
}));
Ok(Json(LeaveConvoResp { convo_id: req.convo_id, rev }))
}
async fn add_reaction(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<AddReactionReq>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let now = now_iso();
let rev = new_rev();
db.execute(
"INSERT OR IGNORE INTO reactions (convo_id, message_id, sender_did, value, created_at)
VALUES (?1, ?2, ?3, ?4, ?5)",
rusqlite::params![req.convo_id, req.message_id, did, req.value, now],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({
"$type": ns_type("logAddReaction"),
"rev": rev, "convoId": req.convo_id,
"message": { "$type": ns_type("messageView"), "id": req.message_id },
"reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now },
}));
Ok(Json(serde_json::json!({
"convo": load_convo(&db, &req.convo_id, Some(&did)),
})))
}
async fn remove_reaction(
State(state): State<Arc<AppState>>,
headers: HeaderMap,
Json(req): Json<RemoveReactionReq>,
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResp>)> {
let did = require_auth(&headers)?;
let db = state.db.lock().unwrap();
let rev = new_rev();
db.execute(
"DELETE FROM reactions WHERE convo_id = ?1 AND message_id = ?2 AND sender_did = ?3 AND value = ?4",
rusqlite::params![req.convo_id, req.message_id, did, req.value],
).unwrap();
add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({
"$type": ns_type("logRemoveReaction"),
"rev": rev, "convoId": req.convo_id,
}));
Ok(Json(serde_json::json!({
"convo": load_convo(&db, &req.convo_id, Some(&did)),
})))
}
/// `getLog` handler: returns event-log entries newer than `cursor`.
///
/// With a cursor (a rev string), up to 100 entries with a greater rev are
/// returned in insertion order, restricted to conversations the caller is
/// currently a member of, so one user's log never contains other users'
/// messages. The response cursor is the rev of the last returned entry.
///
/// Without a cursor no entries are returned; the response cursor is the rev of
/// the newest log entry overall, which a client can use as its starting point.
async fn get_log(
    State(state): State<Arc<AppState>>,
    headers: HeaderMap,
    Query(params): Query<GetLogParams>,
) -> Result<Json<GetLogResp>, (StatusCode, Json<ErrorResp>)> {
    let did = require_auth(&headers)?;
    let db = state.db.lock().unwrap();
    let (logs, last_rev) = if let Some(ref cursor) = params.cursor {
        // cursor is a rev (hex timestamp) - find logs after that rev
        let mut stmt = db
            .prepare(
                "SELECT rev, data FROM logs
                 WHERE rev > ?1
                   AND convo_id IN (SELECT convo_id FROM convo_members WHERE did = ?2)
                 ORDER BY id ASC LIMIT 100",
            )
            .unwrap();
        let rows = stmt
            .query_map(rusqlite::params![cursor, did], |row| {
                let rev: String = row.get(0)?;
                let data: String = row.get(1)?;
                Ok((rev, data))
            })
            .unwrap();

        let mut items: Vec<serde_json::Value> = Vec::new();
        let mut last = None;
        for (rev, data) in rows.filter_map(|r| r.ok()) {
            // Unparseable rows are kept as empty objects so the cursor still advances.
            items.push(serde_json::from_str(&data).unwrap_or_else(|_| serde_json::json!({})));
            last = Some(rev);
        }
        (items, last)
    } else {
        // No cursor - return latest rev only
        let last: Option<String> = db
            .query_row(
                "SELECT rev FROM logs ORDER BY id DESC LIMIT 1",
                [],
                |row| row.get(0),
            )
            .ok();
        (vec![], last)
    };
    Ok(Json(GetLogResp { cursor: last_rev, logs }))
}
// --- Main ---
#[tokio::main]
async fn main() {
let args: Vec<String> = std::env::args().collect();
if args.len() > 1 && (args[1] == "v" || args[1] == "--version" || args[1] == "-v") {
println!("{}", env!("CARGO_PKG_VERSION"));
return;
}
init_namespace();
tracing_subscriber::fmt::init();
let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into());
let port: u16 = std::env::var("PORT")
.unwrap_or_else(|_| "3100".into())
.parse()
.unwrap_or(3100);
let conn = Connection::open(&db_path).expect("Failed to open database");
init_db(&conn);
let state = Arc::new(AppState { db: Mutex::new(conn) });
let cors = CorsLayer::new()
.allow_origin(Any)
.allow_methods(Any)
.allow_headers(Any);
// Leak route strings so they live for 'static (allocated once at startup)
let r = |method: &str| -> &'static str {
Box::leak(ns_route(method).into_boxed_str())
};
let app = Router::new()
.route("/.well-known/did.json", get(well_known_did))
.route("/xrpc/_health", get(health))
.route(r("sendMessage"), post(send_message))
.route(r("sendMessageBatch"), post(send_message_batch))
.route(r("getMessages"), get(get_messages))
.route(r("getConvo"), get(get_convo))
.route(r("getConvoForMembers"), get(get_convo_for_members))
.route(r("getConvoAvailability"), get(get_convo_availability))
.route(r("listConvos"), get(list_convos))
.route(r("updateRead"), post(update_read))
.route(r("updateAllRead"), post(update_all_read))
.route(r("acceptConvo"), post(accept_convo))
.route(r("muteConvo"), post(mute_convo))
.route(r("unmuteConvo"), post(unmute_convo))
.route(r("deleteMessageForSelf"), post(delete_message_for_self))
.route(r("leaveConvo"), post(leave_convo))
.route(r("addReaction"), post(add_reaction))
.route(r("removeReaction"), post(remove_reaction))
.route(r("getLog"), get(get_log))
.layer(cors)
.with_state(state);
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
let addr = format!("{}:{}", host, port);
tracing::info!("{} listening on {}", ns(), addr);
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
axum::serve(listener, app).await.unwrap();
}