Compare commits

...

3 commits

Author SHA1 Message Date
alyssa
b83109b65a add /api/v2/bulk endpoint
Some checks are pending
Build and push Docker image / .net docker build (push) Waiting to run
Build and push Rust service Docker images / rust docker build (push) Waiting to run
also, initial support for patch models in rust!
2025-09-06 18:33:40 +00:00
alyssa
695d1debf2 chore: update recovery message on dashboard
Some checks failed
Build and push Docker image / .net docker build (push) Waiting to run
Build and push Rust service Docker images / rust docker build (push) Waiting to run
rust checks / cargo fmt (push) Waiting to run
Build dashboard Docker image / dashboard docker build (push) Has been cancelled
2025-09-06 18:33:40 +00:00
alyssa
ebb23286d8 chore: bump rust edition to 2024 2025-09-06 18:28:24 +00:00
49 changed files with 786 additions and 102 deletions

View file

@ -11,6 +11,7 @@
!Cargo.toml
!Cargo.lock
!rust-toolchain.toml
!PluralKit.sln
!nuget.config
!ci/dotnet-version.sh

22
Cargo.lock generated
View file

@ -92,6 +92,8 @@ dependencies = [
"pluralkit_models",
"reqwest 0.12.15",
"reverse-proxy-service",
"sea-query",
"sea-query-sqlx",
"serde",
"serde_json",
"serde_urlencoded",
@ -3345,19 +3347,20 @@ dependencies = [
[[package]]
name = "sea-query"
version = "0.32.3"
version = "1.0.0-rc.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5a24d8b9fcd2674a6c878a3d871f4f1380c6c43cc3718728ac96864d888458e"
checksum = "ab621a8d8b03a3e513ea075f71aa26830a55c977d7b40f09e825bb91910db823"
dependencies = [
"chrono",
"inherent",
"sea-query-derive",
]
[[package]]
name = "sea-query-derive"
version = "0.4.3"
version = "1.0.0-rc.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bae0cbad6ab996955664982739354128c58d16e126114fe88c2a493642502aab"
checksum = "217e9422de35f26c16c5f671fce3c075a65e10322068dbc66078428634af6195"
dependencies = [
"darling",
"heck 0.4.1",
@ -3367,6 +3370,17 @@ dependencies = [
"thiserror 2.0.12",
]
[[package]]
name = "sea-query-sqlx"
version = "0.8.0-rc.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed5eb19495858d8ae3663387a4f5298516c6f0171a7ca5681055450f190236b8"
dependencies = [
"chrono",
"sea-query",
"sqlx",
]
[[package]]
name = "security-framework"
version = "3.2.0"

View file

@ -14,6 +14,7 @@ futures = "0.3.30"
lazy_static = "1.4.0"
metrics = "0.23.0"
reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]}
sea-query = { version = "1.0.0-rc.10", features = ["with-chrono"] }
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.117"

View file

@ -4,7 +4,8 @@ WORKDIR /build
RUN apk add rustup build-base
# todo: arm64 target
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain nightly-2024-08-20 --profile default -y
COPY rust-toolchain.toml .
RUN rustup-init --no-update-default-toolchain -y
ENV PATH=/root/.cargo/bin:$PATH
ENV RUSTFLAGS='-C link-arg=-s'

View file

@ -1,7 +1,7 @@
[package]
name = "api"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
pluralkit_models = { path = "../models" }
@ -14,6 +14,7 @@ fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
sea-query = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
@ -28,3 +29,4 @@ serde_urlencoded = "0.7.1"
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
subtle = "2.6.1"
sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] }

View file

