chore: reorganize rust crates

This commit is contained in:
alyssa 2025-01-02 00:50:36 +00:00
parent 357122a892
commit 16ce67e02c
58 changed files with 6 additions and 13 deletions

145
crates/libpk/src/_config.rs Normal file
View file

@ -0,0 +1,145 @@
use config::Config;
use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
use twilight_model::id::{marker::UserMarker, Id};
#[derive(Clone, Deserialize, Debug)]
pub struct ClusterSettings {
pub node_id: u32,
pub total_shards: u32,
pub total_nodes: u32,
}
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: Id<UserMarker>,
pub bot_token: String,
pub client_secret: String,
pub max_concurrency: u32,
#[serde(default)]
pub cluster: Option<ClusterSettings>,
pub api_base_url: Option<String>,
#[serde(default = "_default_api_addr")]
pub cache_api_addr: String,
}
#[derive(Deserialize, Debug)]
pub struct DatabaseConfig {
pub(crate) data_db_uri: String,
pub(crate) data_db_max_connections: Option<u32>,
pub(crate) data_db_min_connections: Option<u32>,
pub(crate) messages_db_uri: Option<String>,
pub(crate) stats_db_uri: Option<String>,
pub(crate) db_password: Option<String>,
pub data_redis_addr: String,
}
fn _default_api_addr() -> String {
"[::]:5000".to_string()
}
#[derive(Deserialize, Clone, Debug)]
pub struct ApiConfig {
#[serde(default = "_default_api_addr")]
pub addr: String,
#[serde(default)]
pub ratelimit_redis_addr: Option<String>,
pub remote_url: String,
#[serde(default)]
pub temp_token2: Option<String>,
}
#[derive(Deserialize, Clone, Debug)]
pub struct AvatarsConfig {
pub s3: S3Config,
pub cdn_url: String,
#[serde(default = "_default_api_addr")]
pub bind_addr: String,
#[serde(default)]
pub migrate_worker_count: u32,
#[serde(default)]
pub cloudflare_zone_id: Option<String>,
#[serde(default)]
pub cloudflare_token: Option<String>,
}
#[derive(Deserialize, Clone, Debug)]
pub struct S3Config {
pub bucket: String,
pub application_id: String,
pub application_key: String,
pub endpoint: String,
}
#[derive(Deserialize, Debug)]
pub struct ScheduledTasksConfig {
pub set_guild_count: bool,
pub expected_gateway_count: usize,
pub gateway_url: String,
}
fn _metrics_default() -> bool {
false
}
fn _json_log_default() -> bool {
false
}
#[derive(Deserialize, Debug)]
pub struct PKConfig {
pub db: DatabaseConfig,
#[serde(default)]
pub discord: Option<DiscordConfig>,
#[serde(default)]
pub api: Option<ApiConfig>,
#[serde(default)]
pub avatars: Option<AvatarsConfig>,
#[serde(default)]
pub scheduled_tasks: Option<ScheduledTasksConfig>,
#[serde(default = "_metrics_default")]
pub run_metrics_server: bool,
#[serde(default = "_json_log_default")]
pub(crate) json_log: bool,
#[serde(default)]
pub sentry_url: Option<String>,
}
impl PKConfig {
pub fn api(self) -> ApiConfig {
self.api.expect("missing api config")
}
pub fn discord_config(self) -> DiscordConfig {
self.discord.expect("missing discord config")
}
}
// todo: consider passing this down instead of making it global
// especially since we have optional discord/api/avatars/etc config
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = {
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
.build().unwrap()
.try_deserialize::<PKConfig>().unwrap())
};
}

View file

@ -0,0 +1,96 @@
use fred::clients::RedisPool;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use std::str::FromStr;
use tracing::info;
pub mod repository;
pub mod types;
pub async fn init_redis() -> anyhow::Result<RedisPool> {
info!("connecting to redis");
let redis = RedisPool::new(
fred::types::RedisConfig::from_url_centralized(crate::config.db.data_redis_addr.as_ref())
.expect("redis url is invalid"),
None,
None,
Some(Default::default()),
10,
)?;
let redis_handle = redis.connect_pool();
tokio::spawn(async move { redis_handle });
Ok(redis)
}
pub async fn init_data_db() -> anyhow::Result<PgPool> {
info!("connecting to database");
let mut options = PgConnectOptions::from_str(&crate::config.db.data_db_uri)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
let mut pool = PgPoolOptions::new();
if let Some(max_conns) = crate::config.db.data_db_max_connections {
pool = pool.max_connections(max_conns);
}
if let Some(min_conns) = crate::config.db.data_db_min_connections {
pool = pool.min_connections(min_conns);
}
Ok(pool.connect_with(options).await?)
}
pub async fn init_messages_db() -> anyhow::Result<PgPool> {
info!("connecting to messages database");
let mut options = PgConnectOptions::from_str(
&crate::config
.db
.messages_db_uri
.as_ref()
.expect("missing messages db uri"),
)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
let mut pool = PgPoolOptions::new();
if let Some(max_conns) = crate::config.db.data_db_max_connections {
pool = pool.max_connections(max_conns);
}
if let Some(min_conns) = crate::config.db.data_db_min_connections {
pool = pool.min_connections(min_conns);
}
Ok(pool.connect_with(options).await?)
}
pub async fn init_stats_db() -> anyhow::Result<PgPool> {
info!("connecting to stats database");
let mut options = PgConnectOptions::from_str(
&crate::config
.db
.stats_db_uri
.as_ref()
.expect("missing messages db uri"),
)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
Ok(PgPoolOptions::new()
.max_connections(1)
.min_connections(1)
.connect_with(options)
.await?)
}

