diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2222e6e --- /dev/null +++ b/.gitignore @@ -0,0 +1,17 @@ +/dist +/repos +/target +/CLAUDE.md +/.claude +node_modules +package-lock.json +Cargo.lock +.env +.mcp.json +bot +/public/at/*/ai.syui.ue.* +/public/at/*/ai.syui.note.* +/wiki +/src/rules/note.md +/src/rules/manga.md +/src/rules/novel.md diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..957987e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "aichat" +version = "0.1.0" +edition = "2021" + +[dependencies] +axum = "0.7" +tokio = { version = "1", features = ["full"] } +serde = { version = "1", features = ["derive"] } +serde_json = "1" +rusqlite = { version = "0.32", features = ["bundled"] } +tower-http = { version = "0.5", features = ["cors"] } +chrono = { version = "0.4", features = ["serde"] } +uuid = { version = "1", features = ["v4"] } +tracing = "0.1" +tracing-subscriber = "0.3" +base64 = "0.22" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..ba261f2 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,14 @@ +FROM rust:1.90 AS builder +WORKDIR /app +COPY Cargo.toml Cargo.lock* ./ +COPY src ./src +RUN cargo build --release + +FROM debian:bookworm-slim +RUN apt-get update && apt-get install -y ca-certificates && rm -rf /var/lib/apt/lists/* +COPY --from=builder /app/target/release/aichat /usr/local/bin/aichat +VOLUME /data +ENV DB_PATH=/data/chat.db +ENV PORT=3100 +EXPOSE 3100 +CMD ["aichat"] diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..0b151a6 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,1115 @@ +use axum::{ + extract::{Query, State}, + http::{HeaderMap, StatusCode}, + response::Json, + routing::{get, post}, + Router, +}; +use chrono::Utc; +use rusqlite::Connection; +use serde::{Deserialize, Serialize}; +use std::sync::{Arc, Mutex}; +use tower_http::cors::{Any, CorsLayer}; +use uuid::Uuid; + +// --- Types --- + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MessageInput { + text: String, + facets: Option, + embed: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct MessageView { + #[serde(rename = "$type")] + typ: String, + id: String, + rev: String, + text: String, + #[serde(skip_serializing_if = "Option::is_none")] + facets: Option, + #[serde(skip_serializing_if = "Option::is_none")] + embed: Option, + sender: MessageViewSender, + sent_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +struct MessageViewSender { + did: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct DeletedMessageView { + #[serde(rename = "$type")] + typ: String, + id: String, + rev: String, + sender: MessageViewSender, + sent_at: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ProfileViewBasic { + did: String, + handle: String, + #[serde(skip_serializing_if = "Option::is_none")] + display_name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + avatar: Option, + #[serde(skip_serializing_if = "Option::is_none")] + chat_disabled: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ConvoView { + id: String, + rev: String, + members: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + last_message: Option, + muted: bool, + #[serde(skip_serializing_if = "Option::is_none")] + status: Option, + unread_count: i64, +} + +// --- Request/Response types --- + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct SendMessageReq { + convo_id: String, + message: MessageInput, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct SendMessageBatchItem { + convo_id: String, + message: MessageInput, +} + +#[derive(Deserialize)] +struct SendMessageBatchReq { + items: Vec, +} + +#[derive(Serialize)] +struct SendMessageBatchResp { + items: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct GetMessagesParams { + convo_id: String, + limit: Option, + cursor: Option, +} + +#[derive(Serialize)] +struct GetMessagesResp { + cursor: Option, + messages: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct GetConvoParams { + convo_id: String, +} + +#[derive(Serialize)] +struct GetConvoResp { + convo: ConvoView, +} + +// members=did1&members=did2 (repeated query params) +#[derive(Deserialize)] +struct GetConvoForMembersParams { + members: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct ListConvosParams { + limit: Option, + cursor: Option, + status: Option, +} + +#[derive(Serialize)] +struct ListConvosResp { + cursor: Option, + convos: Vec, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct UpdateReadReq { + convo_id: String, + message_id: Option, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct UpdateAllReadReq { + status: Option, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct UpdateAllReadResp { + updated_count: i64, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct DeleteMessageReq { + convo_id: String, + message_id: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct ConvoIdReq { + convo_id: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct LeaveConvoResp { + convo_id: String, + rev: String, +} + +#[derive(Deserialize)] +struct GetLogParams { + cursor: Option, +} + +#[derive(Serialize)] +struct GetLogResp { + cursor: Option, + logs: Vec, +} + +#[derive(Deserialize)] +struct GetConvoAvailabilityParams { + members: Vec, +} + +#[derive(Serialize)] +struct GetConvoAvailabilityResp { + members: Vec, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +struct MemberAvailability { + did: String, + allow_incoming: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct AddReactionReq { + convo_id: String, + message_id: String, + value: String, +} + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +struct RemoveReactionReq { + convo_id: String, + message_id: String, + value: String, +} + +#[derive(Serialize)] +struct ErrorResp { + error: String, + message: String, +} + +// --- App State --- + +struct AppState { + db: Mutex, +} + +// --- Auth --- + +fn extract_did(headers: &HeaderMap) -> Option { + let auth = headers.get("authorization")?.to_str().ok()?; + let token = auth.strip_prefix("Bearer ")?; + + if token.starts_with("did:") { + return Some(token.to_string()); + } + + // Decode JWT payload to extract iss (issuer = caller DID) + let parts: Vec<&str> = token.split('.').collect(); + if parts.len() == 3 { + if let Ok(decoded) = base64::Engine::decode( + &base64::engine::general_purpose::URL_SAFE_NO_PAD, + parts[1], + ) { + if let Ok(payload) = serde_json::from_slice::(&decoded) { + if let Some(iss) = payload.get("iss").and_then(|v| v.as_str()) { + return Some(iss.to_string()); + } + } + } + } + + None +} + +fn require_auth(headers: &HeaderMap) -> Result)> { + extract_did(headers).ok_or_else(|| { + ( + StatusCode::UNAUTHORIZED, + Json(ErrorResp { + error: "AuthMissing".into(), + message: "Authentication Required".into(), + }), + ) + }) +} + +fn new_rev() -> String { + let now = Utc::now().timestamp_micros(); + format!("{:x}", now) +} + +fn new_id() -> String { + Uuid::new_v4().to_string().replace('-', "")[..13].to_string() +} + +// --- DB --- + +fn init_db(conn: &Connection) { + conn.execute_batch( + " + CREATE TABLE IF NOT EXISTS accounts ( + did TEXT PRIMARY KEY, + handle TEXT NOT NULL DEFAULT '', + display_name TEXT, + avatar TEXT, + created_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS convos ( + id TEXT PRIMARY KEY, + rev TEXT NOT NULL, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + + CREATE TABLE IF NOT EXISTS convo_members ( + convo_id TEXT NOT NULL, + did TEXT NOT NULL, + status TEXT NOT NULL DEFAULT 'accepted', + muted INTEGER NOT NULL DEFAULT 0, + last_read_id TEXT, + PRIMARY KEY (convo_id, did), + FOREIGN KEY (convo_id) REFERENCES convos(id) + ); + + CREATE TABLE IF NOT EXISTS messages ( + id TEXT PRIMARY KEY, + convo_id TEXT NOT NULL, + sender_did TEXT NOT NULL, + rev TEXT NOT NULL, + text TEXT NOT NULL, + facets TEXT, + embed TEXT, + sent_at TEXT NOT NULL, + deleted INTEGER NOT NULL DEFAULT 0, + FOREIGN KEY (convo_id) REFERENCES convos(id) + ); + + CREATE TABLE IF NOT EXISTS reactions ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + convo_id TEXT NOT NULL, + message_id TEXT NOT NULL, + sender_did TEXT NOT NULL, + value TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE(convo_id, message_id, sender_did, value) + ); + + CREATE TABLE IF NOT EXISTS logs ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + convo_id TEXT NOT NULL, + rev TEXT NOT NULL, + log_type TEXT NOT NULL, + data TEXT NOT NULL, + created_at TEXT NOT NULL + ); + ", + ) + .expect("Failed to initialize database"); +} + +fn ensure_account(conn: &Connection, did: &str) { + conn.execute( + "INSERT OR IGNORE INTO accounts (did, handle, created_at) VALUES (?1, ?2, ?3)", + rusqlite::params![did, did, Utc::now().to_rfc3339()], + ) + .ok(); +} + +fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView { + let mut sorted = members.to_vec(); + sorted.sort(); + + // Look for existing convo with exact same members + let all_convo_ids: Vec = { + let mut stmt = conn + .prepare("SELECT DISTINCT convo_id FROM convo_members WHERE did = ?1") + .unwrap(); + stmt.query_map(rusqlite::params![&sorted[0]], |row| row.get(0)) + .unwrap() + .filter_map(|r| r.ok()) + .collect() + }; + + for cid in &all_convo_ids { + let mut stmt = conn + .prepare("SELECT did FROM convo_members WHERE convo_id = ?1 ORDER BY did") + .unwrap(); + let convo_members: Vec = stmt + .query_map(rusqlite::params![cid], |row| row.get(0)) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + if convo_members == sorted { + return load_convo(conn, cid, sorted.first().map(|s| s.as_str())); + } + } + + // Create new convo + let id = new_id(); + let rev = new_rev(); + let now = Utc::now().to_rfc3339(); + + conn.execute( + "INSERT INTO convos (id, rev, created_at, updated_at) VALUES (?1, ?2, ?3, ?4)", + rusqlite::params![id, rev, now, now], + ) + .unwrap(); + + for did in &sorted { + ensure_account(conn, did); + conn.execute( + "INSERT INTO convo_members (convo_id, did) VALUES (?1, ?2)", + rusqlite::params![id, did], + ) + .unwrap(); + } + + add_log(conn, &id, &rev, "logBeginConvo", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logBeginConvo", + "rev": rev, + "convoId": id, + })); + + load_convo(conn, &id, sorted.first().map(|s| s.as_str())) +} + +fn add_log(conn: &Connection, convo_id: &str, rev: &str, log_type: &str, data: serde_json::Value) { + let now = Utc::now().to_rfc3339(); + conn.execute( + "INSERT INTO logs (convo_id, rev, log_type, data, created_at) VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![convo_id, rev, log_type, data.to_string(), now], + ) + .unwrap(); +} + +fn load_convo(conn: &Connection, convo_id: &str, viewer_did: Option<&str>) -> ConvoView { + let (rev, _created_at, _updated_at): (String, String, String) = conn + .query_row( + "SELECT rev, created_at, updated_at FROM convos WHERE id = ?1", + rusqlite::params![convo_id], + |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)), + ) + .unwrap_or_else(|_| (new_rev(), String::new(), String::new())); + + let mut stmt = conn + .prepare( + "SELECT a.did, a.handle, a.display_name, a.avatar + FROM convo_members cm JOIN accounts a ON a.did = cm.did + WHERE cm.convo_id = ?1", + ) + .unwrap(); + let members: Vec = stmt + .query_map(rusqlite::params![convo_id], |row| { + Ok(ProfileViewBasic { + did: row.get(0)?, + handle: row.get(1)?, + display_name: row.get(2)?, + avatar: row.get(3)?, + chat_disabled: None, + }) + }) + .unwrap() + .filter_map(|r| r.ok()) + .collect(); + + let last_message: Option = conn + .query_row( + "SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted + FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1", + rusqlite::params![convo_id], + |row| build_message_json(row), + ) + .ok(); + + let (muted, status, last_read_id): (bool, String, Option) = viewer_did + .and_then(|did| { + conn.query_row( + "SELECT muted, status, last_read_id FROM convo_members WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![convo_id, did], + |row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)), + ) + .ok() + }) + .unwrap_or((false, "accepted".into(), None)); + + let unread_count: i64 = if let Some(ref read_id) = last_read_id { + conn.query_row( + "SELECT COUNT(*) FROM messages WHERE convo_id = ?1 AND deleted = 0 AND id > ?2", + rusqlite::params![convo_id, read_id], + |row| row.get(0), + ) + .unwrap_or(0) + } else { + conn.query_row( + "SELECT COUNT(*) FROM messages WHERE convo_id = ?1 AND deleted = 0", + rusqlite::params![convo_id], + |row| row.get(0), + ) + .unwrap_or(0) + }; + + ConvoView { + id: convo_id.to_string(), + rev, + members, + last_message, + muted, + status: Some(status), + unread_count, + } +} + +fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result { + let deleted: bool = row.get(7)?; + let sender_did: String = row.get(5)?; + if deleted { + Ok(serde_json::json!({ + "$type": "chat.bsky.convo.defs#deletedMessageView", + "id": row.get::<_, String>(0)?, + "rev": row.get::<_, String>(1)?, + "sender": { "did": sender_did }, + "sentAt": row.get::<_, String>(6)?, + })) + } else { + let facets: Option = row.get(3)?; + let embed: Option = row.get(4)?; + let mut msg = serde_json::json!({ + "$type": "chat.bsky.convo.defs#messageView", + "id": row.get::<_, String>(0)?, + "rev": row.get::<_, String>(1)?, + "text": row.get::<_, String>(2)?, + "sender": { "did": sender_did }, + "sentAt": row.get::<_, String>(6)?, + }); + if let Some(f) = facets { + if let Ok(v) = serde_json::from_str::(&f) { + msg["facets"] = v; + } + } + if let Some(e) = embed { + if let Ok(v) = serde_json::from_str::(&e) { + msg["embed"] = v; + } + } + Ok(msg) + } +} + +fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInput) -> (String, String, String) { + let msg_id = new_id(); + let rev = new_rev(); + let now = Utc::now().to_rfc3339(); + let facets_str = msg.facets.as_ref().map(|v| v.to_string()); + let embed_str = msg.embed.as_ref().map(|v| v.to_string()); + + conn.execute( + "INSERT INTO messages (id, convo_id, sender_did, rev, text, facets, embed, sent_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)", + rusqlite::params![msg_id, convo_id, did, rev, msg.text, facets_str, embed_str, now], + ) + .unwrap(); + + conn.execute( + "UPDATE convos SET rev = ?1, updated_at = ?2 WHERE id = ?3", + rusqlite::params![rev, now, convo_id], + ) + .unwrap(); + + add_log(conn, convo_id, &rev, "logCreateMessage", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logCreateMessage", + "rev": rev, + "convoId": convo_id, + "message": { + "$type": "chat.bsky.convo.defs#messageView", + "id": msg_id, "rev": rev, "text": msg.text, + "sender": { "did": did }, "sentAt": now, + } + })); + + (msg_id, rev, now) +} + +// --- Handlers --- + +async fn health() -> &'static str { + "OK" +} + +async fn well_known_did() -> Json { + let did_host = std::env::var("DID_HOST").unwrap_or_else(|_| "bsky.syu.is".into()); + let service_url = std::env::var("SERVICE_URL").unwrap_or_else(|_| format!("https://{}", did_host)); + let did = format!("did:web:{}", did_host); + + Json(serde_json::json!({ + "@context": ["https://www.w3.org/ns/did/v1"], + "id": did, + "service": [ + { + "id": "#bsky_appview", + "type": "BskyAppView", + "serviceEndpoint": &service_url, + }, + { + "id": "#bsky_chat", + "type": "BskyChat", + "serviceEndpoint": &service_url, + }, + { + "id": "#bsky_notif", + "type": "BskyNotificationService", + "serviceEndpoint": &service_url, + } + ] + })) +} + +async fn send_message( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + ensure_account(&db, &did); + + let is_member: bool = db + .query_row( + "SELECT COUNT(*) FROM convo_members WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![req.convo_id, did], + |row| row.get::<_, i64>(0), + ) + .unwrap_or(0) + > 0; + + if !is_member { + return Err((StatusCode::BAD_REQUEST, Json(ErrorResp { + error: "InvalidRequest".into(), + message: "Not a member of this conversation".into(), + }))); + } + + let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message); + + Ok(Json(MessageView { + typ: "chat.bsky.convo.defs#messageView".into(), + id: msg_id, rev, + text: req.message.text, + facets: req.message.facets, + embed: req.message.embed, + sender: MessageViewSender { did }, + sent_at: now, + })) +} + +async fn send_message_batch( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + ensure_account(&db, &did); + + let mut items = Vec::new(); + for item in &req.items { + let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message); + items.push(serde_json::json!({ + "$type": "chat.bsky.convo.defs#messageView", + "id": msg_id, "rev": rev, "text": item.message.text, + "sender": { "did": did }, "sentAt": now, + })); + } + + Ok(Json(SendMessageBatchResp { items })) +} + +async fn get_messages( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let _did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let limit = params.limit.unwrap_or(50).min(100); + + let messages: Vec = if let Some(ref cursor) = params.cursor { + let mut stmt = db.prepare( + "SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted + FROM messages WHERE convo_id = ?1 AND deleted = 0 AND sent_at < ?2 + ORDER BY sent_at DESC LIMIT ?3" + ).unwrap(); + stmt.query_map(rusqlite::params![params.convo_id, cursor, limit], |row| build_message_json(row)) + .unwrap().filter_map(|r| r.ok()).collect() + } else { + let mut stmt = db.prepare( + "SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted + FROM messages WHERE convo_id = ?1 AND deleted = 0 + ORDER BY sent_at DESC LIMIT ?2" + ).unwrap(); + stmt.query_map(rusqlite::params![params.convo_id, limit], |row| build_message_json(row)) + .unwrap().filter_map(|r| r.ok()).collect() + }; + + let cursor = messages.last().and_then(|m| m["sentAt"].as_str().map(String::from)); + Ok(Json(GetMessagesResp { cursor, messages })) +} + +async fn get_convo( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + Ok(Json(GetConvoResp { convo: load_convo(&db, ¶ms.convo_id, Some(&did)) })) +} + +async fn get_convo_for_members( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + if !params.members.contains(&did) { + return Err((StatusCode::BAD_REQUEST, Json(ErrorResp { + error: "InvalidRequest".into(), + message: "Caller must be a member".into(), + }))); + } + let db = state.db.lock().unwrap(); + for m in ¶ms.members { + ensure_account(&db, m); + } + let convo = get_or_create_convo(&db, ¶ms.members); + Ok(Json(GetConvoResp { convo })) +} + +async fn get_convo_availability( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let _did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let members: Vec = params.members.iter().map(|did| { + ensure_account(&db, did); + MemberAvailability { + did: did.clone(), + allow_incoming: "all".into(), + } + }).collect(); + Ok(Json(GetConvoAvailabilityResp { members })) +} + +async fn list_convos( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let limit = params.limit.unwrap_or(20).min(100); + let status_filter = params.status.unwrap_or_else(|| "accepted".into()); + + let mut stmt = db.prepare( + "SELECT cm.convo_id FROM convo_members cm + JOIN convos c ON c.id = cm.convo_id + WHERE cm.did = ?1 AND cm.status = ?2 + ORDER BY c.updated_at DESC LIMIT ?3", + ).unwrap(); + + let convo_ids: Vec = stmt + .query_map(rusqlite::params![did, status_filter, limit], |row| row.get(0)) + .unwrap().filter_map(|r| r.ok()).collect(); + + let convos: Vec = convo_ids.iter() + .map(|id| load_convo(&db, id, Some(&did))).collect(); + + Ok(Json(ListConvosResp { cursor: None, convos })) +} + +async fn update_read( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + + let read_id = req.message_id.unwrap_or_else(|| { + db.query_row( + "SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1", + rusqlite::params![req.convo_id], + |row| row.get(0), + ).unwrap_or_default() + }); + + db.execute( + "UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3", + rusqlite::params![read_id, req.convo_id, did], + ).unwrap(); + + let convo = load_convo(&db, &req.convo_id, Some(&did)); + Ok(Json(GetConvoResp { convo })) +} + +async fn update_all_read( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let status = req.status.unwrap_or_else(|| "accepted".into()); + + // Get all convos for this user with given status + let mut stmt = db.prepare( + "SELECT convo_id FROM convo_members WHERE did = ?1 AND status = ?2" + ).unwrap(); + let convo_ids: Vec = stmt + .query_map(rusqlite::params![did, status], |row| row.get(0)) + .unwrap().filter_map(|r| r.ok()).collect(); + + let mut count = 0i64; + for cid in &convo_ids { + let last_msg_id: Option = db.query_row( + "SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1", + rusqlite::params![cid], |row| row.get(0), + ).ok(); + if let Some(mid) = last_msg_id { + db.execute( + "UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3", + rusqlite::params![mid, cid, did], + ).unwrap(); + count += 1; + } + } + + Ok(Json(UpdateAllReadResp { updated_count: count })) +} + +async fn accept_convo( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + + db.execute( + "UPDATE convo_members SET status = 'accepted' WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![req.convo_id, did], + ).unwrap(); + db.execute( + "UPDATE convos SET rev = ?1 WHERE id = ?2", + rusqlite::params![rev, req.convo_id], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logAcceptConvo", + "rev": rev, "convoId": req.convo_id, + })); + + let convo = load_convo(&db, &req.convo_id, Some(&did)); + Ok(Json(GetConvoResp { convo })) +} + +async fn mute_convo( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + + db.execute( + "UPDATE convo_members SET muted = 1 WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![req.convo_id, did], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logMuteConvo", + "rev": rev, "convoId": req.convo_id, + })); + + let convo = load_convo(&db, &req.convo_id, Some(&did)); + Ok(Json(GetConvoResp { convo })) +} + +async fn unmute_convo( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + + db.execute( + "UPDATE convo_members SET muted = 0 WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![req.convo_id, did], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logUnmuteConvo", + "rev": rev, "convoId": req.convo_id, + })); + + let convo = load_convo(&db, &req.convo_id, Some(&did)); + Ok(Json(GetConvoResp { convo })) +} + +async fn delete_message_for_self( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + let now = Utc::now().to_rfc3339(); + + db.execute( + "UPDATE messages SET deleted = 1, rev = ?1 WHERE id = ?2 AND convo_id = ?3", + rusqlite::params![rev, req.message_id, req.convo_id], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logDeleteMessage", + "rev": rev, "convoId": req.convo_id, + "message": { + "$type": "chat.bsky.convo.defs#deletedMessageView", + "id": req.message_id, "rev": rev, + "sender": { "did": did }, "sentAt": now, + } + })); + + Ok(Json(DeletedMessageView { + typ: "chat.bsky.convo.defs#deletedMessageView".into(), + id: req.message_id, rev, + sender: MessageViewSender { did }, + sent_at: now, + })) +} + +async fn leave_convo( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + + db.execute( + "DELETE FROM convo_members WHERE convo_id = ?1 AND did = ?2", + rusqlite::params![req.convo_id, did], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logLeaveConvo", + "rev": rev, "convoId": req.convo_id, + })); + + Ok(Json(LeaveConvoResp { convo_id: req.convo_id, rev })) +} + +async fn add_reaction( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let now = Utc::now().to_rfc3339(); + let rev = new_rev(); + + db.execute( + "INSERT OR IGNORE INTO reactions (convo_id, message_id, sender_did, value, created_at) + VALUES (?1, ?2, ?3, ?4, ?5)", + rusqlite::params![req.convo_id, req.message_id, did, req.value, now], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logAddReaction", + "rev": rev, "convoId": req.convo_id, + "message": { "$type": "chat.bsky.convo.defs#messageView", "id": req.message_id }, + "reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now }, + })); + + Ok(Json(serde_json::json!({ + "convo": load_convo(&db, &req.convo_id, Some(&did)), + }))) +} + +async fn remove_reaction( + State(state): State>, + headers: HeaderMap, + Json(req): Json, +) -> Result, (StatusCode, Json)> { + let did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + let rev = new_rev(); + + db.execute( + "DELETE FROM reactions WHERE convo_id = ?1 AND message_id = ?2 AND sender_did = ?3 AND value = ?4", + rusqlite::params![req.convo_id, req.message_id, did, req.value], + ).unwrap(); + + add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({ + "$type": "chat.bsky.convo.defs#logRemoveReaction", + "rev": rev, "convoId": req.convo_id, + })); + + Ok(Json(serde_json::json!({ + "convo": load_convo(&db, &req.convo_id, Some(&did)), + }))) +} + +async fn get_log( + State(state): State>, + headers: HeaderMap, + Query(params): Query, +) -> Result, (StatusCode, Json)> { + let _did = require_auth(&headers)?; + let db = state.db.lock().unwrap(); + + let (logs, last_rev) = if let Some(ref cursor) = params.cursor { + // cursor is a rev (hex timestamp) - find logs after that rev + let mut stmt = db.prepare( + "SELECT rev, data FROM logs WHERE rev > ?1 ORDER BY id ASC LIMIT 100" + ).unwrap(); + let mut last = None; + let items: Vec = stmt + .query_map(rusqlite::params![cursor], |row| { + let rev: String = row.get(0)?; + let data: String = row.get(1)?; + Ok((rev, data)) + }) + .unwrap() + .filter_map(|r| r.ok()) + .map(|(rev, data)| { + last = Some(rev); + serde_json::from_str(&data).unwrap_or(serde_json::json!({})) + }) + .collect(); + (items, last) + } else { + // No cursor - return latest rev only + let last: Option = db.query_row( + "SELECT rev FROM logs ORDER BY id DESC LIMIT 1", + [], |row| row.get(0), + ).ok(); + (vec![], last) + }; + + Ok(Json(GetLogResp { cursor: last_rev, logs })) +} + +// --- Main --- + +#[tokio::main] +async fn main() { + tracing_subscriber::fmt::init(); + + let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into()); + let port: u16 = std::env::var("PORT") + .unwrap_or_else(|_| "3100".into()) + .parse() + .unwrap_or(3100); + + let conn = Connection::open(&db_path).expect("Failed to open database"); + init_db(&conn); + + let state = Arc::new(AppState { db: Mutex::new(conn) }); + + let cors = CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any); + + let app = Router::new() + .route("/.well-known/did.json", get(well_known_did)) + .route("/xrpc/_health", get(health)) + .route("/xrpc/chat.bsky.convo.sendMessage", post(send_message)) + .route("/xrpc/chat.bsky.convo.sendMessageBatch", post(send_message_batch)) + .route("/xrpc/chat.bsky.convo.getMessages", get(get_messages)) + .route("/xrpc/chat.bsky.convo.getConvo", get(get_convo)) + .route("/xrpc/chat.bsky.convo.getConvoForMembers", get(get_convo_for_members)) + .route("/xrpc/chat.bsky.convo.getConvoAvailability", get(get_convo_availability)) + .route("/xrpc/chat.bsky.convo.listConvos", get(list_convos)) + .route("/xrpc/chat.bsky.convo.updateRead", post(update_read)) + .route("/xrpc/chat.bsky.convo.updateAllRead", post(update_all_read)) + .route("/xrpc/chat.bsky.convo.acceptConvo", post(accept_convo)) + .route("/xrpc/chat.bsky.convo.muteConvo", post(mute_convo)) + .route("/xrpc/chat.bsky.convo.unmuteConvo", post(unmute_convo)) + .route("/xrpc/chat.bsky.convo.deleteMessageForSelf", post(delete_message_for_self)) + .route("/xrpc/chat.bsky.convo.leaveConvo", post(leave_convo)) + .route("/xrpc/chat.bsky.convo.addReaction", post(add_reaction)) + .route("/xrpc/chat.bsky.convo.removeReaction", post(remove_reaction)) + .route("/xrpc/chat.bsky.convo.getLog", get(get_log)) + .layer(cors) + .with_state(state); + + let addr = format!("0.0.0.0:{}", port); + tracing::info!("bsky-chat listening on {}", addr); + + let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); + axum::serve(listener, app).await.unwrap(); +}