@ -0,0 +1,211 @@
use axum::{
Extension, Json,
extract::{Json as ExtractJson, State},
response::IntoResponse,
};
use pk_macros::api_endpoint;
use sea_query::{Expr, ExprTrait, PostgresQueryBuilder};
use sea_query_sqlx::SqlxBinder;
use serde_json::{Value, json};
use pluralkit_models::{PKGroup, PKGroupPatch, PKMember, PKMemberPatch, PKSystem};
use crate::{
ApiContext,
auth::AuthState,
error::{
GENERIC_AUTH_ERROR, NOT_OWN_GROUP, NOT_OWN_MEMBER, PKError, TARGET_GROUP_NOT_FOUND,
TARGET_MEMBER_NOT_FOUND,
},
};
#[derive(serde::Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BulkActionRequestFilter {
All,
Ids { ids: Vec<String> },
Connection { id: String },
}
#[derive(serde::Deserialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum BulkActionRequest {
Member {
filter: BulkActionRequestFilter,
patch: PKMemberPatch,
},
Group {
filter: BulkActionRequestFilter,
patch: PKGroupPatch,
},
}
#[api_endpoint]
pub async fn bulk(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
ExtractJson(req): ExtractJson<BulkActionRequest>,
) -> Json<Value> {
let Some(system_id) = auth.system_id() else {
return Err(GENERIC_AUTH_ERROR);
};
#[derive(sqlx::FromRow)]
struct Ider {
id: i32,
hid: String,
uuid: String,
}
#[derive(sqlx::FromRow)]
struct GroupMemberEntry {
member_id: i32,
group_id: i32,
}
#[allow(dead_code)]
#[derive(sqlx::FromRow)]
struct OnlyIder {
id: i32,
}
println!("BulkActionRequest::{req:#?}");
match req {
BulkActionRequest::Member { filter, mut patch } => {
patch.validate_bulk();
if patch.errors().len() > 0 {
return Err(PKError::from_validation_errors(patch.errors()));
}
let ids: Vec<i32> = match filter {
BulkActionRequestFilter::All => {
let ids: Vec<Ider> = sqlx::query_as("select id from members where system = $1")
.bind(system_id as i64)
.fetch_all(&ctx.db)
.await?;
ids.iter().map(|v| v.id).collect()
}
BulkActionRequestFilter::Ids { ids } => {
let members: Vec<PKMember> = sqlx::query_as(
"select * from members where hid = any($1::array) or uuid::text = any($1::array)",
)
.bind(&ids)
.fetch_all(&ctx.db)
.await?;
// todo: better errors
if members.len() != ids.len() {
return Err(TARGET_MEMBER_NOT_FOUND);
}
if members.iter().any(|m| m.system != system_id) {
return Err(NOT_OWN_MEMBER);
}
members.iter().map(|m| m.id).collect()
}
BulkActionRequestFilter::Connection { id } => {
let Some(group): Option<PKGroup> =
sqlx::query_as("select * from groups where hid = $1 or uuid::text = $1")
.bind(id)
.fetch_optional(&ctx.db)
.await?
else {
return Err(TARGET_GROUP_NOT_FOUND);
};
if group.system != system_id {
return Err(NOT_OWN_GROUP);
}
let entries: Vec<GroupMemberEntry> =
sqlx::query_as("select * from group_members where group_id = $1")
.bind(group.id)
.fetch_all(&ctx.db)
.await?;
entries.iter().map(|v| v.member_id).collect()
}
};
let (q, pms) = patch
.to_sql()
.table("members") // todo: this should be in the model definition
.and_where(Expr::col("id").is_in(ids))
.returning_col("id")
.build_sqlx(PostgresQueryBuilder);
let res: Vec<OnlyIder> = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?;
Ok(Json(json! {{ "updated": res.len() }}))
}
BulkActionRequest::Group { filter, mut patch } => {
patch.validate_bulk();
if patch.errors().len() > 0 {
return Err(PKError::from_validation_errors(patch.errors()));
}
let ids: Vec<i32> = match filter {
BulkActionRequestFilter::All => {
let ids: Vec<Ider> = sqlx::query_as("select id from groups where system = $1")
.bind(system_id as i64)
.fetch_all(&ctx.db)
.await?;
ids.iter().map(|v| v.id).collect()
}
BulkActionRequestFilter::Ids { ids } => {
let groups: Vec<PKGroup> = sqlx::query_as(
"select * from groups where hid = any($1) or uuid::text = any($1)",
)
.bind(&ids)
.fetch_all(&ctx.db)
.await?;
// todo: better errors
if groups.len() != ids.len() {
return Err(TARGET_GROUP_NOT_FOUND);
}
if groups.iter().any(|m| m.system != system_id) {
return Err(NOT_OWN_GROUP);
}
groups.iter().map(|m| m.id).collect()
}
BulkActionRequestFilter::Connection { id } => {
let Some(member): Option<PKMember> =
sqlx::query_as("select * from members where hid = $1 or uuid::text = $1")
.bind(id)
.fetch_optional(&ctx.db)
.await?
else {
return Err(TARGET_MEMBER_NOT_FOUND);
};
if member.system != system_id {
return Err(NOT_OWN_MEMBER);
}
let entries: Vec<GroupMemberEntry> =
sqlx::query_as("select * from group_members where member_id = $1")
.bind(member.id)
.fetch_all(&ctx.db)
.await?;
entries.iter().map(|v| v.group_id).collect()
}
};
let (q, pms) = patch
.to_sql()
.table("groups") // todo: this should be in the model definition
.and_where(Expr::col("id").is_in(ids))
.returning_col("id")
.build_sqlx(PostgresQueryBuilder);
println!("{q:#?} {pms:#?}");
let res: Vec<OnlyIder> = sqlx::query_as_with(&q, pms).fetch_all(&ctx.db).await?;
Ok(Json(json! {{ "updated": res.len() }}))
}
}
}

