diff --git a/Cargo.lock b/Cargo.lock index ba35d84b..52ab48dc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,7 @@ dependencies = [ "lazy_static", "libpk", "metrics", + "pk_macros", "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", @@ -2530,6 +2531,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" name = "pk_macros" version = "0.1.0" dependencies = [ + "prettyplease", "proc-macro2", "quote", "syn", @@ -2611,9 +2613,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn", @@ -3965,9 +3967,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index e1f99425..9413d62c 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] pluralkit_models = { path = "../models" } +pk_macros = { path = "../macros" } libpk = { path = "../libpk" } anyhow = { workspace = true } diff --git a/crates/api/src/endpoints/private.rs b/crates/api/src/endpoints/private.rs index df67421c..067fef57 100644 --- a/crates/api/src/endpoints/private.rs +++ b/crates/api/src/endpoints/private.rs @@ -2,6 +2,7 @@ use crate::ApiContext; use axum::{extract::State, response::Json}; use fred::interfaces::*; use libpk::state::ShardState; +use pk_macros::api_endpoint; use serde::Deserialize; use serde_json::{json, Value}; use std::collections::HashMap; @@ -13,34 +14,33 @@ struct ClusterStats { pub channel_count: i32, } +#[api_endpoint] pub async fn discord_state(State(ctx): State) -> Json { let mut shard_status = ctx .redis .hgetall::, &str>("pluralkit:shardstatus") - .await - .unwrap() + .await? .values() .map(|v| serde_json::from_str(v).expect("could not deserialize shard")) .collect::>(); shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id)); - Json(json!({ + Ok(Json(json!({ "shards": shard_status, - })) + }))) } +#[api_endpoint] pub async fn meta(State(ctx): State) -> Json { let stats = serde_json::from_str::( ctx.redis .get::("statsapi") - .await - .unwrap() + .await? .as_str(), - ) - .unwrap(); + )?; - Json(stats) + Ok(Json(stats)) } use std::time::Duration; diff --git a/crates/api/src/endpoints/system.rs b/crates/api/src/endpoints/system.rs index e510f7c5..7b919df5 100644 --- a/crates/api/src/endpoints/system.rs +++ b/crates/api/src/endpoints/system.rs @@ -1,22 +1,18 @@ -use axum::{ - extract::State, - http::StatusCode, - response::{IntoResponse, Response}, - Extension, Json, -}; -use serde_json::json; +use axum::{extract::State, response::IntoResponse, Extension, Json}; +use pk_macros::api_endpoint; +use serde_json::{json, Value}; use sqlx::Postgres; -use tracing::error; use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel}; -use crate::{auth::AuthState, util::json_err, ApiContext}; +use crate::{auth::AuthState, error::fail, ApiContext}; +#[api_endpoint] pub async fn get_system_settings( Extension(auth): Extension, Extension(system): Extension, State(ctx): State, -) -> Response { +) -> Json { let access_level = auth.access_level_for(&system); let mut config = match sqlx::query_as::( @@ -27,23 +23,11 @@ pub async fn get_system_settings( .await { Ok(Some(config)) => config, - Ok(None) => { - error!( - system = system.id, - "failed to find system config for existing system" - ); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } - Err(err) => { - error!(?err, "failed to query system config"); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } + Ok(None) => fail!( + system = system.id, + "failed to find system config for existing system" + ), + Err(err) => fail!(?err, "failed to query system config"), }; // fix this @@ -51,7 +35,7 @@ pub async fn get_system_settings( config.name_format = Some("{name} {tag}".to_string()); } - Json(&match access_level { + Ok(Json(match access_level { PrivacyLevel::Private => config.to_json(), PrivacyLevel::Public => json!({ "pings_enabled": config.pings_enabled, @@ -64,6 +48,5 @@ pub async fn get_system_settings( "proxy_switch": config.proxy_switch, "name_format": config.name_format, }), - }) - .into_response() + })) } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 464534d8..fc481d0c 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -1,13 +1,17 @@ -use axum::http::StatusCode; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; use std::fmt; -// todo -#[allow(dead_code)] +// todo: model parse errors #[derive(Debug)] pub struct PKError { pub response_code: StatusCode, pub json_code: i32, pub message: &'static str, + + pub inner: Option, } impl fmt::Display for PKError { @@ -16,17 +20,67 @@ impl fmt::Display for PKError { } } -impl std::error::Error for PKError {} +impl Clone for PKError { + fn clone(&self) -> PKError { + if self.inner.is_some() { + panic!("cannot clone PKError with inner error"); + } + PKError { + response_code: self.response_code, + json_code: self.json_code, + message: self.message, + inner: None, + } + } +} + +impl From for PKError +where + E: std::fmt::Display + Into, +{ + fn from(err: E) -> Self { + let mut res = GENERIC_SERVER_ERROR.clone(); + res.inner = Some(err.into()); + res + } +} + +impl IntoResponse for PKError { + fn into_response(self) -> Response { + if let Some(inner) = self.inner { + tracing::error!(?inner, "error returned from handler"); + } + crate::util::json_err( + self.response_code, + serde_json::to_string(&serde_json::json!({ + "message": self.message, + "code": self.json_code, + })) + .unwrap(), + ) + } +} + +macro_rules! fail { + ($($stuff:tt)+) => {{ + tracing::error!($($stuff)+); + return Err(crate::error::GENERIC_SERVER_ERROR); + }}; +} + +pub(crate) use fail; -#[allow(unused_macros)] macro_rules! define_error { ( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => { - const $name: PKError = PKError { + #[allow(dead_code)] + pub const $name: PKError = PKError { response_code: $response_code, json_code: $json_code, message: $message, + inner: None, }; }; } -// define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index e3a201cb..07e47f89 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -4,8 +4,8 @@ use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; use axum::{ body::Body, extract::{Request as ExtractRequest, State}, - http::{Response, StatusCode, Uri}, - response::IntoResponse, + http::Uri, + response::{IntoResponse, Response}, routing::{delete, get, patch, post}, Extension, Router, }; @@ -13,7 +13,9 @@ use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor, }; -use tracing::{error, info}; +use tracing::info; + +use pk_macros::api_endpoint; mod auth; mod endpoints; @@ -30,11 +32,12 @@ pub struct ApiContext { rproxy_client: Client, } +#[api_endpoint] async fn rproxy( Extension(auth): Extension, State(ctx): State, mut req: ExtractRequest, -) -> Result, StatusCode> { +) -> Response { let path = req.uri().path(); let path_query = req .uri() @@ -59,15 +62,7 @@ async fn rproxy( headers.append(INTERNAL_APPID_HEADER, aid.into()); } - Ok(ctx - .rproxy_client - .request(req) - .await - .map_err(|error| { - error!(?error, "failed to serve reverse proxy to dotnet-api"); - StatusCode::BAD_GATEWAY - })? - .into_response()) + Ok(ctx.rproxy_client.request(req).await?.into_response()) } // this function is manually formatted for easier legibility of route_services diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index 8090798f..10feaf88 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -10,4 +10,5 @@ proc-macro = true quote = "1.0" proc-macro2 = "1.0" syn = "2.0" +prettyplease = "0.2.36" diff --git a/crates/macros/src/api.rs b/crates/macros/src/api.rs new file mode 100644 index 00000000..7f797b8c --- /dev/null +++ b/crates/macros/src/api.rs @@ -0,0 +1,52 @@ +use quote::quote; +use syn::{parse_macro_input, FnArg, ItemFn, Pat}; + +fn pretty_print(ts: &proc_macro2::TokenStream) -> String { + let file = syn::parse_file(&ts.to_string()).unwrap(); + prettyplease::unparse(&file) +} + +pub fn macro_impl( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as ItemFn); + + let fn_name = &input.sig.ident; + let fn_params = &input.sig.inputs; + let fn_body = &input.block; + let syn::ReturnType::Type(_, fn_return_type) = &input.sig.output else { + panic!("handler return type must not be nothing"); + }; + let pms: Vec = fn_params + .iter() + .map(|v| { + let FnArg::Typed(pat) = v else { + panic!("must not have self param in handler"); + }; + let mut pat = pat.pat.clone(); + if let Pat::Ident(ident) = *pat { + let mut ident = ident.clone(); + ident.mutability = None; + pat = Box::new(Pat::Ident(ident)); + } + quote! { #pat } + }) + .collect(); + + let res = quote! { + #[allow(unused_mut)] + pub async fn #fn_name(#fn_params) -> axum::response::Response { + async fn inner(#fn_params) -> Result<#fn_return_type, crate::error::PKError> { + #fn_body + } + + match inner(#(#pms),*).await { + Ok(res) => res.into_response(), + Err(err) => err.into_response(), + } + } + }; + + res.into() +} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index db5a55b7..ad3c1064 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -1,8 +1,14 @@ use proc_macro::TokenStream; +mod api; mod entrypoint; mod model; +#[proc_macro_attribute] +pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream { + api::macro_impl(args, input) +} + #[proc_macro_attribute] pub fn main(args: TokenStream, input: TokenStream) -> TokenStream { entrypoint::macro_impl(args, input)