diff --git a/Cargo.lock b/Cargo.lock index 0bb714cb..79f65d9f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -92,6 +92,8 @@ dependencies = [ "pluralkit_models", "reqwest 0.12.15", "reverse-proxy-service", + "sea-query", + "sea-query-sqlx", "serde", "serde_json", "serde_urlencoded", @@ -3372,19 +3374,20 @@ dependencies = [ [[package]] name = "sea-query" -version = "0.32.3" +version = "1.0.0-rc.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5a24d8b9fcd2674a6c878a3d871f4f1380c6c43cc3718728ac96864d888458e" +checksum = "ab621a8d8b03a3e513ea075f71aa26830a55c977d7b40f09e825bb91910db823" dependencies = [ + "chrono", "inherent", "sea-query-derive", ] [[package]] name = "sea-query-derive" -version = "0.4.3" +version = "1.0.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab" +checksum = "217e9422de35f26c16c5f671fce3c075a65e10322068dbc66078428634af6195" dependencies = [ "darling", "heck 0.4.1", @@ -3394,6 +3397,17 @@ dependencies = [ "thiserror 2.0.12", ] +[[package]] +name = "sea-query-sqlx" +version = "0.8.0-rc.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed5eb19495858d8ae3663387a4f5298516c6f0171a7ca5681055450f190236b8" +dependencies = [ + "chrono", + "sea-query", + "sqlx", +] + [[package]] name = "security-framework" version = "3.2.0" diff --git a/Cargo.toml b/Cargo.toml index 69a9048a..3707ce77 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,6 +14,7 @@ futures = "0.3.30" lazy_static = "1.4.0" metrics = "0.23.0" reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]} +sea-query = { version = "1.0.0-rc.10", features = ["with-chrono"] } sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls serde = { version = "1.0.196", features = ["derive"] } serde_json = "1.0.117" diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index b19bfe74..16c54a7f 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -14,6 +14,7 @@ fred = { workspace = true } lazy_static = { workspace = true } metrics = { workspace = true } reqwest = { workspace = true } +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlx = { workspace = true } @@ -28,3 +29,4 @@ serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } subtle = "2.6.1" +sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] } diff --git a/crates/api/src/endpoints/bulk.rs b/crates/api/src/endpoints/bulk.rs new file mode 100644 index 00000000..d859da88 --- /dev/null +++ b/crates/api/src/endpoints/bulk.rs @@ -0,0 +1,211 @@ +use axum::{ + Extension, Json, + extract::{Json as ExtractJson, State}, + response::IntoResponse, +}; +use pk_macros::api_endpoint; +use sea_query::{Expr, ExprTrait, PostgresQueryBuilder}; +use sea_query_sqlx::SqlxBinder; +use serde_json::{Value, json}; + +use pluralkit_models::{PKGroup, PKGroupPatch, PKMember, PKMemberPatch, PKSystem}; + +use crate::{ + ApiContext, + auth::AuthState, + error::{ + GENERIC_AUTH_ERROR, NOT_OWN_GROUP, NOT_OWN_MEMBER, PKError, TARGET_GROUP_NOT_FOUND, + TARGET_MEMBER_NOT_FOUND, + }, +}; + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequestFilter { + All, + Ids { ids: Vec }, + Connection { id: String }, +} + +#[derive(serde::Deserialize, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum BulkActionRequest { + Member { + filter: BulkActionRequestFilter, + patch: PKMemberPatch, + }, + Group { + filter: BulkActionRequestFilter, + patch: PKGroupPatch, + }, +} + +#[api_endpoint] +pub async fn bulk( + Extension(auth): Extension, + State(ctx): State, + ExtractJson(req): ExtractJson, +) -> Json { + let Some(system_id) = auth.system_id() else { + return Err(GENERIC_AUTH_ERROR); + }; + + #[derive(sqlx::FromRow)] + struct Ider { + id: i32, + hid: String, + uuid: String, + } + + #[derive(sqlx::FromRow)] + struct GroupMemberEntry { + member_id: i32, + group_id: i32, + } + + #[allow(dead_code)] + #[derive(sqlx::FromRow)] + struct OnlyIder { + id: i32, + } + + println!("BulkActionRequest::{req:#?}"); + match req { + BulkActionRequest::Member { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from members where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let members: Vec = sqlx::query_as( + "select * from members where hid = any($1::array) or uuid::text = any($1::array)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if members.len() != ids.len() { + return Err(TARGET_MEMBER_NOT_FOUND); + } + + if members.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_MEMBER); + } + + members.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(group): Option = + sqlx::query_as("select * from groups where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_GROUP_NOT_FOUND); + }; + + if group.system != system_id { + return Err(NOT_OWN_GROUP); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where group_id = $1") + .bind(group.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.member_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("members") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! {{ "updated": res.len() }})) + } + BulkActionRequest::Group { filter, mut patch } => { + patch.validate_bulk(); + if patch.errors().len() > 0 { + return Err(PKError::from_validation_errors(patch.errors())); + } + + let ids: Vec = match filter { + BulkActionRequestFilter::All => { + let ids: Vec = sqlx::query_as("select id from groups where system = $1") + .bind(system_id as i64) + .fetch_all(&ctx.db) + .await?; + ids.iter().map(|v| v.id).collect() + } + BulkActionRequestFilter::Ids { ids } => { + let groups: Vec = sqlx::query_as( + "select * from groups where hid = any($1) or uuid::text = any($1)", + ) + .bind(&ids) + .fetch_all(&ctx.db) + .await?; + + // todo: better errors + if groups.len() != ids.len() { + return Err(TARGET_GROUP_NOT_FOUND); + } + + if groups.iter().any(|m| m.system != system_id) { + return Err(NOT_OWN_GROUP); + } + + groups.iter().map(|m| m.id).collect() + } + BulkActionRequestFilter::Connection { id } => { + let Some(member): Option = + sqlx::query_as("select * from members where hid = $1 or uuid::text = $1") + .bind(id) + .fetch_optional(&ctx.db) + .await? + else { + return Err(TARGET_MEMBER_NOT_FOUND); + }; + + if member.system != system_id { + return Err(NOT_OWN_MEMBER); + } + + let entries: Vec = + sqlx::query_as("select * from group_members where member_id = $1") + .bind(member.id) + .fetch_all(&ctx.db) + .await?; + + entries.iter().map(|v| v.group_id).collect() + } + }; + + let (q, pms) = patch + .to_sql() + .table("groups") // todo: this should be in the model definition + .and_where(Expr::col("id").is_in(ids)) + .returning_col("id") + .build_sqlx(PostgresQueryBuilder); + + println!("{q:#?} {pms:#?}"); + + let res: Vec = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?; + Ok(Json(json! {{ "updated": res.len() }})) + } + } +} diff --git a/crates/api/src/endpoints/mod.rs b/crates/api/src/endpoints/mod.rs index c311367c..167acee8 100644 --- a/crates/api/src/endpoints/mod.rs +++ b/crates/api/src/endpoints/mod.rs @@ -1,2 +1,3 @@ +pub mod bulk; pub mod private; pub mod system; diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index fc481d0c..ae7e5a99 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -2,6 +2,7 @@ use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; +use pluralkit_models::ValidationError; use std::fmt; // todo: model parse errors @@ -11,6 +12,8 @@ pub struct PKError { pub json_code: i32, pub message: &'static str, + pub errors: Vec, + pub inner: Option, } @@ -30,6 +33,21 @@ impl Clone for PKError { json_code: self.json_code, message: self.message, inner: None, + errors: self.errors.clone(), + } + } +} + +// can't `impl From>` +// because "upstream crate may add a new impl" >:( +impl PKError { + pub fn from_validation_errors(errs: Vec) -> Self { + Self { + message: "Error parsing JSON model", + json_code: 40001, + errors: errs, + response_code: StatusCode::BAD_REQUEST, + inner: None, } } } @@ -50,14 +68,19 @@ impl IntoResponse for PKError { if let Some(inner) = self.inner { tracing::error!(?inner, "error returned from handler"); } - crate::util::json_err( - self.response_code, - serde_json::to_string(&serde_json::json!({ + let json = if self.errors.len() > 0 { + serde_json::json!({ "message": self.message, "code": self.json_code, - })) - .unwrap(), - ) + "errors": self.errors, + }) + } else { + serde_json::json!({ + "message": self.message, + "code": self.json_code, + }) + }; + crate::util::json_err(self.response_code, serde_json::to_string(&json).unwrap()) } } @@ -78,9 +101,17 @@ macro_rules! define_error { json_code: $json_code, message: $message, inner: None, + errors: vec![], }; }; } +define_error! { GENERIC_AUTH_ERROR, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" } define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" } define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" } + +define_error! { NOT_OWN_MEMBER, StatusCode::FORBIDDEN, 30006, "Target member is not part of your system." } +define_error! { NOT_OWN_GROUP, StatusCode::FORBIDDEN, 30007, "Target group is not part of your system." } + +define_error! { TARGET_MEMBER_NOT_FOUND, StatusCode::BAD_REQUEST, 40010, "Target member not found." } +define_error! { TARGET_GROUP_NOT_FOUND, StatusCode::BAD_REQUEST, 40011, "Target group not found." } diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index 5c2bcfd4..a6e2680a 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -115,6 +115,8 @@ fn router(ctx: ApiContext) -> Router { .route("/v2/messages/{message_id}", get(rproxy)) + .route("/v2/bulk", post(endpoints::bulk::bulk)) + .route("/private/bulk_privacy/member", post(rproxy)) .route("/private/bulk_privacy/group", post(rproxy)) .route("/private/discord/callback", post(rproxy)) diff --git a/crates/macros/src/model.rs b/crates/macros/src/model.rs index e37d0dde..5505e76a 100644 --- a/crates/macros/src/model.rs +++ b/crates/macros/src/model.rs @@ -85,8 +85,14 @@ fn parse_field(field: syn::Field) -> ModelField { panic!("must have json name to be publicly patchable"); } - if f.json.is_some() && f.is_privacy { - panic!("cannot set custom json name for privacy field"); + if f.is_privacy && f.json.is_none() { + f.json = Some(syn::Expr::Lit(syn::ExprLit { + attrs: vec![], + lit: syn::Lit::Str(syn::LitStr::new( + f.name.clone().to_string().as_str(), + proc_macro2::Span::call_site(), + )), + })) } f @@ -122,17 +128,17 @@ pub fn macro_impl( let fields: Vec = fields .iter() - .filter(|f| !matches!(f.patch, ElemPatchability::None)) + .filter(|f| f.is_privacy || !matches!(f.patch, ElemPatchability::None)) .cloned() .collect(); let patch_fields = mk_patch_fields(fields.clone()); - let patch_from_json = mk_patch_from_json(fields.clone()); let patch_validate = mk_patch_validate(fields.clone()); + let patch_validate_bulk = mk_patch_validate_bulk(fields.clone()); let patch_to_json = mk_patch_to_json(fields.clone()); let patch_to_sql = mk_patch_to_sql(fields.clone()); - return quote! { + let code = quote! { #[derive(sqlx::FromRow, Debug, Clone)] pub struct #tname { #tfields @@ -146,31 +152,42 @@ pub fn macro_impl( #to_json } - #[derive(Debug, Clone)] + #[derive(Debug, Clone, Default)] pub struct #patchable_name { #patch_fields + + errors: Vec, } impl #patchable_name { - pub fn from_json(input: String) -> Self { - #patch_from_json - } - - pub fn validate(self) -> bool { + pub fn validate(&mut self) { #patch_validate } + pub fn errors(&self) -> Vec { + self.errors.clone() + } + + pub fn validate_bulk(&mut self) { + #patch_validate_bulk + } + pub fn to_sql(self) -> sea_query::UpdateStatement { - // sea_query::Query::update() - #patch_to_sql + use sea_query::types::*; + let mut patch = &mut sea_query::Query::update(); + #patch_to_sql + patch.clone() } pub fn to_json(self) -> serde_json::Value { #patch_to_json } } - } - .into(); + }; + + // panic!("{:#?}", code.to_string()); + + return code.into(); } fn mk_tfields(fields: Vec) -> TokenStream { @@ -225,7 +242,7 @@ fn mk_tto_json(fields: Vec) -> TokenStream { .filter_map(|f| { if f.is_privacy { let tname = f.name.clone(); - let tnamestr = f.name.clone().to_string(); + let tnamestr = f.json.clone(); Some(quote! { #tnamestr: self.#tname, }) @@ -280,13 +297,48 @@ fn mk_patch_fields(fields: Vec) -> TokenStream { .collect() } fn mk_patch_validate(_fields: Vec) -> TokenStream { - quote! { true } -} -fn mk_patch_from_json(_fields: Vec) -> TokenStream { quote! { unimplemented!(); } } -fn mk_patch_to_sql(_fields: Vec) -> TokenStream { - quote! { unimplemented!(); } +fn mk_patch_validate_bulk(fields: Vec) -> TokenStream { + // iterate over all nullable patchable fields other than privacy + // add an error if any field is set to a value other than null + fields + .iter() + .map(|f| { + if let syn::Type::Path(path) = &f.ty && let Some(inner) = path.path.segments.last() && inner.ident != "Option" { + return quote! {}; + } + let name = f.name.clone(); + if matches!(f.patch, ElemPatchability::Public) { + let json = f.json.clone().unwrap(); + quote! { + if let Some(val) = self.#name.clone() && val.is_some() { + self.errors.push(ValidationError::simple(#json, "Only null values are supported in bulk endpoint")); + } + } + } else { + quote! {} + } + }) + .collect() +} +fn mk_patch_to_sql(fields: Vec) -> TokenStream { + fields + .iter() + .filter_map(|f| { + if !matches!(f.patch, ElemPatchability::None) || f.is_privacy { + let name = f.name.clone(); + let column = f.name.to_string(); + Some(quote! { + if let Some(value) = self.#name { + patch = patch.value(#column, value); + } + }) + } else { + None + } + }) + .collect() } fn mk_patch_to_json(_fields: Vec) -> TokenStream { quote! { unimplemented!(); } diff --git a/crates/models/Cargo.toml b/crates/models/Cargo.toml index 752fbaa5..93366a82 100644 --- a/crates/models/Cargo.toml +++ b/crates/models/Cargo.toml @@ -6,7 +6,7 @@ edition = "2024" [dependencies] chrono = { workspace = true, features = ["serde"] } pk_macros = { path = "../macros" } -sea-query = "0.32.1" +sea-query = { workspace = true } serde = { workspace = true } serde_json = { workspace = true, features = ["preserve_order"] } # in theory we want to default-features = false for sqlx diff --git a/crates/models/src/group.rs b/crates/models/src/group.rs new file mode 100644 index 00000000..ab94d27b --- /dev/null +++ b/crates/models/src/group.rs @@ -0,0 +1,132 @@ +use pk_macros::pk_model; + +use chrono::{DateTime, Utc}; +use serde::Deserialize; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type GroupId = i32; + +#[pk_model] +struct Group { + id: GroupId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "color"] + #[patchable] + color: Option, + #[json = "icon"] + #[patchable] + icon: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "created"] + created: DateTime, + + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + icon_privacy: PrivacyLevel, + #[privacy] + list_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + visibility: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKGroupPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKGroupPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Group name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.display_name = parse_string_simple!("display_name"); + patch.description = parse_string_simple!("description"); + patch.icon = parse_string_simple!("icon"); + patch.banner_image = parse_string_simple!("banner"); + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.icon_privacy = parse_privacy!("icon_privacy"); + patch.list_privacy = parse_privacy!("list_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + patch.visibility = parse_privacy!("visibility"); + } + + Ok(patch) + } +} diff --git a/crates/models/src/lib.rs b/crates/models/src/lib.rs index 08350488..6bb4adf4 100644 --- a/crates/models/src/lib.rs +++ b/crates/models/src/lib.rs @@ -9,6 +9,8 @@ macro_rules! model { model!(system); model!(system_config); +model!(member); +model!(group); #[derive(serde::Serialize, Debug, Clone)] #[serde(rename_all = "snake_case")] @@ -31,3 +33,30 @@ impl From for PrivacyLevel { } } } + +impl From for sea_query::Value { + fn from(level: PrivacyLevel) -> sea_query::Value { + match level { + PrivacyLevel::Public => sea_query::Value::Int(Some(1)), + PrivacyLevel::Private => sea_query::Value::Int(Some(2)), + } + } +} + +#[derive(serde::Serialize, Debug, Clone)] +pub enum ValidationError { + Simple { key: String, value: String }, +} + +impl ValidationError { + fn new(key: &str) -> Self { + Self::simple(key, "is invalid") + } + + fn simple(key: &str, value: &str) -> Self { + Self::Simple { + key: key.to_string(), + value: value.to_string(), + } + } +} diff --git a/crates/models/src/member.rs b/crates/models/src/member.rs new file mode 100644 index 00000000..84109cbe --- /dev/null +++ b/crates/models/src/member.rs @@ -0,0 +1,208 @@ +use pk_macros::pk_model; + +use chrono::NaiveDateTime; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use uuid::Uuid; + +use crate::{PrivacyLevel, SystemId, ValidationError}; + +// todo: fix +pub type MemberId = i32; + +#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)] +#[sqlx(type_name = "proxy_tag")] +pub struct ProxyTag { + pub prefix: Option, + pub suffix: Option, +} + +#[pk_model] +struct Member { + id: MemberId, + #[json = "hid"] + #[private_patchable] + hid: String, + #[json = "uuid"] + uuid: Uuid, + // TODO fix + #[json = "system"] + system: SystemId, + + #[json = "color"] + #[patchable] + color: Option, + #[json = "webhook_avatar_url"] + #[patchable] + webhook_avatar_url: Option, + #[json = "avatar_url"] + #[patchable] + avatar_url: Option, + #[json = "banner_image"] + #[patchable] + banner_image: Option, + #[json = "name"] + #[privacy = name_privacy] + #[patchable] + name: String, + #[json = "display_name"] + #[patchable] + display_name: Option, + #[json = "birthday"] + #[patchable] + birthday: Option, + #[json = "pronouns"] + #[privacy = pronoun_privacy] + #[patchable] + pronouns: Option, + #[json = "description"] + #[privacy = description_privacy] + #[patchable] + description: Option, + #[json = "proxy_tags"] + // #[patchable] + proxy_tags: Vec, + #[json = "keep_proxy"] + #[patchable] + keep_proxy: bool, + #[json = "tts"] + #[patchable] + tts: bool, + #[json = "created"] + created: NaiveDateTime, + #[json = "message_count"] + #[private_patchable] + message_count: i32, + #[json = "last_message_timestamp"] + #[private_patchable] + last_message_timestamp: Option, + #[json = "allow_autoproxy"] + #[patchable] + allow_autoproxy: bool, + + #[privacy] + #[json = "visibility"] + member_visibility: PrivacyLevel, + #[privacy] + description_privacy: PrivacyLevel, + #[privacy] + banner_privacy: PrivacyLevel, + #[privacy] + avatar_privacy: PrivacyLevel, + #[privacy] + name_privacy: PrivacyLevel, + #[privacy] + birthday_privacy: PrivacyLevel, + #[privacy] + pronoun_privacy: PrivacyLevel, + #[privacy] + metadata_privacy: PrivacyLevel, + #[privacy] + proxy_privacy: PrivacyLevel, +} + +impl<'de> Deserialize<'de> for PKMemberPatch { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let mut patch: PKMemberPatch = Default::default(); + let value: Value = Value::deserialize(deserializer)?; + + if let Some(v) = value.get("name") { + if let Some(name) = v.as_str() { + patch.name = Some(name.to_string()); + } else if v.is_null() { + patch.errors.push(ValidationError::simple( + "name", + "Member name cannot be set to null.", + )); + } + } + + macro_rules! parse_string_simple { + ($k:expr) => { + match value.get($k) { + None => None, + Some(Value::Null) => Some(None), + Some(Value::String(s)) => Some(Some(s.clone())), + _ => { + patch.errors.push(ValidationError::new($k)); + None + } + } + }; + } + + patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase())); + patch.display_name = parse_string_simple!("display_name"); + patch.avatar_url = parse_string_simple!("avatar_url"); + patch.banner_image = parse_string_simple!("banner"); + patch.birthday = parse_string_simple!("birthday"); // fix + patch.pronouns = parse_string_simple!("pronouns"); + patch.description = parse_string_simple!("description"); + + if let Some(keep_proxy) = value.get("keep_proxy").and_then(Value::as_bool) { + patch.keep_proxy = Some(keep_proxy); + } + if let Some(tts) = value.get("tts").and_then(Value::as_bool) { + patch.tts = Some(tts); + } + + // todo: legacy import handling + + // todo: fix proxy_tag type in sea_query + + // if let Some(proxy_tags) = value.get("proxy_tags").and_then(Value::as_array) { + // patch.proxy_tags = Some( + // proxy_tags + // .iter() + // .filter_map(|tag| { + // tag.as_object().map(|tag_obj| { + // let prefix = tag_obj + // .get("prefix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // let suffix = tag_obj + // .get("suffix") + // .and_then(Value::as_str) + // .map(|s| s.to_string()); + // ProxyTag { prefix, suffix } + // }) + // }) + // .collect(), + // ) + // } + + if let Some(privacy) = value.get("privacy").and_then(Value::as_object) { + macro_rules! parse_privacy { + ($v:expr) => { + match privacy.get($v) { + None => None, + Some(Value::Null) => Some(PrivacyLevel::Private), + Some(Value::String(s)) if s == "" || s == "private" => { + Some(PrivacyLevel::Private) + } + Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public), + _ => { + patch.errors.push(ValidationError::new($v)); + None + } + } + }; + } + + patch.member_visibility = parse_privacy!("visibility"); + patch.name_privacy = parse_privacy!("name_privacy"); + patch.description_privacy = parse_privacy!("description_privacy"); + patch.banner_privacy = parse_privacy!("banner_privacy"); + patch.avatar_privacy = parse_privacy!("avatar_privacy"); + patch.birthday_privacy = parse_privacy!("birthday_privacy"); + patch.pronoun_privacy = parse_privacy!("pronoun_privacy"); + patch.proxy_privacy = parse_privacy!("proxy_privacy"); + patch.metadata_privacy = parse_privacy!("metadata_privacy"); + } + + Ok(patch) + } +}