View file

@ -1,2 +1,3 @@
pub mod bulk;
pub mod private;
pub mod system;

View file

@ -4,9 +4,10 @@ use fred::interfaces::*;
use libpk::state::ShardState;
use pk_macros::api_endpoint;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
#[allow(dead_code)]
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {

View file

@ -1,11 +1,11 @@
use axum::{extract::State, response::IntoResponse, Extension, Json};
use axum::{Extension, Json, extract::State, response::IntoResponse};
use pk_macros::api_endpoint;
use serde_json::{json, Value};
use serde_json::{Value, json};
use sqlx::Postgres;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use crate::{auth::AuthState, error::fail, ApiContext};
use crate::{ApiContext, auth::AuthState, error::fail};
#[api_endpoint]
pub async fn get_system_settings(

View file

@ -2,6 +2,7 @@ use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use pluralkit_models::ValidationError;
use std::fmt;
// todo: model parse errors
@ -11,6 +12,8 @@ pub struct PKError {
pub json_code: i32,
pub message: &'static str,
pub errors: Vec<ValidationError>,
pub inner: Option<anyhow::Error>,
}
@ -30,6 +33,21 @@ impl Clone for PKError {
json_code: self.json_code,
message: self.message,
inner: None,
errors: self.errors.clone(),
}
}
}
// can't `impl From<Vec<ValidationError>>`
// because "upstream crate may add a new impl" >:(
impl PKError {
pub fn from_validation_errors(errs: Vec<ValidationError>) -> Self {
Self {
message: "Error parsing JSON model",
json_code: 40001,
errors: errs,
response_code: StatusCode::BAD_REQUEST,
inner: None,
}
}
}
@ -50,14 +68,19 @@ impl IntoResponse for PKError {
if let Some(inner) = self.inner {
tracing::error!(?inner, "error returned from handler");
}
crate::util::json_err(
self.response_code,
serde_json::to_string(&serde_json::json!({
let json = if self.errors.len() > 0 {
serde_json::json!({
"message": self.message,
"code": self.json_code,
}))
.unwrap(),
)
"errors": self.errors,
})
} else {
serde_json::json!({
"message": self.message,
"code": self.json_code,
})
};
crate::util::json_err(self.response_code, serde_json::to_string(&json).unwrap())
}
}
@ -78,9 +101,17 @@ macro_rules! define_error {
json_code: $json_code,
message: $message,
inner: None,
errors: vec![],
};
};
}
define_error! { GENERIC_AUTH_ERROR, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" }
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }
define_error! { NOT_OWN_MEMBER, StatusCode::FORBIDDEN, 30006, "Target member is not part of your system." }
define_error! { NOT_OWN_GROUP, StatusCode::FORBIDDEN, 30007, "Target group is not part of your system." }
define_error! { TARGET_MEMBER_NOT_FOUND, StatusCode::BAD_REQUEST, 40010, "Target member not found." }
define_error! { TARGET_GROUP_NOT_FOUND, StatusCode::BAD_REQUEST, 40011, "Target group not found." }

View file

@ -1,16 +1,14 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use axum::{
Extension, Router,
body::Body,
extract::{Request as ExtractRequest, State},
http::Uri,
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
Extension, Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
client::legacy::{Client, connect::HttpConnector},
rt::TokioExecutor,
};
use tracing::info;
@ -117,6 +115,8 @@ fn router(ctx: ApiContext) -> Router {
.route("/v2/messages/{message_id}", get(rproxy))
.route("/v2/bulk", post(endpoints::bulk::bulk))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))

View file

@ -10,7 +10,7 @@ use subtle::ConstantTimeEq;
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use crate::{ApiContext, util::json_err};
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let mut authed_system_id: Option<i32> = None;

View file

@ -2,7 +2,7 @@ use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
use crate::{auth::AuthState, util::header_or_unknown};

View file

@ -6,11 +6,11 @@ use axum::{
routing::url_params::UrlParams,
};
use sqlx::{types::Uuid, Postgres};
use sqlx::{Postgres, types::Uuid};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use crate::{ApiContext, util::json_err};
use pluralkit_models::PKSystem;
// move this somewhere else
@ -31,7 +31,7 @@ pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next)
StatusCode::BAD_REQUEST,
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
)
.into()
.into();
}
};

View file

