diff --git a/crates/api/src/endpoints/private.rs b/crates/api/src/endpoints/private.rs index 2116e3c5..cbc7cf41 100644 --- a/crates/api/src/endpoints/private.rs +++ b/crates/api/src/endpoints/private.rs @@ -83,8 +83,8 @@ pub async fn discord_callback( .expect("error making client"); let reqbody = serde_urlencoded::to_string(&CallbackDiscordData { - client_id: config.discord.as_ref().unwrap().client_id.get().to_string(), - client_secret: config.discord.as_ref().unwrap().client_secret.clone(), + client_id: config.discord().client_id.get().to_string(), + client_secret: config.discord().client_secret.clone(), grant_type: "authorization_code".to_string(), redirect_uri: request_data.redirect_domain, // change this! code: request_data.code, diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index a6e2680a..a29f21ad 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -11,7 +11,8 @@ use hyper_util::{ client::legacy::{Client, connect::HttpConnector}, rt::TokioExecutor, }; -use tracing::info; +use libpk::config; +use tracing::{info, warn}; use pk_macros::api_endpoint; @@ -128,7 +129,15 @@ fn router(ctx: ApiContext) -> Router { .route("/v2/members/{member_id}/oembed.json", get(rproxy)) .route("/v2/groups/{group_id}/oembed.json", get(rproxy)) - .layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks + .layer(axum::middleware::from_fn_with_state( + if config.api().use_ratelimiter { + Some(ctx.redis.clone()) + } else { + warn!("running without request rate limiting!"); + None + }, + middleware::ratelimit::do_request_ratelimited) + ) .layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes)) .layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::params::params)) .layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth)) @@ -146,14 +155,7 @@ async fn main() -> anyhow::Result<()> { let db = libpk::db::init_data_db().await?; let redis = libpk::db::init_redis().await?; - let rproxy_uri = Uri::from_static( - &libpk::config - .api - .as_ref() - .expect("missing api config") - .remote_url, - ) - .to_string(); + let rproxy_uri = Uri::from_static(&libpk::config.api().remote_url).to_string(); let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) .build(HttpConnector::new()); @@ -167,12 +169,7 @@ async fn main() -> anyhow::Result<()> { let app = router(ctx); - let addr: &str = libpk::config - .api - .as_ref() - .expect("missing api config") - .addr - .as_ref(); + let addr: &str = libpk::config.api().addr.as_ref(); let listener = tokio::net::TcpListener::bind(addr).await?; info!("listening on {}", addr); diff --git a/crates/api/src/middleware/auth.rs b/crates/api/src/middleware/auth.rs index 8487e932..5a47e2b2 100644 --- a/crates/api/src/middleware/auth.rs +++ b/crates/api/src/middleware/auth.rs @@ -44,12 +44,7 @@ pub async fn auth(State(ctx): State, mut req: Request, next: Next) - .get("x-pluralkit-app") .map(|h| h.to_str().ok()) .flatten() - && let Some(config_token2) = libpk::config - .api - .as_ref() - .expect("missing api config") - .temp_token2 - .as_ref() + && let Some(config_token2) = libpk::config.api().temp_token2.as_ref() && app_auth_header .as_bytes() .ct_eq(config_token2.as_bytes()) diff --git a/crates/api/src/middleware/ratelimit.rs b/crates/api/src/middleware/ratelimit.rs index 1638ecc9..e8ac7976 100644 --- a/crates/api/src/middleware/ratelimit.rs +++ b/crates/api/src/middleware/ratelimit.rs @@ -3,12 +3,12 @@ use std::time::{Duration, SystemTime}; use axum::{ extract::{MatchedPath, Request, State}, http::{HeaderValue, Method, StatusCode}, - middleware::{FromFnLayer, Next}, + middleware::Next, response::Response, }; -use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash}; +use fred::{clients::RedisPool, prelude::LuaInterface, util::sha1_hash}; use metrics::counter; -use tracing::{debug, error, info, warn}; +use tracing::{debug, error, info}; use crate::{ auth::AuthState, @@ -21,40 +21,6 @@ lazy_static::lazy_static! { static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT); } -// this is awful but it works -pub fn ratelimiter(f: F) -> FromFnLayer, T> { - let redis = libpk::config - .api - .as_ref() - .expect("missing api config") - .ratelimit_redis_addr - .as_ref() - .map(|val| { - // todo: this should probably use the global pool - let r = RedisPool::new( - fred::types::RedisConfig::from_url_centralized(val.as_ref()) - .expect("redis url is invalid"), - None, - None, - Some(Default::default()), - 10, - ) - .expect("failed to connect to redis"); - - let handle = r.connect(); - - tokio::spawn(async move { handle }); - - r - }); - - if redis.is_none() { - warn!("running without request rate limiting!"); - } - - axum::middleware::from_fn_with_state(redis, f) -} - enum RatelimitType { GenericGet, GenericUpdate, diff --git a/crates/app-commands/src/main.rs b/crates/app-commands/src/main.rs index a03f5b70..93fc4b4e 100644 --- a/crates/app-commands/src/main.rs +++ b/crates/app-commands/src/main.rs @@ -1,30 +1,14 @@ -use twilight_model::{ - application::command::{Command, CommandType}, - guild::IntegrationApplication, -}; +use twilight_model::application::command::CommandType; use twilight_util::builder::command::CommandBuilder; #[libpk::main] async fn main() -> anyhow::Result<()> { let discord = twilight_http::Client::builder() - .token( - libpk::config - .discord - .as_ref() - .expect("missing discord config") - .bot_token - .clone(), - ) + .token(libpk::config.discord().bot_token.clone()) .build(); let interaction = discord.interaction(twilight_model::id::Id::new( - libpk::config - .discord - .as_ref() - .expect("missing discord config") - .client_id - .clone() - .get(), + libpk::config.discord().client_id.clone().get(), )); let commands = vec![ diff --git a/crates/avatars/src/cleanup.rs b/crates/avatars/src/cleanup.rs index 48b25f98..ebff3960 100644 --- a/crates/avatars/src/cleanup.rs +++ b/crates/avatars/src/cleanup.rs @@ -6,10 +6,7 @@ use tracing::{error, info}; #[libpk::main] async fn main() -> anyhow::Result<()> { - let config = libpk::config - .avatars - .as_ref() - .expect("missing avatar service config"); + let config = libpk::config.avatars(); let bucket = { let region = s3::Region::Custom { @@ -83,10 +80,7 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc) -> anyhow::Res } let image_data = image_data.unwrap(); - let config = libpk::config - .avatars - .as_ref() - .expect("missing avatar service config"); + let config = libpk::config.avatars(); let path = image_data .url diff --git a/crates/avatars/src/main.rs b/crates/avatars/src/main.rs index df80ac82..d976ae52 100644 --- a/crates/avatars/src/main.rs +++ b/crates/avatars/src/main.rs @@ -172,10 +172,7 @@ pub struct AppState { #[libpk::main] async fn main() -> anyhow::Result<()> { - let config = libpk::config - .avatars - .as_ref() - .expect("missing avatar service config"); + let config = libpk::config.avatars(); let bucket = { let region = s3::Region::Custom { diff --git a/crates/gateway/src/api.rs b/crates/gateway/src/api.rs index aa2d069e..0fbb88c1 100644 --- a/crates/gateway/src/api.rs +++ b/crates/gateway/src/api.rs @@ -45,7 +45,7 @@ pub async fn run_server(cache: Arc, shard_state: Arc>, Path(guild_id): Path| async move { - match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) { + match cache.0.member(Id::new(guild_id), libpk::config.discord().client_id) { Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()), None => status_code(StatusCode::NOT_FOUND, "".to_string()), } @@ -54,7 +54,7 @@ pub async fn run_server(cache: Arc, shard_state: Arc>, Path(guild_id): Path| async move { - match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await { + match cache.guild_permissions(Id::new(guild_id), libpk::config.discord().client_id).await { Ok(val) => { status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()) }, @@ -122,7 +122,7 @@ pub async fn run_server(cache: Arc, shard_state: Arc status_code(StatusCode::FOUND, to_string(&val).unwrap()), Err(err) => { error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions"); @@ -219,7 +219,7 @@ pub async fn run_server(cache: Arc, shard_state: Arc()).await?; diff --git a/crates/gateway/src/discord/cache.rs b/crates/gateway/src/discord/cache.rs index cc538d08..9cd5c112 100644 --- a/crates/gateway/src/discord/cache.rs +++ b/crates/gateway/src/discord/cache.rs @@ -91,22 +91,10 @@ fn member_to_cached_member(item: Member, id: Id) -> CachedMember { } pub fn new() -> DiscordCache { - let mut client_builder = twilight_http::Client::builder().token( - libpk::config - .discord - .as_ref() - .expect("missing discord config") - .bot_token - .clone(), - ); + let mut client_builder = + twilight_http::Client::builder().token(libpk::config.discord().bot_token.clone()); - if let Some(base_url) = libpk::config - .discord - .as_ref() - .expect("missing discord config") - .api_base_url - .clone() - { + if let Some(base_url) = libpk::config.discord().api_base_url.clone() { client_builder = client_builder.proxy(base_url, true).ratelimiter(None); } @@ -268,13 +256,7 @@ impl DiscordCache { return Ok(Permissions::all()); } - let member = if user_id - == libpk::config - .discord - .as_ref() - .expect("missing discord config") - .client_id - { + let member = if user_id == libpk::config.discord().client_id { self.0 .member(guild_id, user_id) .ok_or(format_err!("self member not found"))? @@ -340,13 +322,7 @@ impl DiscordCache { return Ok(Permissions::all()); } - let member = if user_id - == libpk::config - .discord - .as_ref() - .expect("missing discord config") - .client_id - { + let member = if user_id == libpk::config.discord().client_id { self.0 .member(guild_id, user_id) .ok_or_else(|| { diff --git a/crates/gateway/src/discord/gateway.rs b/crates/gateway/src/discord/gateway.rs index c4d0483e..10e1e86b 100644 --- a/crates/gateway/src/discord/gateway.rs +++ b/crates/gateway/src/discord/gateway.rs @@ -23,9 +23,7 @@ use super::cache::DiscordCache; pub fn cluster_config() -> ClusterSettings { libpk::config - .discord - .as_ref() - .expect("missing discord config") + .discord() .cluster .clone() .unwrap_or(libpk::_config::ClusterSettings { @@ -63,28 +61,15 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result RedisQueue { RedisQueue { redis, - concurrency: libpk::config - .discord - .as_ref() - .expect("missing discord config") - .max_concurrency, + concurrency: libpk::config.discord().max_concurrency, } } diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index 3ac7be21..5e93d35e 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -41,13 +41,7 @@ async fn main() -> anyhow::Result<()> { ); // hacky, but needed for selfhost for now - if let Some(target) = libpk::config - .discord - .as_ref() - .unwrap() - .gateway_target - .clone() - { + if let Some(target) = libpk::config.discord().gateway_target.clone() { runtime_config .set(RUNTIME_CONFIG_KEY_EVENT_TARGET.to_string(), target) .await?; @@ -237,12 +231,7 @@ async fn main() -> anyhow::Result<()> { } async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) { - let prefix = libpk::config - .discord - .as_ref() - .expect("missing discord config") - .bot_prefix_for_gateway - .clone(); + let prefix = libpk::config.discord().bot_prefix_for_gateway.clone(); println!("{prefix}"); diff --git a/crates/gdpr_worker/src/main.rs b/crates/gdpr_worker/src/main.rs index bcedbedd..ef9a4cf0 100644 --- a/crates/gdpr_worker/src/main.rs +++ b/crates/gdpr_worker/src/main.rs @@ -14,23 +14,10 @@ async fn main() -> anyhow::Result<()> { let db = libpk::db::init_messages_db().await?; let mut client_builder = twilight_http::Client::builder() - .token( - libpk::config - .discord - .as_ref() - .expect("missing discord config") - .bot_token - .clone(), - ) + .token(libpk::config.discord().bot_token.clone()) .timeout(Duration::from_secs(30)); - if let Some(base_url) = libpk::config - .discord - .as_ref() - .expect("missing discord config") - .api_base_url - .clone() - { + if let Some(base_url) = libpk::config.discord().api_base_url.clone() { client_builder = client_builder.proxy(base_url, true).ratelimiter(None); } diff --git a/crates/libpk/src/_config.rs b/crates/libpk/src/_config.rs index f21d9adf..ec76e27a 100644 --- a/crates/libpk/src/_config.rs +++ b/crates/libpk/src/_config.rs @@ -56,7 +56,7 @@ pub struct ApiConfig { pub addr: String, #[serde(default)] - pub ratelimit_redis_addr: Option, + pub use_ratelimiter: bool, pub remote_url: String, @@ -109,11 +109,11 @@ pub struct PKConfig { pub db: DatabaseConfig, #[serde(default)] - pub discord: Option, + discord: Option, #[serde(default)] - pub api: Option, + api: Option, #[serde(default)] - pub avatars: Option, + avatars: Option, #[serde(default)] pub scheduled_tasks: Option, @@ -134,12 +134,18 @@ pub struct PKConfig { } impl PKConfig { - pub fn api(self) -> ApiConfig { - self.api.expect("missing api config") + pub fn api(&self) -> &ApiConfig { + self.api.as_ref().expect("missing api config") } - pub fn discord_config(self) -> DiscordConfig { - self.discord.expect("missing discord config") + pub fn discord(&self) -> &DiscordConfig { + self.discord.as_ref().expect("missing discord config") + } + + pub fn avatars(&self) -> &AvatarsConfig { + self.avatars + .as_ref() + .expect("missing avatar service config") } } diff --git a/crates/scheduled_tasks/src/main.rs b/crates/scheduled_tasks/src/main.rs index 805246f1..c78f3d16 100644 --- a/crates/scheduled_tasks/src/main.rs +++ b/crates/scheduled_tasks/src/main.rs @@ -22,22 +22,10 @@ pub struct AppCtx { #[libpk::main] async fn main() -> anyhow::Result<()> { - let mut client_builder = twilight_http::Client::builder().token( - libpk::config - .discord - .as_ref() - .expect("missing discord config") - .bot_token - .clone(), - ); + let mut client_builder = + twilight_http::Client::builder().token(libpk::config.discord().bot_token.clone()); - if let Some(base_url) = libpk::config - .discord - .as_ref() - .expect("missing discord config") - .api_base_url - .clone() - { + if let Some(base_url) = libpk::config.discord().api_base_url.clone() { client_builder = client_builder.proxy(base_url, true).ratelimiter(None); }