View file

@ -0,0 +1,20 @@
pub async fn legacy_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
) -> anyhow::Result<Option<i32>> {
let mut system: Vec<LegacyTokenDbResponse> =
sqlx::query_as("select id from systems where token = $1")
.bind(token)
.fetch_all(pool)
.await?;
Ok(if let Some(system) = system.pop() {
Some(system.id)
} else {
None
})
}
#[derive(sqlx::FromRow)]
struct LegacyTokenDbResponse {
id: i32,
}

View file

@ -0,0 +1,111 @@
use sqlx::{PgPool, Postgres, Transaction};
use crate::db::types::avatars::*;
pub async fn get_by_id(pool: &PgPool, id: String) -> anyhow::Result<Option<ImageMeta>> {
Ok(sqlx::query_as("select * from images where id = $1")
.bind(id)
.fetch_optional(pool)
.await?)
}
pub async fn get_by_original_url(
pool: &PgPool,
original_url: &str,
) -> anyhow::Result<Option<ImageMeta>> {
Ok(
sqlx::query_as("select * from images where original_url = $1")
.bind(original_url)
.fetch_optional(pool)
.await?,
)
}
pub async fn get_by_attachment_id(
pool: &PgPool,
attachment_id: u64,
) -> anyhow::Result<Option<ImageMeta>> {
Ok(
sqlx::query_as("select * from images where original_attachment_id = $1")
.bind(attachment_id as i64)
.fetch_optional(pool)
.await?,
)
}
pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow::Result<()> {
sqlx::query(
r#"
delete from image_cleanup_jobs
where id in (
select id from images
where original_attachment_id = $1
)
"#,
)
.bind(attachment_id as i64)
.execute(pool)
.await?;
Ok(())
}
pub async fn pop_queue(
pool: &PgPool,
) -> anyhow::Result<Option<(Transaction<Postgres>, ImageQueueEntry)>> {
let mut tx = pool.begin().await?;
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")
.fetch_optional(&mut *tx).await?;
Ok(res.map(|x| (tx, x)))
}
pub async fn get_queue_length(pool: &PgPool) -> anyhow::Result<i64> {
Ok(sqlx::query_scalar("select count(*) from image_queue")
.fetch_one(pool)
.await?)
}
pub async fn get_stats(pool: &PgPool) -> anyhow::Result<Stats> {
Ok(sqlx::query_as(
"select count(*) as total_images, sum(file_size) as total_file_size from images",
)
.fetch_one(pool)
.await?)
}
pub async fn add_image(pool: &PgPool, meta: ImageMeta) -> anyhow::Result<bool> {
let kind_str = match meta.kind {
ImageKind::Avatar => "avatar",
ImageKind::Banner => "banner",
};
let res = sqlx::query("insert into images (id, url, content_type, original_url, file_size, width, height, original_file_size, original_type, original_attachment_id, kind, uploaded_by_account, uploaded_by_system, uploaded_at) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, (now() at time zone 'utc')) on conflict (id) do nothing")
.bind(meta.id)
.bind(meta.url)
.bind(meta.content_type)
.bind(meta.original_url)
.bind(meta.file_size)
.bind(meta.width)
.bind(meta.height)
.bind(meta.original_file_size)
.bind(meta.original_type)
.bind(meta.original_attachment_id)
.bind(kind_str)
.bind(meta.uploaded_by_account)
.bind(meta.uploaded_by_system)
.execute(pool).await?;
Ok(res.rows_affected() > 0)
}
pub async fn push_queue(
conn: &mut sqlx::PgConnection,
url: &str,
kind: ImageKind,
) -> anyhow::Result<()> {
sqlx::query("insert into image_queue (url, kind) values ($1, $2)")
.bind(url)
.bind(kind)
.execute(conn)
.await?;
Ok(())
}