@ -3,7 +3,7 @@ use axum::{
http::{HeaderValue, StatusCode},
response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use serde_json::{Value, json, to_string};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {

View file

@ -1,7 +1,7 @@
[package]
name = "avatars"
version = "0.1.0"
edition = "2021"
edition = "2024"
[[bin]]
name = "avatar_cleanup"

View file

@ -8,10 +8,10 @@ use anyhow::Context;
use axum::extract::State;
use axum::routing::get;
use axum::{
Json, Router,
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use libpk::_config::AvatarsConfig;
use libpk::db::repository::avatars as db;
@ -153,7 +153,7 @@ async fn verify(
)
.await?;
let encoded = process::process_async(result.data, req.kind).await?;
process::process_async(result.data, req.kind).await?;
Ok(())
}

View file

@ -4,7 +4,7 @@ use std::io::Cursor;
use std::time::Instant;
use tracing::{debug, error, info, instrument};
use crate::{hash::Hash, ImageKind, PKAvatarError};
use crate::{ImageKind, PKAvatarError, hash::Hash};
const MAX_DIMENSION: u32 = 4000;

View file

@ -62,7 +62,7 @@ pub async fn pull(
let size = match response.content_length() {
None => return Err(PKAvatarError::MissingHeader("Content-Length")),
Some(size) if size > MAX_SIZE => {
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE))
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE));
}
Some(size) => size,
};
@ -162,7 +162,7 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
attachment_id: 0,
filename: "".to_string(),
full_url: url.to_string(),
})
});
}
_ => anyhow::bail!("not a discord cdn url"),
}

View file

@ -1,7 +1,7 @@
[package]
name = "dispatch"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
anyhow = { workspace = true }

View file

@ -1,7 +1,7 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
// log any requests that take longer than 2 seconds
// todo: change as necessary

View file

@ -5,17 +5,16 @@ use hickory_client::{
rr::{DNSClass, Name, RData, RecordType},
udp::UdpClientStream,
};
use reqwest::{redirect::Policy, StatusCode};
use reqwest::{StatusCode, redirect::Policy};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
use tracing::{debug, error};
use axum::{extract::State, http::Uri, routing::post, Json, Router};
use axum::{Json, Router, extract::State, http::Uri, routing::post};
mod logger;
@ -128,7 +127,7 @@ async fn dispatch(
match res {
Ok(res) if res.status() != 200 => {
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
return DispatchResponse::InvalidResponseCode(res.status()).to_string();
}
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");

View file

@ -1,7 +1,7 @@
[package]
name = "gateway"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
anyhow = { workspace = true }

View file

@ -1,18 +1,18 @@
use axum::{
Router,
extract::{ConnectInfo, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Router,
};
use libpk::runtime_config::RuntimeConfig;
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::{marker::ChannelMarker, Id};
use twilight_model::id::{Id, marker::ChannelMarker};
use crate::{
discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
cache::{DM_PERMISSIONS, DiscordCache, dm_channel},
gateway::cluster_config,
shard_state::ShardStateManager,
},

View file

@ -4,18 +4,18 @@ use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
InMemoryCache, ResourceType,
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_gateway::Event;
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
Id,
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
},
};
use twilight_util::permission_calculator::PermissionCalculator;

View file

@ -6,17 +6,17 @@ use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, CloseFrame, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
CloseFrame, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId, create_iterator,
};
use twilight_model::gateway::{
Intents,
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::{
discord::identify_queue::{self, RedisQueue},
RUNTIME_CONFIG_KEY_EVENT_TARGET,
discord::identify_queue::{self, RedisQueue},
};
use super::cache::DiscordCache;

View file

@ -3,7 +3,7 @@
// - interaction: (custom_id where not_includes "help-menu")
use std::{
collections::{hash_map::Entry, HashMap},
collections::{HashMap, hash_map::Entry},
net::{IpAddr, SocketAddr},
time::Duration,
};
@ -15,8 +15,8 @@ use twilight_gateway::Event;
use twilight_model::{
application::interaction::InteractionData,
id::{
marker::{ChannelMarker, MessageMarker, UserMarker},
Id,
marker::{ChannelMarker, MessageMarker, UserMarker},
},
};
@ -103,7 +103,13 @@ impl EventAwaiter {
}
}
}
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
info!(
"ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions",
Instant::now().duration_since(now).as_micros(),
counts.0,
counts.1,
counts.2
);
}
}

View file

@ -4,7 +4,7 @@ use axum::{
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
// log any requests that take longer than 2 seconds
// todo: change as necessary

View file

@ -1,4 +1,3 @@
#![feature(let_chains)]
#![feature(if_let_guard)]
#![feature(duration_constructors)]
@ -10,7 +9,7 @@ use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
use reqwest::{ClientBuilder, StatusCode};
use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{
signal::unix::{signal, SignalKind},
signal::unix::{SignalKind, signal},
sync::mpsc::channel,
task::JoinSet,
};

View file

@ -1,7 +1,7 @@
[package]
name = "gdpr_worker"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
libpk = { path = "../libpk" }

View file

@ -1,12 +1,10 @@
#![feature(let_chains)]
use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info, warn};
use twilight_http::api_error::{ApiError, GeneralApiError};
use twilight_model::id::{
marker::{ChannelMarker, MessageMarker},
Id,
marker::{ChannelMarker, MessageMarker},
};
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null);

