From 957dfb4074497aacad37e9f08179e05d911fa4e4 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sun, 27 Jul 2025 00:18:47 +0000 Subject: [PATCH 1/5] feat(stats): add metric for basebackup age --- Cargo.lock | 1 + ci/rust-docker-target.sh | 7 ++- crates/scheduled_tasks/Cargo.toml | 1 + crates/scheduled_tasks/src/main.rs | 14 ++++- crates/scheduled_tasks/src/tasks.rs | 91 ++++++++++++++++++++++++++++- 5 files changed, 111 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 96bb2e24..9309d1a9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3311,6 +3311,7 @@ dependencies = [ "chrono", "croner", "fred", + "lazy_static", "libpk", "metrics", "num-format", diff --git a/ci/rust-docker-target.sh b/ci/rust-docker-target.sh index ba6df5e9..d7d5f6fc 100755 --- a/ci/rust-docker-target.sh +++ b/ci/rust-docker-target.sh @@ -42,5 +42,10 @@ build api build dispatch build gateway build avatars "COPY .docker-bin/avatar_cleanup /bin/avatar_cleanup" -build scheduled_tasks +build scheduled_tasks "$(cat < anyhow::Result<()> { update_db_message_meta ); doforever!("* * * * *", "discord stats updater", update_discord_stats); - // on :00 and :30 + // on hh:00 and hh:30 doforever!( "0,30 * * * *", "queue deleted image cleanup job", queue_deleted_image_cleanup ); + // non-standard cron: at hh:mm:00, hh:mm:30 doforever!("0,30 * * * * *", "stats api updater", update_stats_api); + // every hour (could probably even be less frequent, basebackups are taken rarely) + doforever!( + "* * * * *", + "data basebackup info updater", + update_data_basebackup_prometheus + ); + doforever!( + "* * * * *", + "messages basebackup info updater", + update_messages_basebackup_prometheus + ); set.join_next() .await diff --git a/crates/scheduled_tasks/src/tasks.rs b/crates/scheduled_tasks/src/tasks.rs index 64246fc9..84773149 100644 --- a/crates/scheduled_tasks/src/tasks.rs +++ b/crates/scheduled_tasks/src/tasks.rs @@ -1,4 +1,4 @@ -use std::time::Duration; +use std::{collections::HashMap, time::Duration}; use anyhow::anyhow; use fred::prelude::KeysInterface; @@ -10,10 +10,22 @@ use metrics::gauge; use num_format::{Locale, ToFormattedString}; use reqwest::ClientBuilder; use sqlx::Executor; +use tokio::{process::Command, sync::Mutex}; use crate::AppCtx; pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> { + let data_ts = *BASEBACKUP_TS.lock().await.get("data").unwrap_or(&0) as f64; + let messages_ts = *BASEBACKUP_TS.lock().await.get("messages").unwrap_or(&0) as f64; + + let now_ts = chrono::Utc::now().timestamp() as f64; + + gauge!("pluralkit_latest_backup_ts", "repo" => "data").set(data_ts); + gauge!("pluralkit_latest_backup_ts", "repo" => "messages").set(messages_ts); + + gauge!("pluralkit_latest_backup_age", "repo" => "data").set(now_ts - data_ts); + gauge!("pluralkit_latest_backup_age", "repo" => "messages").set(now_ts - messages_ts); + #[derive(sqlx::FromRow)] struct Count { count: i64, @@ -41,6 +53,83 @@ pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> { Ok(()) } +lazy_static::lazy_static! { + static ref BASEBACKUP_TS: Mutex> = Mutex::new(HashMap::new()); +} + +pub async fn update_data_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> { + update_basebackup_ts("data".to_string()).await +} + +pub async fn update_messages_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> { + update_basebackup_ts("messages".to_string()).await +} + +async fn update_basebackup_ts(repo: String) -> anyhow::Result<()> { + let mut env = HashMap::new(); + + for (key, value) in std::env::vars() { + if key.starts_with("AWS") { + env.insert(key, value); + } + } + + env.insert( + "WALG_S3_PREFIX".to_string(), + format!("s3://pluralkit-backups/{repo}/"), + ); + + let output = Command::new("wal-g") + .arg("backup-list") + .arg("--json") + .envs(env) + .output() + .await?; + + if !output.status.success() { + // todo: we should return error here + tracing::error!( + status = output.status.code(), + "failed to execute wal-g command" + ); + return Ok(()); + } + + #[derive(serde::Deserialize)] + struct WalgBackupInfo { + backup_name: String, + time: String, + ts_parsed: Option, + } + + let mut info = + serde_json::from_str::>(&String::from_utf8_lossy(&output.stdout))? + .into_iter() + .filter(|v| v.backup_name.contains("base")) + .filter_map(|mut v| { + chrono::DateTime::parse_from_rfc3339(&v.time) + .ok() + .map(|dt| { + v.ts_parsed = Some(dt.with_timezone(&chrono::Utc).timestamp()); + v + }) + }) + .collect::>(); + + info.sort_by(|a, b| b.ts_parsed.cmp(&a.ts_parsed)); + + let Some(info) = info.first() else { + anyhow::bail!("could not find any basebackups in repo {repo}"); + }; + + BASEBACKUP_TS + .lock() + .await + .insert(repo, info.ts_parsed.unwrap()); + + Ok(()) +} + pub async fn update_db_meta(ctx: AppCtx) -> anyhow::Result<()> { ctx.data .execute( From e5a169d8a3e24d6013642535cc7677be407572ef Mon Sep 17 00:00:00 2001 From: alyssa Date: Fri, 8 Aug 2025 20:36:51 +0000 Subject: [PATCH 2/5] fix(api): use constant time comparison for tokens --- Cargo.lock | 1 + crates/api/Cargo.toml | 1 + crates/api/src/middleware/auth.rs | 9 ++++++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 9309d1a9..ba35d84b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sqlx", + "subtle", "tokio", "tower 0.4.13", "tower-http", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index d2f883d7..e1f99425 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -26,3 +26,4 @@ reverse-proxy-service = { version = "0.2.1", features = ["axum"] } serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } +subtle = "2.6.1" diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 08981c3a..3d1d813b 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -5,6 +5,8 @@ use axum::{ response::Response, }; +use subtle::ConstantTimeEq; + use tracing::error; use crate::auth::AuthState; @@ -48,9 +50,10 @@ pub async fn auth(State(ctx): State, mut req: Request, next: Next) - .expect("missing api config") .temp_token2 .as_ref() - // this is NOT how you validate tokens - // but this is low abuse risk so we're keeping it for now - && app_auth_header == config_token2 + && app_auth_header + .as_bytes() + .ct_eq(config_token2.as_bytes()) + .into() { authed_app_id = Some(1); } From fe6085f3d04350adac6ace6e84a1b0181e6764bd Mon Sep 17 00:00:00 2001 From: alyssa Date: Fri, 8 Aug 2025 20:57:38 +0000 Subject: [PATCH 3/5] feat(api): add internal auth --- crates/api/src/auth.rs | 13 +++++++++++-- crates/api/src/middleware/auth.rs | 17 ++++++++++++++++- crates/libpk/src/_config.rs | 3 +++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/crates/api/src/auth.rs b/crates/api/src/auth.rs index c084eafe..4e12a287 100644 --- a/crates/api/src/auth.rs +++ b/crates/api/src/auth.rs @@ -7,11 +7,16 @@ pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid"; pub struct AuthState { system_id: Option, app_id: Option, + internal: bool, } impl AuthState { - pub fn new(system_id: Option, app_id: Option) -> Self { - Self { system_id, app_id } + pub fn new(system_id: Option, app_id: Option, internal: bool) -> Self { + Self { + system_id, + app_id, + internal, + } } pub fn system_id(&self) -> Option { @@ -22,6 +27,10 @@ impl AuthState { self.app_id } + pub fn internal(&self) -> bool { + self.internal + } + pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel { if self .system_id diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 3d1d813b..0992757f 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -58,8 +58,23 @@ pub async fn auth(State(ctx): State, mut req: Request, next: Next) - authed_app_id = Some(1); } + // todo: fix syntax + let internal = if req.headers().get("x-pluralkit-client-ip").is_none() + && let Some(auth_header) = req + .headers() + .get("x-pluralkit-internalauth") + .map(|h| h.to_str().ok()) + .flatten() + && let Some(real_token) = libpk::config.internal_auth.clone() + && auth_header.as_bytes().ct_eq(real_token.as_bytes()).into() + { + true + } else { + false + }; + req.extensions_mut() - .insert(AuthState::new(authed_system_id, authed_app_id)); + .insert(AuthState::new(authed_system_id, authed_app_id, internal)); next.run(req).await } diff --git a/crates/libpk/src/_config.rs b/crates/libpk/src/_config.rs index 8358440b..7f992d95 100644 --- a/crates/libpk/src/_config.rs +++ b/crates/libpk/src/_config.rs @@ -128,6 +128,9 @@ pub struct PKConfig { #[serde(default)] pub sentry_url: Option, + + #[serde(default)] + pub internal_auth: Option, } impl PKConfig { From c1f5327bccc1fc2548760628f7ec3039fe8fd49a Mon Sep 17 00:00:00 2001 From: alyssa Date: Sat, 9 Aug 2025 14:50:57 +0000 Subject: [PATCH 4/5] fix(api): automatically reload ratelimit script on redis server restart --- crates/api/src/middleware/ratelimit.rs | 43 +++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/crates/api/src/middleware/ratelimit.rs b/crates/api/src/middleware/ratelimit.rs index f4a63f7e..1638ecc9 100644 --- a/crates/api/src/middleware/ratelimit.rs +++ b/crates/api/src/middleware/ratelimit.rs @@ -45,21 +45,6 @@ pub fn ratelimiter(f: F) -> FromFnLayer, T> { tokio::spawn(async move { handle }); - let rscript = r.clone(); - tokio::spawn(async move { - if let Ok(()) = rscript.wait_for_connect().await { - match rscript - .script_load::(LUA_SCRIPT.to_string()) - .await - { - Ok(_) => info!("connected to redis for request rate limiting"), - Err(error) => error!(?error, "could not load redis script"), - } - } else { - error!("could not wait for connection to load redis script!"); - } - }); - r }); @@ -152,12 +137,34 @@ pub async fn do_request_ratelimited( let period = 1; // seconds let cost = 1; // todo: update this for group member endpoints + let script_exists: Vec = + match redis.script_exists(vec![LUA_SCRIPT_SHA.to_string()]).await { + Ok(exists) => exists, + Err(error) => { + error!(?error, "failed to check ratelimit script"); + return json_err( + StatusCode::INTERNAL_SERVER_ERROR, + r#"{"message": "500: internal server error", "code": 0}"#.to_string(), + ); + } + }; + + if script_exists[0] != 1 { + match redis + .script_load::(LUA_SCRIPT.to_string()) + .await + { + Ok(_) => info!("successfully loaded ratelimit script to redis"), + Err(error) => { + error!(?error, "could not load redis script") + } + } + } + // local rate_limit_key = KEYS[1] // local rate = ARGV[1] // local period = ARGV[2] // return {remaining, tostring(retry_after), reset_after} - - // todo: check if error is script not found and reload script let resp = redis .evalsha::<(i32, String, u64), String, Vec, Vec>( LUA_SCRIPT_SHA.to_string(), @@ -219,7 +226,7 @@ pub async fn do_request_ratelimited( return response; } Err(error) => { - tracing::error!(?error, "error getting ratelimit info"); + error!(?error, "error getting ratelimit info"); return json_err( StatusCode::INTERNAL_SERVER_ERROR, r#"{"message": "500: internal server error", "code": 0}"#.to_string(), From d4d4ed7feb499bacf83ac24dedea100fbb95e105 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sun, 10 Aug 2025 00:25:29 +0000 Subject: [PATCH 5/5] feat(api): implement PKError in rust-api --- Cargo.lock | 10 +++-- crates/api/Cargo.toml | 1 + crates/api/src/endpoints/private.rs | 18 ++++---- crates/api/src/endpoints/system.rs | 43 ++++++------------ crates/api/src/error.rs | 68 ++++++++++++++++++++++++++--- crates/api/src/main.rs | 21 ++++----- crates/macros/Cargo.toml | 1 + crates/macros/src/api.rs | 52 ++++++++++++++++++++++ crates/macros/src/lib.rs | 6 +++ 9 files changed, 157 insertions(+), 63 deletions(-) create mode 100644 crates/macros/src/api.rs diff --git a/Cargo.lock b/Cargo.lock index ba35d84b..52ab48dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,7 @@ dependencies = [ "lazy_static", "libpk", "metrics", + "pk_macros", "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", @@ -2530,6 +2531,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" name = "pk_macros" version = "0.1.0" dependencies = [ + "prettyplease", "proc-macro2", "quote", "syn", @@ -2611,9 +2613,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn", @@ -3965,9 +3967,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index e1f99425..9413d62c 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] pluralkit_models = { path = "../models" } +pk_macros = { path = "../macros" } libpk = { path = "../libpk" } anyhow = { workspace = true } diff --git a/crates/api/src/endpoints/private.rs b/crates/api/src/endpoints/private.rs index df67421c..067fef57 100644 --- a/crates/api/src/endpoints/private.rs +++ b/crates/api/src/endpoints/private.rs @@ -2,6 +2,7 @@ use crate::ApiContext; use axum::{extract::State, response::Json}; use fred::interfaces::*; use libpk::state::ShardState; +use pk_macros::api_endpoint; use serde::Deserialize; use serde_json::{json, Value}; use std::collections::HashMap; @@ -13,34 +14,33 @@ struct ClusterStats { pub channel_count: i32, } +#[api_endpoint] pub async fn discord_state(State(ctx): State) -> Json { let mut shard_status = ctx .redis .hgetall::, &str>("pluralkit:shardstatus") - .await - .unwrap() + .await? .values() .map(|v| serde_json::from_str(v).expect("could not deserialize shard")) .collect::>(); shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id)); - Json(json!({ + Ok(Json(json!({ "shards": shard_status, - })) + }))) } +#[api_endpoint] pub async fn meta(State(ctx): State) -> Json { let stats = serde_json::from_str::( ctx.redis .get::("statsapi") - .await - .unwrap() + .await? .as_str(), - ) - .unwrap(); + )?; - Json(stats) + Ok(Json(stats)) } use std::time::Duration; diff --git a/crates/api/src/endpoints/system.rs b/crates/api/src/endpoints/system.rs index e510f7c5..7b919df5 100644 --- a/crates/api/src/endpoints/system.rs +++ b/crates/api/src/endpoints/system.rs @@ -1,22 +1,18 @@ -use axum::{ - extract::State, - http::StatusCode, - response::{IntoResponse, Response}, - Extension, Json, -}; -use serde_json::json; +use axum::{extract::State, response::IntoResponse, Extension, Json}; +use pk_macros::api_endpoint; +use serde_json::{json, Value}; use sqlx::Postgres; -use tracing::error; use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel}; -use crate::{auth::AuthState, util::json_err, ApiContext}; +use crate::{auth::AuthState, error::fail, ApiContext}; +#[api_endpoint] pub async fn get_system_settings( Extension(auth): Extension, Extension(system): Extension, State(ctx): State, -) -> Response { +) -> Json { let access_level = auth.access_level_for(&system); let mut config = match sqlx::query_as::( @@ -27,23 +23,11 @@ pub async fn get_system_settings( .await { Ok(Some(config)) => config, - Ok(None) => { - error!( - system = system.id, - "failed to find system config for existing system" - ); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } - Err(err) => { - error!(?err, "failed to query system config"); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } + Ok(None) => fail!( + system = system.id, + "failed to find system config for existing system" + ), + Err(err) => fail!(?err, "failed to query system config"), }; // fix this @@ -51,7 +35,7 @@ pub async fn get_system_settings( config.name_format = Some("{name} {tag}".to_string()); } - Json(&match access_level { + Ok(Json(match access_level { PrivacyLevel::Private => config.to_json(), PrivacyLevel::Public => json!({ "pings_enabled": config.pings_enabled, @@ -64,6 +48,5 @@ pub async fn get_system_settings( "proxy_switch": config.proxy_switch, "name_format": config.name_format, }), - }) - .into_response() + })) } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 464534d8..fc481d0c 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -1,13 +1,17 @@ -use axum::http::StatusCode; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; use std::fmt; -// todo -#[allow(dead_code)] +// todo: model parse errors #[derive(Debug)] pub struct PKError { pub response_code: StatusCode, pub json_code: i32, pub message: &'static str, + + pub inner: Option, } impl fmt::Display for PKError { @@ -16,17 +20,67 @@ impl fmt::Display for PKError { } } -impl std::error::Error for PKError {} +impl Clone for PKError { + fn clone(&self) -> PKError { + if self.inner.is_some() { + panic!("cannot clone PKError with inner error"); + } + PKError { + response_code: self.response_code, + json_code: self.json_code, + message: self.message, + inner: None, + } + } +} + +impl From for PKError +where + E: std::fmt::Display + Into, +{ + fn from(err: E) -> Self { + let mut res = GENERIC_SERVER_ERROR.clone(); + res.inner = Some(err.into()); + res + } +} + +impl IntoResponse for PKError { + fn into_response(self) -> Response { + if let Some(inner) = self.inner { + tracing::error!(?inner, "error returned from handler"); + } + crate::util::json_err( + self.response_code, + serde_json::to_string(&serde_json::json!({ + "message": self.message, + "code": self.json_code, + })) + .unwrap(), + ) + } +} + +macro_rules! fail { + ($($stuff:tt)+) => {{ + tracing::error!($($stuff)+); + return Err(crate::error::GENERIC_SERVER_ERROR); + }}; +} + +pub(crate) use fail; -#[allow(unused_macros)] macro_rules! define_error { ( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => { - const $name: PKError = PKError { + #[allow(dead_code)] + pub const $name: PKError = PKError { response_code: $response_code, json_code: $json_code, message: $message, + inner: None, }; }; } -// define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index e3a201cb..07e47f89 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -4,8 +4,8 @@ use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; use axum::{ body::Body, extract::{Request as ExtractRequest, State}, - http::{Response, StatusCode, Uri}, - response::IntoResponse, + http::Uri, + response::{IntoResponse, Response}, routing::{delete, get, patch, post}, Extension, Router, }; @@ -13,7 +13,9 @@ use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor, }; -use tracing::{error, info}; +use tracing::info; + +use pk_macros::api_endpoint; mod auth; mod endpoints; @@ -30,11 +32,12 @@ pub struct ApiContext { rproxy_client: Client, } +#[api_endpoint] async fn rproxy( Extension(auth): Extension, State(ctx): State, mut req: ExtractRequest, -) -> Result, StatusCode> { +) -> Response { let path = req.uri().path(); let path_query = req .uri() @@ -59,15 +62,7 @@ async fn rproxy( headers.append(INTERNAL_APPID_HEADER, aid.into()); } - Ok(ctx - .rproxy_client - .request(req) - .await - .map_err(|error| { - error!(?error, "failed to serve reverse proxy to dotnet-api"); - StatusCode::BAD_GATEWAY - })? - .into_response()) + Ok(ctx.rproxy_client.request(req).await?.into_response()) } // this function is manually formatted for easier legibility of route_services diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index 8090798f..10feaf88 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -10,4 +10,5 @@ proc-macro = true quote = "1.0" proc-macro2 = "1.0" syn = "2.0" +prettyplease = "0.2.36" diff --git a/crates/macros/src/api.rs b/crates/macros/src/api.rs new file mode 100644 index 00000000..7f797b8c --- /dev/null +++ b/crates/macros/src/api.rs @@ -0,0 +1,52 @@ +use quote::quote; +use syn::{parse_macro_input, FnArg, ItemFn, Pat}; + +fn pretty_print(ts: &proc_macro2::TokenStream) -> String { + let file = syn::parse_file(&ts.to_string()).unwrap(); + prettyplease::unparse(&file) +} + +pub fn macro_impl( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as ItemFn); + + let fn_name = &input.sig.ident; + let fn_params = &input.sig.inputs; + let fn_body = &input.block; + let syn::ReturnType::Type(_, fn_return_type) = &input.sig.output else { + panic!("handler return type must not be nothing"); + }; + let pms: Vec = fn_params + .iter() + .map(|v| { + let FnArg::Typed(pat) = v else { + panic!("must not have self param in handler"); + }; + let mut pat = pat.pat.clone(); + if let Pat::Ident(ident) = *pat { + let mut ident = ident.clone(); + ident.mutability = None; + pat = Box::new(Pat::Ident(ident)); + } + quote! { #pat } + }) + .collect(); + + let res = quote! { + #[allow(unused_mut)] + pub async fn #fn_name(#fn_params) -> axum::response::Response { + async fn inner(#fn_params) -> Result<#fn_return_type, crate::error::PKError> { + #fn_body + } + + match inner(#(#pms),*).await { + Ok(res) => res.into_response(), + Err(err) => err.into_response(), + } + } + }; + + res.into() +} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index db5a55b7..ad3c1064 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -1,8 +1,14 @@ use proc_macro::TokenStream; +mod api; mod entrypoint; mod model; +#[proc_macro_attribute] +pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream { + api::macro_impl(args, input) +} + #[proc_macro_attribute] pub fn main(args: TokenStream, input: TokenStream) -> TokenStream { entrypoint::macro_impl(args, input)