1179 lines
35 KiB
Rust
1179 lines
35 KiB
Rust
use axum::{
|
|
extract::{Query, RawQuery, State},
|
|
http::{HeaderMap, StatusCode},
|
|
response::Json,
|
|
routing::{get, post},
|
|
Router,
|
|
};
|
|
use chrono::Utc;
|
|
use rusqlite::Connection;
|
|
use serde::{Deserialize, Serialize};
|
|
use std::sync::{Arc, Mutex, OnceLock};
|
|
use tower_http::cors::{Any, CorsLayer};
|
|
use uuid::Uuid;
|
|
|
|
// --- Namespace ---
|
|
// CHAT_DOMAIN env (default: "bsky.chat") → reversed for XRPC/types: "chat.bsky"
|
|
|
|
/// Process-wide XRPC namespace (for example "chat.bsky"); set once by `init_namespace`.
static NS: OnceLock<String> = OnceLock::new();

/// Returns the configured namespace.
///
/// # Panics
/// Panics with "namespace not initialized" when `NS` has not been set yet.
fn ns() -> &'static str {
    match NS.get() {
        Some(namespace) => namespace.as_str(),
        None => panic!("namespace not initialized"),
    }
}
|
|
|
|
fn ns_type(name: &str) -> String {
|
|
format!("{}.convo.defs#{}", ns(), name)
|
|
}
|
|
|
|
fn ns_route(method: &str) -> String {
|
|
format!("/xrpc/{}.convo.{}", ns(), method)
|
|
}
|
|
|
|
fn init_namespace() {
|
|
let domain = std::env::var("CHAT_DOMAIN").unwrap_or_else(|_| "bsky.chat".into());
|
|
let parts: Vec<&str> = domain.split('.').collect();
|
|
let reversed = if parts.len() == 2 {
|
|
format!("{}.{}", parts[1], parts[0])
|
|
} else {
|
|
domain
|
|
};
|
|
NS.set(reversed).expect("namespace already initialized");
|
|
}
|
|
|
|
// --- Types ---
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct MessageInput {
|
|
text: String,
|
|
facets: Option<serde_json::Value>,
|
|
embed: Option<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct MessageView {
|
|
#[serde(rename = "$type")]
|
|
typ: String,
|
|
id: String,
|
|
rev: String,
|
|
text: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
facets: Option<serde_json::Value>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
embed: Option<serde_json::Value>,
|
|
sender: MessageViewSender,
|
|
sent_at: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
struct MessageViewSender {
|
|
did: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct DeletedMessageView {
|
|
#[serde(rename = "$type")]
|
|
typ: String,
|
|
id: String,
|
|
rev: String,
|
|
sender: MessageViewSender,
|
|
sent_at: String,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct ProfileViewBasic {
|
|
did: String,
|
|
handle: String,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
display_name: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
avatar: Option<String>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
chat_disabled: Option<bool>,
|
|
}
|
|
|
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct ConvoView {
|
|
id: String,
|
|
rev: String,
|
|
members: Vec<ProfileViewBasic>,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
last_message: Option<serde_json::Value>,
|
|
muted: bool,
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
status: Option<String>,
|
|
unread_count: i64,
|
|
}
|
|
|
|
// --- Request/Response types ---
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct SendMessageReq {
|
|
convo_id: String,
|
|
message: MessageInput,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct SendMessageBatchItem {
|
|
convo_id: String,
|
|
message: MessageInput,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct SendMessageBatchReq {
|
|
items: Vec<SendMessageBatchItem>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct SendMessageBatchResp {
|
|
items: Vec<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct GetMessagesParams {
|
|
convo_id: String,
|
|
limit: Option<i64>,
|
|
cursor: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct GetMessagesResp {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
cursor: Option<String>,
|
|
messages: Vec<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct GetConvoParams {
|
|
convo_id: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct GetConvoResp {
|
|
convo: ConvoView,
|
|
}
|
|
|
|
// members=did1&members=did2 (repeated query params)
|
|
#[derive(Deserialize)]
|
|
struct GetConvoForMembersParams {
|
|
members: Vec<String>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct ListConvosParams {
|
|
limit: Option<i64>,
|
|
cursor: Option<String>,
|
|
status: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ListConvosResp {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
cursor: Option<String>,
|
|
convos: Vec<ConvoView>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct UpdateReadReq {
|
|
convo_id: String,
|
|
message_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct UpdateAllReadReq {
|
|
status: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct UpdateAllReadResp {
|
|
updated_count: i64,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct DeleteMessageReq {
|
|
convo_id: String,
|
|
message_id: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct ConvoIdReq {
|
|
convo_id: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct LeaveConvoResp {
|
|
convo_id: String,
|
|
rev: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct GetLogParams {
|
|
cursor: Option<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct GetLogResp {
|
|
#[serde(skip_serializing_if = "Option::is_none")]
|
|
cursor: Option<String>,
|
|
logs: Vec<serde_json::Value>,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
struct GetConvoAvailabilityParams {
|
|
members: Vec<String>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct GetConvoAvailabilityResp {
|
|
members: Vec<MemberAvailability>,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct MemberAvailability {
|
|
did: String,
|
|
allow_incoming: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct AddReactionReq {
|
|
convo_id: String,
|
|
message_id: String,
|
|
value: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct RemoveReactionReq {
|
|
convo_id: String,
|
|
message_id: String,
|
|
value: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
struct ErrorResp {
|
|
error: String,
|
|
message: String,
|
|
}
|
|
|
|
// --- App State ---
|
|
|
|
struct AppState {
|
|
db: Mutex<Connection>,
|
|
}
|
|
|
|
// --- Auth ---
|
|
|
|
fn extract_did(headers: &HeaderMap) -> Option<String> {
|
|
let auth = headers.get("authorization")?.to_str().ok()?;
|
|
let token = auth.strip_prefix("Bearer ")?;
|
|
|
|
if token.starts_with("did:") {
|
|
return Some(token.to_string());
|
|
}
|
|
|
|
// Decode JWT payload to extract iss (issuer = caller DID)
|
|
let parts: Vec<&str> = token.split('.').collect();
|
|
if parts.len() == 3 {
|
|
if let Ok(decoded) = base64::Engine::decode(
|
|
&base64::engine::general_purpose::URL_SAFE_NO_PAD,
|
|
parts[1],
|
|
) {
|
|
if let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&decoded) {
|
|
if let Some(iss) = payload.get("iss").and_then(|v| v.as_str()) {
|
|
return Some(iss.to_string());
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
None
|
|
}
|
|
|
|
fn require_auth(headers: &HeaderMap) -> Result<String, (StatusCode, Json<ErrorResp>)> {
|
|
extract_did(headers).ok_or_else(|| {
|
|
(
|
|
StatusCode::UNAUTHORIZED,
|
|
Json(ErrorResp {
|
|
error: "AuthMissing".into(),
|
|
message: "Authentication Required".into(),
|
|
}),
|
|
)
|
|
})
|
|
}
|
|
|
|
fn new_rev() -> String {
|
|
let now = Utc::now().timestamp_micros();
|
|
format!("{:x}", now)
|
|
}
|
|
|
|
fn new_id() -> String {
|
|
Uuid::new_v4().to_string().replace('-', "")[..13].to_string()
|
|
}
|
|
|
|
/// ISO 8601 format compatible with AT Protocol: "2026-03-22T07:21:00.448Z"
|
|
fn now_iso() -> String {
|
|
Utc::now().format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()
|
|
}
|
|
|
|
// --- DB ---
|
|
|
|
fn init_db(conn: &Connection) {
|
|
conn.execute_batch(
|
|
"
|
|
CREATE TABLE IF NOT EXISTS accounts (
|
|
did TEXT PRIMARY KEY,
|
|
handle TEXT NOT NULL DEFAULT '',
|
|
display_name TEXT,
|
|
avatar TEXT,
|
|
created_at TEXT NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS convos (
|
|
id TEXT PRIMARY KEY,
|
|
rev TEXT NOT NULL,
|
|
created_at TEXT NOT NULL,
|
|
updated_at TEXT NOT NULL
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS convo_members (
|
|
convo_id TEXT NOT NULL,
|
|
did TEXT NOT NULL,
|
|
status TEXT NOT NULL DEFAULT 'accepted',
|
|
muted INTEGER NOT NULL DEFAULT 0,
|
|
last_read_id TEXT,
|
|
PRIMARY KEY (convo_id, did),
|
|
FOREIGN KEY (convo_id) REFERENCES convos(id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS messages (
|
|
id TEXT PRIMARY KEY,
|
|
convo_id TEXT NOT NULL,
|
|
sender_did TEXT NOT NULL,
|
|
rev TEXT NOT NULL,
|
|
text TEXT NOT NULL,
|
|
facets TEXT,
|
|
embed TEXT,
|
|
sent_at TEXT NOT NULL,
|
|
deleted INTEGER NOT NULL DEFAULT 0,
|
|
FOREIGN KEY (convo_id) REFERENCES convos(id)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS reactions (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
convo_id TEXT NOT NULL,
|
|
message_id TEXT NOT NULL,
|
|
sender_did TEXT NOT NULL,
|
|
value TEXT NOT NULL,
|
|
created_at TEXT NOT NULL,
|
|
UNIQUE(convo_id, message_id, sender_did, value)
|
|
);
|
|
|
|
CREATE TABLE IF NOT EXISTS logs (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
convo_id TEXT NOT NULL,
|
|
rev TEXT NOT NULL,
|
|
log_type TEXT NOT NULL,
|
|
data TEXT NOT NULL,
|
|
created_at TEXT NOT NULL
|
|
);
|
|
",
|
|
)
|
|
.expect("Failed to initialize database");
|
|
}
|
|
|
|
fn ensure_account(conn: &Connection, did: &str) {
|
|
// handle must be a valid handle format, not a DID
|
|
let handle = format!("{}.chat.invalid", did.split(':').last().unwrap_or("unknown"));
|
|
conn.execute(
|
|
"INSERT OR IGNORE INTO accounts (did, handle, created_at) VALUES (?1, ?2, ?3)",
|
|
rusqlite::params![did, handle, now_iso()],
|
|
)
|
|
.ok();
|
|
}
|
|
|
|
fn get_or_create_convo(conn: &Connection, members: &[String]) -> ConvoView {
|
|
let mut sorted = members.to_vec();
|
|
sorted.sort();
|
|
|
|
// Look for existing convo with exact same members
|
|
let all_convo_ids: Vec<String> = {
|
|
let mut stmt = conn
|
|
.prepare("SELECT DISTINCT convo_id FROM convo_members WHERE did = ?1")
|
|
.unwrap();
|
|
stmt.query_map(rusqlite::params![&sorted[0]], |row| row.get(0))
|
|
.unwrap()
|
|
.filter_map(|r| r.ok())
|
|
.collect()
|
|
};
|
|
|
|
for cid in &all_convo_ids {
|
|
let mut stmt = conn
|
|
.prepare("SELECT did FROM convo_members WHERE convo_id = ?1 ORDER BY did")
|
|
.unwrap();
|
|
let convo_members: Vec<String> = stmt
|
|
.query_map(rusqlite::params![cid], |row| row.get(0))
|
|
.unwrap()
|
|
.filter_map(|r| r.ok())
|
|
.collect();
|
|
if convo_members == sorted {
|
|
return load_convo(conn, cid, sorted.first().map(|s| s.as_str()));
|
|
}
|
|
}
|
|
|
|
// Create new convo
|
|
let id = new_id();
|
|
let rev = new_rev();
|
|
let now = now_iso();
|
|
|
|
conn.execute(
|
|
"INSERT INTO convos (id, rev, created_at, updated_at) VALUES (?1, ?2, ?3, ?4)",
|
|
rusqlite::params![id, rev, now, now],
|
|
)
|
|
.unwrap();
|
|
|
|
for did in &sorted {
|
|
ensure_account(conn, did);
|
|
conn.execute(
|
|
"INSERT INTO convo_members (convo_id, did) VALUES (?1, ?2)",
|
|
rusqlite::params![id, did],
|
|
)
|
|
.unwrap();
|
|
}
|
|
|
|
add_log(conn, &id, &rev, "logBeginConvo", serde_json::json!({
|
|
"$type": ns_type("logBeginConvo"),
|
|
"rev": rev,
|
|
"convoId": id,
|
|
}));
|
|
|
|
load_convo(conn, &id, sorted.first().map(|s| s.as_str()))
|
|
}
|
|
|
|
fn add_log(conn: &Connection, convo_id: &str, rev: &str, log_type: &str, data: serde_json::Value) {
|
|
let now = now_iso();
|
|
conn.execute(
|
|
"INSERT INTO logs (convo_id, rev, log_type, data, created_at) VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
rusqlite::params![convo_id, rev, log_type, data.to_string(), now],
|
|
)
|
|
.unwrap();
|
|
}
|
|
|
|
fn load_convo(conn: &Connection, convo_id: &str, viewer_did: Option<&str>) -> ConvoView {
|
|
let (rev, _created_at, _updated_at): (String, String, String) = conn
|
|
.query_row(
|
|
"SELECT rev, created_at, updated_at FROM convos WHERE id = ?1",
|
|
rusqlite::params![convo_id],
|
|
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
|
|
)
|
|
.unwrap_or_else(|_| (new_rev(), String::new(), String::new()));
|
|
|
|
let mut stmt = conn
|
|
.prepare(
|
|
"SELECT a.did, a.handle, a.display_name, a.avatar
|
|
FROM convo_members cm JOIN accounts a ON a.did = cm.did
|
|
WHERE cm.convo_id = ?1",
|
|
)
|
|
.unwrap();
|
|
let members: Vec<ProfileViewBasic> = stmt
|
|
.query_map(rusqlite::params![convo_id], |row| {
|
|
Ok(ProfileViewBasic {
|
|
did: row.get(0)?,
|
|
handle: row.get(1)?,
|
|
display_name: row.get(2)?,
|
|
avatar: row.get(3)?,
|
|
chat_disabled: None,
|
|
})
|
|
})
|
|
.unwrap()
|
|
.filter_map(|r| r.ok())
|
|
.collect();
|
|
|
|
let last_message: Option<serde_json::Value> = conn
|
|
.query_row(
|
|
"SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
|
|
FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1",
|
|
rusqlite::params![convo_id],
|
|
|row| build_message_json(row),
|
|
)
|
|
.ok();
|
|
|
|
let (muted, status, last_read_id): (bool, String, Option<String>) = viewer_did
|
|
.and_then(|did| {
|
|
conn.query_row(
|
|
"SELECT muted, status, last_read_id FROM convo_members WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![convo_id, did],
|
|
|row| Ok((row.get(0)?, row.get(1)?, row.get(2)?)),
|
|
)
|
|
.ok()
|
|
})
|
|
.unwrap_or((false, "accepted".into(), None));
|
|
|
|
let unread_count: i64 = if let Some(ref read_id) = last_read_id {
|
|
conn.query_row(
|
|
"SELECT COUNT(*) FROM messages WHERE convo_id = ?1 AND deleted = 0 AND id > ?2",
|
|
rusqlite::params![convo_id, read_id],
|
|
|row| row.get(0),
|
|
)
|
|
.unwrap_or(0)
|
|
} else {
|
|
conn.query_row(
|
|
"SELECT COUNT(*) FROM messages WHERE convo_id = ?1 AND deleted = 0",
|
|
rusqlite::params![convo_id],
|
|
|row| row.get(0),
|
|
)
|
|
.unwrap_or(0)
|
|
};
|
|
|
|
ConvoView {
|
|
id: convo_id.to_string(),
|
|
rev,
|
|
members,
|
|
last_message,
|
|
muted,
|
|
status: Some(status),
|
|
unread_count,
|
|
}
|
|
}
|
|
|
|
/// Converts a `messages` row into its JSON view.
///
/// Expected column order: `id, rev, text, facets, embed, sender_did,
/// sent_at, deleted`. Deleted rows become a `deletedMessageView` (no text,
/// facets or embed); other rows become a `messageView`. `facets` and
/// `embed` are included only when the stored text parses as JSON.
fn build_message_json(row: &rusqlite::Row) -> rusqlite::Result<serde_json::Value> {
    let is_deleted: bool = row.get(7)?;
    let sender: String = row.get(5)?;
    let id: String = row.get(0)?;
    let rev: String = row.get(1)?;
    let sent_at: String = row.get(6)?;

    if is_deleted {
        return Ok(serde_json::json!({
            "$type": ns_type("deletedMessageView"),
            "id": id,
            "rev": rev,
            "sender": { "did": sender },
            "sentAt": sent_at,
        }));
    }

    // Reads an optional TEXT column and parses it as JSON; NULL or
    // unparsable content yields `None`.
    let json_column = |index: usize| -> rusqlite::Result<Option<serde_json::Value>> {
        let raw: Option<String> = row.get(index)?;
        Ok(raw.and_then(|text| serde_json::from_str(&text).ok()))
    };
    let facets = json_column(3)?;
    let embed = json_column(4)?;
    let text: String = row.get(2)?;

    let mut view = serde_json::json!({
        "$type": ns_type("messageView"),
        "id": id,
        "rev": rev,
        "text": text,
        "sender": { "did": sender },
        "sentAt": sent_at,
    });
    if let Some(value) = facets {
        view["facets"] = value;
    }
    if let Some(value) = embed {
        view["embed"] = value;
    }
    Ok(view)
}
|
|
|
|
fn insert_message(conn: &Connection, convo_id: &str, did: &str, msg: &MessageInput) -> (String, String, String) {
|
|
let msg_id = new_id();
|
|
let rev = new_rev();
|
|
let now = now_iso();
|
|
let facets_str = msg.facets.as_ref().map(|v| v.to_string());
|
|
let embed_str = msg.embed.as_ref().map(|v| v.to_string());
|
|
|
|
conn.execute(
|
|
"INSERT INTO messages (id, convo_id, sender_did, rev, text, facets, embed, sent_at)
|
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
|
|
rusqlite::params![msg_id, convo_id, did, rev, msg.text, facets_str, embed_str, now],
|
|
)
|
|
.unwrap();
|
|
|
|
conn.execute(
|
|
"UPDATE convos SET rev = ?1, updated_at = ?2 WHERE id = ?3",
|
|
rusqlite::params![rev, now, convo_id],
|
|
)
|
|
.unwrap();
|
|
|
|
add_log(conn, convo_id, &rev, "logCreateMessage", serde_json::json!({
|
|
"$type": ns_type("logCreateMessage"),
|
|
"rev": rev,
|
|
"convoId": convo_id,
|
|
"message": {
|
|
"$type": ns_type("messageView"),
|
|
"id": msg_id, "rev": rev, "text": msg.text,
|
|
"sender": { "did": did }, "sentAt": now,
|
|
}
|
|
}));
|
|
|
|
(msg_id, rev, now)
|
|
}
|
|
|
|
// --- Handlers ---
|
|
|
|
/// Health-check handler; always answers with the literal body "OK".
async fn health() -> &'static str {
    const BODY: &str = "OK";
    BODY
}
|
|
|
|
async fn well_known_did() -> Json<serde_json::Value> {
|
|
let did_host = std::env::var("DID_HOST").unwrap_or_else(|_| "bsky.syu.is".into());
|
|
let service_url = std::env::var("SERVICE_URL").unwrap_or_else(|_| format!("https://{}", did_host));
|
|
let did = format!("did:web:{}", did_host);
|
|
|
|
Json(serde_json::json!({
|
|
"@context": ["https://www.w3.org/ns/did/v1"],
|
|
"id": did,
|
|
"service": [
|
|
{
|
|
"id": "#bsky_appview",
|
|
"type": "BskyAppView",
|
|
"serviceEndpoint": &service_url,
|
|
},
|
|
{
|
|
"id": "#bsky_chat",
|
|
"type": "BskyChat",
|
|
"serviceEndpoint": &service_url,
|
|
},
|
|
{
|
|
"id": "#bsky_notif",
|
|
"type": "BskyNotificationService",
|
|
"serviceEndpoint": &service_url,
|
|
}
|
|
]
|
|
}))
|
|
}
|
|
|
|
async fn send_message(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<SendMessageReq>,
|
|
) -> Result<Json<MessageView>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
ensure_account(&db, &did);
|
|
|
|
let is_member: bool = db
|
|
.query_row(
|
|
"SELECT COUNT(*) FROM convo_members WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![req.convo_id, did],
|
|
|row| row.get::<_, i64>(0),
|
|
)
|
|
.unwrap_or(0)
|
|
> 0;
|
|
|
|
if !is_member {
|
|
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
|
|
error: "InvalidRequest".into(),
|
|
message: "Not a member of this conversation".into(),
|
|
})));
|
|
}
|
|
|
|
let (msg_id, rev, now) = insert_message(&db, &req.convo_id, &did, &req.message);
|
|
|
|
Ok(Json(MessageView {
|
|
typ: ns_type("messageView"),
|
|
id: msg_id, rev,
|
|
text: req.message.text,
|
|
facets: req.message.facets,
|
|
embed: req.message.embed,
|
|
sender: MessageViewSender { did },
|
|
sent_at: now,
|
|
}))
|
|
}
|
|
|
|
async fn send_message_batch(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<SendMessageBatchReq>,
|
|
) -> Result<Json<SendMessageBatchResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
ensure_account(&db, &did);
|
|
|
|
let mut items = Vec::new();
|
|
for item in &req.items {
|
|
let (msg_id, rev, now) = insert_message(&db, &item.convo_id, &did, &item.message);
|
|
items.push(serde_json::json!({
|
|
"$type": ns_type("messageView"),
|
|
"id": msg_id, "rev": rev, "text": item.message.text,
|
|
"sender": { "did": did }, "sentAt": now,
|
|
}));
|
|
}
|
|
|
|
Ok(Json(SendMessageBatchResp { items }))
|
|
}
|
|
|
|
async fn get_messages(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Query(params): Query<GetMessagesParams>,
|
|
) -> Result<Json<GetMessagesResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let _did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let limit = params.limit.unwrap_or(50).min(100);
|
|
|
|
let messages: Vec<serde_json::Value> = if let Some(ref cursor) = params.cursor {
|
|
let mut stmt = db.prepare(
|
|
"SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
|
|
FROM messages WHERE convo_id = ?1 AND deleted = 0 AND sent_at < ?2
|
|
ORDER BY sent_at DESC LIMIT ?3"
|
|
).unwrap();
|
|
stmt.query_map(rusqlite::params![params.convo_id, cursor, limit], |row| build_message_json(row))
|
|
.unwrap().filter_map(|r| r.ok()).collect()
|
|
} else {
|
|
let mut stmt = db.prepare(
|
|
"SELECT id, rev, text, facets, embed, sender_did, sent_at, deleted
|
|
FROM messages WHERE convo_id = ?1 AND deleted = 0
|
|
ORDER BY sent_at DESC LIMIT ?2"
|
|
).unwrap();
|
|
stmt.query_map(rusqlite::params![params.convo_id, limit], |row| build_message_json(row))
|
|
.unwrap().filter_map(|r| r.ok()).collect()
|
|
};
|
|
|
|
let cursor = messages.last().and_then(|m| m["sentAt"].as_str().map(String::from));
|
|
Ok(Json(GetMessagesResp { cursor, messages }))
|
|
}
|
|
|
|
async fn get_convo(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Query(params): Query<GetConvoParams>,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
Ok(Json(GetConvoResp { convo: load_convo(&db, ¶ms.convo_id, Some(&did)) }))
|
|
}
|
|
|
|
/// Parse repeated query params: ?members=did1&members=did2
|
|
fn parse_members_query(query: &Option<String>) -> Vec<String> {
|
|
query.as_deref().unwrap_or("").split('&')
|
|
.filter_map(|pair| {
|
|
let (key, val) = pair.split_once('=')?;
|
|
if key == "members" { Some(urlencoding::decode(val).ok()?.into_owned()) } else { None }
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
async fn get_convo_for_members(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
RawQuery(query): RawQuery,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let members = parse_members_query(&query);
|
|
if !members.contains(&did) {
|
|
return Err((StatusCode::BAD_REQUEST, Json(ErrorResp {
|
|
error: "InvalidRequest".into(),
|
|
message: "Caller must be a member".into(),
|
|
})));
|
|
}
|
|
let db = state.db.lock().unwrap();
|
|
for m in &members {
|
|
ensure_account(&db, m);
|
|
}
|
|
let convo = get_or_create_convo(&db, &members);
|
|
Ok(Json(GetConvoResp { convo }))
|
|
}
|
|
|
|
async fn get_convo_availability(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
RawQuery(query): RawQuery,
|
|
) -> Result<Json<GetConvoAvailabilityResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let _did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let members_raw = parse_members_query(&query);
|
|
let members: Vec<MemberAvailability> = members_raw.iter().map(|did| {
|
|
ensure_account(&db, did);
|
|
MemberAvailability {
|
|
did: did.clone(),
|
|
allow_incoming: "all".into(),
|
|
}
|
|
}).collect();
|
|
Ok(Json(GetConvoAvailabilityResp { members }))
|
|
}
|
|
|
|
async fn list_convos(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Query(params): Query<ListConvosParams>,
|
|
) -> Result<Json<ListConvosResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let limit = params.limit.unwrap_or(20).min(100);
|
|
let status_filter = params.status.unwrap_or_else(|| "accepted".into());
|
|
|
|
let mut stmt = db.prepare(
|
|
"SELECT cm.convo_id FROM convo_members cm
|
|
JOIN convos c ON c.id = cm.convo_id
|
|
WHERE cm.did = ?1 AND cm.status = ?2
|
|
ORDER BY c.updated_at DESC LIMIT ?3",
|
|
).unwrap();
|
|
|
|
let convo_ids: Vec<String> = stmt
|
|
.query_map(rusqlite::params![did, status_filter, limit], |row| row.get(0))
|
|
.unwrap().filter_map(|r| r.ok()).collect();
|
|
|
|
let convos: Vec<ConvoView> = convo_ids.iter()
|
|
.map(|id| load_convo(&db, id, Some(&did))).collect();
|
|
|
|
Ok(Json(ListConvosResp { cursor: None, convos }))
|
|
}
|
|
|
|
async fn update_read(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<UpdateReadReq>,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
|
|
let read_id = req.message_id.unwrap_or_else(|| {
|
|
db.query_row(
|
|
"SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1",
|
|
rusqlite::params![req.convo_id],
|
|
|row| row.get(0),
|
|
).unwrap_or_default()
|
|
});
|
|
|
|
db.execute(
|
|
"UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3",
|
|
rusqlite::params![read_id, req.convo_id, did],
|
|
).unwrap();
|
|
|
|
let convo = load_convo(&db, &req.convo_id, Some(&did));
|
|
Ok(Json(GetConvoResp { convo }))
|
|
}
|
|
|
|
async fn update_all_read(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<UpdateAllReadReq>,
|
|
) -> Result<Json<UpdateAllReadResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let status = req.status.unwrap_or_else(|| "accepted".into());
|
|
|
|
// Get all convos for this user with given status
|
|
let mut stmt = db.prepare(
|
|
"SELECT convo_id FROM convo_members WHERE did = ?1 AND status = ?2"
|
|
).unwrap();
|
|
let convo_ids: Vec<String> = stmt
|
|
.query_map(rusqlite::params![did, status], |row| row.get(0))
|
|
.unwrap().filter_map(|r| r.ok()).collect();
|
|
|
|
let mut count = 0i64;
|
|
for cid in &convo_ids {
|
|
let last_msg_id: Option<String> = db.query_row(
|
|
"SELECT id FROM messages WHERE convo_id = ?1 ORDER BY sent_at DESC LIMIT 1",
|
|
rusqlite::params![cid], |row| row.get(0),
|
|
).ok();
|
|
if let Some(mid) = last_msg_id {
|
|
db.execute(
|
|
"UPDATE convo_members SET last_read_id = ?1 WHERE convo_id = ?2 AND did = ?3",
|
|
rusqlite::params![mid, cid, did],
|
|
).unwrap();
|
|
count += 1;
|
|
}
|
|
}
|
|
|
|
Ok(Json(UpdateAllReadResp { updated_count: count }))
|
|
}
|
|
|
|
async fn accept_convo(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<ConvoIdReq>,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"UPDATE convo_members SET status = 'accepted' WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![req.convo_id, did],
|
|
).unwrap();
|
|
db.execute(
|
|
"UPDATE convos SET rev = ?1 WHERE id = ?2",
|
|
rusqlite::params![rev, req.convo_id],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logAcceptConvo", serde_json::json!({
|
|
"$type": ns_type("logAcceptConvo"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
}));
|
|
|
|
let convo = load_convo(&db, &req.convo_id, Some(&did));
|
|
Ok(Json(GetConvoResp { convo }))
|
|
}
|
|
|
|
async fn mute_convo(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<ConvoIdReq>,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"UPDATE convo_members SET muted = 1 WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![req.convo_id, did],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logMuteConvo", serde_json::json!({
|
|
"$type": ns_type("logMuteConvo"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
}));
|
|
|
|
let convo = load_convo(&db, &req.convo_id, Some(&did));
|
|
Ok(Json(GetConvoResp { convo }))
|
|
}
|
|
|
|
async fn unmute_convo(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<ConvoIdReq>,
|
|
) -> Result<Json<GetConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"UPDATE convo_members SET muted = 0 WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![req.convo_id, did],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logUnmuteConvo", serde_json::json!({
|
|
"$type": ns_type("logUnmuteConvo"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
}));
|
|
|
|
let convo = load_convo(&db, &req.convo_id, Some(&did));
|
|
Ok(Json(GetConvoResp { convo }))
|
|
}
|
|
|
|
async fn delete_message_for_self(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<DeleteMessageReq>,
|
|
) -> Result<Json<DeletedMessageView>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
let now = now_iso();
|
|
|
|
db.execute(
|
|
"UPDATE messages SET deleted = 1, rev = ?1 WHERE id = ?2 AND convo_id = ?3",
|
|
rusqlite::params![rev, req.message_id, req.convo_id],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logDeleteMessage", serde_json::json!({
|
|
"$type": ns_type("logDeleteMessage"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
"message": {
|
|
"$type": ns_type("deletedMessageView"),
|
|
"id": req.message_id, "rev": rev,
|
|
"sender": { "did": did }, "sentAt": now,
|
|
}
|
|
}));
|
|
|
|
Ok(Json(DeletedMessageView {
|
|
typ: ns_type("deletedMessageView"),
|
|
id: req.message_id, rev,
|
|
sender: MessageViewSender { did },
|
|
sent_at: now,
|
|
}))
|
|
}
|
|
|
|
async fn leave_convo(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<ConvoIdReq>,
|
|
) -> Result<Json<LeaveConvoResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"DELETE FROM convo_members WHERE convo_id = ?1 AND did = ?2",
|
|
rusqlite::params![req.convo_id, did],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logLeaveConvo", serde_json::json!({
|
|
"$type": ns_type("logLeaveConvo"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
}));
|
|
|
|
Ok(Json(LeaveConvoResp { convo_id: req.convo_id, rev }))
|
|
}
|
|
|
|
async fn add_reaction(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<AddReactionReq>,
|
|
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let now = now_iso();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"INSERT OR IGNORE INTO reactions (convo_id, message_id, sender_did, value, created_at)
|
|
VALUES (?1, ?2, ?3, ?4, ?5)",
|
|
rusqlite::params![req.convo_id, req.message_id, did, req.value, now],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logAddReaction", serde_json::json!({
|
|
"$type": ns_type("logAddReaction"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
"message": { "$type": ns_type("messageView"), "id": req.message_id },
|
|
"reaction": { "value": req.value, "sender": { "did": did }, "createdAt": now },
|
|
}));
|
|
|
|
Ok(Json(serde_json::json!({
|
|
"convo": load_convo(&db, &req.convo_id, Some(&did)),
|
|
})))
|
|
}
|
|
|
|
async fn remove_reaction(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Json(req): Json<RemoveReactionReq>,
|
|
) -> Result<Json<serde_json::Value>, (StatusCode, Json<ErrorResp>)> {
|
|
let did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
let rev = new_rev();
|
|
|
|
db.execute(
|
|
"DELETE FROM reactions WHERE convo_id = ?1 AND message_id = ?2 AND sender_did = ?3 AND value = ?4",
|
|
rusqlite::params![req.convo_id, req.message_id, did, req.value],
|
|
).unwrap();
|
|
|
|
add_log(&db, &req.convo_id, &rev, "logRemoveReaction", serde_json::json!({
|
|
"$type": ns_type("logRemoveReaction"),
|
|
"rev": rev, "convoId": req.convo_id,
|
|
}));
|
|
|
|
Ok(Json(serde_json::json!({
|
|
"convo": load_convo(&db, &req.convo_id, Some(&did)),
|
|
})))
|
|
}
|
|
|
|
async fn get_log(
|
|
State(state): State<Arc<AppState>>,
|
|
headers: HeaderMap,
|
|
Query(params): Query<GetLogParams>,
|
|
) -> Result<Json<GetLogResp>, (StatusCode, Json<ErrorResp>)> {
|
|
let _did = require_auth(&headers)?;
|
|
let db = state.db.lock().unwrap();
|
|
|
|
let (logs, last_rev) = if let Some(ref cursor) = params.cursor {
|
|
// cursor is a rev (hex timestamp) - find logs after that rev
|
|
let mut stmt = db.prepare(
|
|
"SELECT rev, data FROM logs WHERE rev > ?1 ORDER BY id ASC LIMIT 100"
|
|
).unwrap();
|
|
let mut last = None;
|
|
let items: Vec<serde_json::Value> = stmt
|
|
.query_map(rusqlite::params![cursor], |row| {
|
|
let rev: String = row.get(0)?;
|
|
let data: String = row.get(1)?;
|
|
Ok((rev, data))
|
|
})
|
|
.unwrap()
|
|
.filter_map(|r| r.ok())
|
|
.map(|(rev, data)| {
|
|
last = Some(rev);
|
|
serde_json::from_str(&data).unwrap_or(serde_json::json!({}))
|
|
})
|
|
.collect();
|
|
(items, last)
|
|
} else {
|
|
// No cursor - return latest rev only
|
|
let last: Option<String> = db.query_row(
|
|
"SELECT rev FROM logs ORDER BY id DESC LIMIT 1",
|
|
[], |row| row.get(0),
|
|
).ok();
|
|
(vec![], last)
|
|
};
|
|
|
|
Ok(Json(GetLogResp { cursor: last_rev, logs }))
|
|
}
|
|
|
|
// --- Main ---
|
|
|
|
#[tokio::main]
|
|
async fn main() {
|
|
let args: Vec<String> = std::env::args().collect();
|
|
if args.len() > 1 && (args[1] == "v" || args[1] == "--version" || args[1] == "-v") {
|
|
println!("{}", env!("CARGO_PKG_VERSION"));
|
|
return;
|
|
}
|
|
|
|
init_namespace();
|
|
tracing_subscriber::fmt::init();
|
|
|
|
let db_path = std::env::var("DB_PATH").unwrap_or_else(|_| "chat.db".into());
|
|
let port: u16 = std::env::var("PORT")
|
|
.unwrap_or_else(|_| "3100".into())
|
|
.parse()
|
|
.unwrap_or(3100);
|
|
|
|
let conn = Connection::open(&db_path).expect("Failed to open database");
|
|
init_db(&conn);
|
|
|
|
let state = Arc::new(AppState { db: Mutex::new(conn) });
|
|
|
|
let cors = CorsLayer::new()
|
|
.allow_origin(Any)
|
|
.allow_methods(Any)
|
|
.allow_headers(Any);
|
|
|
|
// Leak route strings so they live for 'static (allocated once at startup)
|
|
let r = |method: &str| -> &'static str {
|
|
Box::leak(ns_route(method).into_boxed_str())
|
|
};
|
|
|
|
let app = Router::new()
|
|
.route("/.well-known/did.json", get(well_known_did))
|
|
.route("/xrpc/_health", get(health))
|
|
.route(r("sendMessage"), post(send_message))
|
|
.route(r("sendMessageBatch"), post(send_message_batch))
|
|
.route(r("getMessages"), get(get_messages))
|
|
.route(r("getConvo"), get(get_convo))
|
|
.route(r("getConvoForMembers"), get(get_convo_for_members))
|
|
.route(r("getConvoAvailability"), get(get_convo_availability))
|
|
.route(r("listConvos"), get(list_convos))
|
|
.route(r("updateRead"), post(update_read))
|
|
.route(r("updateAllRead"), post(update_all_read))
|
|
.route(r("acceptConvo"), post(accept_convo))
|
|
.route(r("muteConvo"), post(mute_convo))
|
|
.route(r("unmuteConvo"), post(unmute_convo))
|
|
.route(r("deleteMessageForSelf"), post(delete_message_for_self))
|
|
.route(r("leaveConvo"), post(leave_convo))
|
|
.route(r("addReaction"), post(add_reaction))
|
|
.route(r("removeReaction"), post(remove_reaction))
|
|
.route(r("getLog"), get(get_log))
|
|
.layer(cors)
|
|
.with_state(state);
|
|
|
|
let host = std::env::var("HOST").unwrap_or_else(|_| "0.0.0.0".into());
|
|
let addr = format!("{}:{}", host, port);
|
|
tracing::info!("{} listening on {}", ns(), addr);
|
|
|
|
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
|
|
axum::serve(listener, app).await.unwrap();
|
|
}
|