mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
add /api/v2/bulk endpoint
also, initial support for patch models in rust!
This commit is contained in:
parent
1776902000
commit
034865cc13
12 changed files with 715 additions and 32 deletions
22
Cargo.lock
generated
22
Cargo.lock
generated
|
|
@ -92,6 +92,8 @@ dependencies = [
|
||||||
"pluralkit_models",
|
"pluralkit_models",
|
||||||
"reqwest 0.12.15",
|
"reqwest 0.12.15",
|
||||||
"reverse-proxy-service",
|
"reverse-proxy-service",
|
||||||
|
"sea-query",
|
||||||
|
"sea-query-sqlx",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
|
|
@ -3372,19 +3374,20 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sea-query"
|
name = "sea-query"
|
||||||
version = "0.32.3"
|
version = "1.0.0-rc.12"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "f5a24d8b9fcd2674a6c878a3d871f4f1380c6c43cc3718728ac96864d888458e"
|
checksum = "ab621a8d8b03a3e513ea075f71aa26830a55c977d7b40f09e825bb91910db823"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"chrono",
|
||||||
"inherent",
|
"inherent",
|
||||||
"sea-query-derive",
|
"sea-query-derive",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "sea-query-derive"
|
name = "sea-query-derive"
|
||||||
version = "0.4.3"
|
version = "1.0.0-rc.9"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab"
|
checksum = "217e9422de35f26c16c5f671fce3c075a65e10322068dbc66078428634af6195"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"darling",
|
"darling",
|
||||||
"heck 0.4.1",
|
"heck 0.4.1",
|
||||||
|
|
@ -3394,6 +3397,17 @@ dependencies = [
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "sea-query-sqlx"
|
||||||
|
version = "0.8.0-rc.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ed5eb19495858d8ae3663387a4f5298516c6f0171a7ca5681055450f190236b8"
|
||||||
|
dependencies = [
|
||||||
|
"chrono",
|
||||||
|
"sea-query",
|
||||||
|
"sqlx",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "security-framework"
|
name = "security-framework"
|
||||||
version = "3.2.0"
|
version = "3.2.0"
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ futures = "0.3.30"
|
||||||
lazy_static = "1.4.0"
|
lazy_static = "1.4.0"
|
||||||
metrics = "0.23.0"
|
metrics = "0.23.0"
|
||||||
reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]}
|
reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]}
|
||||||
|
sea-query = { version = "1.0.0-rc.10", features = ["with-chrono"] }
|
||||||
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
|
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
|
||||||
serde = { version = "1.0.196", features = ["derive"] }
|
serde = { version = "1.0.196", features = ["derive"] }
|
||||||
serde_json = "1.0.117"
|
serde_json = "1.0.117"
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ fred = { workspace = true }
|
||||||
lazy_static = { workspace = true }
|
lazy_static = { workspace = true }
|
||||||
metrics = { workspace = true }
|
metrics = { workspace = true }
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
|
sea-query = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
sqlx = { workspace = true }
|
sqlx = { workspace = true }
|
||||||
|
|
@ -28,3 +29,4 @@ serde_urlencoded = "0.7.1"
|
||||||
tower = "0.4.13"
|
tower = "0.4.13"
|
||||||
tower-http = { version = "0.5.2", features = ["catch-panic"] }
|
tower-http = { version = "0.5.2", features = ["catch-panic"] }
|
||||||
subtle = "2.6.1"
|
subtle = "2.6.1"
|
||||||
|
sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] }
|
||||||
|
|
|
||||||
211
crates/api/src/endpoints/bulk.rs
Normal file
211
crates/api/src/endpoints/bulk.rs
Normal file
|
|
@ -0,0 +1,211 @@
|
||||||
|
use axum::{
|
||||||
|
Extension, Json,
|
||||||
|
extract::{Json as ExtractJson, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use pk_macros::api_endpoint;
|
||||||
|
use sea_query::{Expr, ExprTrait, PostgresQueryBuilder};
|
||||||
|
use sea_query_sqlx::SqlxBinder;
|
||||||
|
use serde_json::{Value, json};
|
||||||
|
|
||||||
|
use pluralkit_models::{PKGroup, PKGroupPatch, PKMember, PKMemberPatch, PKSystem};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
ApiContext,
|
||||||
|
auth::AuthState,
|
||||||
|
error::{
|
||||||
|
GENERIC_AUTH_ERROR, NOT_OWN_GROUP, NOT_OWN_MEMBER, PKError, TARGET_GROUP_NOT_FOUND,
|
||||||
|
TARGET_MEMBER_NOT_FOUND,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum BulkActionRequestFilter {
|
||||||
|
All,
|
||||||
|
Ids { ids: Vec<String> },
|
||||||
|
Connection { id: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize, Debug)]
|
||||||
|
#[serde(tag = "type", rename_all = "snake_case")]
|
||||||
|
pub enum BulkActionRequest {
|
||||||
|
Member {
|
||||||
|
filter: BulkActionRequestFilter,
|
||||||
|
patch: PKMemberPatch,
|
||||||
|
},
|
||||||
|
Group {
|
||||||
|
filter: BulkActionRequestFilter,
|
||||||
|
patch: PKGroupPatch,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[api_endpoint]
|
||||||
|
pub async fn bulk(
|
||||||
|
Extension(auth): Extension<AuthState>,
|
||||||
|
State(ctx): State<ApiContext>,
|
||||||
|
ExtractJson(req): ExtractJson<BulkActionRequest>,
|
||||||
|
) -> Json<Value> {
|
||||||
|
let Some(system_id) = auth.system_id() else {
|
||||||
|
return Err(GENERIC_AUTH_ERROR);
|
||||||
|
};
|
||||||
|
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
struct Ider {
|
||||||
|
id: i32,
|
||||||
|
hid: String,
|
||||||
|
uuid: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
struct GroupMemberEntry {
|
||||||
|
member_id: i32,
|
||||||
|
group_id: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[allow(dead_code)]
|
||||||
|
#[derive(sqlx::FromRow)]
|
||||||
|
struct OnlyIder {
|
||||||
|
id: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("BulkActionRequest::{req:#?}");
|
||||||
|
match req {
|
||||||
|
BulkActionRequest::Member { filter, mut patch } => {
|
||||||
|
patch.validate_bulk();
|
||||||
|
if patch.errors().len() > 0 {
|
||||||
|
return Err(PKError::from_validation_errors(patch.errors()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ids: Vec<i32> = match filter {
|
||||||
|
BulkActionRequestFilter::All => {
|
||||||
|
let ids: Vec<Ider> = sqlx::query_as("select id from members where system = $1")
|
||||||
|
.bind(system_id as i64)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
ids.iter().map(|v| v.id).collect()
|
||||||
|
}
|
||||||
|
BulkActionRequestFilter::Ids { ids } => {
|
||||||
|
let members: Vec<PKMember> = sqlx::query_as(
|
||||||
|
"select * from members where hid = any($1::array) or uuid::text = any($1::array)",
|
||||||
|
)
|
||||||
|
.bind(&ids)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// todo: better errors
|
||||||
|
if members.len() != ids.len() {
|
||||||
|
return Err(TARGET_MEMBER_NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
if members.iter().any(|m| m.system != system_id) {
|
||||||
|
return Err(NOT_OWN_MEMBER);
|
||||||
|
}
|
||||||
|
|
||||||
|
members.iter().map(|m| m.id).collect()
|
||||||
|
}
|
||||||
|
BulkActionRequestFilter::Connection { id } => {
|
||||||
|
let Some(group): Option<PKGroup> =
|
||||||
|
sqlx::query_as("select * from groups where hid = $1 or uuid::text = $1")
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&ctx.db)
|
||||||
|
.await?
|
||||||
|
else {
|
||||||
|
return Err(TARGET_GROUP_NOT_FOUND);
|
||||||
|
};
|
||||||
|
|
||||||
|
if group.system != system_id {
|
||||||
|
return Err(NOT_OWN_GROUP);
|
||||||
|
}
|
||||||
|
|
||||||
|
let entries: Vec<GroupMemberEntry> =
|
||||||
|
sqlx::query_as("select * from group_members where group_id = $1")
|
||||||
|
.bind(group.id)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
entries.iter().map(|v| v.member_id).collect()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (q, pms) = patch
|
||||||
|
.to_sql()
|
||||||
|
.table("members") // todo: this should be in the model definition
|
||||||
|
.and_where(Expr::col("id").is_in(ids))
|
||||||
|
.returning_col("id")
|
||||||
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
let res: Vec<OnlyIder> = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?;
|
||||||
|
Ok(Json(json! {{ "updated": res.len() }}))
|
||||||
|
}
|
||||||
|
BulkActionRequest::Group { filter, mut patch } => {
|
||||||
|
patch.validate_bulk();
|
||||||
|
if patch.errors().len() > 0 {
|
||||||
|
return Err(PKError::from_validation_errors(patch.errors()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let ids: Vec<i32> = match filter {
|
||||||
|
BulkActionRequestFilter::All => {
|
||||||
|
let ids: Vec<Ider> = sqlx::query_as("select id from groups where system = $1")
|
||||||
|
.bind(system_id as i64)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
ids.iter().map(|v| v.id).collect()
|
||||||
|
}
|
||||||
|
BulkActionRequestFilter::Ids { ids } => {
|
||||||
|
let groups: Vec<PKGroup> = sqlx::query_as(
|
||||||
|
"select * from groups where hid = any($1) or uuid::text = any($1)",
|
||||||
|
)
|
||||||
|
.bind(&ids)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// todo: better errors
|
||||||
|
if groups.len() != ids.len() {
|
||||||
|
return Err(TARGET_GROUP_NOT_FOUND);
|
||||||
|
}
|
||||||
|
|
||||||
|
if groups.iter().any(|m| m.system != system_id) {
|
||||||
|
return Err(NOT_OWN_GROUP);
|
||||||
|
}
|
||||||
|
|
||||||
|
groups.iter().map(|m| m.id).collect()
|
||||||
|
}
|
||||||
|
BulkActionRequestFilter::Connection { id } => {
|
||||||
|
let Some(member): Option<PKMember> =
|
||||||
|
sqlx::query_as("select * from members where hid = $1 or uuid::text = $1")
|
||||||
|
.bind(id)
|
||||||
|
.fetch_optional(&ctx.db)
|
||||||
|
.await?
|
||||||
|
else {
|
||||||
|
return Err(TARGET_MEMBER_NOT_FOUND);
|
||||||
|
};
|
||||||
|
|
||||||
|
if member.system != system_id {
|
||||||
|
return Err(NOT_OWN_MEMBER);
|
||||||
|
}
|
||||||
|
|
||||||
|
let entries: Vec<GroupMemberEntry> =
|
||||||
|
sqlx::query_as("select * from group_members where member_id = $1")
|
||||||
|
.bind(member.id)
|
||||||
|
.fetch_all(&ctx.db)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
entries.iter().map(|v| v.group_id).collect()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let (q, pms) = patch
|
||||||
|
.to_sql()
|
||||||
|
.table("groups") // todo: this should be in the model definition
|
||||||
|
.and_where(Expr::col("id").is_in(ids))
|
||||||
|
.returning_col("id")
|
||||||
|
.build_sqlx(PostgresQueryBuilder);
|
||||||
|
|
||||||
|
println!("{q:#?} {pms:#?}");
|
||||||
|
|
||||||
|
let res: Vec<OnlyIder> = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?;
|
||||||
|
Ok(Json(json! {{ "updated": res.len() }}))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,2 +1,3 @@
|
||||||
|
pub mod bulk;
|
||||||
pub mod private;
|
pub mod private;
|
||||||
pub mod system;
|
pub mod system;
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,7 @@ use axum::{
|
||||||
http::StatusCode,
|
http::StatusCode,
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
|
use pluralkit_models::ValidationError;
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
// todo: model parse errors
|
// todo: model parse errors
|
||||||
|
|
@ -11,6 +12,8 @@ pub struct PKError {
|
||||||
pub json_code: i32,
|
pub json_code: i32,
|
||||||
pub message: &'static str,
|
pub message: &'static str,
|
||||||
|
|
||||||
|
pub errors: Vec<ValidationError>,
|
||||||
|
|
||||||
pub inner: Option<anyhow::Error>,
|
pub inner: Option<anyhow::Error>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -30,6 +33,21 @@ impl Clone for PKError {
|
||||||
json_code: self.json_code,
|
json_code: self.json_code,
|
||||||
message: self.message,
|
message: self.message,
|
||||||
inner: None,
|
inner: None,
|
||||||
|
errors: self.errors.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// can't `impl From<Vec<ValidationError>>`
|
||||||
|
// because "upstream crate may add a new impl" >:(
|
||||||
|
impl PKError {
|
||||||
|
pub fn from_validation_errors(errs: Vec<ValidationError>) -> Self {
|
||||||
|
Self {
|
||||||
|
message: "Error parsing JSON model",
|
||||||
|
json_code: 40001,
|
||||||
|
errors: errs,
|
||||||
|
response_code: StatusCode::BAD_REQUEST,
|
||||||
|
inner: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -50,14 +68,19 @@ impl IntoResponse for PKError {
|
||||||
if let Some(inner) = self.inner {
|
if let Some(inner) = self.inner {
|
||||||
tracing::error!(?inner, "error returned from handler");
|
tracing::error!(?inner, "error returned from handler");
|
||||||
}
|
}
|
||||||
crate::util::json_err(
|
let json = if self.errors.len() > 0 {
|
||||||
self.response_code,
|
serde_json::json!({
|
||||||
serde_json::to_string(&serde_json::json!({
|
|
||||||
"message": self.message,
|
"message": self.message,
|
||||||
"code": self.json_code,
|
"code": self.json_code,
|
||||||
}))
|
"errors": self.errors,
|
||||||
.unwrap(),
|
})
|
||||||
)
|
} else {
|
||||||
|
serde_json::json!({
|
||||||
|
"message": self.message,
|
||||||
|
"code": self.json_code,
|
||||||
|
})
|
||||||
|
};
|
||||||
|
crate::util::json_err(self.response_code, serde_json::to_string(&json).unwrap())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -78,9 +101,17 @@ macro_rules! define_error {
|
||||||
json_code: $json_code,
|
json_code: $json_code,
|
||||||
message: $message,
|
message: $message,
|
||||||
inner: None,
|
inner: None,
|
||||||
|
errors: vec![],
|
||||||
};
|
};
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
define_error! { GENERIC_AUTH_ERROR, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" }
|
||||||
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
|
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
|
||||||
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }
|
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }
|
||||||
|
|
||||||
|
define_error! { NOT_OWN_MEMBER, StatusCode::FORBIDDEN, 30006, "Target member is not part of your system." }
|
||||||
|
define_error! { NOT_OWN_GROUP, StatusCode::FORBIDDEN, 30007, "Target group is not part of your system." }
|
||||||
|
|
||||||
|
define_error! { TARGET_MEMBER_NOT_FOUND, StatusCode::BAD_REQUEST, 40010, "Target member not found." }
|
||||||
|
define_error! { TARGET_GROUP_NOT_FOUND, StatusCode::BAD_REQUEST, 40011, "Target group not found." }
|
||||||
|
|
|
||||||
|
|
@ -115,6 +115,8 @@ fn router(ctx: ApiContext) -> Router {
|
||||||
|
|
||||||
.route("/v2/messages/{message_id}", get(rproxy))
|
.route("/v2/messages/{message_id}", get(rproxy))
|
||||||
|
|
||||||
|
.route("/v2/bulk", post(endpoints::bulk::bulk))
|
||||||
|
|
||||||
.route("/private/bulk_privacy/member", post(rproxy))
|
.route("/private/bulk_privacy/member", post(rproxy))
|
||||||
.route("/private/bulk_privacy/group", post(rproxy))
|
.route("/private/bulk_privacy/group", post(rproxy))
|
||||||
.route("/private/discord/callback", post(rproxy))
|
.route("/private/discord/callback", post(rproxy))
|
||||||
|
|
|
||||||
|
|
@ -85,8 +85,14 @@ fn parse_field(field: syn::Field) -> ModelField {
|
||||||
panic!("must have json name to be publicly patchable");
|
panic!("must have json name to be publicly patchable");
|
||||||
}
|
}
|
||||||
|
|
||||||
if f.json.is_some() && f.is_privacy {
|
if f.is_privacy && f.json.is_none() {
|
||||||
panic!("cannot set custom json name for privacy field");
|
f.json = Some(syn::Expr::Lit(syn::ExprLit {
|
||||||
|
attrs: vec![],
|
||||||
|
lit: syn::Lit::Str(syn::LitStr::new(
|
||||||
|
f.name.clone().to_string().as_str(),
|
||||||
|
proc_macro2::Span::call_site(),
|
||||||
|
)),
|
||||||
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
f
|
f
|
||||||
|
|
@ -122,17 +128,17 @@ pub fn macro_impl(
|
||||||
|
|
||||||
let fields: Vec<ModelField> = fields
|
let fields: Vec<ModelField> = fields
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|f| !matches!(f.patch, ElemPatchability::None))
|
.filter(|f| f.is_privacy || !matches!(f.patch, ElemPatchability::None))
|
||||||
.cloned()
|
.cloned()
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let patch_fields = mk_patch_fields(fields.clone());
|
let patch_fields = mk_patch_fields(fields.clone());
|
||||||
let patch_from_json = mk_patch_from_json(fields.clone());
|
|
||||||
let patch_validate = mk_patch_validate(fields.clone());
|
let patch_validate = mk_patch_validate(fields.clone());
|
||||||
|
let patch_validate_bulk = mk_patch_validate_bulk(fields.clone());
|
||||||
let patch_to_json = mk_patch_to_json(fields.clone());
|
let patch_to_json = mk_patch_to_json(fields.clone());
|
||||||
let patch_to_sql = mk_patch_to_sql(fields.clone());
|
let patch_to_sql = mk_patch_to_sql(fields.clone());
|
||||||
|
|
||||||
return quote! {
|
let code = quote! {
|
||||||
#[derive(sqlx::FromRow, Debug, Clone)]
|
#[derive(sqlx::FromRow, Debug, Clone)]
|
||||||
pub struct #tname {
|
pub struct #tname {
|
||||||
#tfields
|
#tfields
|
||||||
|
|
@ -146,31 +152,42 @@ pub fn macro_impl(
|
||||||
#to_json
|
#to_json
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone, Default)]
|
||||||
pub struct #patchable_name {
|
pub struct #patchable_name {
|
||||||
#patch_fields
|
#patch_fields
|
||||||
|
|
||||||
|
errors: Vec<crate::ValidationError>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl #patchable_name {
|
impl #patchable_name {
|
||||||
pub fn from_json(input: String) -> Self {
|
pub fn validate(&mut self) {
|
||||||
#patch_from_json
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn validate(self) -> bool {
|
|
||||||
#patch_validate
|
#patch_validate
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn errors(&self) -> Vec<crate::ValidationError> {
|
||||||
|
self.errors.clone()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn validate_bulk(&mut self) {
|
||||||
|
#patch_validate_bulk
|
||||||
|
}
|
||||||
|
|
||||||
pub fn to_sql(self) -> sea_query::UpdateStatement {
|
pub fn to_sql(self) -> sea_query::UpdateStatement {
|
||||||
// sea_query::Query::update()
|
use sea_query::types::*;
|
||||||
#patch_to_sql
|
let mut patch = &mut sea_query::Query::update();
|
||||||
|
#patch_to_sql
|
||||||
|
patch.clone()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_json(self) -> serde_json::Value {
|
pub fn to_json(self) -> serde_json::Value {
|
||||||
#patch_to_json
|
#patch_to_json
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
};
|
||||||
.into();
|
|
||||||
|
// panic!("{:#?}", code.to_string());
|
||||||
|
|
||||||
|
return code.into();
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
|
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
|
||||||
|
|
@ -225,7 +242,7 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
||||||
.filter_map(|f| {
|
.filter_map(|f| {
|
||||||
if f.is_privacy {
|
if f.is_privacy {
|
||||||
let tname = f.name.clone();
|
let tname = f.name.clone();
|
||||||
let tnamestr = f.name.clone().to_string();
|
let tnamestr = f.json.clone();
|
||||||
Some(quote! {
|
Some(quote! {
|
||||||
#tnamestr: self.#tname,
|
#tnamestr: self.#tname,
|
||||||
})
|
})
|
||||||
|
|
@ -280,13 +297,48 @@ fn mk_patch_fields(fields: Vec<ModelField>) -> TokenStream {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
|
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
|
||||||
quote! { true }
|
|
||||||
}
|
|
||||||
fn mk_patch_from_json(_fields: Vec<ModelField>) -> TokenStream {
|
|
||||||
quote! { unimplemented!(); }
|
quote! { unimplemented!(); }
|
||||||
}
|
}
|
||||||
fn mk_patch_to_sql(_fields: Vec<ModelField>) -> TokenStream {
|
fn mk_patch_validate_bulk(fields: Vec<ModelField>) -> TokenStream {
|
||||||
quote! { unimplemented!(); }
|
// iterate over all nullable patchable fields other than privacy
|
||||||
|
// add an error if any field is set to a value other than null
|
||||||
|
fields
|
||||||
|
.iter()
|
||||||
|
.map(|f| {
|
||||||
|
if let syn::Type::Path(path) = &f.ty && let Some(inner) = path.path.segments.last() && inner.ident != "Option" {
|
||||||
|
return quote! {};
|
||||||
|
}
|
||||||
|
let name = f.name.clone();
|
||||||
|
if matches!(f.patch, ElemPatchability::Public) {
|
||||||
|
let json = f.json.clone().unwrap();
|
||||||
|
quote! {
|
||||||
|
if let Some(val) = self.#name.clone() && val.is_some() {
|
||||||
|
self.errors.push(ValidationError::simple(#json, "Only null values are supported in bulk endpoint"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
quote! {}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
fn mk_patch_to_sql(fields: Vec<ModelField>) -> TokenStream {
|
||||||
|
fields
|
||||||
|
.iter()
|
||||||
|
.filter_map(|f| {
|
||||||
|
if !matches!(f.patch, ElemPatchability::None) || f.is_privacy {
|
||||||
|
let name = f.name.clone();
|
||||||
|
let column = f.name.to_string();
|
||||||
|
Some(quote! {
|
||||||
|
if let Some(value) = self.#name {
|
||||||
|
patch = patch.value(#column, value);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
}
|
}
|
||||||
fn mk_patch_to_json(_fields: Vec<ModelField>) -> TokenStream {
|
fn mk_patch_to_json(_fields: Vec<ModelField>) -> TokenStream {
|
||||||
quote! { unimplemented!(); }
|
quote! { unimplemented!(); }
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ edition = "2024"
|
||||||
[dependencies]
|
[dependencies]
|
||||||
chrono = { workspace = true, features = ["serde"] }
|
chrono = { workspace = true, features = ["serde"] }
|
||||||
pk_macros = { path = "../macros" }
|
pk_macros = { path = "../macros" }
|
||||||
sea-query = "0.32.1"
|
sea-query = { workspace = true }
|
||||||
serde = { workspace = true }
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true, features = ["preserve_order"] }
|
serde_json = { workspace = true, features = ["preserve_order"] }
|
||||||
# in theory we want to default-features = false for sqlx
|
# in theory we want to default-features = false for sqlx
|
||||||
|
|
|
||||||
132
crates/models/src/group.rs
Normal file
132
crates/models/src/group.rs
Normal file
|
|
@ -0,0 +1,132 @@
|
||||||
|
use pk_macros::pk_model;
|
||||||
|
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::Deserialize;
|
||||||
|
use serde_json::Value;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{PrivacyLevel, SystemId, ValidationError};
|
||||||
|
|
||||||
|
// todo: fix
|
||||||
|
pub type GroupId = i32;
|
||||||
|
|
||||||
|
#[pk_model]
|
||||||
|
struct Group {
|
||||||
|
id: GroupId,
|
||||||
|
#[json = "hid"]
|
||||||
|
#[private_patchable]
|
||||||
|
hid: String,
|
||||||
|
#[json = "uuid"]
|
||||||
|
uuid: Uuid,
|
||||||
|
// TODO fix
|
||||||
|
#[json = "system"]
|
||||||
|
system: SystemId,
|
||||||
|
|
||||||
|
#[json = "name"]
|
||||||
|
#[privacy = name_privacy]
|
||||||
|
#[patchable]
|
||||||
|
name: String,
|
||||||
|
#[json = "display_name"]
|
||||||
|
#[patchable]
|
||||||
|
display_name: Option<String>,
|
||||||
|
#[json = "color"]
|
||||||
|
#[patchable]
|
||||||
|
color: Option<String>,
|
||||||
|
#[json = "icon"]
|
||||||
|
#[patchable]
|
||||||
|
icon: Option<String>,
|
||||||
|
#[json = "banner_image"]
|
||||||
|
#[patchable]
|
||||||
|
banner_image: Option<String>,
|
||||||
|
#[json = "description"]
|
||||||
|
#[privacy = description_privacy]
|
||||||
|
#[patchable]
|
||||||
|
description: Option<String>,
|
||||||
|
#[json = "created"]
|
||||||
|
created: DateTime<Utc>,
|
||||||
|
|
||||||
|
#[privacy]
|
||||||
|
name_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
description_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
banner_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
icon_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
list_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
metadata_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
visibility: PrivacyLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for PKGroupPatch {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let mut patch: PKGroupPatch = Default::default();
|
||||||
|
let value: Value = Value::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
if let Some(v) = value.get("name") {
|
||||||
|
if let Some(name) = v.as_str() {
|
||||||
|
patch.name = Some(name.to_string());
|
||||||
|
} else if v.is_null() {
|
||||||
|
patch.errors.push(ValidationError::simple(
|
||||||
|
"name",
|
||||||
|
"Group name cannot be set to null.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! parse_string_simple {
|
||||||
|
($k:expr) => {
|
||||||
|
match value.get($k) {
|
||||||
|
None => None,
|
||||||
|
Some(Value::Null) => Some(None),
|
||||||
|
Some(Value::String(s)) => Some(Some(s.clone())),
|
||||||
|
_ => {
|
||||||
|
patch.errors.push(ValidationError::new($k));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
patch.display_name = parse_string_simple!("display_name");
|
||||||
|
patch.description = parse_string_simple!("description");
|
||||||
|
patch.icon = parse_string_simple!("icon");
|
||||||
|
patch.banner_image = parse_string_simple!("banner");
|
||||||
|
patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase()));
|
||||||
|
|
||||||
|
if let Some(privacy) = value.get("privacy").and_then(Value::as_object) {
|
||||||
|
macro_rules! parse_privacy {
|
||||||
|
($v:expr) => {
|
||||||
|
match privacy.get($v) {
|
||||||
|
None => None,
|
||||||
|
Some(Value::Null) => Some(PrivacyLevel::Private),
|
||||||
|
Some(Value::String(s)) if s == "" || s == "private" => {
|
||||||
|
Some(PrivacyLevel::Private)
|
||||||
|
}
|
||||||
|
Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public),
|
||||||
|
_ => {
|
||||||
|
patch.errors.push(ValidationError::new($v));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
patch.name_privacy = parse_privacy!("name_privacy");
|
||||||
|
patch.description_privacy = parse_privacy!("description_privacy");
|
||||||
|
patch.banner_privacy = parse_privacy!("banner_privacy");
|
||||||
|
patch.icon_privacy = parse_privacy!("icon_privacy");
|
||||||
|
patch.list_privacy = parse_privacy!("list_privacy");
|
||||||
|
patch.metadata_privacy = parse_privacy!("metadata_privacy");
|
||||||
|
patch.visibility = parse_privacy!("visibility");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(patch)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -9,6 +9,8 @@ macro_rules! model {
|
||||||
|
|
||||||
model!(system);
|
model!(system);
|
||||||
model!(system_config);
|
model!(system_config);
|
||||||
|
model!(member);
|
||||||
|
model!(group);
|
||||||
|
|
||||||
#[derive(serde::Serialize, Debug, Clone)]
|
#[derive(serde::Serialize, Debug, Clone)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
|
|
@ -31,3 +33,30 @@ impl From<i32> for PrivacyLevel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl From<PrivacyLevel> for sea_query::Value {
|
||||||
|
fn from(level: PrivacyLevel) -> sea_query::Value {
|
||||||
|
match level {
|
||||||
|
PrivacyLevel::Public => sea_query::Value::Int(Some(1)),
|
||||||
|
PrivacyLevel::Private => sea_query::Value::Int(Some(2)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, Debug, Clone)]
|
||||||
|
pub enum ValidationError {
|
||||||
|
Simple { key: String, value: String },
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ValidationError {
|
||||||
|
fn new(key: &str) -> Self {
|
||||||
|
Self::simple(key, "is invalid")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn simple(key: &str, value: &str) -> Self {
|
||||||
|
Self::Simple {
|
||||||
|
key: key.to_string(),
|
||||||
|
value: value.to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
208
crates/models/src/member.rs
Normal file
208
crates/models/src/member.rs
Normal file
|
|
@ -0,0 +1,208 @@
|
||||||
|
use pk_macros::pk_model;
|
||||||
|
|
||||||
|
use chrono::NaiveDateTime;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use serde_json::Value;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::{PrivacyLevel, SystemId, ValidationError};
|
||||||
|
|
||||||
|
// todo: fix
|
||||||
|
pub type MemberId = i32;
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)]
|
||||||
|
#[sqlx(type_name = "proxy_tag")]
|
||||||
|
pub struct ProxyTag {
|
||||||
|
pub prefix: Option<String>,
|
||||||
|
pub suffix: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[pk_model]
|
||||||
|
struct Member {
|
||||||
|
id: MemberId,
|
||||||
|
#[json = "hid"]
|
||||||
|
#[private_patchable]
|
||||||
|
hid: String,
|
||||||
|
#[json = "uuid"]
|
||||||
|
uuid: Uuid,
|
||||||
|
// TODO fix
|
||||||
|
#[json = "system"]
|
||||||
|
system: SystemId,
|
||||||
|
|
||||||
|
#[json = "color"]
|
||||||
|
#[patchable]
|
||||||
|
color: Option<String>,
|
||||||
|
#[json = "webhook_avatar_url"]
|
||||||
|
#[patchable]
|
||||||
|
webhook_avatar_url: Option<String>,
|
||||||
|
#[json = "avatar_url"]
|
||||||
|
#[patchable]
|
||||||
|
avatar_url: Option<String>,
|
||||||
|
#[json = "banner_image"]
|
||||||
|
#[patchable]
|
||||||
|
banner_image: Option<String>,
|
||||||
|
#[json = "name"]
|
||||||
|
#[privacy = name_privacy]
|
||||||
|
#[patchable]
|
||||||
|
name: String,
|
||||||
|
#[json = "display_name"]
|
||||||
|
#[patchable]
|
||||||
|
display_name: Option<String>,
|
||||||
|
#[json = "birthday"]
|
||||||
|
#[patchable]
|
||||||
|
birthday: Option<String>,
|
||||||
|
#[json = "pronouns"]
|
||||||
|
#[privacy = pronoun_privacy]
|
||||||
|
#[patchable]
|
||||||
|
pronouns: Option<String>,
|
||||||
|
#[json = "description"]
|
||||||
|
#[privacy = description_privacy]
|
||||||
|
#[patchable]
|
||||||
|
description: Option<String>,
|
||||||
|
#[json = "proxy_tags"]
|
||||||
|
// #[patchable]
|
||||||
|
proxy_tags: Vec<ProxyTag>,
|
||||||
|
#[json = "keep_proxy"]
|
||||||
|
#[patchable]
|
||||||
|
keep_proxy: bool,
|
||||||
|
#[json = "tts"]
|
||||||
|
#[patchable]
|
||||||
|
tts: bool,
|
||||||
|
#[json = "created"]
|
||||||
|
created: NaiveDateTime,
|
||||||
|
#[json = "message_count"]
|
||||||
|
#[private_patchable]
|
||||||
|
message_count: i32,
|
||||||
|
#[json = "last_message_timestamp"]
|
||||||
|
#[private_patchable]
|
||||||
|
last_message_timestamp: Option<NaiveDateTime>,
|
||||||
|
#[json = "allow_autoproxy"]
|
||||||
|
#[patchable]
|
||||||
|
allow_autoproxy: bool,
|
||||||
|
|
||||||
|
#[privacy]
|
||||||
|
#[json = "visibility"]
|
||||||
|
member_visibility: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
description_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
banner_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
avatar_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
name_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
birthday_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
pronoun_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
metadata_privacy: PrivacyLevel,
|
||||||
|
#[privacy]
|
||||||
|
proxy_privacy: PrivacyLevel,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de> Deserialize<'de> for PKMemberPatch {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||||
|
where
|
||||||
|
D: serde::Deserializer<'de>,
|
||||||
|
{
|
||||||
|
let mut patch: PKMemberPatch = Default::default();
|
||||||
|
let value: Value = Value::deserialize(deserializer)?;
|
||||||
|
|
||||||
|
if let Some(v) = value.get("name") {
|
||||||
|
if let Some(name) = v.as_str() {
|
||||||
|
patch.name = Some(name.to_string());
|
||||||
|
} else if v.is_null() {
|
||||||
|
patch.errors.push(ValidationError::simple(
|
||||||
|
"name",
|
||||||
|
"Member name cannot be set to null.",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
macro_rules! parse_string_simple {
|
||||||
|
($k:expr) => {
|
||||||
|
match value.get($k) {
|
||||||
|
None => None,
|
||||||
|
Some(Value::Null) => Some(None),
|
||||||
|
Some(Value::String(s)) => Some(Some(s.clone())),
|
||||||
|
_ => {
|
||||||
|
patch.errors.push(ValidationError::new($k));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase()));
|
||||||
|
patch.display_name = parse_string_simple!("display_name");
|
||||||
|
patch.avatar_url = parse_string_simple!("avatar_url");
|
||||||
|
patch.banner_image = parse_string_simple!("banner");
|
||||||
|
patch.birthday = parse_string_simple!("birthday"); // fix
|
||||||
|
patch.pronouns = parse_string_simple!("pronouns");
|
||||||
|
patch.description = parse_string_simple!("description");
|
||||||
|
|
||||||
|
if let Some(keep_proxy) = value.get("keep_proxy").and_then(Value::as_bool) {
|
||||||
|
patch.keep_proxy = Some(keep_proxy);
|
||||||
|
}
|
||||||
|
if let Some(tts) = value.get("tts").and_then(Value::as_bool) {
|
||||||
|
patch.tts = Some(tts);
|
||||||
|
}
|
||||||
|
|
||||||
|
// todo: legacy import handling
|
||||||
|
|
||||||
|
// todo: fix proxy_tag type in sea_query
|
||||||
|
|
||||||
|
// if let Some(proxy_tags) = value.get("proxy_tags").and_then(Value::as_array) {
|
||||||
|
// patch.proxy_tags = Some(
|
||||||
|
// proxy_tags
|
||||||
|
// .iter()
|
||||||
|
// .filter_map(|tag| {
|
||||||
|
// tag.as_object().map(|tag_obj| {
|
||||||
|
// let prefix = tag_obj
|
||||||
|
// .get("prefix")
|
||||||
|
// .and_then(Value::as_str)
|
||||||
|
// .map(|s| s.to_string());
|
||||||
|
// let suffix = tag_obj
|
||||||
|
// .get("suffix")
|
||||||
|
// .and_then(Value::as_str)
|
||||||
|
// .map(|s| s.to_string());
|
||||||
|
// ProxyTag { prefix, suffix }
|
||||||
|
// })
|
||||||
|
// })
|
||||||
|
// .collect(),
|
||||||
|
// )
|
||||||
|
// }
|
||||||
|
|
||||||
|
if let Some(privacy) = value.get("privacy").and_then(Value::as_object) {
|
||||||
|
macro_rules! parse_privacy {
|
||||||
|
($v:expr) => {
|
||||||
|
match privacy.get($v) {
|
||||||
|
None => None,
|
||||||
|
Some(Value::Null) => Some(PrivacyLevel::Private),
|
||||||
|
Some(Value::String(s)) if s == "" || s == "private" => {
|
||||||
|
Some(PrivacyLevel::Private)
|
||||||
|
}
|
||||||
|
Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public),
|
||||||
|
_ => {
|
||||||
|
patch.errors.push(ValidationError::new($v));
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
patch.member_visibility = parse_privacy!("visibility");
|
||||||
|
patch.name_privacy = parse_privacy!("name_privacy");
|
||||||
|
patch.description_privacy = parse_privacy!("description_privacy");
|
||||||
|
patch.banner_privacy = parse_privacy!("banner_privacy");
|
||||||
|
patch.avatar_privacy = parse_privacy!("avatar_privacy");
|
||||||
|
patch.birthday_privacy = parse_privacy!("birthday_privacy");
|
||||||
|
patch.pronoun_privacy = parse_privacy!("pronoun_privacy");
|
||||||
|
patch.proxy_privacy = parse_privacy!("proxy_privacy");
|
||||||
|
patch.metadata_privacy = parse_privacy!("metadata_privacy");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(patch)
|
||||||
|
}
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue