mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-06 22:07:55 +00:00
chore: reorganize rust crates
This commit is contained in:
parent
357122a892
commit
16ce67e02c
58 changed files with 6 additions and 13 deletions
|
|
@ -1,24 +0,0 @@
|
|||
[package]
|
||||
name = "libpk"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
sentry = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true}
|
||||
twilight-model = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
config = "0.14.0"
|
||||
json-subscriber = { version = "0.2.2", features = ["env-filter"] }
|
||||
metrics-exporter-prometheus = { version = "0.15.3", default-features = false, features = ["tokio", "http-listener", "tracing"] }
|
||||
|
|
@ -1,145 +0,0 @@
|
|||
use config::Config;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use twilight_model::id::{marker::UserMarker, Id};
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
pub struct ClusterSettings {
|
||||
pub node_id: u32,
|
||||
pub total_shards: u32,
|
||||
pub total_nodes: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DiscordConfig {
|
||||
pub client_id: Id<UserMarker>,
|
||||
pub bot_token: String,
|
||||
pub client_secret: String,
|
||||
pub max_concurrency: u32,
|
||||
#[serde(default)]
|
||||
pub cluster: Option<ClusterSettings>,
|
||||
pub api_base_url: Option<String>,
|
||||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub cache_api_addr: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DatabaseConfig {
|
||||
pub(crate) data_db_uri: String,
|
||||
pub(crate) data_db_max_connections: Option<u32>,
|
||||
pub(crate) data_db_min_connections: Option<u32>,
|
||||
pub(crate) messages_db_uri: Option<String>,
|
||||
pub(crate) stats_db_uri: Option<String>,
|
||||
pub(crate) db_password: Option<String>,
|
||||
pub data_redis_addr: String,
|
||||
}
|
||||
|
||||
fn _default_api_addr() -> String {
|
||||
"[::]:5000".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct ApiConfig {
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub ratelimit_redis_addr: Option<String>,
|
||||
|
||||
pub remote_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub temp_token2: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct AvatarsConfig {
|
||||
pub s3: S3Config,
|
||||
pub cdn_url: String,
|
||||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub bind_addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub migrate_worker_count: u32,
|
||||
|
||||
#[serde(default)]
|
||||
pub cloudflare_zone_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub cloudflare_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct S3Config {
|
||||
pub bucket: String,
|
||||
pub application_id: String,
|
||||
pub application_key: String,
|
||||
pub endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ScheduledTasksConfig {
|
||||
pub set_guild_count: bool,
|
||||
pub expected_gateway_count: usize,
|
||||
pub gateway_url: String,
|
||||
}
|
||||
|
||||
fn _metrics_default() -> bool {
|
||||
false
|
||||
}
|
||||
fn _json_log_default() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PKConfig {
|
||||
pub db: DatabaseConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub discord: Option<DiscordConfig>,
|
||||
#[serde(default)]
|
||||
pub api: Option<ApiConfig>,
|
||||
#[serde(default)]
|
||||
pub avatars: Option<AvatarsConfig>,
|
||||
#[serde(default)]
|
||||
pub scheduled_tasks: Option<ScheduledTasksConfig>,
|
||||
|
||||
#[serde(default = "_metrics_default")]
|
||||
pub run_metrics_server: bool,
|
||||
|
||||
#[serde(default = "_json_log_default")]
|
||||
pub(crate) json_log: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub sentry_url: Option<String>,
|
||||
}
|
||||
|
||||
impl PKConfig {
|
||||
pub fn api(self) -> ApiConfig {
|
||||
self.api.expect("missing api config")
|
||||
}
|
||||
|
||||
pub fn discord_config(self) -> DiscordConfig {
|
||||
self.discord.expect("missing discord config")
|
||||
}
|
||||
}
|
||||
|
||||
// todo: consider passing this down instead of making it global
|
||||
// especially since we have optional discord/api/avatars/etc config
|
||||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
pub static ref CONFIG: Arc<PKConfig> = {
|
||||
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var);
|
||||
}
|
||||
|
||||
Arc::new(Config::builder()
|
||||
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
|
||||
.build().unwrap()
|
||||
.try_deserialize::<PKConfig>().unwrap())
|
||||
};
|
||||
}
|
||||
|
|
@ -1,96 +0,0 @@
|
|||
use fred::clients::RedisPool;
|
||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
||||
use std::str::FromStr;
|
||||
use tracing::info;
|
||||
|
||||
pub mod repository;
|
||||
pub mod types;
|
||||
|
||||
pub async fn init_redis() -> anyhow::Result<RedisPool> {
|
||||
info!("connecting to redis");
|
||||
let redis = RedisPool::new(
|
||||
fred::types::RedisConfig::from_url_centralized(crate::config.db.data_redis_addr.as_ref())
|
||||
.expect("redis url is invalid"),
|
||||
None,
|
||||
None,
|
||||
Some(Default::default()),
|
||||
10,
|
||||
)?;
|
||||
|
||||
let redis_handle = redis.connect_pool();
|
||||
tokio::spawn(async move { redis_handle });
|
||||
|
||||
Ok(redis)
|
||||
}
|
||||
|
||||
pub async fn init_data_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(&crate::config.db.data_db_uri)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
let mut pool = PgPoolOptions::new();
|
||||
|
||||
if let Some(max_conns) = crate::config.db.data_db_max_connections {
|
||||
pool = pool.max_connections(max_conns);
|
||||
}
|
||||
|
||||
if let Some(min_conns) = crate::config.db.data_db_min_connections {
|
||||
pool = pool.min_connections(min_conns);
|
||||
}
|
||||
|
||||
Ok(pool.connect_with(options).await?)
|
||||
}
|
||||
|
||||
pub async fn init_messages_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to messages database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(
|
||||
&crate::config
|
||||
.db
|
||||
.messages_db_uri
|
||||
.as_ref()
|
||||
.expect("missing messages db uri"),
|
||||
)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
let mut pool = PgPoolOptions::new();
|
||||
|
||||
if let Some(max_conns) = crate::config.db.data_db_max_connections {
|
||||
pool = pool.max_connections(max_conns);
|
||||
}
|
||||
|
||||
if let Some(min_conns) = crate::config.db.data_db_min_connections {
|
||||
pool = pool.min_connections(min_conns);
|
||||
}
|
||||
|
||||
Ok(pool.connect_with(options).await?)
|
||||
}
|
||||
|
||||
pub async fn init_stats_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to stats database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(
|
||||
&crate::config
|
||||
.db
|
||||
.stats_db_uri
|
||||
.as_ref()
|
||||
.expect("missing messages db uri"),
|
||||
)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
Ok(PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.min_connections(1)
|
||||
.connect_with(options)
|
||||
.await?)
|
||||
}
|
||||
|
|
@ -1,20 +0,0 @@
|
|||
pub async fn legacy_token_auth(
|
||||
pool: &sqlx::postgres::PgPool,
|
||||
token: &str,
|
||||
) -> anyhow::Result<Option<i32>> {
|
||||
let mut system: Vec<LegacyTokenDbResponse> =
|
||||
sqlx::query_as("select id from systems where token = $1")
|
||||
.bind(token)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(if let Some(system) = system.pop() {
|
||||
Some(system.id)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct LegacyTokenDbResponse {
|
||||
id: i32,
|
||||
}
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
|
||||
use crate::db::types::avatars::*;
|
||||
|
||||
pub async fn get_by_id(pool: &PgPool, id: String) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(sqlx::query_as("select * from images where id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_by_original_url(
|
||||
pool: &PgPool,
|
||||
original_url: &str,
|
||||
) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from images where original_url = $1")
|
||||
.bind(original_url)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn get_by_attachment_id(
|
||||
pool: &PgPool,
|
||||
attachment_id: u64,
|
||||
) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from images where original_attachment_id = $1")
|
||||
.bind(attachment_id as i64)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
delete from image_cleanup_jobs
|
||||
where id in (
|
||||
select id from images
|
||||
where original_attachment_id = $1
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(attachment_id as i64)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn pop_queue(
|
||||
pool: &PgPool,
|
||||
) -> anyhow::Result<Option<(Transaction<Postgres>, ImageQueueEntry)>> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")
|
||||
.fetch_optional(&mut *tx).await?;
|
||||
Ok(res.map(|x| (tx, x)))
|
||||
}
|
||||
|
||||
pub async fn get_queue_length(pool: &PgPool) -> anyhow::Result<i64> {
|
||||
Ok(sqlx::query_scalar("select count(*) from image_queue")
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_stats(pool: &PgPool) -> anyhow::Result<Stats> {
|
||||
Ok(sqlx::query_as(
|
||||
"select count(*) as total_images, sum(file_size) as total_file_size from images",
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn add_image(pool: &PgPool, meta: ImageMeta) -> anyhow::Result<bool> {
|
||||
let kind_str = match meta.kind {
|
||||
ImageKind::Avatar => "avatar",
|
||||
ImageKind::Banner => "banner",
|
||||
};
|
||||
|
||||
let res = sqlx::query("insert into images (id, url, content_type, original_url, file_size, width, height, original_file_size, original_type, original_attachment_id, kind, uploaded_by_account, uploaded_by_system, uploaded_at) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, (now() at time zone 'utc')) on conflict (id) do nothing")
|
||||
.bind(meta.id)
|
||||
.bind(meta.url)
|
||||
.bind(meta.content_type)
|
||||
.bind(meta.original_url)
|
||||
.bind(meta.file_size)
|
||||
.bind(meta.width)
|
||||
.bind(meta.height)
|
||||
.bind(meta.original_file_size)
|
||||
.bind(meta.original_type)
|
||||
.bind(meta.original_attachment_id)
|
||||
.bind(kind_str)
|
||||
.bind(meta.uploaded_by_account)
|
||||
.bind(meta.uploaded_by_system)
|
||||
.execute(pool).await?;
|
||||
Ok(res.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn push_queue(
|
||||
conn: &mut sqlx::PgConnection,
|
||||
url: &str,
|
||||
kind: ImageKind,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query("insert into image_queue (url, kind) values ($1, $2)")
|
||||
.bind(url)
|
||||
.bind(kind)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,7 +0,0 @@
|
|||
mod stats;
|
||||
pub use stats::*;
|
||||
|
||||
pub mod avatars;
|
||||
|
||||
mod auth;
|
||||
pub use auth::*;
|
||||
|
|
@ -1,26 +0,0 @@
|
|||
pub async fn get_stats(pool: &sqlx::postgres::PgPool) -> anyhow::Result<Counts> {
|
||||
let counts: Counts = sqlx::query_as("select * from info").fetch_one(pool).await?;
|
||||
Ok(counts)
|
||||
}
|
||||
|
||||
pub async fn insert_stats(
|
||||
pool: &sqlx::postgres::PgPool,
|
||||
table: &str,
|
||||
value: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
// danger sql injection
|
||||
sqlx::query(format!("insert into {table} values (now(), $1)").as_str())
|
||||
.bind(value)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, sqlx::FromRow)]
|
||||
pub struct Counts {
|
||||
pub system_count: i64,
|
||||
pub member_count: i64,
|
||||
pub group_count: i64,
|
||||
pub switch_count: i64,
|
||||
pub message_count: i64,
|
||||
}
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use time::OffsetDateTime;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(FromRow)]
|
||||
pub struct ImageMeta {
|
||||
pub id: String,
|
||||
pub kind: ImageKind,
|
||||
pub content_type: String,
|
||||
pub url: String,
|
||||
pub file_size: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
pub uploaded_at: Option<OffsetDateTime>,
|
||||
|
||||
pub original_url: Option<String>,
|
||||
pub original_attachment_id: Option<i64>,
|
||||
pub original_file_size: Option<i32>,
|
||||
pub original_type: Option<String>,
|
||||
pub uploaded_by_account: Option<i64>,
|
||||
pub uploaded_by_system: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct Stats {
|
||||
pub total_images: i64,
|
||||
pub total_file_size: i64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Debug, sqlx::Type, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[sqlx(rename_all = "snake_case", type_name = "text")]
|
||||
pub enum ImageKind {
|
||||
Avatar,
|
||||
Banner,
|
||||
}
|
||||
|
||||
impl ImageKind {
|
||||
pub fn size(&self) -> (u32, u32) {
|
||||
match self {
|
||||
Self::Avatar => (512, 512),
|
||||
Self::Banner => (1024, 1024),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromRow)]
|
||||
pub struct ImageQueueEntry {
|
||||
pub itemid: i32,
|
||||
pub url: String,
|
||||
pub kind: ImageKind,
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
pub mod avatars;
|
||||
|
|
@ -1,81 +0,0 @@
|
|||
#![feature(let_chains)]
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use sentry::IntoDsn;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
pub mod db;
|
||||
pub mod state;
|
||||
pub mod util;
|
||||
|
||||
pub mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
// functions in this file are only used by the main function below
|
||||
|
||||
pub fn init_logging(component: &str) -> anyhow::Result<()> {
|
||||
if config.json_log {
|
||||
let mut layer = json_subscriber::layer();
|
||||
layer.inner_layer_mut().add_static_field(
|
||||
"component",
|
||||
serde_json::Value::String(component.to_string()),
|
||||
);
|
||||
tracing_subscriber::registry()
|
||||
.with(layer)
|
||||
.with(EnvFilter::from_default_env())
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_metrics() -> anyhow::Result<()> {
|
||||
if config.run_metrics_server {
|
||||
PrometheusBuilder::new()
|
||||
.with_http_listener("[::]:9000".parse::<SocketAddr>().unwrap())
|
||||
.install()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_sentry() -> sentry::ClientInitGuard {
|
||||
sentry::init(sentry::ClientOptions {
|
||||
dsn: config
|
||||
.sentry_url
|
||||
.clone()
|
||||
.map(|u| u.into_dsn().unwrap())
|
||||
.flatten(),
|
||||
release: sentry::release_name!(),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! main {
|
||||
($component:expr) => {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
// we might also be able to use env!("CARGO_CRATE_NAME") here
|
||||
libpk::init_logging($component)?;
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(err) = libpk::init_metrics() {
|
||||
tracing::error!("failed to init metrics collector: {err}");
|
||||
};
|
||||
tracing::info!("hello world");
|
||||
if let Err(err) = real_main().await {
|
||||
tracing::error!("failed to run service: {err}");
|
||||
};
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
@ -1,12 +0,0 @@
|
|||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
|
||||
pub struct ShardState {
|
||||
pub shard_id: i32,
|
||||
pub up: bool,
|
||||
pub disconnection_count: i32,
|
||||
/// milliseconds
|
||||
pub latency: i32,
|
||||
/// unix timestamp
|
||||
pub last_heartbeat: i32,
|
||||
pub last_connection: i32,
|
||||
pub cluster_id: Option<i32>,
|
||||
}
|
||||
|
|
@ -1 +0,0 @@
|
|||
pub mod redis;
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
use fred::error::RedisError;
|
||||
|
||||
pub trait RedisErrorExt<T> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
|
||||
}
|
||||
|
||||
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(error) if error.is_not_found() => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
[package]
|
||||
name = "model_macros"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
syn = "2.0"
|
||||
|
||||
|
|
@ -1,259 +0,0 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum ElemPatchability {
|
||||
None,
|
||||
Private,
|
||||
Public,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ModelField {
|
||||
name: Ident,
|
||||
ty: Type,
|
||||
patch: ElemPatchability,
|
||||
json: Option<Expr>,
|
||||
is_privacy: bool,
|
||||
default: Option<Expr>,
|
||||
}
|
||||
|
||||
fn parse_field(field: syn::Field) -> ModelField {
|
||||
let mut f = ModelField {
|
||||
name: field.ident.expect("field missing ident"),
|
||||
ty: field.ty,
|
||||
patch: ElemPatchability::None,
|
||||
json: None,
|
||||
is_privacy: false,
|
||||
default: None,
|
||||
};
|
||||
|
||||
for attr in field.attrs.iter() {
|
||||
match &attr.meta {
|
||||
Meta::Path(path) => {
|
||||
let ident = path.segments[0].ident.to_string();
|
||||
match ident.as_str() {
|
||||
"private_patchable" => match f.patch {
|
||||
ElemPatchability::None => {
|
||||
f.patch = ElemPatchability::Private;
|
||||
}
|
||||
_ => {
|
||||
panic!("cannot have multiple patch tags on same field");
|
||||
}
|
||||
},
|
||||
"patchable" => match f.patch {
|
||||
ElemPatchability::None => {
|
||||
f.patch = ElemPatchability::Public;
|
||||
}
|
||||
_ => {
|
||||
panic!("cannot have multiple patch tags on same field");
|
||||
}
|
||||
},
|
||||
"privacy" => f.is_privacy = true,
|
||||
_ => panic!("unknown attribute"),
|
||||
}
|
||||
}
|
||||
Meta::NameValue(nv) => match nv.path.segments[0].ident.to_string().as_str() {
|
||||
"json" => {
|
||||
if f.json.is_some() {
|
||||
panic!("cannot set json multiple times for same field");
|
||||
}
|
||||
f.json = Some(nv.value.clone());
|
||||
}
|
||||
"default" => {
|
||||
if f.default.is_some() {
|
||||
panic!("cannot set default multiple times for same field");
|
||||
}
|
||||
f.default = Some(nv.value.clone());
|
||||
}
|
||||
_ => panic!("unknown attribute"),
|
||||
},
|
||||
Meta::List(_) => panic!("unknown attribute"),
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(f.patch, ElemPatchability::Public) && f.json.is_none() {
|
||||
panic!("must have json name to be publicly patchable");
|
||||
}
|
||||
|
||||
if f.json.is_some() && f.is_privacy {
|
||||
panic!("cannot set custom json name for privacy field");
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let ast = parse_macro_input!(input as DeriveInput);
|
||||
let model_type = match ast.data {
|
||||
syn::Data::Struct(struct_data) => struct_data,
|
||||
_ => panic!("pk_model can only be used on a struct"),
|
||||
};
|
||||
|
||||
let tname = Ident::new(&format!("PK{}", ast.ident), Span::call_site());
|
||||
let patchable_name = Ident::new(&format!("PK{}Patch", ast.ident), Span::call_site());
|
||||
|
||||
let fields = if let syn::Fields::Named(fields) = model_type.fields {
|
||||
fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|f| parse_field(f.clone()))
|
||||
.collect::<Vec<ModelField>>()
|
||||
} else {
|
||||
panic!("fields of a struct must be named");
|
||||
};
|
||||
|
||||
// println!("{}: {:#?}", tname, fields);
|
||||
|
||||
let tfields = mk_tfields(fields.clone());
|
||||
let from_json = mk_tfrom_json(fields.clone());
|
||||
let from_sql = mk_tfrom_sql(fields.clone());
|
||||
let to_json = mk_tto_json(fields.clone());
|
||||
|
||||
let fields: Vec<ModelField> = fields
|
||||
.iter()
|
||||
.filter(|f| !matches!(f.patch, ElemPatchability::None))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let patch_fields = mk_patch_fields(fields.clone());
|
||||
let patch_from_json = mk_patch_from_json(fields.clone());
|
||||
let patch_validate = mk_patch_validate(fields.clone());
|
||||
let patch_to_json = mk_patch_to_json(fields.clone());
|
||||
let patch_to_sql = mk_patch_to_sql(fields.clone());
|
||||
|
||||
return quote! {
|
||||
#[derive(sqlx::FromRow, Debug, Clone)]
|
||||
pub struct #tname {
|
||||
#tfields
|
||||
}
|
||||
|
||||
impl #tname {
|
||||
pub fn from_json(input: String) -> Self {
|
||||
#from_json
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#to_json
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct #patchable_name {
|
||||
#patch_fields
|
||||
}
|
||||
|
||||
impl #patchable_name {
|
||||
pub fn from_json(input: String) -> Self {
|
||||
#patch_from_json
|
||||
}
|
||||
|
||||
pub fn validate(self) -> bool {
|
||||
#patch_validate
|
||||
}
|
||||
|
||||
pub fn to_sql(self) -> sea_query::UpdateStatement {
|
||||
// sea_query::Query::update()
|
||||
#patch_to_sql
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#patch_to_json
|
||||
}
|
||||
}
|
||||
}
|
||||
.into();
|
||||
}
|
||||
|
||||
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let name = f.name.clone();
|
||||
let ty = f.ty.clone();
|
||||
quote! {
|
||||
pub #name: #ty,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn mk_tfrom_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tfrom_sql(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
// todo: check privacy access
|
||||
let fielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
f.json.as_ref().map(|v| {
|
||||
let tname = f.name.clone();
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
quote! {
|
||||
#v: self.#tname.unwrap_or(#default),
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#v: self.#tname,
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let privacyfielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
if f.is_privacy {
|
||||
let tname = f.name.clone();
|
||||
let tnamestr = f.name.clone().to_string();
|
||||
Some(quote! {
|
||||
#tnamestr: self.#tname,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
"privacy": {
|
||||
#privacyfielddefs
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mk_patch_fields(fields: Vec<ModelField>) -> TokenStream {
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let name = f.name.clone();
|
||||
let ty = f.ty.clone();
|
||||
quote! {
|
||||
pub #name: Option<#ty>,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { true }
|
||||
}
|
||||
fn mk_patch_from_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_patch_to_sql(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_patch_to_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
|
|
@ -1,13 +0,0 @@
|
|||
[package]
|
||||
name = "pluralkit_models"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
model_macros = { path = "../model_macros" }
|
||||
sea-query = "0.32.1"
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true, features = ["preserve_order"] }
|
||||
sqlx = { workspace = true, default-features = false, features = ["chrono"] }
|
||||
uuid = { workspace = true }
|
||||
|
|
@ -1,35 +0,0 @@
|
|||
// postgres enums created in c# pluralkit implementations are "fake", i.e. they
|
||||
// are actually ints in the database rather than postgres enums, because dapper
|
||||
// does not support postgres enums
|
||||
// here, we add some impls to support this kind of enum in sqlx
|
||||
// there is probably a better way to do this, but works for now.
|
||||
// note: caller needs to implement From<i32> for their type
|
||||
macro_rules! fake_enum_impls {
|
||||
($n:ident) => {
|
||||
impl Type<Postgres> for $n {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("INT4")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<$n> for i32 {
|
||||
fn from(enum_value: $n) -> Self {
|
||||
enum_value as i32
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, DB: Database> Decode<'r, DB> for $n
|
||||
where
|
||||
i32: Decode<'r, DB>,
|
||||
{
|
||||
fn decode(
|
||||
value: <DB as Database>::ValueRef<'r>,
|
||||
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
|
||||
let value = <i32 as Decode<DB>>::decode(value)?;
|
||||
Ok(Self::from(value))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use fake_enum_impls;
|
||||
|
|
@ -1,11 +0,0 @@
|
|||
mod _util;
|
||||
|
||||
macro_rules! model {
|
||||
($n:ident) => {
|
||||
mod $n;
|
||||
pub use $n::*;
|
||||
};
|
||||
}
|
||||
|
||||
model!(system);
|
||||
model!(system_config);
|
||||
|
|
@ -1,80 +0,0 @@
|
|||
use std::error::Error;
|
||||
|
||||
use model_macros::pk_model;
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::_util::fake_enum_impls;
|
||||
|
||||
// todo: fix this
|
||||
pub type SystemId = i32;
|
||||
|
||||
// todo: move this
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PrivacyLevel {
|
||||
Public,
|
||||
Private,
|
||||
}
|
||||
|
||||
fake_enum_impls!(PrivacyLevel);
|
||||
|
||||
impl From<i32> for PrivacyLevel {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
1 => PrivacyLevel::Public,
|
||||
2 => PrivacyLevel::Private,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pk_model]
|
||||
struct System {
|
||||
id: SystemId,
|
||||
#[json = "id"]
|
||||
#[private_patchable]
|
||||
hid: String,
|
||||
#[json = "uuid"]
|
||||
uuid: Uuid,
|
||||
#[json = "name"]
|
||||
name: Option<String>,
|
||||
#[json = "description"]
|
||||
description: Option<String>,
|
||||
#[json = "tag"]
|
||||
tag: Option<String>,
|
||||
#[json = "pronouns"]
|
||||
pronouns: Option<String>,
|
||||
#[json = "avatar_url"]
|
||||
avatar_url: Option<String>,
|
||||
#[json = "banner_image"]
|
||||
banner_image: Option<String>,
|
||||
#[json = "color"]
|
||||
color: Option<String>,
|
||||
token: Option<String>,
|
||||
#[json = "webhook_url"]
|
||||
webhook_url: Option<String>,
|
||||
webhook_token: Option<String>,
|
||||
#[json = "created"]
|
||||
created: NaiveDateTime,
|
||||
#[privacy]
|
||||
name_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
avatar_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
description_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
banner_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
member_list_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
front_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
front_history_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
group_list_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
pronoun_privacy: PrivacyLevel,
|
||||
}
|
||||
|
|
@ -1,89 +0,0 @@
|
|||
use model_macros::pk_model;
|
||||
|
||||
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{SystemId, _util::fake_enum_impls};
|
||||
|
||||
pub const DEFAULT_MEMBER_LIMIT: i32 = 1000;
|
||||
pub const DEFAULT_GROUP_LIMIT: i32 = 250;
|
||||
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum HidPadFormat {
|
||||
#[serde(rename = "off")]
|
||||
None,
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
fake_enum_impls!(HidPadFormat);
|
||||
|
||||
impl From<i32> for HidPadFormat {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
0 => HidPadFormat::None,
|
||||
1 => HidPadFormat::Left,
|
||||
2 => HidPadFormat::Right,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ProxySwitchAction {
|
||||
Off,
|
||||
New,
|
||||
Add,
|
||||
}
|
||||
fake_enum_impls!(ProxySwitchAction);
|
||||
|
||||
impl From<i32> for ProxySwitchAction {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
0 => ProxySwitchAction::Off,
|
||||
1 => ProxySwitchAction::New,
|
||||
2 => ProxySwitchAction::Add,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pk_model]
|
||||
struct SystemConfig {
|
||||
system: SystemId,
|
||||
#[json = "timezone"]
|
||||
ui_tz: String,
|
||||
#[json = "pings_enabled"]
|
||||
pings_enabled: bool,
|
||||
#[json = "latch_timeout"]
|
||||
latch_timeout: Option<i32>,
|
||||
#[json = "member_default_private"]
|
||||
member_default_private: bool,
|
||||
#[json = "group_default_private"]
|
||||
group_default_private: bool,
|
||||
#[json = "show_private_info"]
|
||||
show_private_info: bool,
|
||||
#[json = "member_limit"]
|
||||
#[default = DEFAULT_MEMBER_LIMIT]
|
||||
member_limit_override: Option<i32>,
|
||||
#[json = "group_limit"]
|
||||
#[default = DEFAULT_GROUP_LIMIT]
|
||||
group_limit_override: Option<i32>,
|
||||
#[json = "case_sensitive_proxy_tags"]
|
||||
case_sensitive_proxy_tags: bool,
|
||||
#[json = "proxy_error_message_enabled"]
|
||||
proxy_error_message_enabled: bool,
|
||||
#[json = "hid_display_split"]
|
||||
hid_display_split: bool,
|
||||
#[json = "hid_display_caps"]
|
||||
hid_display_caps: bool,
|
||||
#[json = "hid_list_padding"]
|
||||
hid_list_padding: HidPadFormat,
|
||||
#[json = "proxy_switch"]
|
||||
proxy_switch: ProxySwitchAction,
|
||||
#[json = "name_format"]
|
||||
name_format: String,
|
||||
#[json = "description_templates"]
|
||||
description_templates: Vec<String>,
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue