diff --git a/Cargo.lock b/Cargo.lock index d52d073b..7dafae51 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,8 @@ dependencies = [ "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", + "sea-query", + "sea-query-sqlx", "serde", "serde_json", "serde_urlencoded", @@ -3345,19 +3347,20 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.32.3" +version = "1.0.0-rc.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5a24d8b9fcd2674a6c878a3d871f4f1380c6c43cc3718728ac96864d888458e" +checksum = "ab621a8d8b03a3e513ea075f71aa26830a55c977d7b40f09e825bb91910db823" dependencies = [ + "chrono", "inherent", "sea-query-derive", ] [[package]] name = "sea-query-derive" -version = "0.4.3" +version = "1.0.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab" +checksum = "217e9422de35f26c16c5f671fce3c075a65e10322068dbc66078428634af6195" dependencies = [ "darling", "heck 0.4.1", @@ -3367,6 +3370,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "sea-query-sqlx" +version = "0.8.0-rc.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5eb19495858d8ae3663387a4f5298516c6f0171a7ca5681055450f190236b8" +dependencies = [ + "chrono", + "sea-query", + "sqlx", +] + [[package]] name = "security-framework" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 444bcee6..9bae7053 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ futures = "0.3.30" lazy_static = "1.4.0" metrics = "0.23.0" reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]} +sea-query = { version = "1.0.0-rc.10", features = ["with-chrono"] } sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.117" diff --git a/PluralKit.Bot/Commands/Help.cs b/PluralKit.Bot/Commands/Help.cs index e148714b..333f4997 100644 --- a/PluralKit.Bot/Commands/Help.cs +++ b/PluralKit.Bot/Commands/Help.cs @@ -7,12 +7,94 @@ namespace PluralKit.Bot; public class Help { + public Task HelpRoot(Context ctx) + { + if (ctx.MatchFlag("show-embed", "se")) + return HelpRootOld(ctx); + + return ctx.Reply(BuildComponents(ctx.Author.Id, Help.Description.Replace("{prefix}", ctx.DefaultPrefix), -1)); + } + + public static Task ButtonClick(InteractionContext ctx, string prefix) + { + if (!ctx.CustomId.Contains(ctx.User.Id.ToString())) + return ctx.Ignore(); + + if (ctx.CustomId.StartsWith("new-")) + { + Console.WriteLine($"{ctx.Event.Message.Components.First().Components.Length}"); + if (ctx.Event.Message.Components.First().Components[1].Components.Where(x => x.CustomId == ctx.CustomId).First().Style == ButtonStyle.Primary) + return ctx.Respond(InteractionResponse.ResponseType.UpdateMessage, new() + { + Components = BuildComponents(ctx.User.Id, Help.Description.Replace("{prefix}", prefix), -1), + Flags = Message.MessageFlags.IsComponentsV2, + }); + + var text = helpEmbedPages.GetValueOrDefault(ctx.CustomId.Split("-")[3]).Select( + (item, index) => $"### {item.Name.Replace("{prefix}", prefix)}\n{item.Value.Replace("{prefix}", prefix)}" + ).ToArray(); + + var index = Array.FindIndex(ctx.Event.Message.Components.First().Components[1].Components, x => x.CustomId == ctx.CustomId); + var components = BuildComponents(ctx.User.Id, Help.Description.Replace("{prefix}", prefix), index); + + components.First().Components[ctx.Event.Message.Components.First().Components.Length - 1] = new MessageComponent() + { + Type = ComponentType.Text, + Content = String.Join("\n", text), + }; + + return ctx.Respond(InteractionResponse.ResponseType.UpdateMessage, new() + { + Components = components, + Flags = Message.MessageFlags.IsComponentsV2, + }); + } + + return ButtonClickOld(ctx, prefix); + } + + private static MessageComponent[] BuildComponents(ulong userId, string textContent, int menuIndex) + { + return [ + new MessageComponent() + { + Type = ComponentType.Container, + AccentColor = DiscordUtils.Blue, + Components = [ + new MessageComponent() + { + Type = ComponentType.Text, + Content = "# PluralKit\n-# Use the buttons below to see more info!" + }, + helpPageButtons(userId, "new-", menuIndex), + new MessageComponent() + { + Type = ComponentType.Separator, + }, + new MessageComponent() + { + Type = ComponentType.Text, + Content = textContent, + }, + ], + }, + new MessageComponent() + { + Type = ComponentType.Text, + Content = EmbedFooter("\n-# "), + }, + ]; + } + + /// + private static string Description = "PluralKit is a bot designed for plural communities on Discord, and is open for anyone to use. It allows you to register systems, maintain system information, set up message proxying, log switches, and more.\n\n" + "**System recovery:** in the case of your Discord account getting lost or deleted, the PluralKit staff can help you recover your system, **only if you save the system token from `{prefix}token`**. See [this FAQ entry](https://pluralkit.me/faq/#can-i-recover-my-system-if-i-lose-access-to-my-discord-account) for more details.\n\n" + - "If PluralKit is useful to you, please consider donating on [Patreon](https://patreon.com/pluralkit) or [Buy Me A Coffee](https://buymeacoffee.com/pluralkit).\n" + - "## Use the buttons below to see more info!"; + "If PluralKit is useful to you, please consider donating on [Patreon](https://patreon.com/pluralkit) or [Buy Me A Coffee](https://buymeacoffee.com/pluralkit)."; - public static string EmbedFooter = "-# PluralKit by @ske and contributors | Myriad design by @layl, icon by @tedkalashnikov, banner by @fulmine | GitHub: https://github.com/PluralKit/PluralKit/ | Website: https://pluralkit.me/"; + private static string DescriptionOld = $"{Description}\n## Use the buttons below to see more info!"; + + public static string EmbedFooter(string linkSeparator) => $"-# PluralKit by @ske and contributors | Myriad design by @layl, icon by @tedkalashnikov, banner by @fulmine{linkSeparator}GitHub: https://github.com/PluralKit/PluralKit/ | Website: https://pluralkit.me/"; public static Embed helpEmbed = new() { @@ -98,7 +180,7 @@ public class Help } }; - private static MessageComponent helpPageButtons(ulong userId) => new MessageComponent + private static MessageComponent helpPageButtons(ulong userId, string pfx = "", int menuIndex = -1) => new MessageComponent { Type = ComponentType.ActionRow, Components = new[] @@ -106,58 +188,54 @@ public class Help new MessageComponent { Type = ComponentType.Button, - Style = ButtonStyle.Secondary, + Style = menuIndex == 0 ? ButtonStyle.Primary : ButtonStyle.Secondary, Label = "Basic Info", - CustomId = $"help-menu-basicinfo-{userId}", + CustomId = $"{pfx}help-menu-basicinfo-{userId}", Emoji = new() { Name = "\u2139" }, }, new() { Type = ComponentType.Button, - Style = ButtonStyle.Secondary, + Style = menuIndex == 1 ? ButtonStyle.Primary : ButtonStyle.Secondary, Label = "Getting Started", - CustomId = $"help-menu-gettingstarted-{userId}", + CustomId = $"{pfx}help-menu-gettingstarted-{userId}", Emoji = new() { Name = "\u2753", }, }, new() { Type = ComponentType.Button, - Style = ButtonStyle.Secondary, + Style = menuIndex == 2 ? ButtonStyle.Primary : ButtonStyle.Secondary, Label = "Useful Tips", - CustomId = $"help-menu-usefultips-{userId}", + CustomId = $"{pfx}help-menu-usefultips-{userId}", Emoji = new() { Name = "\U0001f4a1", }, - }, new() { Type = ComponentType.Button, - Style = ButtonStyle.Secondary, + Style = menuIndex == 3 ? ButtonStyle.Primary : ButtonStyle.Secondary, Label = "More Info", - CustomId = $"help-menu-moreinfo-{userId}", + CustomId = $"{pfx}help-menu-moreinfo-{userId}", Emoji = new() { Id = 986379675066593330, }, } } }; - public Task HelpRoot(Context ctx) + public Task HelpRootOld(Context ctx) => ctx.Rest.CreateMessage(ctx.Channel.Id, new MessageRequest { Content = $"{Emojis.Warn} If you cannot see the rest of this message see [the FAQ]()", - Embeds = new[] { helpEmbed with { Description = Help.Description.Replace("{prefix}", ctx.DefaultPrefix), Fields = new Embed.Field[] { new("", EmbedFooter) } } }, + Embeds = new[] { helpEmbed with { Description = Help.DescriptionOld.Replace("{prefix}", ctx.DefaultPrefix), Fields = new Embed.Field[] { new("", EmbedFooter(" | ")) } } }, Components = new[] { helpPageButtons(ctx.Author.Id) }, }); - public static Task ButtonClick(InteractionContext ctx, string prefix) + public static Task ButtonClickOld(InteractionContext ctx, string prefix) { - if (!ctx.CustomId.Contains(ctx.User.Id.ToString())) - return ctx.Ignore(); - var buttons = helpPageButtons(ctx.User.Id); if (ctx.Event.Message.Components.First().Components.Where(x => x.CustomId == ctx.CustomId).First().Style == ButtonStyle.Primary) return ctx.Respond(InteractionResponse.ResponseType.UpdateMessage, new() { - Embeds = new[] { helpEmbed with { Description = Help.Description.Replace("{prefix}", prefix), Fields = new Embed.Field[] { new("", EmbedFooter) } } }, + Embeds = new[] { helpEmbed with { Description = Help.DescriptionOld.Replace("{prefix}", prefix), Fields = new Embed.Field[] { new("", EmbedFooter(" | ")) } } }, Components = new[] { buttons } }); @@ -167,7 +245,7 @@ public class Help { Embeds = new[] { helpEmbed with { Fields = helpEmbedPages.GetValueOrDefault(ctx.CustomId.Split("-")[2]).Select( (item, index) => new Embed.Field(item.Name.Replace("{prefix}", prefix), item.Value.Replace("{prefix}", prefix)) - ).Append(new("", EmbedFooter)).ToArray() } }, + ).Append(new("", EmbedFooter(" | "))).ToArray() } }, Components = new[] { buttons } }); } diff --git a/PluralKit.Bot/Commands/Misc.cs b/PluralKit.Bot/Commands/Misc.cs index 8688f9ef..514c8999 100644 --- a/PluralKit.Bot/Commands/Misc.cs +++ b/PluralKit.Bot/Commands/Misc.cs @@ -92,7 +92,7 @@ public class Misc + $"**{stats.db.switches:N0}** switches, **{stats.db.messages:N0}** messages\n" + $"**{stats.db.guilds:N0}** servers with **{stats.db.channels:N0}** channels")); - embed.Field(new("", Help.EmbedFooter)); + embed.Field(new("", Help.EmbedFooter(" | "))); var uptime = ((DateTimeOffset)process.StartTime).ToUnixTimeSeconds(); embed.Description($"### PluralKit [{BuildInfoService.Version}](https://github.com/pluralkit/pluralkit/commit/{BuildInfoService.FullVersion})\n" + diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index b19bfe74..16c54a7f 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -14,6 +14,7 @@ fred = { workspace = true } lazy_static = { workspace = true } metrics = { workspace = true } reqwest = { workspace = true } +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlx = { workspace = true } @@ -28,3 +29,4 @@ serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } subtle = "2.6.1" +sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] } diff --git a/crates/api/src/endpoints/bulk.rs b/crates/api/src/endpoints/bulk.rs new file mode 100644 index 00000000..d859da88 --- /dev/null +++ b/crates/api/src/endpoints/bulk.rs @@ -0,0 +1,211 @@ +use axum::{ + Extension, Json, + extract::{Json as ExtractJson, State}, + response::IntoResponse, +}; +use pk_macros::api_endpoint; +use sea_query::{Expr, ExprTrait, PostgresQueryBuilder}; +use sea_query_sqlx::SqlxBinder; +use serde_json::{Value, json}; + +use pluralkit_models::{PKGroup, PKGroupPatch, PKMember, PKMemberPatch, PKSystem}; + +use crate::{ + ApiContext, + auth::AuthState, + error::{ + GENERIC_AUTH_ERROR, NOT_OWN_GROUP, NOT_OWN_MEMBER, PKError, TARGET_GROUP_NOT_FOUND, + TARGET_MEMBER_NOT_FOUND, + }, +}; + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequestFilter { + All, + Ids { ids: Vec }, + Connection { id: String }, +} + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequest { + Member { + filter: BulkActionRequestFilter, + patch: PKMemberPatch, + }, + Group { + filter: BulkActionRequestFilter, + patch: PKGroupPatch, + }, +} + +#[api_endpoint] +pub async fn bulk( + Extension(auth): Extension, + State(ctx): State, + ExtractJson(req): ExtractJson, +) -> Json { + let Some(system_id) = auth.system_id() else { + return Err(GENERIC_AUTH_ERROR); + }; + + #[derive(sqlx::FromRow)] + struct Ider { + id: i32, + hid: String, + uuid: String, + } + + #[derive(sqlx::FromRow)] + struct GroupMemberEntry { + member_id: i32, + group_id: i32, + } + + #[allow(dead_code)] + #[derive(sqlx::FromRow)] + struct OnlyIder { + id: i32, + } + + println!("BulkActionRequest::{req:#?}"); + match req { + BulkActionRequest::Member { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from members where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let members: Vec = sqlx::query_as( + "select * from members where hid = any($1::array) or uuid::text = any($1::array)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if members.len() != ids.len() { + return Err(TARGET_MEMBER_NOT_FOUND); + } + + if members.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_MEMBER); + } + + members.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(group): Option = + sqlx::query_as("select * from groups where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_GROUP_NOT_FOUND); + }; + + if group.system != system_id { + return Err(NOT_OWN_GROUP); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where group_id = $1") + .bind(group.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.member_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("members") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! {{ "updated": res.len() }})) + } + BulkActionRequest::Group { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from groups where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let groups: Vec = sqlx::query_as( + "select * from groups where hid = any($1) or uuid::text = any($1)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if groups.len() != ids.len() { + return Err(TARGET_GROUP_NOT_FOUND); + } + + if groups.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_GROUP); + } + + groups.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(member): Option = + sqlx::query_as("select * from members where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_MEMBER_NOT_FOUND); + }; + + if member.system != system_id { + return Err(NOT_OWN_MEMBER); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where member_id = $1") + .bind(member.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.group_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("groups") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + println!("{q:#?} {pms:#?}"); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! {{ "updated": res.len() }})) + } + } +} diff --git a/crates/api/src/endpoints/mod.rs b/crates/api/src/endpoints/mod.rs index c311367c..167acee8 100644 --- a/crates/api/src/endpoints/mod.rs +++ b/crates/api/src/endpoints/mod.rs @@ -1,2 +1,3 @@ +pub mod bulk; pub mod private; pub mod system; diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index fc481d0c..ae7e5a99 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -2,6 +2,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; +use pluralkit_models::ValidationError; use std::fmt; // todo: model parse errors @@ -11,6 +12,8 @@ pub struct PKError { pub json_code: i32, pub message: &'static str, + pub errors: Vec, + pub inner: Option, } @@ -30,6 +33,21 @@ impl Clone for PKError { json_code: self.json_code, message: self.message, inner: None, + errors: self.errors.clone(), + } + } +} + +// can't `impl From>` +// because "upstream crate may add a new impl" >:( +impl PKError { + pub fn from_validation_errors(errs: Vec) -> Self { + Self { + message: "Error parsing JSON model", + json_code: 40001, + errors: errs, + response_code: StatusCode::BAD_REQUEST, + inner: None, } } } @@ -50,14 +68,19 @@ impl IntoResponse for PKError { if let Some(inner) = self.inner { tracing::error!(?inner, "error returned from handler"); } - crate::util::json_err( - self.response_code, - serde_json::to_string(&serde_json::json!({ + let json = if self.errors.len() > 0 { + serde_json::json!({ "message": self.message, "code": self.json_code, - })) - .unwrap(), - ) + "errors": self.errors, + }) + } else { + serde_json::json!({ + "message": self.message, + "code": self.json_code, + }) + }; + crate::util::json_err(self.response_code, serde_json::to_string(&json).unwrap()) } } @@ -78,9 +101,17 @@ macro_rules! define_error { json_code: $json_code, message: $message, inner: None, + errors: vec![], }; }; } +define_error! { GENERIC_AUTH_ERROR, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" } define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } + +define_error! { NOT_OWN_MEMBER, StatusCode::FORBIDDEN, 30006, "Target member is not part of your system." } +define_error! { NOT_OWN_GROUP, StatusCode::FORBIDDEN, 30007, "Target group is not part of your system." } + +define_error! { TARGET_MEMBER_NOT_FOUND, StatusCode::BAD_REQUEST, 40010, "Target member not found." } +define_error! { TARGET_GROUP_NOT_FOUND, StatusCode::BAD_REQUEST, 40011, "Target group not found." } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index f22450ce..1850861b 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -115,6 +115,8 @@ fn router(ctx: ApiContext) -> Router { .route("/v2/messages/{message_id}", get(rproxy)) + .route("/v2/bulk", post(endpoints::bulk::bulk)) + .route("/private/bulk_privacy/member", post(rproxy)) .route("/private/bulk_privacy/group", post(rproxy)) .route("/private/discord/callback", post(rproxy)) diff --git a/crates/macros/src/model.rs b/crates/macros/src/model.rs index e37d0dde..5505e76a 100644 --- a/crates/macros/src/model.rs +++ b/crates/macros/src/model.rs @@ -85,8 +85,14 @@ fn parse_field(field: syn::Field) -> ModelField { panic!("must have json name to be publicly patchable"); } - if f.json.is_some() && f.is_privacy { - panic!("cannot set custom json name for privacy field"); + if f.is_privacy && f.json.is_none() { + f.json = Some(syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new( + f.name.clone().to_string().as_str(), + proc_macro2::Span::call_site(), + )), + })) } f @@ -122,17 +128,17 @@ pub fn macro_impl( let fields: Vec = fields .iter() - .filter(|f| !matches!(f.patch, ElemPatchability::None)) + .filter(|f| f.is_privacy || !matches!(f.patch, ElemPatchability::None)) .cloned() .collect(); let patch_fields = mk_patch_fields(fields.clone()); - let patch_from_json = mk_patch_from_json(fields.clone()); let patch_validate = mk_patch_validate(fields.clone()); + let patch_validate_bulk = mk_patch_validate_bulk(fields.clone()); let patch_to_json = mk_patch_to_json(fields.clone()); let patch_to_sql = mk_patch_to_sql(fields.clone()); - return quote! { + let code = quote! { #[derive(sqlx::FromRow, Debug, Clone)] pub struct #tname { #tfields @@ -146,31 +152,42 @@ pub fn macro_impl( #to_json } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, Default)] pub struct #patchable_name { #patch_fields + + errors: Vec, } impl #patchable_name { - pub fn from_json(input: String) -> Self { - #patch_from_json - } - - pub fn validate(self) -> bool { + pub fn validate(&mut self) { #patch_validate } + pub fn errors(&self) -> Vec { + self.errors.clone() + } + + pub fn validate_bulk(&mut self) { + #patch_validate_bulk + } + pub fn to_sql(self) -> sea_query::UpdateStatement { - // sea_query::Query::update() - #patch_to_sql + use sea_query::types::*; + let mut patch = &mut sea_query::Query::update(); + #patch_to_sql + patch.clone() } pub fn to_json(self) -> serde_json::Value { #patch_to_json } } - } - .into(); + }; + + // panic!("{:#?}", code.to_string()); + + return code.into(); } fn mk_tfields(fields: Vec) -> TokenStream { @@ -225,7 +242,7 @@ fn mk_tto_json(fields: Vec) -> TokenStream { .filter_map(|f| { if f.is_privacy { let tname = f.name.clone(); - let tnamestr = f.name.clone().to_string(); + let tnamestr = f.json.clone(); Some(quote! { #tnamestr: self.#tname, }) @@ -280,13 +297,48 @@ fn mk_patch_fields(fields: Vec) -> TokenStream { .collect() } fn mk_patch_validate(_fields: Vec) -> TokenStream { - quote! { true } -} -fn mk_patch_from_json(_fields: Vec) -> TokenStream { quote! { unimplemented!(); } } -fn mk_patch_to_sql(_fields: Vec) -> TokenStream { - quote! { unimplemented!(); } +fn mk_patch_validate_bulk(fields: Vec) -> TokenStream { + // iterate over all nullable patchable fields other than privacy + // add an error if any field is set to a value other than null + fields + .iter() + .map(|f| { + if let syn::Type::Path(path) = &f.ty && let Some(inner) = path.path.segments.last() && inner.ident != "Option" { + return quote! {}; + } + let name = f.name.clone(); + if matches!(f.patch, ElemPatchability::Public) { + let json = f.json.clone().unwrap(); + quote! { + if let Some(val) = self.#name.clone() && val.is_some() { + self.errors.push(ValidationError::simple(#json, "Only null values are supported in bulk endpoint")); + } + } + } else { + quote! {} + } + }) + .collect() +} +fn mk_patch_to_sql(fields: Vec) -> TokenStream { + fields + .iter() + .filter_map(|f| { + if !matches!(f.patch, ElemPatchability::None) || f.is_privacy { + let name = f.name.clone(); + let column = f.name.to_string(); + Some(quote! { + if let Some(value) = self.#name { + patch = patch.value(#column, value); + } + }) + } else { + None + } + }) + .collect() } fn mk_patch_to_json(_fields: Vec) -> TokenStream { quote! { unimplemented!(); } diff --git a/crates/models/Cargo.toml b/crates/models/Cargo.toml index 752fbaa5..93366a82 100644 --- a/crates/models/Cargo.toml +++ b/crates/models/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] chrono = { workspace = true, features = ["serde"] } pk_macros = { path = "../macros" } -sea-query = "0.32.1" +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true, features = ["preserve_order"] } # in theory we want to default-features = false for sqlx diff --git a/crates/models/src/group.rs b/crates/models/src/group.rs new file mode 100644 index 00000000..ab94d27b --- /dev/null +++ b/crates/models/src/group.rs @@ -0,0 +1,132 @@ +use pk_macros::pk_model; + +use chrono::{DateTime, Utc}; +use serde::Deserialize; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type GroupId = i32; + +#[pk_model] +struct Group { + id: GroupId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "color"] + #[patchable] + color: Option, + #[json = "icon"] + #[patchable] + icon: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "created"] + created: DateTime, + + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + icon_privacy: PrivacyLevel, + #[privacy] + list_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + visibility: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKGroupPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKGroupPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Group name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.display_name = parse_string_simple!("display_name"); + patch.description = parse_string_simple!("description"); + patch.icon = parse_string_simple!("icon"); + patch.banner_image = parse_string_simple!("banner"); + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.icon_privacy = parse_privacy!("icon_privacy"); + patch.list_privacy = parse_privacy!("list_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + patch.visibility = parse_privacy!("visibility"); + } + + Ok(patch) + } +} diff --git a/crates/models/src/lib.rs b/crates/models/src/lib.rs index 08350488..6bb4adf4 100644 --- a/crates/models/src/lib.rs +++ b/crates/models/src/lib.rs @@ -9,6 +9,8 @@ macro_rules! model { model!(system); model!(system_config); +model!(member); +model!(group); #[derive(serde::Serialize, Debug, Clone)] #[serde(rename_all = "snake_case")] @@ -31,3 +33,30 @@ impl From for PrivacyLevel { } } } + +impl From for sea_query::Value { + fn from(level: PrivacyLevel) -> sea_query::Value { + match level { + PrivacyLevel::Public => sea_query::Value::Int(Some(1)), + PrivacyLevel::Private => sea_query::Value::Int(Some(2)), + } + } +} + +#[derive(serde::Serialize, Debug, Clone)] +pub enum ValidationError { + Simple { key: String, value: String }, +} + +impl ValidationError { + fn new(key: &str) -> Self { + Self::simple(key, "is invalid") + } + + fn simple(key: &str, value: &str) -> Self { + Self::Simple { + key: key.to_string(), + value: value.to_string(), + } + } +} diff --git a/crates/models/src/member.rs b/crates/models/src/member.rs new file mode 100644 index 00000000..84109cbe --- /dev/null +++ b/crates/models/src/member.rs @@ -0,0 +1,208 @@ +use pk_macros::pk_model; + +use chrono::NaiveDateTime; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type MemberId = i32; + +#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)] +#[sqlx(type_name = "proxy_tag")] +pub struct ProxyTag { + pub prefix: Option, + pub suffix: Option, +} + +#[pk_model] +struct Member { + id: MemberId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "color"] + #[patchable] + color: Option, + #[json = "webhook_avatar_url"] + #[patchable] + webhook_avatar_url: Option, + #[json = "avatar_url"] + #[patchable] + avatar_url: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "birthday"] + #[patchable] + birthday: Option, + #[json = "pronouns"] + #[privacy = pronoun_privacy] + #[patchable] + pronouns: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "proxy_tags"] + // #[patchable] + proxy_tags: Vec, + #[json = "keep_proxy"] + #[patchable] + keep_proxy: bool, + #[json = "tts"] + #[patchable] + tts: bool, + #[json = "created"] + created: NaiveDateTime, + #[json = "message_count"] + #[private_patchable] + message_count: i32, + #[json = "last_message_timestamp"] + #[private_patchable] + last_message_timestamp: Option, + #[json = "allow_autoproxy"] + #[patchable] + allow_autoproxy: bool, + + #[privacy] + #[json = "visibility"] + member_visibility: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + avatar_privacy: PrivacyLevel, + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + birthday_privacy: PrivacyLevel, + #[privacy] + pronoun_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + proxy_privacy: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKMemberPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKMemberPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Member name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + patch.display_name = parse_string_simple!("display_name"); + patch.avatar_url = parse_string_simple!("avatar_url"); + patch.banner_image = parse_string_simple!("banner"); + patch.birthday = parse_string_simple!("birthday"); // fix + patch.pronouns = parse_string_simple!("pronouns"); + patch.description = parse_string_simple!("description"); + + if let Some(keep_proxy) = value.get("keep_proxy").and_then(Value::as_bool) { + patch.keep_proxy = Some(keep_proxy); + } + if let Some(tts) = value.get("tts").and_then(Value::as_bool) { + patch.tts = Some(tts); + } + + // todo: legacy import handling + + // todo: fix proxy_tag type in sea_query + + // if let Some(proxy_tags) = value.get("proxy_tags").and_then(Value::as_array) { + // patch.proxy_tags = Some( + // proxy_tags + // .iter() + // .filter_map(|tag| { + // tag.as_object().map(|tag_obj| { + // let prefix = tag_obj + // .get("prefix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // let suffix = tag_obj + // .get("suffix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // ProxyTag { prefix, suffix } + // }) + // }) + // .collect(), + // ) + // } + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.member_visibility = parse_privacy!("visibility"); + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.avatar_privacy = parse_privacy!("avatar_privacy"); + patch.birthday_privacy = parse_privacy!("birthday_privacy"); + patch.pronoun_privacy = parse_privacy!("pronoun_privacy"); + patch.proxy_privacy = parse_privacy!("proxy_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + } + + Ok(patch) + } +}