From b3eb108a13f50e61fb4be654ed815008b554f385 Mon Sep 17 00:00:00 2001
From: asleepyskye
Date: Mon, 1 Sep 2025 21:23:24 -0400
Subject: [PATCH 01/10] chore(bot): update wording on error message

---
 PluralKit.Bot/Services/ErrorMessageService.cs | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/PluralKit.Bot/Services/ErrorMessageService.cs b/PluralKit.Bot/Services/ErrorMessageService.cs
index 50f47d1c..efab9d02 100644
--- a/PluralKit.Bot/Services/ErrorMessageService.cs
+++ b/PluralKit.Bot/Services/ErrorMessageService.cs
@@ -116,7 +116,7 @@ public class ErrorMessageService
         return new EmbedBuilder()
             .Color(0xE74C3C)
             .Title("Internal error occurred")
-            .Description($"For support, please send the error code above as text in {channelInfo} with a description of what you were doing at the time.")
+            .Description($"**If you need support,** please send/forward the error code above **as text** in {channelInfo} with a description of what you were doing at the time.")
             .Footer(new Embed.EmbedFooter(errorId))
             .Timestamp(now.ToDateTimeOffset().ToString("O"))
             .Build();

From 2248403140031ecd7fa3ed14e4826763e6558d28 Mon Sep 17 00:00:00 2001
From: asleepyskye
Date: Mon, 1 Sep 2025 21:33:16 -0400
Subject: [PATCH 02/10] fix(flake): change systems url to default

---
 flake.lock | 10 +++++-----
 flake.nix  |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/flake.lock b/flake.lock
index 7298bd18..4024bfd9 100644
--- a/flake.lock
+++ b/flake.lock
@@ -297,16 +297,16 @@
     },
     "systems": {
       "locked": {
-        "lastModified": 1680978846,
-        "narHash": "sha256-Gtqg8b/v49BFDpDetjclCYXm8mAnTrUzR0JnE2nv5aw=",
+        "lastModified": 1681028828,
+        "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
         "owner": "nix-systems",
-        "repo": "x86_64-linux",
-        "rev": "2ecfcac5e15790ba6ce360ceccddb15ad16d08a8",
+        "repo": "default",
+        "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
         "type": "github"
       },
       "original": {
         "owner": "nix-systems",
-        "repo": "x86_64-linux",
+        "repo": "default",
         "type": "github"
       }
     },
diff --git a/flake.nix b/flake.nix
index 8fd2ed6b..85793415 100644
--- a/flake.nix
+++ b/flake.nix
@@ -4,7 +4,7 @@
   inputs = {
     nixpkgs.url = "nixpkgs/nixpkgs-unstable";
     parts.url = "github:hercules-ci/flake-parts";
-    systems.url = "github:nix-systems/x86_64-linux";
+    systems.url = "github:nix-systems/default";
     # process compose
     process-compose.url = "github:Platonic-Systems/process-compose-flake";
     services.url = "github:juspay/services-flake";

From 2d40a1ee1623a1d0cc9d1173befb8d89c9962f38 Mon Sep 17 00:00:00 2001
From: alyssa
Date: Sun, 27 Jul 2025 00:18:47 +0000
Subject: [PATCH 03/10] feat(stats): add metric for basebackup age

---
 Cargo.lock                          |  1 +
 ci/rust-docker-target.sh            |  7 ++-
 crates/scheduled_tasks/Cargo.toml   |  1 +
 crates/scheduled_tasks/src/main.rs  | 14 ++++-
 crates/scheduled_tasks/src/tasks.rs | 91 ++++++++++++++++++++++++++++-
 5 files changed, 111 insertions(+), 3 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index a0ea3bde..2b168f1b 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -3311,6 +3311,7 @@ dependencies = [
  "chrono",
  "croner",
  "fred",
+ "lazy_static",
 "libpk",
 "metrics",
 "num-format",
diff --git a/ci/rust-docker-target.sh b/ci/rust-docker-target.sh
index ba6df5e9..d7d5f6fc 100755
--- a/ci/rust-docker-target.sh
+++ b/ci/rust-docker-target.sh
@@ -42,5 +42,10 @@ build api
 build dispatch
 build gateway
 build avatars "COPY .docker-bin/avatar_cleanup /bin/avatar_cleanup"
-build scheduled_tasks
+build scheduled_tasks "$(cat < anyhow::Result<()> { update_db_message_meta );
     doforever!("* * * * *", "discord stats updater", update_discord_stats);
-    // on :00 and :30
+    // on hh:00 and hh:30
     doforever!(
         "0,30 * * * *",
         "queue deleted image cleanup job",
         queue_deleted_image_cleanup
     );
+    // non-standard cron: at hh:mm:00, hh:mm:30
     doforever!("0,30 * * * * *", "stats api updater", update_stats_api);
+    // every hour (could probably even be less frequent, basebackups are taken rarely)
+    doforever!(
+        "* * * * *",
+        "data basebackup info updater",
+        update_data_basebackup_prometheus
+    );
+    doforever!(
+        "* * * * *",
+        "messages basebackup info updater",
+        update_messages_basebackup_prometheus
+    );

     set.join_next()
         .await
diff --git a/crates/scheduled_tasks/src/tasks.rs b/crates/scheduled_tasks/src/tasks.rs
index 64246fc9..84773149 100644
--- a/crates/scheduled_tasks/src/tasks.rs
+++ b/crates/scheduled_tasks/src/tasks.rs
@@ -1,4 +1,4 @@
-use std::time::Duration;
+use std::{collections::HashMap, time::Duration};

 use anyhow::anyhow;
 use fred::prelude::KeysInterface;
@@ -10,10 +10,22 @@ use metrics::gauge;
 use num_format::{Locale, ToFormattedString};
 use reqwest::ClientBuilder;
 use sqlx::Executor;
+use tokio::{process::Command, sync::Mutex};

 use crate::AppCtx;

 pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
+    let data_ts = *BASEBACKUP_TS.lock().await.get("data").unwrap_or(&0) as f64;
+    let messages_ts = *BASEBACKUP_TS.lock().await.get("messages").unwrap_or(&0) as f64;
+
+    let now_ts = chrono::Utc::now().timestamp() as f64;
+
+    gauge!("pluralkit_latest_backup_ts", "repo" => "data").set(data_ts);
+    gauge!("pluralkit_latest_backup_ts", "repo" => "messages").set(messages_ts);
+
+    gauge!("pluralkit_latest_backup_age", "repo" => "data").set(now_ts - data_ts);
+    gauge!("pluralkit_latest_backup_age", "repo" => "messages").set(now_ts - messages_ts);
+
     #[derive(sqlx::FromRow)]
     struct Count {
         count: i64,
     }
@@ -41,6 +53,83 @@ pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
     Ok(())
 }
+lazy_static::lazy_static! {
+    static ref BASEBACKUP_TS: Mutex> = Mutex::new(HashMap::new());
+}
+
+pub async fn update_data_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> {
+    update_basebackup_ts("data".to_string()).await
+}
+
+pub async fn update_messages_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> {
+    update_basebackup_ts("messages".to_string()).await
+}
+
+async fn update_basebackup_ts(repo: String) -> anyhow::Result<()> {
+    let mut env = HashMap::new();
+
+    for (key, value) in std::env::vars() {
+        if key.starts_with("AWS") {
+            env.insert(key, value);
+        }
+    }
+
+    env.insert(
+        "WALG_S3_PREFIX".to_string(),
+        format!("s3://pluralkit-backups/{repo}/"),
+    );
+
+    let output = Command::new("wal-g")
+        .arg("backup-list")
+        .arg("--json")
+        .envs(env)
+        .output()
+        .await?;
+
+    if !output.status.success() {
+        // todo: we should return error here
+        tracing::error!(
+            status = output.status.code(),
+            "failed to execute wal-g command"
+        );
+        return Ok(());
+    }
+
+    #[derive(serde::Deserialize)]
+    struct WalgBackupInfo {
+        backup_name: String,
+        time: String,
+        ts_parsed: Option,
+    }
+
+    let mut info =
+        serde_json::from_str::>(&String::from_utf8_lossy(&output.stdout))?
+ .into_iter() + .filter(|v| v.backup_name.contains("base")) + .filter_map(|mut v| { + chrono::DateTime::parse_from_rfc3339(&v.time) + .ok() + .map(|dt| { + v.ts_parsed = Some(dt.with_timezone(&chrono::Utc).timestamp()); + v + }) + }) + .collect::>(); + + info.sort_by(|a, b| b.ts_parsed.cmp(&a.ts_parsed)); + + let Some(info) = info.first() else { + anyhow::bail!("could not find any basebackups in repo {repo}"); + }; + + BASEBACKUP_TS + .lock() + .await + .insert(repo, info.ts_parsed.unwrap()); + + Ok(()) +} + pub async fn update_db_meta(ctx: AppCtx) -> anyhow::Result<()> { ctx.data .execute( From 9c1acd84e1bae6b8bdaa5a0c9b32b4faa28fcba7 Mon Sep 17 00:00:00 2001 From: alyssa Date: Fri, 8 Aug 2025 20:36:51 +0000 Subject: [PATCH 04/10] fix(api): use constant time comparison for tokens --- Cargo.lock | 1 + crates/api/Cargo.toml | 1 + crates/api/src/middleware/auth.rs | 9 ++++++--- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2b168f1b..d4b91f58 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -95,6 +95,7 @@ dependencies = [ "serde_json", "serde_urlencoded", "sqlx", + "subtle", "tokio", "tower 0.4.13", "tower-http", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index d2f883d7..e1f99425 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -26,3 +26,4 @@ reverse-proxy-service = { version = "0.2.1", features = ["axum"] } serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } +subtle = "2.6.1" diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 08981c3a..3d1d813b 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -5,6 +5,8 @@ use axum::{ response::Response, }; +use subtle::ConstantTimeEq; + use tracing::error; use crate::auth::AuthState; @@ -48,9 +50,10 @@ pub async fn auth(State(ctx): State, mut req: Request, next: Next) - .expect("missing api config") .temp_token2 .as_ref() - // this is NOT how you validate tokens - // but this is low abuse risk so we're keeping it for now - && app_auth_header == config_token2 + && app_auth_header + .as_bytes() + .ct_eq(config_token2.as_bytes()) + .into() { authed_app_id = Some(1); } From dd14e7daefb0757c41afc761fd3488ae466caa09 Mon Sep 17 00:00:00 2001 From: alyssa Date: Fri, 8 Aug 2025 20:57:38 +0000 Subject: [PATCH 05/10] feat(api): add internal auth --- crates/api/src/auth.rs | 13 +++++++++++-- crates/api/src/middleware/auth.rs | 17 ++++++++++++++++- crates/libpk/src/_config.rs | 3 +++ 3 files changed, 30 insertions(+), 3 deletions(-) diff --git a/crates/api/src/auth.rs b/crates/api/src/auth.rs index c084eafe..4e12a287 100644 --- a/crates/api/src/auth.rs +++ b/crates/api/src/auth.rs @@ -7,11 +7,16 @@ pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid"; pub struct AuthState { system_id: Option, app_id: Option, + internal: bool, } impl AuthState { - pub fn new(system_id: Option, app_id: Option) -> Self { - Self { system_id, app_id } + pub fn new(system_id: Option, app_id: Option, internal: bool) -> Self { + Self { + system_id, + app_id, + internal, + } } pub fn system_id(&self) -> Option { @@ -22,6 +27,10 @@ impl AuthState { self.app_id } + pub fn internal(&self) -> bool { + self.internal + } + pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel { if self .system_id diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 3d1d813b..0992757f 100644 --- a/crates/api/src/middleware/auth.rs +++ 
b/crates/api/src/middleware/auth.rs @@ -58,8 +58,23 @@ pub async fn auth(State(ctx): State, mut req: Request, next: Next) - authed_app_id = Some(1); } + // todo: fix syntax + let internal = if req.headers().get("x-pluralkit-client-ip").is_none() + && let Some(auth_header) = req + .headers() + .get("x-pluralkit-internalauth") + .map(|h| h.to_str().ok()) + .flatten() + && let Some(real_token) = libpk::config.internal_auth.clone() + && auth_header.as_bytes().ct_eq(real_token.as_bytes()).into() + { + true + } else { + false + }; + req.extensions_mut() - .insert(AuthState::new(authed_system_id, authed_app_id)); + .insert(AuthState::new(authed_system_id, authed_app_id, internal)); next.run(req).await } diff --git a/crates/libpk/src/_config.rs b/crates/libpk/src/_config.rs index 8358440b..7f992d95 100644 --- a/crates/libpk/src/_config.rs +++ b/crates/libpk/src/_config.rs @@ -128,6 +128,9 @@ pub struct PKConfig { #[serde(default)] pub sentry_url: Option, + + #[serde(default)] + pub internal_auth: Option, } impl PKConfig { From a49dbefe83ee2e756308d4e762d04252309b0046 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sat, 9 Aug 2025 14:50:57 +0000 Subject: [PATCH 06/10] fix(api): automatically reload ratelimit script on redis server restart --- crates/api/src/middleware/ratelimit.rs | 43 +++++++++++++++----------- 1 file changed, 25 insertions(+), 18 deletions(-) diff --git a/crates/api/src/middleware/ratelimit.rs b/crates/api/src/middleware/ratelimit.rs index f4a63f7e..1638ecc9 100644 --- a/crates/api/src/middleware/ratelimit.rs +++ b/crates/api/src/middleware/ratelimit.rs @@ -45,21 +45,6 @@ pub fn ratelimiter(f: F) -> FromFnLayer, T> { tokio::spawn(async move { handle }); - let rscript = r.clone(); - tokio::spawn(async move { - if let Ok(()) = rscript.wait_for_connect().await { - match rscript - .script_load::(LUA_SCRIPT.to_string()) - .await - { - Ok(_) => info!("connected to redis for request rate limiting"), - Err(error) => error!(?error, "could not load redis script"), - } - } else { - error!("could not wait for connection to load redis script!"); - } - }); - r }); @@ -152,12 +137,34 @@ pub async fn do_request_ratelimited( let period = 1; // seconds let cost = 1; // todo: update this for group member endpoints + let script_exists: Vec = + match redis.script_exists(vec![LUA_SCRIPT_SHA.to_string()]).await { + Ok(exists) => exists, + Err(error) => { + error!(?error, "failed to check ratelimit script"); + return json_err( + StatusCode::INTERNAL_SERVER_ERROR, + r#"{"message": "500: internal server error", "code": 0}"#.to_string(), + ); + } + }; + + if script_exists[0] != 1 { + match redis + .script_load::(LUA_SCRIPT.to_string()) + .await + { + Ok(_) => info!("successfully loaded ratelimit script to redis"), + Err(error) => { + error!(?error, "could not load redis script") + } + } + } + // local rate_limit_key = KEYS[1] // local rate = ARGV[1] // local period = ARGV[2] // return {remaining, tostring(retry_after), reset_after} - - // todo: check if error is script not found and reload script let resp = redis .evalsha::<(i32, String, u64), String, Vec, Vec>( LUA_SCRIPT_SHA.to_string(), @@ -219,7 +226,7 @@ pub async fn do_request_ratelimited( return response; } Err(error) => { - tracing::error!(?error, "error getting ratelimit info"); + error!(?error, "error getting ratelimit info"); return json_err( StatusCode::INTERNAL_SERVER_ERROR, r#"{"message": "500: internal server error", "code": 0}"#.to_string(), From 214f164fbcee1eb7d26ede0018b2cd5efac6024e Mon Sep 17 00:00:00 2001 From: alyssa Date: Sun, 10 Aug 
2025 00:25:29 +0000 Subject: [PATCH 07/10] feat(api): implement PKError in rust-api --- Cargo.lock | 10 +++-- crates/api/Cargo.toml | 1 + crates/api/src/endpoints/private.rs | 18 ++++---- crates/api/src/endpoints/system.rs | 43 ++++++------------ crates/api/src/error.rs | 68 ++++++++++++++++++++++++++--- crates/api/src/main.rs | 21 ++++----- crates/macros/Cargo.toml | 1 + crates/macros/src/api.rs | 52 ++++++++++++++++++++++ crates/macros/src/lib.rs | 6 +++ 9 files changed, 157 insertions(+), 63 deletions(-) create mode 100644 crates/macros/src/api.rs diff --git a/Cargo.lock b/Cargo.lock index d4b91f58..d52d073b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -88,6 +88,7 @@ dependencies = [ "lazy_static", "libpk", "metrics", + "pk_macros", "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", @@ -2530,6 +2531,7 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" name = "pk_macros" version = "0.1.0" dependencies = [ + "prettyplease", "proc-macro2", "quote", "syn", @@ -2611,9 +2613,9 @@ dependencies = [ [[package]] name = "prettyplease" -version = "0.2.31" +version = "0.2.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5316f57387668042f561aae71480de936257848f9c43ce528e311d89a07cadeb" +checksum = "ff24dfcda44452b9816fff4cd4227e1bb73ff5a2f1bc1105aa92fb8565ce44d2" dependencies = [ "proc-macro2", "syn", @@ -3965,9 +3967,9 @@ checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" -version = "2.0.100" +version = "2.0.104" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b09a44accad81e1ba1cd74a32461ba89dee89095ba17b32f5d03683b1b1fc2a0" +checksum = "17b6f705963418cdb9927482fa304bc562ece2fdd4f616084c50b7023b435a40" dependencies = [ "proc-macro2", "quote", diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index e1f99425..9413d62c 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -5,6 +5,7 @@ edition = "2021" [dependencies] pluralkit_models = { path = "../models" } +pk_macros = { path = "../macros" } libpk = { path = "../libpk" } anyhow = { workspace = true } diff --git a/crates/api/src/endpoints/private.rs b/crates/api/src/endpoints/private.rs index df67421c..067fef57 100644 --- a/crates/api/src/endpoints/private.rs +++ b/crates/api/src/endpoints/private.rs @@ -2,6 +2,7 @@ use crate::ApiContext; use axum::{extract::State, response::Json}; use fred::interfaces::*; use libpk::state::ShardState; +use pk_macros::api_endpoint; use serde::Deserialize; use serde_json::{json, Value}; use std::collections::HashMap; @@ -13,34 +14,33 @@ struct ClusterStats { pub channel_count: i32, } +#[api_endpoint] pub async fn discord_state(State(ctx): State) -> Json { let mut shard_status = ctx .redis .hgetall::, &str>("pluralkit:shardstatus") - .await - .unwrap() + .await? .values() .map(|v| serde_json::from_str(v).expect("could not deserialize shard")) .collect::>(); shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id)); - Json(json!({ + Ok(Json(json!({ "shards": shard_status, - })) + }))) } +#[api_endpoint] pub async fn meta(State(ctx): State) -> Json { let stats = serde_json::from_str::( ctx.redis .get::("statsapi") - .await - .unwrap() + .await? 
.as_str(), - ) - .unwrap(); + )?; - Json(stats) + Ok(Json(stats)) } use std::time::Duration; diff --git a/crates/api/src/endpoints/system.rs b/crates/api/src/endpoints/system.rs index e510f7c5..7b919df5 100644 --- a/crates/api/src/endpoints/system.rs +++ b/crates/api/src/endpoints/system.rs @@ -1,22 +1,18 @@ -use axum::{ - extract::State, - http::StatusCode, - response::{IntoResponse, Response}, - Extension, Json, -}; -use serde_json::json; +use axum::{extract::State, response::IntoResponse, Extension, Json}; +use pk_macros::api_endpoint; +use serde_json::{json, Value}; use sqlx::Postgres; -use tracing::error; use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel}; -use crate::{auth::AuthState, util::json_err, ApiContext}; +use crate::{auth::AuthState, error::fail, ApiContext}; +#[api_endpoint] pub async fn get_system_settings( Extension(auth): Extension, Extension(system): Extension, State(ctx): State, -) -> Response { +) -> Json { let access_level = auth.access_level_for(&system); let mut config = match sqlx::query_as::( @@ -27,23 +23,11 @@ pub async fn get_system_settings( .await { Ok(Some(config)) => config, - Ok(None) => { - error!( - system = system.id, - "failed to find system config for existing system" - ); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } - Err(err) => { - error!(?err, "failed to query system config"); - return json_err( - StatusCode::INTERNAL_SERVER_ERROR, - r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), - ); - } + Ok(None) => fail!( + system = system.id, + "failed to find system config for existing system" + ), + Err(err) => fail!(?err, "failed to query system config"), }; // fix this @@ -51,7 +35,7 @@ pub async fn get_system_settings( config.name_format = Some("{name} {tag}".to_string()); } - Json(&match access_level { + Ok(Json(match access_level { PrivacyLevel::Private => config.to_json(), PrivacyLevel::Public => json!({ "pings_enabled": config.pings_enabled, @@ -64,6 +48,5 @@ pub async fn get_system_settings( "proxy_switch": config.proxy_switch, "name_format": config.name_format, }), - }) - .into_response() + })) } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index 464534d8..fc481d0c 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -1,13 +1,17 @@ -use axum::http::StatusCode; +use axum::{ + http::StatusCode, + response::{IntoResponse, Response}, +}; use std::fmt; -// todo -#[allow(dead_code)] +// todo: model parse errors #[derive(Debug)] pub struct PKError { pub response_code: StatusCode, pub json_code: i32, pub message: &'static str, + + pub inner: Option, } impl fmt::Display for PKError { @@ -16,17 +20,67 @@ impl fmt::Display for PKError { } } -impl std::error::Error for PKError {} +impl Clone for PKError { + fn clone(&self) -> PKError { + if self.inner.is_some() { + panic!("cannot clone PKError with inner error"); + } + PKError { + response_code: self.response_code, + json_code: self.json_code, + message: self.message, + inner: None, + } + } +} + +impl From for PKError +where + E: std::fmt::Display + Into, +{ + fn from(err: E) -> Self { + let mut res = GENERIC_SERVER_ERROR.clone(); + res.inner = Some(err.into()); + res + } +} + +impl IntoResponse for PKError { + fn into_response(self) -> Response { + if let Some(inner) = self.inner { + tracing::error!(?inner, "error returned from handler"); + } + crate::util::json_err( + self.response_code, + serde_json::to_string(&serde_json::json!({ + "message": 
self.message, + "code": self.json_code, + })) + .unwrap(), + ) + } +} + +macro_rules! fail { + ($($stuff:tt)+) => {{ + tracing::error!($($stuff)+); + return Err(crate::error::GENERIC_SERVER_ERROR); + }}; +} + +pub(crate) use fail; -#[allow(unused_macros)] macro_rules! define_error { ( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => { - const $name: PKError = PKError { + #[allow(dead_code)] + pub const $name: PKError = PKError { response_code: $response_code, json_code: $json_code, message: $message, + inner: None, }; }; } -// define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } +define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index e3a201cb..07e47f89 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -4,8 +4,8 @@ use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; use axum::{ body::Body, extract::{Request as ExtractRequest, State}, - http::{Response, StatusCode, Uri}, - response::IntoResponse, + http::Uri, + response::{IntoResponse, Response}, routing::{delete, get, patch, post}, Extension, Router, }; @@ -13,7 +13,9 @@ use hyper_util::{ client::legacy::{connect::HttpConnector, Client}, rt::TokioExecutor, }; -use tracing::{error, info}; +use tracing::info; + +use pk_macros::api_endpoint; mod auth; mod endpoints; @@ -30,11 +32,12 @@ pub struct ApiContext { rproxy_client: Client, } +#[api_endpoint] async fn rproxy( Extension(auth): Extension, State(ctx): State, mut req: ExtractRequest, -) -> Result, StatusCode> { +) -> Response { let path = req.uri().path(); let path_query = req .uri() @@ -59,15 +62,7 @@ async fn rproxy( headers.append(INTERNAL_APPID_HEADER, aid.into()); } - Ok(ctx - .rproxy_client - .request(req) - .await - .map_err(|error| { - error!(?error, "failed to serve reverse proxy to dotnet-api"); - StatusCode::BAD_GATEWAY - })? 
- .into_response()) + Ok(ctx.rproxy_client.request(req).await?.into_response()) } // this function is manually formatted for easier legibility of route_services diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index 8090798f..10feaf88 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -10,4 +10,5 @@ proc-macro = true quote = "1.0" proc-macro2 = "1.0" syn = "2.0" +prettyplease = "0.2.36" diff --git a/crates/macros/src/api.rs b/crates/macros/src/api.rs new file mode 100644 index 00000000..7f797b8c --- /dev/null +++ b/crates/macros/src/api.rs @@ -0,0 +1,52 @@ +use quote::quote; +use syn::{parse_macro_input, FnArg, ItemFn, Pat}; + +fn pretty_print(ts: &proc_macro2::TokenStream) -> String { + let file = syn::parse_file(&ts.to_string()).unwrap(); + prettyplease::unparse(&file) +} + +pub fn macro_impl( + _args: proc_macro::TokenStream, + input: proc_macro::TokenStream, +) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as ItemFn); + + let fn_name = &input.sig.ident; + let fn_params = &input.sig.inputs; + let fn_body = &input.block; + let syn::ReturnType::Type(_, fn_return_type) = &input.sig.output else { + panic!("handler return type must not be nothing"); + }; + let pms: Vec = fn_params + .iter() + .map(|v| { + let FnArg::Typed(pat) = v else { + panic!("must not have self param in handler"); + }; + let mut pat = pat.pat.clone(); + if let Pat::Ident(ident) = *pat { + let mut ident = ident.clone(); + ident.mutability = None; + pat = Box::new(Pat::Ident(ident)); + } + quote! { #pat } + }) + .collect(); + + let res = quote! { + #[allow(unused_mut)] + pub async fn #fn_name(#fn_params) -> axum::response::Response { + async fn inner(#fn_params) -> Result<#fn_return_type, crate::error::PKError> { + #fn_body + } + + match inner(#(#pms),*).await { + Ok(res) => res.into_response(), + Err(err) => err.into_response(), + } + } + }; + + res.into() +} diff --git a/crates/macros/src/lib.rs b/crates/macros/src/lib.rs index db5a55b7..ad3c1064 100644 --- a/crates/macros/src/lib.rs +++ b/crates/macros/src/lib.rs @@ -1,8 +1,14 @@ use proc_macro::TokenStream; +mod api; mod entrypoint; mod model; +#[proc_macro_attribute] +pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream { + api::macro_impl(args, input) +} + #[proc_macro_attribute] pub fn main(args: TokenStream, input: TokenStream) -> TokenStream { entrypoint::macro_impl(args, input) From 4c940fa925a7df52c80e4c7c31a19db45c9bf1f3 Mon Sep 17 00:00:00 2001 From: alyssa Date: Mon, 1 Sep 2025 03:36:13 +0000 Subject: [PATCH 08/10] chore: bump rust edition to 2024 --- ci/Dockerfile.rust | 2 +- crates/api/Cargo.toml | 2 +- crates/api/src/endpoints/private.rs | 3 ++- crates/api/src/endpoints/system.rs | 6 +++--- crates/api/src/main.rs | 6 ++---- crates/api/src/middleware/auth.rs | 2 +- crates/api/src/middleware/logger.rs | 2 +- crates/api/src/middleware/params.rs | 6 +++--- crates/api/src/util.rs | 2 +- crates/avatars/Cargo.toml | 2 +- crates/avatars/src/main.rs | 4 ++-- crates/avatars/src/process.rs | 2 +- crates/avatars/src/pull.rs | 4 ++-- crates/dispatch/Cargo.toml | 2 +- crates/dispatch/src/logger.rs | 2 +- crates/dispatch/src/main.rs | 9 ++++----- crates/gateway/Cargo.toml | 2 +- crates/gateway/src/api.rs | 6 +++--- crates/gateway/src/discord/cache.rs | 4 ++-- crates/gateway/src/discord/gateway.rs | 6 +++--- crates/gateway/src/event_awaiter.rs | 12 +++++++++--- crates/gateway/src/logger.rs | 2 +- crates/gateway/src/main.rs | 3 +-- crates/gdpr_worker/Cargo.toml | 2 +- 
crates/gdpr_worker/src/main.rs | 4 +--- crates/libpk/Cargo.toml | 2 +- crates/libpk/src/_config.rs | 6 +++--- crates/libpk/src/db/repository/avatars.rs | 2 +- crates/libpk/src/db/types/avatars.rs | 2 +- crates/libpk/src/lib.rs | 3 +-- crates/macros/Cargo.toml | 2 +- crates/macros/src/api.rs | 4 ++-- crates/macros/src/model.rs | 2 +- crates/migrate/Cargo.toml | 2 +- crates/migrate/src/main.rs | 2 -- crates/models/Cargo.toml | 2 +- crates/models/src/lib.rs | 2 +- crates/scheduled_tasks/Cargo.toml | 2 +- 38 files changed, 64 insertions(+), 66 deletions(-) diff --git a/ci/Dockerfile.rust b/ci/Dockerfile.rust index e320fb00..063cb6ea 100644 --- a/ci/Dockerfile.rust +++ b/ci/Dockerfile.rust @@ -4,7 +4,7 @@ WORKDIR /build RUN apk add rustup build-base # todo: arm64 target -RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain nightly-2024-08-20 --profile default -y +RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain nightly-2025-08-22 --profile default -y ENV PATH=/root/.cargo/bin:$PATH ENV RUSTFLAGS='-C link-arg=-s' diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 9413d62c..b19bfe74 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "api" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] pluralkit_models = { path = "../models" } diff --git a/crates/api/src/endpoints/private.rs b/crates/api/src/endpoints/private.rs index 067fef57..2116e3c5 100644 --- a/crates/api/src/endpoints/private.rs +++ b/crates/api/src/endpoints/private.rs @@ -4,9 +4,10 @@ use fred::interfaces::*; use libpk::state::ShardState; use pk_macros::api_endpoint; use serde::Deserialize; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use std::collections::HashMap; +#[allow(dead_code)] #[derive(Deserialize)] #[serde(rename_all = "PascalCase")] struct ClusterStats { diff --git a/crates/api/src/endpoints/system.rs b/crates/api/src/endpoints/system.rs index 7b919df5..58c9a154 100644 --- a/crates/api/src/endpoints/system.rs +++ b/crates/api/src/endpoints/system.rs @@ -1,11 +1,11 @@ -use axum::{extract::State, response::IntoResponse, Extension, Json}; +use axum::{Extension, Json, extract::State, response::IntoResponse}; use pk_macros::api_endpoint; -use serde_json::{json, Value}; +use serde_json::{Value, json}; use sqlx::Postgres; use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel}; -use crate::{auth::AuthState, error::fail, ApiContext}; +use crate::{ApiContext, auth::AuthState, error::fail}; #[api_endpoint] pub async fn get_system_settings( diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index 07e47f89..f22450ce 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -1,16 +1,14 @@ -#![feature(let_chains)] - use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; use axum::{ + Extension, Router, body::Body, extract::{Request as ExtractRequest, State}, http::Uri, response::{IntoResponse, Response}, routing::{delete, get, patch, post}, - Extension, Router, }; use hyper_util::{ - client::legacy::{connect::HttpConnector, Client}, + client::legacy::{Client, connect::HttpConnector}, rt::TokioExecutor, }; use tracing::info; diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 0992757f..1d536e97 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -10,7 +10,7 @@ use subtle::ConstantTimeEq; use tracing::error; use crate::auth::AuthState; -use crate::{util::json_err, ApiContext}; 
+use crate::{ApiContext, util::json_err}; pub async fn auth(State(ctx): State, mut req: Request, next: Next) -> Response { let mut authed_system_id: Option = None; diff --git a/crates/api/src/middleware/logger.rs b/crates/api/src/middleware/logger.rs index 38e45e2c..512234bb 100644 --- a/crates/api/src/middleware/logger.rs +++ b/crates/api/src/middleware/logger.rs @@ -2,7 +2,7 @@ use std::time::Instant; use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response}; use metrics::{counter, histogram}; -use tracing::{info, span, warn, Instrument, Level}; +use tracing::{Instrument, Level, info, span, warn}; use crate::{auth::AuthState, util::header_or_unknown}; diff --git a/crates/api/src/middleware/params.rs b/crates/api/src/middleware/params.rs index 06a76f64..1b52bfbf 100644 --- a/crates/api/src/middleware/params.rs +++ b/crates/api/src/middleware/params.rs @@ -6,11 +6,11 @@ use axum::{ routing::url_params::UrlParams, }; -use sqlx::{types::Uuid, Postgres}; +use sqlx::{Postgres, types::Uuid}; use tracing::error; use crate::auth::AuthState; -use crate::{util::json_err, ApiContext}; +use crate::{ApiContext, util::json_err}; use pluralkit_models::PKSystem; // move this somewhere else @@ -31,7 +31,7 @@ pub async fn params(State(ctx): State, mut req: Request, next: Next) StatusCode::BAD_REQUEST, r#"{"message":"400: Bad Request","code": 0}"#.to_string(), ) - .into() + .into(); } }; diff --git a/crates/api/src/util.rs b/crates/api/src/util.rs index 35a5bf0d..e9723976 100644 --- a/crates/api/src/util.rs +++ b/crates/api/src/util.rs @@ -3,7 +3,7 @@ use axum::{ http::{HeaderValue, StatusCode}, response::IntoResponse, }; -use serde_json::{json, to_string, Value}; +use serde_json::{Value, json, to_string}; use tracing::error; pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str { diff --git a/crates/avatars/Cargo.toml b/crates/avatars/Cargo.toml index 725e5396..ee1aa91e 100644 --- a/crates/avatars/Cargo.toml +++ b/crates/avatars/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "avatars" version = "0.1.0" -edition = "2021" +edition = "2024" [[bin]] name = "avatar_cleanup" diff --git a/crates/avatars/src/main.rs b/crates/avatars/src/main.rs index c8399086..df80ac82 100644 --- a/crates/avatars/src/main.rs +++ b/crates/avatars/src/main.rs @@ -8,10 +8,10 @@ use anyhow::Context; use axum::extract::State; use axum::routing::get; use axum::{ + Json, Router, http::StatusCode, response::{IntoResponse, Response}, routing::post, - Json, Router, }; use libpk::_config::AvatarsConfig; use libpk::db::repository::avatars as db; @@ -153,7 +153,7 @@ async fn verify( ) .await?; - let encoded = process::process_async(result.data, req.kind).await?; + process::process_async(result.data, req.kind).await?; Ok(()) } diff --git a/crates/avatars/src/process.rs b/crates/avatars/src/process.rs index 024f40de..0c9ba8c1 100644 --- a/crates/avatars/src/process.rs +++ b/crates/avatars/src/process.rs @@ -4,7 +4,7 @@ use std::io::Cursor; use std::time::Instant; use tracing::{debug, error, info, instrument}; -use crate::{hash::Hash, ImageKind, PKAvatarError}; +use crate::{ImageKind, PKAvatarError, hash::Hash}; const MAX_DIMENSION: u32 = 4000; diff --git a/crates/avatars/src/pull.rs b/crates/avatars/src/pull.rs index fdf5f073..44c4a952 100644 --- a/crates/avatars/src/pull.rs +++ b/crates/avatars/src/pull.rs @@ -62,7 +62,7 @@ pub async fn pull( let size = match response.content_length() { None => return Err(PKAvatarError::MissingHeader("Content-Length")), Some(size) if size > MAX_SIZE => { - return 
Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE)) + return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE)); } Some(size) => size, }; @@ -162,7 +162,7 @@ pub fn parse_url(url: &str) -> anyhow::Result { attachment_id: 0, filename: "".to_string(), full_url: url.to_string(), - }) + }); } _ => anyhow::bail!("not a discord cdn url"), } diff --git a/crates/dispatch/Cargo.toml b/crates/dispatch/Cargo.toml index c76856d1..f48acf8c 100644 --- a/crates/dispatch/Cargo.toml +++ b/crates/dispatch/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "dispatch" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] anyhow = { workspace = true } diff --git a/crates/dispatch/src/logger.rs b/crates/dispatch/src/logger.rs index aa65bc67..ec2576b6 100644 --- a/crates/dispatch/src/logger.rs +++ b/crates/dispatch/src/logger.rs @@ -1,7 +1,7 @@ use std::time::Instant; use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response}; -use tracing::{info, span, warn, Instrument, Level}; +use tracing::{Instrument, Level, info, span, warn}; // log any requests that take longer than 2 seconds // todo: change as necessary diff --git a/crates/dispatch/src/main.rs b/crates/dispatch/src/main.rs index 6570cf19..3a3403bd 100644 --- a/crates/dispatch/src/main.rs +++ b/crates/dispatch/src/main.rs @@ -5,17 +5,16 @@ use hickory_client::{ rr::{DNSClass, Name, RData, RecordType}, udp::UdpClientStream, }; -use reqwest::{redirect::Policy, StatusCode}; +use reqwest::{StatusCode, redirect::Policy}; use std::{ net::{Ipv4Addr, SocketAddr, SocketAddrV4}, sync::Arc, time::Duration, }; use tokio::{net::UdpSocket, sync::RwLock}; -use tracing::{debug, error, info}; -use tracing_subscriber::EnvFilter; +use tracing::{debug, error}; -use axum::{extract::State, http::Uri, routing::post, Json, Router}; +use axum::{Json, Router, extract::State, http::Uri, routing::post}; mod logger; @@ -128,7 +127,7 @@ async fn dispatch( match res { Ok(res) if res.status() != 200 => { - return DispatchResponse::InvalidResponseCode(res.status()).to_string() + return DispatchResponse::InvalidResponseCode(res.status()).to_string(); } Err(error) => { error!(?error, url = req.url.clone(), "failed to fetch"); diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index c707b29b..0222ab18 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gateway" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] anyhow = { workspace = true } diff --git a/crates/gateway/src/api.rs b/crates/gateway/src/api.rs index f8c3f556..aa2d069e 100644 --- a/crates/gateway/src/api.rs +++ b/crates/gateway/src/api.rs @@ -1,18 +1,18 @@ use axum::{ + Router, extract::{ConnectInfo, Path, State}, http::StatusCode, response::{IntoResponse, Response}, routing::{delete, get, post}, - Router, }; use libpk::runtime_config::RuntimeConfig; use serde_json::{json, to_string}; use tracing::{error, info}; -use twilight_model::id::{marker::ChannelMarker, Id}; +use twilight_model::id::{Id, marker::ChannelMarker}; use crate::{ discord::{ - cache::{dm_channel, DiscordCache, DM_PERMISSIONS}, + cache::{DM_PERMISSIONS, DiscordCache, dm_channel}, gateway::cluster_config, shard_state::ShardStateManager, }, diff --git a/crates/gateway/src/discord/cache.rs b/crates/gateway/src/discord/cache.rs index e0a4aacf..cc538d08 100644 --- a/crates/gateway/src/discord/cache.rs +++ b/crates/gateway/src/discord/cache.rs @@ -4,18 +4,18 @@ use serde::Serialize; use std::{collections::HashMap, sync::Arc}; use 
tokio::sync::RwLock; use twilight_cache_inmemory::{ + InMemoryCache, ResourceType, model::CachedMember, permission::{MemberRoles, RootError}, traits::CacheableChannel, - InMemoryCache, ResourceType, }; use twilight_gateway::Event; use twilight_model::{ channel::{Channel, ChannelType}, guild::{Guild, Member, Permissions}, id::{ - marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker}, Id, + marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker}, }, }; use twilight_util::permission_calculator::PermissionCalculator; diff --git a/crates/gateway/src/discord/gateway.rs b/crates/gateway/src/discord/gateway.rs index 8210e06e..215fb4cf 100644 --- a/crates/gateway/src/discord/gateway.rs +++ b/crates/gateway/src/discord/gateway.rs @@ -6,17 +6,17 @@ use std::sync::Arc; use tokio::sync::mpsc::Sender; use tracing::{error, info, warn}; use twilight_gateway::{ - create_iterator, CloseFrame, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId, + CloseFrame, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId, create_iterator, }; use twilight_model::gateway::{ + Intents, payload::outgoing::update_presence::UpdatePresencePayload, presence::{Activity, ActivityType, Status}, - Intents, }; use crate::{ - discord::identify_queue::{self, RedisQueue}, RUNTIME_CONFIG_KEY_EVENT_TARGET, + discord::identify_queue::{self, RedisQueue}, }; use super::cache::DiscordCache; diff --git a/crates/gateway/src/event_awaiter.rs b/crates/gateway/src/event_awaiter.rs index 765ad8e5..97a6955e 100644 --- a/crates/gateway/src/event_awaiter.rs +++ b/crates/gateway/src/event_awaiter.rs @@ -3,7 +3,7 @@ // - interaction: (custom_id where not_includes "help-menu") use std::{ - collections::{hash_map::Entry, HashMap}, + collections::{HashMap, hash_map::Entry}, net::{IpAddr, SocketAddr}, time::Duration, }; @@ -15,8 +15,8 @@ use twilight_gateway::Event; use twilight_model::{ application::interaction::InteractionData, id::{ - marker::{ChannelMarker, MessageMarker, UserMarker}, Id, + marker::{ChannelMarker, MessageMarker, UserMarker}, }, }; @@ -103,7 +103,13 @@ impl EventAwaiter { } } } - info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2); + info!( + "ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", + Instant::now().duration_since(now).as_micros(), + counts.0, + counts.1, + counts.2 + ); } } diff --git a/crates/gateway/src/logger.rs b/crates/gateway/src/logger.rs index 459aef31..0a081432 100644 --- a/crates/gateway/src/logger.rs +++ b/crates/gateway/src/logger.rs @@ -4,7 +4,7 @@ use axum::{ extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response, }; use metrics::{counter, histogram}; -use tracing::{info, span, warn, Instrument, Level}; +use tracing::{Instrument, Level, info, span, warn}; // log any requests that take longer than 2 seconds // todo: change as necessary diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index 12db76b5..3ac7be21 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -1,4 +1,3 @@ -#![feature(let_chains)] #![feature(if_let_guard)] #![feature(duration_constructors)] @@ -10,7 +9,7 @@ use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent}; use reqwest::{ClientBuilder, StatusCode}; use std::{sync::Arc, time::Duration, vec::Vec}; use tokio::{ - signal::unix::{signal, SignalKind}, + signal::unix::{SignalKind, signal}, sync::mpsc::channel, 
task::JoinSet, }; diff --git a/crates/gdpr_worker/Cargo.toml b/crates/gdpr_worker/Cargo.toml index a30751f9..b57ccddf 100644 --- a/crates/gdpr_worker/Cargo.toml +++ b/crates/gdpr_worker/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "gdpr_worker" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] libpk = { path = "../libpk" } diff --git a/crates/gdpr_worker/src/main.rs b/crates/gdpr_worker/src/main.rs index b40557c0..bcedbedd 100644 --- a/crates/gdpr_worker/src/main.rs +++ b/crates/gdpr_worker/src/main.rs @@ -1,12 +1,10 @@ -#![feature(let_chains)] - use sqlx::prelude::FromRow; use std::{sync::Arc, time::Duration}; use tracing::{error, info, warn}; use twilight_http::api_error::{ApiError, GeneralApiError}; use twilight_model::id::{ - marker::{ChannelMarker, MessageMarker}, Id, + marker::{ChannelMarker, MessageMarker}, }; // create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null); diff --git a/crates/libpk/Cargo.toml b/crates/libpk/Cargo.toml index 30d77ae0..1f0c3c42 100644 --- a/crates/libpk/Cargo.toml +++ b/crates/libpk/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "libpk" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] anyhow = { workspace = true } diff --git a/crates/libpk/src/_config.rs b/crates/libpk/src/_config.rs index 7f992d95..f21d9adf 100644 --- a/crates/libpk/src/_config.rs +++ b/crates/libpk/src/_config.rs @@ -3,7 +3,7 @@ use lazy_static::lazy_static; use serde::Deserialize; use std::sync::Arc; -use twilight_model::id::{marker::UserMarker, Id}; +use twilight_model::id::{Id, marker::UserMarker}; #[derive(Clone, Deserialize, Debug)] pub struct ClusterSettings { @@ -151,11 +151,11 @@ lazy_static! { // hacks if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX") && std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() { - std::env::set_var("pluralkit__discord__cluster__node_id", var); + unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var); } } if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX") && std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() { - std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap()); + unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap()); } } Arc::new(Config::builder() diff --git a/crates/libpk/src/db/repository/avatars.rs b/crates/libpk/src/db/repository/avatars.rs index 1ff10cc7..9a667c11 100644 --- a/crates/libpk/src/db/repository/avatars.rs +++ b/crates/libpk/src/db/repository/avatars.rs @@ -52,7 +52,7 @@ pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow: pub async fn pop_queue( pool: &PgPool, -) -> anyhow::Result, ImageQueueEntry)>> { +) -> anyhow::Result, ImageQueueEntry)>> { let mut tx = pool.begin().await?; let res: Option = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *") .fetch_optional(&mut *tx).await?; diff --git a/crates/libpk/src/db/types/avatars.rs b/crates/libpk/src/db/types/avatars.rs index aea6aafd..0b07fbb2 100644 --- a/crates/libpk/src/db/types/avatars.rs +++ b/crates/libpk/src/db/types/avatars.rs @@ -1,7 +1,7 @@ use serde::{Deserialize, Serialize}; use sqlx::{ - types::chrono::{DateTime, Utc}, FromRow, + types::chrono::{DateTime, Utc}, }; use uuid::Uuid; diff --git a/crates/libpk/src/lib.rs b/crates/libpk/src/lib.rs index 55031bf3..137eb94d 100644 --- a/crates/libpk/src/lib.rs +++ 
b/crates/libpk/src/lib.rs @@ -1,9 +1,8 @@ -#![feature(let_chains)] use std::net::SocketAddr; use metrics_exporter_prometheus::PrometheusBuilder; use sentry::IntoDsn; -use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter}; +use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt}; use sentry_tracing::event_from_event; diff --git a/crates/macros/Cargo.toml b/crates/macros/Cargo.toml index 10feaf88..7f320881 100644 --- a/crates/macros/Cargo.toml +++ b/crates/macros/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "pk_macros" version = "0.1.0" -edition = "2021" +edition = "2024" [lib] proc-macro = true diff --git a/crates/macros/src/api.rs b/crates/macros/src/api.rs index 7f797b8c..8094b9ea 100644 --- a/crates/macros/src/api.rs +++ b/crates/macros/src/api.rs @@ -1,7 +1,7 @@ use quote::quote; -use syn::{parse_macro_input, FnArg, ItemFn, Pat}; +use syn::{FnArg, ItemFn, Pat, parse_macro_input}; -fn pretty_print(ts: &proc_macro2::TokenStream) -> String { +fn _pretty_print(ts: &proc_macro2::TokenStream) -> String { let file = syn::parse_file(&ts.to_string()).unwrap(); prettyplease::unparse(&file) } diff --git a/crates/macros/src/model.rs b/crates/macros/src/model.rs index 924b5bcd..e37d0dde 100644 --- a/crates/macros/src/model.rs +++ b/crates/macros/src/model.rs @@ -1,6 +1,6 @@ use proc_macro2::{Span, TokenStream}; use quote::quote; -use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type}; +use syn::{DeriveInput, Expr, Ident, Meta, Type, parse_macro_input}; #[derive(Clone, Debug)] enum ElemPatchability { diff --git a/crates/migrate/Cargo.toml b/crates/migrate/Cargo.toml index cf4eff2d..0843cb3f 100644 --- a/crates/migrate/Cargo.toml +++ b/crates/migrate/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "migrate" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] libpk = { path = "../libpk" } diff --git a/crates/migrate/src/main.rs b/crates/migrate/src/main.rs index 85b15e33..0ee621e2 100644 --- a/crates/migrate/src/main.rs +++ b/crates/migrate/src/main.rs @@ -1,5 +1,3 @@ -#![feature(let_chains)] - use tracing::info; include!(concat!(env!("OUT_DIR"), "/data.rs")); diff --git a/crates/models/Cargo.toml b/crates/models/Cargo.toml index 0fbc358c..752fbaa5 100644 --- a/crates/models/Cargo.toml +++ b/crates/models/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "pluralkit_models" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] chrono = { workspace = true, features = ["serde"] } diff --git a/crates/models/src/lib.rs b/crates/models/src/lib.rs index 0bd1f92b..08350488 100644 --- a/crates/models/src/lib.rs +++ b/crates/models/src/lib.rs @@ -18,7 +18,7 @@ pub enum PrivacyLevel { } // this sucks, put it somewhere else -use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type}; +use sqlx::{Database, Decode, Postgres, Type, postgres::PgTypeInfo}; use std::error::Error; _util::fake_enum_impls!(PrivacyLevel); diff --git a/crates/scheduled_tasks/Cargo.toml b/crates/scheduled_tasks/Cargo.toml index 624db0e8..1e86f2c3 100644 --- a/crates/scheduled_tasks/Cargo.toml +++ b/crates/scheduled_tasks/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "scheduled_tasks" version = "0.1.0" -edition = "2021" +edition = "2024" [dependencies] libpk = { path = "../libpk" } From 4ba5b785e5eebcfb77c9b4cdf6a2e94448f7b592 Mon Sep 17 00:00:00 2001 From: alyssa Date: Wed, 3 Sep 2025 00:35:25 +0000 Subject: [PATCH 09/10] add /api/v2/bulk endpoint also, initial support for patch models in rust! 
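
For illustration: with the new internally-tagged BulkActionRequest / BulkActionRequestFilter
enums added below (serde tag = "type", rename_all = "snake_case"), a member bulk update is
expected to look roughly like the following request body. The five-character ids and the
display_name field are placeholders for the example only, not taken from this patch; per
validate_bulk, only null values are accepted for patched fields on this endpoint.

    {
        "type": "member",
        "filter": { "type": "ids", "ids": ["abcde", "fghij"] },
        "patch": { "display_name": null }
    }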
--- Cargo.lock | 22 +++- Cargo.toml | 1 + crates/api/Cargo.toml | 2 + crates/api/src/endpoints/bulk.rs | 211 +++++++++++++++++++++++++++++++ crates/api/src/endpoints/mod.rs | 1 + crates/api/src/error.rs | 43 ++++++- crates/api/src/main.rs | 2 + crates/macros/src/model.rs | 94 +++++++++++--- crates/models/Cargo.toml | 2 +- crates/models/src/group.rs | 132 +++++++++++++++++++ crates/models/src/lib.rs | 29 +++++ crates/models/src/member.rs | 208 ++++++++++++++++++++++++++++++ flake.nix | 1 + 13 files changed, 716 insertions(+), 32 deletions(-) create mode 100644 crates/api/src/endpoints/bulk.rs create mode 100644 crates/models/src/group.rs create mode 100644 crates/models/src/member.rs diff --git a/Cargo.lock b/Cargo.lock index d52d073b..7dafae51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,8 @@ dependencies = [ "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", + "sea-query", + "sea-query-sqlx", "serde", "serde_json", "serde_urlencoded", @@ -3345,19 +3347,20 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.32.3" +version = "1.0.0-rc.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5a24d8b9fcd2674a6c878a3d871f4f1380c6c43cc3718728ac96864d888458e" +checksum = "ab621a8d8b03a3e513ea075f71aa26830a55c977d7b40f09e825bb91910db823" dependencies = [ + "chrono", "inherent", "sea-query-derive", ] [[package]] name = "sea-query-derive" -version = "0.4.3" +version = "1.0.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab" +checksum = "217e9422de35f26c16c5f671fce3c075a65e10322068dbc66078428634af6195" dependencies = [ "darling", "heck 0.4.1", @@ -3367,6 +3370,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "sea-query-sqlx" +version = "0.8.0-rc.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5eb19495858d8ae3663387a4f5298516c6f0171a7ca5681055450f190236b8" +dependencies = [ + "chrono", + "sea-query", + "sqlx", +] + [[package]] name = "security-framework" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 270d00a6..63404b8e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ futures = "0.3.30" lazy_static = "1.4.0" metrics = "0.23.0" reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]} +sea-query = { version = "1.0.0-rc.10", features = ["with-chrono"] } sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.117" diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index b19bfe74..16c54a7f 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -14,6 +14,7 @@ fred = { workspace = true } lazy_static = { workspace = true } metrics = { workspace = true } reqwest = { workspace = true } +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlx = { workspace = true } @@ -28,3 +29,4 @@ serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } subtle = "2.6.1" +sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] } diff --git a/crates/api/src/endpoints/bulk.rs b/crates/api/src/endpoints/bulk.rs new file mode 100644 index 00000000..d859da88 --- /dev/null +++ b/crates/api/src/endpoints/bulk.rs @@ -0,0 +1,211 @@ 
+use axum::{ + Extension, Json, + extract::{Json as ExtractJson, State}, + response::IntoResponse, +}; +use pk_macros::api_endpoint; +use sea_query::{Expr, ExprTrait, PostgresQueryBuilder}; +use sea_query_sqlx::SqlxBinder; +use serde_json::{Value, json}; + +use pluralkit_models::{PKGroup, PKGroupPatch, PKMember, PKMemberPatch, PKSystem}; + +use crate::{ + ApiContext, + auth::AuthState, + error::{ + GENERIC_AUTH_ERROR, NOT_OWN_GROUP, NOT_OWN_MEMBER, PKError, TARGET_GROUP_NOT_FOUND, + TARGET_MEMBER_NOT_FOUND, + }, +}; + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequestFilter { + All, + Ids { ids: Vec }, + Connection { id: String }, +} + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequest { + Member { + filter: BulkActionRequestFilter, + patch: PKMemberPatch, + }, + Group { + filter: BulkActionRequestFilter, + patch: PKGroupPatch, + }, +} + +#[api_endpoint] +pub async fn bulk( + Extension(auth): Extension, + State(ctx): State, + ExtractJson(req): ExtractJson, +) -> Json { + let Some(system_id) = auth.system_id() else { + return Err(GENERIC_AUTH_ERROR); + }; + + #[derive(sqlx::FromRow)] + struct Ider { + id: i32, + hid: String, + uuid: String, + } + + #[derive(sqlx::FromRow)] + struct GroupMemberEntry { + member_id: i32, + group_id: i32, + } + + #[allow(dead_code)] + #[derive(sqlx::FromRow)] + struct OnlyIder { + id: i32, + } + + println!("BulkActionRequest::{req:#?}"); + match req { + BulkActionRequest::Member { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from members where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let members: Vec = sqlx::query_as( + "select * from members where hid = any($1::array) or uuid::text = any($1::array)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if members.len() != ids.len() { + return Err(TARGET_MEMBER_NOT_FOUND); + } + + if members.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_MEMBER); + } + + members.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(group): Option = + sqlx::query_as("select * from groups where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_GROUP_NOT_FOUND); + }; + + if group.system != system_id { + return Err(NOT_OWN_GROUP); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where group_id = $1") + .bind(group.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.member_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("members") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! 
{{ "updated": res.len() }})) + } + BulkActionRequest::Group { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from groups where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let groups: Vec = sqlx::query_as( + "select * from groups where hid = any($1) or uuid::text = any($1)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if groups.len() != ids.len() { + return Err(TARGET_GROUP_NOT_FOUND); + } + + if groups.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_GROUP); + } + + groups.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(member): Option = + sqlx::query_as("select * from members where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_MEMBER_NOT_FOUND); + }; + + if member.system != system_id { + return Err(NOT_OWN_MEMBER); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where member_id = $1") + .bind(member.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.group_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("groups") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + println!("{q:#?} {pms:#?}"); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! {{ "updated": res.len() }})) + } + } +} diff --git a/crates/api/src/endpoints/mod.rs b/crates/api/src/endpoints/mod.rs index c311367c..167acee8 100644 --- a/crates/api/src/endpoints/mod.rs +++ b/crates/api/src/endpoints/mod.rs @@ -1,2 +1,3 @@ +pub mod bulk; pub mod private; pub mod system; diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index fc481d0c..ae7e5a99 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -2,6 +2,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; +use pluralkit_models::ValidationError; use std::fmt; // todo: model parse errors @@ -11,6 +12,8 @@ pub struct PKError { pub json_code: i32, pub message: &'static str, + pub errors: Vec, + pub inner: Option, } @@ -30,6 +33,21 @@ impl Clone for PKError { json_code: self.json_code, message: self.message, inner: None, + errors: self.errors.clone(), + } + } +} + +// can't `impl From>` +// because "upstream crate may add a new impl" >:( +impl PKError { + pub fn from_validation_errors(errs: Vec) -> Self { + Self { + message: "Error parsing JSON model", + json_code: 40001, + errors: errs, + response_code: StatusCode::BAD_REQUEST, + inner: None, } } } @@ -50,14 +68,19 @@ impl IntoResponse for PKError { if let Some(inner) = self.inner { tracing::error!(?inner, "error returned from handler"); } - crate::util::json_err( - self.response_code, - serde_json::to_string(&serde_json::json!({ + let json = if self.errors.len() > 0 { + serde_json::json!({ "message": self.message, "code": self.json_code, - })) - .unwrap(), - ) + "errors": self.errors, + }) + } else { + serde_json::json!({ + "message": self.message, + "code": self.json_code, + }) + }; + crate::util::json_err(self.response_code, serde_json::to_string(&json).unwrap()) } } @@ -78,9 +101,17 @@ macro_rules! 
define_error { json_code: $json_code, message: $message, inner: None, + errors: vec![], }; }; } +define_error! { GENERIC_AUTH_ERROR, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" } define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } + +define_error! { NOT_OWN_MEMBER, StatusCode::FORBIDDEN, 30006, "Target member is not part of your system." } +define_error! { NOT_OWN_GROUP, StatusCode::FORBIDDEN, 30007, "Target group is not part of your system." } + +define_error! { TARGET_MEMBER_NOT_FOUND, StatusCode::BAD_REQUEST, 40010, "Target member not found." } +define_error! { TARGET_GROUP_NOT_FOUND, StatusCode::BAD_REQUEST, 40011, "Target group not found." } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index f22450ce..1850861b 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -115,6 +115,8 @@ fn router(ctx: ApiContext) -> Router { .route("/v2/messages/{message_id}", get(rproxy)) + .route("/v2/bulk", post(endpoints::bulk::bulk)) + .route("/private/bulk_privacy/member", post(rproxy)) .route("/private/bulk_privacy/group", post(rproxy)) .route("/private/discord/callback", post(rproxy)) diff --git a/crates/macros/src/model.rs b/crates/macros/src/model.rs index e37d0dde..5505e76a 100644 --- a/crates/macros/src/model.rs +++ b/crates/macros/src/model.rs @@ -85,8 +85,14 @@ fn parse_field(field: syn::Field) -> ModelField { panic!("must have json name to be publicly patchable"); } - if f.json.is_some() && f.is_privacy { - panic!("cannot set custom json name for privacy field"); + if f.is_privacy && f.json.is_none() { + f.json = Some(syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new( + f.name.clone().to_string().as_str(), + proc_macro2::Span::call_site(), + )), + })) } f @@ -122,17 +128,17 @@ pub fn macro_impl( let fields: Vec = fields .iter() - .filter(|f| !matches!(f.patch, ElemPatchability::None)) + .filter(|f| f.is_privacy || !matches!(f.patch, ElemPatchability::None)) .cloned() .collect(); let patch_fields = mk_patch_fields(fields.clone()); - let patch_from_json = mk_patch_from_json(fields.clone()); let patch_validate = mk_patch_validate(fields.clone()); + let patch_validate_bulk = mk_patch_validate_bulk(fields.clone()); let patch_to_json = mk_patch_to_json(fields.clone()); let patch_to_sql = mk_patch_to_sql(fields.clone()); - return quote! { + let code = quote! 
{ #[derive(sqlx::FromRow, Debug, Clone)] pub struct #tname { #tfields @@ -146,31 +152,42 @@ pub fn macro_impl( #to_json } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, Default)] pub struct #patchable_name { #patch_fields + + errors: Vec, } impl #patchable_name { - pub fn from_json(input: String) -> Self { - #patch_from_json - } - - pub fn validate(self) -> bool { + pub fn validate(&mut self) { #patch_validate } + pub fn errors(&self) -> Vec { + self.errors.clone() + } + + pub fn validate_bulk(&mut self) { + #patch_validate_bulk + } + pub fn to_sql(self) -> sea_query::UpdateStatement { - // sea_query::Query::update() - #patch_to_sql + use sea_query::types::*; + let mut patch = &mut sea_query::Query::update(); + #patch_to_sql + patch.clone() } pub fn to_json(self) -> serde_json::Value { #patch_to_json } } - } - .into(); + }; + + // panic!("{:#?}", code.to_string()); + + return code.into(); } fn mk_tfields(fields: Vec) -> TokenStream { @@ -225,7 +242,7 @@ fn mk_tto_json(fields: Vec) -> TokenStream { .filter_map(|f| { if f.is_privacy { let tname = f.name.clone(); - let tnamestr = f.name.clone().to_string(); + let tnamestr = f.json.clone(); Some(quote! { #tnamestr: self.#tname, }) @@ -280,13 +297,48 @@ fn mk_patch_fields(fields: Vec) -> TokenStream { .collect() } fn mk_patch_validate(_fields: Vec) -> TokenStream { - quote! { true } -} -fn mk_patch_from_json(_fields: Vec) -> TokenStream { quote! { unimplemented!(); } } -fn mk_patch_to_sql(_fields: Vec) -> TokenStream { - quote! { unimplemented!(); } +fn mk_patch_validate_bulk(fields: Vec) -> TokenStream { + // iterate over all nullable patchable fields other than privacy + // add an error if any field is set to a value other than null + fields + .iter() + .map(|f| { + if let syn::Type::Path(path) = &f.ty && let Some(inner) = path.path.segments.last() && inner.ident != "Option" { + return quote! {}; + } + let name = f.name.clone(); + if matches!(f.patch, ElemPatchability::Public) { + let json = f.json.clone().unwrap(); + quote! { + if let Some(val) = self.#name.clone() && val.is_some() { + self.errors.push(ValidationError::simple(#json, "Only null values are supported in bulk endpoint")); + } + } + } else { + quote! {} + } + }) + .collect() +} +fn mk_patch_to_sql(fields: Vec) -> TokenStream { + fields + .iter() + .filter_map(|f| { + if !matches!(f.patch, ElemPatchability::None) || f.is_privacy { + let name = f.name.clone(); + let column = f.name.to_string(); + Some(quote! { + if let Some(value) = self.#name { + patch = patch.value(#column, value); + } + }) + } else { + None + } + }) + .collect() } fn mk_patch_to_json(_fields: Vec) -> TokenStream { quote! 
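// Rough shape of the UPDATE built by the generated to_sql() once the bulk
// endpoint chains .table("members"), .and_where(Expr::col("id").is_in(ids)) and
// .returning_col("id") — a sketch assuming sea_query's usual Postgres
// rendering, for a patch that only nulls `description` on two members:
//
//   UPDATE "members" SET "description" = $1 WHERE "id" IN ($2, $3) RETURNING "id"
//
// with $1 bound to NULL and the member ids bound after it by build_sqlx.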
{ unimplemented!(); } diff --git a/crates/models/Cargo.toml b/crates/models/Cargo.toml index 752fbaa5..93366a82 100644 --- a/crates/models/Cargo.toml +++ b/crates/models/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] chrono = { workspace = true, features = ["serde"] } pk_macros = { path = "../macros" } -sea-query = "0.32.1" +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true, features = ["preserve_order"] } # in theory we want to default-features = false for sqlx diff --git a/crates/models/src/group.rs b/crates/models/src/group.rs new file mode 100644 index 00000000..ab94d27b --- /dev/null +++ b/crates/models/src/group.rs @@ -0,0 +1,132 @@ +use pk_macros::pk_model; + +use chrono::{DateTime, Utc}; +use serde::Deserialize; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type GroupId = i32; + +#[pk_model] +struct Group { + id: GroupId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "color"] + #[patchable] + color: Option, + #[json = "icon"] + #[patchable] + icon: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "created"] + created: DateTime, + + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + icon_privacy: PrivacyLevel, + #[privacy] + list_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + visibility: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKGroupPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKGroupPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Group name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.display_name = parse_string_simple!("display_name"); + patch.description = parse_string_simple!("description"); + patch.icon = parse_string_simple!("icon"); + patch.banner_image = parse_string_simple!("banner"); + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! 
parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.icon_privacy = parse_privacy!("icon_privacy"); + patch.list_privacy = parse_privacy!("list_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + patch.visibility = parse_privacy!("visibility"); + } + + Ok(patch) + } +} diff --git a/crates/models/src/lib.rs b/crates/models/src/lib.rs index 08350488..6bb4adf4 100644 --- a/crates/models/src/lib.rs +++ b/crates/models/src/lib.rs @@ -9,6 +9,8 @@ macro_rules! model { model!(system); model!(system_config); +model!(member); +model!(group); #[derive(serde::Serialize, Debug, Clone)] #[serde(rename_all = "snake_case")] @@ -31,3 +33,30 @@ impl From for PrivacyLevel { } } } + +impl From for sea_query::Value { + fn from(level: PrivacyLevel) -> sea_query::Value { + match level { + PrivacyLevel::Public => sea_query::Value::Int(Some(1)), + PrivacyLevel::Private => sea_query::Value::Int(Some(2)), + } + } +} + +#[derive(serde::Serialize, Debug, Clone)] +pub enum ValidationError { + Simple { key: String, value: String }, +} + +impl ValidationError { + fn new(key: &str) -> Self { + Self::simple(key, "is invalid") + } + + fn simple(key: &str, value: &str) -> Self { + Self::Simple { + key: key.to_string(), + value: value.to_string(), + } + } +} diff --git a/crates/models/src/member.rs b/crates/models/src/member.rs new file mode 100644 index 00000000..84109cbe --- /dev/null +++ b/crates/models/src/member.rs @@ -0,0 +1,208 @@ +use pk_macros::pk_model; + +use chrono::NaiveDateTime; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type MemberId = i32; + +#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)] +#[sqlx(type_name = "proxy_tag")] +pub struct ProxyTag { + pub prefix: Option, + pub suffix: Option, +} + +#[pk_model] +struct Member { + id: MemberId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "color"] + #[patchable] + color: Option, + #[json = "webhook_avatar_url"] + #[patchable] + webhook_avatar_url: Option, + #[json = "avatar_url"] + #[patchable] + avatar_url: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "birthday"] + #[patchable] + birthday: Option, + #[json = "pronouns"] + #[privacy = pronoun_privacy] + #[patchable] + pronouns: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "proxy_tags"] + // #[patchable] + proxy_tags: Vec, + #[json = "keep_proxy"] + #[patchable] + keep_proxy: bool, + #[json = "tts"] + #[patchable] + tts: bool, + #[json = "created"] + created: NaiveDateTime, + #[json = "message_count"] + #[private_patchable] + message_count: i32, + #[json = "last_message_timestamp"] + #[private_patchable] 
+ last_message_timestamp: Option, + #[json = "allow_autoproxy"] + #[patchable] + allow_autoproxy: bool, + + #[privacy] + #[json = "visibility"] + member_visibility: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + avatar_privacy: PrivacyLevel, + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + birthday_privacy: PrivacyLevel, + #[privacy] + pronoun_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + proxy_privacy: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKMemberPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKMemberPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Member name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + patch.display_name = parse_string_simple!("display_name"); + patch.avatar_url = parse_string_simple!("avatar_url"); + patch.banner_image = parse_string_simple!("banner"); + patch.birthday = parse_string_simple!("birthday"); // fix + patch.pronouns = parse_string_simple!("pronouns"); + patch.description = parse_string_simple!("description"); + + if let Some(keep_proxy) = value.get("keep_proxy").and_then(Value::as_bool) { + patch.keep_proxy = Some(keep_proxy); + } + if let Some(tts) = value.get("tts").and_then(Value::as_bool) { + patch.tts = Some(tts); + } + + // todo: legacy import handling + + // todo: fix proxy_tag type in sea_query + + // if let Some(proxy_tags) = value.get("proxy_tags").and_then(Value::as_array) { + // patch.proxy_tags = Some( + // proxy_tags + // .iter() + // .filter_map(|tag| { + // tag.as_object().map(|tag_obj| { + // let prefix = tag_obj + // .get("prefix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // let suffix = tag_obj + // .get("suffix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // ProxyTag { prefix, suffix } + // }) + // }) + // .collect(), + // ) + // } + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! 
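// A body this deserializer accepts, as a sketch (values are hypothetical; the
// privacy keys are handled by the parse_privacy macro just below):
//
//   { "display_name": "Lia", "keep_proxy": true,
//     "privacy": { "visibility": "public", "name_privacy": "private" } }
//
// Values of an unexpected type are collected into `errors` via ValidationError
// rather than failing deserialization outright.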
parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.member_visibility = parse_privacy!("visibility"); + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.avatar_privacy = parse_privacy!("avatar_privacy"); + patch.birthday_privacy = parse_privacy!("birthday_privacy"); + patch.pronoun_privacy = parse_privacy!("pronoun_privacy"); + patch.proxy_privacy = parse_privacy!("proxy_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + } + + Ok(patch) + } +} diff --git a/flake.nix b/flake.nix index 85793415..f693c704 100644 --- a/flake.nix +++ b/flake.nix @@ -54,6 +54,7 @@ gcc omnisharp-roslyn bashInteractive + rust-analyzer ]; runScript = cmd; }; From b8b98e7fd0a563ded3acafe432da1eea62c8d8e3 Mon Sep 17 00:00:00 2001 From: alyssa Date: Wed, 3 Sep 2025 00:43:56 +0000 Subject: [PATCH 10/10] chore: update recovery message on dashboard --- dashboard/src/routes/Settings/Settings.svelte | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dashboard/src/routes/Settings/Settings.svelte b/dashboard/src/routes/Settings/Settings.svelte index 62ce89b6..c2e8d333 100644 --- a/dashboard/src/routes/Settings/Settings.svelte +++ b/dashboard/src/routes/Settings/Settings.svelte @@ -154,7 +154,7 @@

 If you've lost access to your discord account, you can retrieve your token here.
-Send a direct message to a staff member (a helper, moderator or developer in the support server), they can recover your system with this token.
+Ask in the #bot-support channel of the support server for a staff member to DM you, they can recover your system with this token. Do not post the token in the channel.
{#if showToken}