add cmd lexicon

This commit is contained in:
2026-01-18 12:15:43 +09:00
parent 28eb463b74
commit 1fd619e32b
15 changed files with 1317 additions and 50 deletions

192
src/commands/gen.rs Normal file
View File

@@ -0,0 +1,192 @@
use anyhow::{Context, Result};
use serde::Deserialize;
use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
#[derive(Debug, Deserialize)]
struct Lexicon {
id: String,
defs: BTreeMap<String, LexiconDef>,
}
#[derive(Debug, Deserialize)]
struct LexiconDef {
#[serde(rename = "type")]
def_type: Option<String>,
}
struct EndpointInfo {
nsid: String,
method: String, // GET or POST
}
/// Generate lexicon Rust code from ATProto lexicon JSON files
pub fn generate(input: &str, output: &str) -> Result<()> {
let input_path = Path::new(input);
if !input_path.exists() {
anyhow::bail!("Input directory does not exist: {}", input);
}
println!("Scanning lexicons from: {}", input);
// Collect all endpoints grouped by namespace
let mut namespaces: BTreeMap<String, Vec<EndpointInfo>> = BTreeMap::new();
// Scan com/atproto directory
let atproto_path = input_path.join("com/atproto");
if atproto_path.exists() {
scan_namespace(&atproto_path, "com.atproto", &mut namespaces)?;
}
// Scan app/bsky directory
let bsky_path = input_path.join("app/bsky");
if bsky_path.exists() {
scan_namespace(&bsky_path, "app.bsky", &mut namespaces)?;
}
// Generate Rust code
let code = generate_rust_code(&namespaces);
// Write output
let output_path = Path::new(output).join("mod.rs");
fs::create_dir_all(output)?;
fs::write(&output_path, &code)?;
println!("Generated: {}", output_path.display());
println!("Total namespaces: {}", namespaces.len());
let total_endpoints: usize = namespaces.values().map(|v| v.len()).sum();
println!("Total endpoints: {}", total_endpoints);
Ok(())
}
fn scan_namespace(
base_path: &Path,
prefix: &str,
namespaces: &mut BTreeMap<String, Vec<EndpointInfo>>,
) -> Result<()> {
for entry in fs::read_dir(base_path)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
let ns_name = path.file_name()
.and_then(|n| n.to_str())
.context("Invalid directory name")?;
let full_ns = format!("{}.{}", prefix, ns_name);
let mut endpoints = Vec::new();
// Scan JSON files in this namespace
for file_entry in fs::read_dir(&path)? {
let file_entry = file_entry?;
let file_path = file_entry.path();
if file_path.extension().map(|e| e == "json").unwrap_or(false) {
if let Some(endpoint) = parse_lexicon_file(&file_path)? {
endpoints.push(endpoint);
}
}
}
if !endpoints.is_empty() {
endpoints.sort_by(|a, b| a.nsid.cmp(&b.nsid));
namespaces.insert(full_ns, endpoints);
}
}
}
Ok(())
}
fn parse_lexicon_file(path: &Path) -> Result<Option<EndpointInfo>> {
let content = fs::read_to_string(path)
.with_context(|| format!("Failed to read: {}", path.display()))?;
let lexicon: Lexicon = serde_json::from_str(&content)
.with_context(|| format!("Failed to parse: {}", path.display()))?;
// Get the main definition type
let main_def = match lexicon.defs.get("main") {
Some(def) => def,
None => return Ok(None),
};
let method = match main_def.def_type.as_deref() {
Some("query") => "GET",
Some("procedure") => "POST",
Some("subscription") => return Ok(None), // Skip websocket subscriptions
_ => return Ok(None), // Skip records, tokens, etc.
};
Ok(Some(EndpointInfo {
nsid: lexicon.id,
method: method.to_string(),
}))
}
fn generate_rust_code(namespaces: &BTreeMap<String, Vec<EndpointInfo>>) -> String {
let mut code = String::new();
// Header
code.push_str("//! Auto-generated from ATProto lexicons\n");
code.push_str("//! Run `ailog gen` to regenerate\n");
code.push_str("//! Do not edit manually\n\n");
// Endpoint struct
code.push_str("#[derive(Debug, Clone, Copy)]\n");
code.push_str("pub struct Endpoint {\n");
code.push_str(" pub nsid: &'static str,\n");
code.push_str(" pub method: &'static str,\n");
code.push_str("}\n\n");
// URL helper function
code.push_str("/// Build XRPC URL for an endpoint\n");
code.push_str("pub fn url(pds: &str, endpoint: &Endpoint) -> String {\n");
code.push_str(" format!(\"https://{}/xrpc/{}\", pds, endpoint.nsid)\n");
code.push_str("}\n\n");
// Generate modules for each namespace
for (ns, endpoints) in namespaces {
// Convert namespace to module name: com.atproto.repo -> com_atproto_repo
let mod_name = ns.replace('.', "_");
code.push_str(&format!("pub mod {} {{\n", mod_name));
code.push_str(" use super::Endpoint;\n\n");
for endpoint in endpoints {
// Extract the method name from NSID: com.atproto.repo.listRecords -> LIST_RECORDS
let method_name = endpoint.nsid
.rsplit('.')
.next()
.unwrap_or(&endpoint.nsid);
// Convert camelCase to SCREAMING_SNAKE_CASE
let const_name = to_screaming_snake_case(method_name);
code.push_str(&format!(
" pub const {}: Endpoint = Endpoint {{ nsid: \"{}\", method: \"{}\" }};\n",
const_name, endpoint.nsid, endpoint.method
));
}
code.push_str("}\n\n");
}
code
}
fn to_screaming_snake_case(s: &str) -> String {
let mut result = String::new();
for (i, c) in s.chars().enumerate() {
if c.is_uppercase() && i > 0 {
result.push('_');
}
result.push(c.to_ascii_uppercase());
}
result
}