View file

@ -0,0 +1,7 @@
mod stats;
pub use stats::*;
pub mod avatars;
mod auth;
pub use auth::*;

View file

@ -0,0 +1,26 @@
pub async fn get_stats(pool: &sqlx::postgres::PgPool) -> anyhow::Result<Counts> {
let counts: Counts = sqlx::query_as("select * from info").fetch_one(pool).await?;
Ok(counts)
}
pub async fn insert_stats(
pool: &sqlx::postgres::PgPool,
table: &str,
value: i64,
) -> anyhow::Result<()> {
// danger sql injection
sqlx::query(format!("insert into {table} values (now(), $1)").as_str())
.bind(value)
.execute(pool)
.await?;
Ok(())
}
#[derive(serde::Serialize, sqlx::FromRow)]
pub struct Counts {
pub system_count: i64,
pub member_count: i64,
pub group_count: i64,
pub switch_count: i64,
pub message_count: i64,
}

View file

@ -0,0 +1,53 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(FromRow)]
pub struct ImageMeta {
pub id: String,
pub kind: ImageKind,
pub content_type: String,
pub url: String,
pub file_size: i32,
pub width: i32,
pub height: i32,
pub uploaded_at: Option<OffsetDateTime>,
pub original_url: Option<String>,
pub original_attachment_id: Option<i64>,
pub original_file_size: Option<i32>,
pub original_type: Option<String>,
pub uploaded_by_account: Option<i64>,
pub uploaded_by_system: Option<Uuid>,
}
#[derive(FromRow, Serialize)]
pub struct Stats {
pub total_images: i64,
pub total_file_size: i64,
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug, sqlx::Type, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(rename_all = "snake_case", type_name = "text")]
pub enum ImageKind {
Avatar,
Banner,
}
impl ImageKind {
pub fn size(&self) -> (u32, u32) {
match self {
Self::Avatar => (512, 512),
Self::Banner => (1024, 1024),
}
}
}
#[derive(FromRow)]
pub struct ImageQueueEntry {
pub itemid: i32,
pub url: String,
pub kind: ImageKind,
}

View file

@ -0,0 +1 @@
pub mod avatars;

81
crates/libpk/src/lib.rs Normal file
View file

@ -0,0 +1,81 @@
#![feature(let_chains)]
use std::net::SocketAddr;
use metrics_exporter_prometheus::PrometheusBuilder;
use sentry::IntoDsn;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
pub mod db;
pub mod state;
pub mod util;
pub mod _config;
pub use crate::_config::CONFIG as config;
// functions in this file are only used by the main function below
pub fn init_logging(component: &str) -> anyhow::Result<()> {
if config.json_log {
let mut layer = json_subscriber::layer();
layer.inner_layer_mut().add_static_field(
"component",
serde_json::Value::String(component.to_string()),
);
tracing_subscriber::registry()
.with(layer)
.with(EnvFilter::from_default_env())
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.init();
}
Ok(())
}
pub fn init_metrics() -> anyhow::Result<()> {
if config.run_metrics_server {
PrometheusBuilder::new()
.with_http_listener("[::]:9000".parse::<SocketAddr>().unwrap())
.install()?;
}
Ok(())
}
pub fn init_sentry() -> sentry::ClientInitGuard {
sentry::init(sentry::ClientOptions {
dsn: config
.sentry_url
.clone()
.map(|u| u.into_dsn().unwrap())
.flatten(),
release: sentry::release_name!(),
..Default::default()
})
}
#[macro_export]
macro_rules! main {
($component:expr) => {
fn main() -> anyhow::Result<()> {
let _sentry_guard = libpk::init_sentry();
// we might also be able to use env!("CARGO_CRATE_NAME") here
libpk::init_logging($component)?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(err) = libpk::init_metrics() {
tracing::error!("failed to init metrics collector: {err}");
};
tracing::info!("hello world");
if let Err(err) = real_main().await {
tracing::error!("failed to run service: {err}");
};
});
Ok(())
}
};
}

12
crates/libpk/src/state.rs Normal file
View file

@ -0,0 +1,12 @@
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
pub struct ShardState {
pub shard_id: i32,
pub up: bool,
pub disconnection_count: i32,
/// milliseconds
pub latency: i32,
/// unix timestamp
pub last_heartbeat: i32,
pub last_connection: i32,
pub cluster_id: Option<i32>,
}

View file

@ -0,0 +1 @@
pub mod redis;

View file

@ -0,0 +1,15 @@
use fred::error::RedisError;
pub trait RedisErrorExt<T> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
}
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
match self {
Ok(v) => Ok(Some(v)),
Err(error) if error.is_not_found() => Ok(None),
Err(error) => Err(error),
}
}
}