View file

@ -1,7 +1,7 @@
[package]
name = "libpk"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
anyhow = { workspace = true }

View file

@ -3,7 +3,7 @@ use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
use twilight_model::id::{marker::UserMarker, Id};
use twilight_model::id::{Id, marker::UserMarker};
#[derive(Clone, Deserialize, Debug)]
pub struct ClusterSettings {
@ -151,11 +151,11 @@ lazy_static! {
// hacks
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var); }
}
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap()); }
}
Arc::new(Config::builder()

View file

@ -52,7 +52,7 @@ pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow:
pub async fn pop_queue(
pool: &PgPool,
) -> anyhow::Result<Option<(Transaction<Postgres>, ImageQueueEntry)>> {
) -> anyhow::Result<Option<(Transaction<'_, Postgres>, ImageQueueEntry)>> {
let mut tx = pool.begin().await?;
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")
.fetch_optional(&mut *tx).await?;

View file

@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use sqlx::{
types::chrono::{DateTime, Utc},
FromRow,
types::chrono::{DateTime, Utc},
};
use uuid::Uuid;

View file

@ -1,9 +1,8 @@
#![feature(let_chains)]
use std::net::SocketAddr;
use metrics_exporter_prometheus::PrometheusBuilder;
use sentry::IntoDsn;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
use sentry_tracing::event_from_event;

View file

@ -1,7 +1,7 @@
[package]
name = "pk_macros"
version = "0.1.0"
edition = "2021"
edition = "2024"
[lib]
proc-macro = true

View file

@ -1,7 +1,7 @@
use quote::quote;
use syn::{parse_macro_input, FnArg, ItemFn, Pat};
use syn::{FnArg, ItemFn, Pat, parse_macro_input};
fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
fn _pretty_print(ts: &proc_macro2::TokenStream) -> String {
let file = syn::parse_file(&ts.to_string()).unwrap();
prettyplease::unparse(&file)
}

View file

@ -1,6 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type};
use syn::{DeriveInput, Expr, Ident, Meta, Type, parse_macro_input};
#[derive(Clone, Debug)]
enum ElemPatchability {
@ -85,8 +85,14 @@ fn parse_field(field: syn::Field) -> ModelField {
panic!("must have json name to be publicly patchable");
}
if f.json.is_some() && f.is_privacy {
panic!("cannot set custom json name for privacy field");
if f.is_privacy && f.json.is_none() {
f.json = Some(syn::Expr::Lit(syn::ExprLit {
attrs: vec![],
lit: syn::Lit::Str(syn::LitStr::new(
f.name.clone().to_string().as_str(),
proc_macro2::Span::call_site(),
)),
}))
}
f
@ -122,17 +128,17 @@ pub fn macro_impl(
let fields: Vec<ModelField> = fields
.iter()
.filter(|f| !matches!(f.patch, ElemPatchability::None))
.filter(|f| f.is_privacy || !matches!(f.patch, ElemPatchability::None))
.cloned()
.collect();
let patch_fields = mk_patch_fields(fields.clone());
let patch_from_json = mk_patch_from_json(fields.clone());
let patch_validate = mk_patch_validate(fields.clone());
let patch_validate_bulk = mk_patch_validate_bulk(fields.clone());
let patch_to_json = mk_patch_to_json(fields.clone());
let patch_to_sql = mk_patch_to_sql(fields.clone());
return quote! {
let code = quote! {
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct #tname {
#tfields
@ -146,31 +152,42 @@ pub fn macro_impl(
#to_json
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
pub struct #patchable_name {
#patch_fields
errors: Vec<crate::ValidationError>,
}
impl #patchable_name {
pub fn from_json(input: String) -> Self {
#patch_from_json
}
pub fn validate(self) -> bool {
pub fn validate(&mut self) {
#patch_validate
}
pub fn errors(&self) -> Vec<crate::ValidationError> {
self.errors.clone()
}
pub fn validate_bulk(&mut self) {
#patch_validate_bulk
}
pub fn to_sql(self) -> sea_query::UpdateStatement {
// sea_query::Query::update()
#patch_to_sql
use sea_query::types::*;
let mut patch = &mut sea_query::Query::update();
#patch_to_sql
patch.clone()
}
pub fn to_json(self) -> serde_json::Value {
#patch_to_json
}
}
}
.into();
};
// panic!("{:#?}", code.to_string());
return code.into();
}
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
@ -225,7 +242,7 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
.filter_map(|f| {
if f.is_privacy {
let tname = f.name.clone();
let tnamestr = f.name.clone().to_string();
let tnamestr = f.json.clone();
Some(quote! {
#tnamestr: self.#tname,
})
@ -280,13 +297,48 @@ fn mk_patch_fields(fields: Vec<ModelField>) -> TokenStream {
.collect()
}
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
quote! { true }
}
fn mk_patch_from_json(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_patch_to_sql(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
fn mk_patch_validate_bulk(fields: Vec<ModelField>) -> TokenStream {
// iterate over all nullable patchable fields other than privacy
// add an error if any field is set to a value other than null
fields
.iter()
.map(|f| {
if let syn::Type::Path(path) = &f.ty && let Some(inner) = path.path.segments.last() && inner.ident != "Option" {
return quote! {};
}
let name = f.name.clone();
if matches!(f.patch, ElemPatchability::Public) {
let json = f.json.clone().unwrap();
quote! {
if let Some(val) = self.#name.clone() && val.is_some() {
self.errors.push(ValidationError::simple(#json, "Only null values are supported in bulk endpoint"));
}
}
} else {
quote! {}
}
})
.collect()
}
fn mk_patch_to_sql(fields: Vec<ModelField>) -> TokenStream {
fields
.iter()
.filter_map(|f| {
if !matches!(f.patch, ElemPatchability::None) || f.is_privacy {
let name = f.name.clone();
let column = f.name.to_string();
Some(quote! {
if let Some(value) = self.#name {
patch = patch.value(#column, value);
}
})
} else {
None
}
})
.collect()
}
fn mk_patch_to_json(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }

View file

@ -1,7 +1,7 @@
[package]
name = "migrate"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
libpk = { path = "../libpk" }

View file

@ -1,5 +1,3 @@
#![feature(let_chains)]
use tracing::info;
include!(concat!(env!("OUT_DIR"), "/data.rs"));

View file

@ -1,12 +1,12 @@
[package]
name = "pluralkit_models"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
chrono = { workspace = true, features = ["serde"] }
pk_macros = { path = "../macros" }
sea-query = "0.32.1"
sea-query = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] }
# in theory we want to default-features = false for sqlx

132
crates/models/src/group.rs Normal file
View file

@ -0,0 +1,132 @@
use pk_macros::pk_model;
use chrono::{DateTime, Utc};
use serde::Deserialize;
use serde_json::Value;
use uuid::Uuid;
use crate::{PrivacyLevel, SystemId, ValidationError};
// todo: fix
pub type GroupId = i32;
#[pk_model]
struct Group {
id: GroupId,
#[json = "hid"]
#[private_patchable]
hid: String,
#[json = "uuid"]
uuid: Uuid,
// TODO fix
#[json = "system"]
system: SystemId,
#[json = "name"]
#[privacy = name_privacy]
#[patchable]
name: String,
#[json = "display_name"]
#[patchable]
display_name: Option<String>,
#[json = "color"]
#[patchable]
color: Option<String>,
#[json = "icon"]
#[patchable]
icon: Option<String>,
#[json = "banner_image"]
#[patchable]
banner_image: Option<String>,
#[json = "description"]
#[privacy = description_privacy]
#[patchable]
description: Option<String>,
#[json = "created"]
created: DateTime<Utc>,
#[privacy]
name_privacy: PrivacyLevel,
#[privacy]
description_privacy: PrivacyLevel,
#[privacy]
banner_privacy: PrivacyLevel,
#[privacy]
icon_privacy: PrivacyLevel,
#[privacy]
list_privacy: PrivacyLevel,
#[privacy]
metadata_privacy: PrivacyLevel,
#[privacy]
visibility: PrivacyLevel,
}
impl<'de> Deserialize<'de> for PKGroupPatch {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut patch: PKGroupPatch = Default::default();
let value: Value = Value::deserialize(deserializer)?;
if let Some(v) = value.get("name") {
if let Some(name) = v.as_str() {
patch.name = Some(name.to_string());
} else if v.is_null() {
patch.errors.push(ValidationError::simple(
"name",
"Group name cannot be set to null.",
));
}
}
macro_rules! parse_string_simple {
($k:expr) => {
match value.get($k) {
None => None,
Some(Value::Null) => Some(None),
Some(Value::String(s)) => Some(Some(s.clone())),
_ => {
patch.errors.push(ValidationError::new($k));
None
}
}
};
}
patch.display_name = parse_string_simple!("display_name");
patch.description = parse_string_simple!("description");
patch.icon = parse_string_simple!("icon");
patch.banner_image = parse_string_simple!("banner");
patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase()));
if let Some(privacy) = value.get("privacy").and_then(Value::as_object) {
macro_rules! parse_privacy {
($v:expr) => {
match privacy.get($v) {
None => None,
Some(Value::Null) => Some(PrivacyLevel::Private),
Some(Value::String(s)) if s == "" || s == "private" => {
Some(PrivacyLevel::Private)
}
Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public),
_ => {
patch.errors.push(ValidationError::new($v));
None
}
}
};
}
patch.name_privacy = parse_privacy!("name_privacy");
patch.description_privacy = parse_privacy!("description_privacy");
patch.banner_privacy = parse_privacy!("banner_privacy");
patch.icon_privacy = parse_privacy!("icon_privacy");
patch.list_privacy = parse_privacy!("list_privacy");
patch.metadata_privacy = parse_privacy!("metadata_privacy");
patch.visibility = parse_privacy!("visibility");
}
Ok(patch)
}
}

View file

@ -9,6 +9,8 @@ macro_rules! model {
model!(system);
model!(system_config);
model!(member);
model!(group);
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
@ -18,7 +20,7 @@ pub enum PrivacyLevel {
}
// this sucks, put it somewhere else
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use sqlx::{Database, Decode, Postgres, Type, postgres::PgTypeInfo};
use std::error::Error;
_util::fake_enum_impls!(PrivacyLevel);
@ -31,3 +33,30 @@ impl From<i32> for PrivacyLevel {
}
}
}
impl From<PrivacyLevel> for sea_query::Value {
fn from(level: PrivacyLevel) -> sea_query::Value {
match level {
PrivacyLevel::Public => sea_query::Value::Int(Some(1)),
PrivacyLevel::Private => sea_query::Value::Int(Some(2)),
}
}
}
#[derive(serde::Serialize, Debug, Clone)]
pub enum ValidationError {
Simple { key: String, value: String },
}
impl ValidationError {
fn new(key: &str) -> Self {
Self::simple(key, "is invalid")
}
fn simple(key: &str, value: &str) -> Self {
Self::Simple {
key: key.to_string(),
value: value.to_string(),
}
}
}

208
crates/models/src/member.rs Normal file
View file

@ -0,0 +1,208 @@
use pk_macros::pk_model;
use chrono::NaiveDateTime;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use uuid::Uuid;
use crate::{PrivacyLevel, SystemId, ValidationError};
// todo: fix
pub type MemberId = i32;
#[derive(Clone, Debug, Serialize, Deserialize, sqlx::Type)]
#[sqlx(type_name = "proxy_tag")]
pub struct ProxyTag {
pub prefix: Option<String>,
pub suffix: Option<String>,
}
#[pk_model]
struct Member {
id: MemberId,
#[json = "hid"]
#[private_patchable]
hid: String,
#[json = "uuid"]
uuid: Uuid,
// TODO fix
#[json = "system"]
system: SystemId,
#[json = "color"]
#[patchable]
color: Option<String>,
#[json = "webhook_avatar_url"]
#[patchable]
webhook_avatar_url: Option<String>,
#[json = "avatar_url"]
#[patchable]
avatar_url: Option<String>,
#[json = "banner_image"]
#[patchable]
banner_image: Option<String>,
#[json = "name"]
#[privacy = name_privacy]
#[patchable]
name: String,
#[json = "display_name"]
#[patchable]
display_name: Option<String>,
#[json = "birthday"]
#[patchable]
birthday: Option<String>,
#[json = "pronouns"]
#[privacy = pronoun_privacy]
#[patchable]
pronouns: Option<String>,
#[json = "description"]
#[privacy = description_privacy]
#[patchable]
description: Option<String>,
#[json = "proxy_tags"]
// #[patchable]
proxy_tags: Vec<ProxyTag>,
#[json = "keep_proxy"]
#[patchable]
keep_proxy: bool,
#[json = "tts"]
#[patchable]
tts: bool,
#[json = "created"]
created: NaiveDateTime,
#[json = "message_count"]
#[private_patchable]
message_count: i32,
#[json = "last_message_timestamp"]
#[private_patchable]
last_message_timestamp: Option<NaiveDateTime>,
#[json = "allow_autoproxy"]
#[patchable]
allow_autoproxy: bool,
#[privacy]
#[json = "visibility"]
member_visibility: PrivacyLevel,
#[privacy]
description_privacy: PrivacyLevel,
#[privacy]
banner_privacy: PrivacyLevel,
#[privacy]
avatar_privacy: PrivacyLevel,
#[privacy]
name_privacy: PrivacyLevel,
#[privacy]
birthday_privacy: PrivacyLevel,
#[privacy]
pronoun_privacy: PrivacyLevel,
#[privacy]
metadata_privacy: PrivacyLevel,
#[privacy]
proxy_privacy: PrivacyLevel,
}
impl<'de> Deserialize<'de> for PKMemberPatch {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let mut patch: PKMemberPatch = Default::default();
let value: Value = Value::deserialize(deserializer)?;
if let Some(v) = value.get("name") {
if let Some(name) = v.as_str() {
patch.name = Some(name.to_string());
} else if v.is_null() {
patch.errors.push(ValidationError::simple(
"name",
"Member name cannot be set to null.",
));
}
}
macro_rules! parse_string_simple {
($k:expr) => {
match value.get($k) {
None => None,
Some(Value::Null) => Some(None),
Some(Value::String(s)) => Some(Some(s.clone())),
_ => {
patch.errors.push(ValidationError::new($k));
None
}
}
};
}
patch.color = parse_string_simple!("color").map(|v| v.map(|t| t.to_lowercase()));
patch.display_name = parse_string_simple!("display_name");
patch.avatar_url = parse_string_simple!("avatar_url");
patch.banner_image = parse_string_simple!("banner");
patch.birthday = parse_string_simple!("birthday"); // fix
patch.pronouns = parse_string_simple!("pronouns");
patch.description = parse_string_simple!("description");
if let Some(keep_proxy) = value.get("keep_proxy").and_then(Value::as_bool) {
patch.keep_proxy = Some(keep_proxy);
}
if let Some(tts) = value.get("tts").and_then(Value::as_bool) {
patch.tts = Some(tts);
}
// todo: legacy import handling
// todo: fix proxy_tag type in sea_query
// if let Some(proxy_tags) = value.get("proxy_tags").and_then(Value::as_array) {
// patch.proxy_tags = Some(
// proxy_tags
// .iter()
// .filter_map(|tag| {
// tag.as_object().map(|tag_obj| {
// let prefix = tag_obj
// .get("prefix")
// .and_then(Value::as_str)
// .map(|s| s.to_string());
// let suffix = tag_obj
// .get("suffix")
// .and_then(Value::as_str)
// .map(|s| s.to_string());
// ProxyTag { prefix, suffix }
// })
// })
// .collect(),
// )
// }
if let Some(privacy) = value.get("privacy").and_then(Value::as_object) {
macro_rules! parse_privacy {
($v:expr) => {
match privacy.get($v) {
None => None,
Some(Value::Null) => Some(PrivacyLevel::Private),
Some(Value::String(s)) if s == "" || s == "private" => {
Some(PrivacyLevel::Private)
}
Some(Value::String(s)) if s == "public" => Some(PrivacyLevel::Public),
_ => {
patch.errors.push(ValidationError::new($v));
None
}
}
};
}
patch.member_visibility = parse_privacy!("visibility");
patch.name_privacy = parse_privacy!("name_privacy");
patch.description_privacy = parse_privacy!("description_privacy");
patch.banner_privacy = parse_privacy!("banner_privacy");
patch.avatar_privacy = parse_privacy!("avatar_privacy");
patch.birthday_privacy = parse_privacy!("birthday_privacy");
patch.pronoun_privacy = parse_privacy!("pronoun_privacy");
patch.proxy_privacy = parse_privacy!("proxy_privacy");
patch.metadata_privacy = parse_privacy!("metadata_privacy");
}
Ok(patch)
}
}

View file

@ -1,7 +1,7 @@
[package]
name = "scheduled_tasks"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
libpk = { path = "../libpk" }

View file

@ -154,7 +154,7 @@
</CardHeader>
<CardBody>
<p>If you've lost access to your discord account, you can retrieve your token here.</p>
<p>Send a direct message to a staff member (a helper, moderator or developer <a href="https://discord.gg/PczBt78">in the support server</a>), they can recover your system with this token.</p>
<p>Ask in the #bot-support channel <a href="https://discord.gg/PczBt78">of the support server</a> for a staff member to DM you, they can recover your system with this token. <b>Do not post the token in the channel.</b></p>
<Button color="danger" on:click={() => revealToken()}>Reveal token</Button>
{#if showToken}
<Row>

View file

@ -72,9 +72,7 @@
programs.nixfmt.enable = true;
};
nci.toolchainConfig = {
channel = "nightly";
};
nci.toolchainConfig = ./rust-toolchain.toml;
nci.projects."pluralkit-services" = {
path = ./.;
export = false;

3
rust-toolchain.toml Normal file
View file

@ -0,0 +1,3 @@
[toolchain]
channel = "nightly-2025-08-22"
components = ["rust-src", "rustfmt", "rust-analyzer"]