use std::time::{Duration, SystemTime}; use axum::{ extract::{MatchedPath, Request, State}, http::{HeaderValue, Method, StatusCode}, middleware::{FromFnLayer, Next}, response::Response, }; use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash}; use metrics::increment_counter; use tracing::{debug, error, info, warn}; use crate::util::{header_or_unknown, json_err}; const LUA_SCRIPT: &str = include_str!("ratelimit.lua"); lazy_static::lazy_static! { static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT); } // this is awful but it works pub fn ratelimiter(f: F) -> FromFnLayer, T> { let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| { let r = fred::pool::RedisPool::new( fred::types::RedisConfig::from_url_centralized(val.as_ref()) .expect("redis url is invalid"), 10, ) .expect("failed to connect to redis"); let handle = r.connect(Some(ReconnectPolicy::default())); tokio::spawn(async move { handle }); let rscript = r.clone(); tokio::spawn(async move { if let Ok(()) = rscript.wait_for_connect().await { match rscript.script_load(LUA_SCRIPT).await { Ok(_) => info!("connected to redis for request rate limiting"), Err(err) => error!("could not load redis script: {}", err), } } else { error!("could not wait for connection to load redis script!"); } }); r }); if redis.is_none() { warn!("running without request rate limiting!"); } axum::middleware::from_fn_with_state(redis, f) } enum RatelimitType { GenericGet, GenericUpdate, Message, TempCustom, } impl RatelimitType { fn key(&self) -> String { match self { RatelimitType::GenericGet => "generic_get", RatelimitType::GenericUpdate => "generic_update", RatelimitType::Message => "message", RatelimitType::TempCustom => "token2", // this should be "app_custom" or something } .to_string() } fn rate(&self) -> i32 { match self { RatelimitType::GenericGet => 10, RatelimitType::GenericUpdate => 3, RatelimitType::Message => 10, RatelimitType::TempCustom => 20, } } } pub async fn do_request_ratelimited( State(redis): State>, request: Request, next: Next, ) -> Response { if let Some(redis) = redis { let headers = request.headers().clone(); let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP")); let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid")); // https://github.com/rust-lang/rust/issues/53667 let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App") { if let Some(token2) = &libpk::config.api.temp_token2 { if header.to_str().unwrap_or("invalid") == token2 { true } else { false } } else { false } } else { false }; let endpoint = request .extensions() .get::() .cloned() .map(|v| v.as_str().to_string()) .unwrap_or("unknown".to_string()); let rlimit = if is_temp_token2 { RatelimitType::TempCustom } else if endpoint == "/v2/messages/:message_id" { RatelimitType::Message } else if request.method() == Method::GET { RatelimitType::GenericGet } else { RatelimitType::GenericUpdate }; let rl_key = format!( "{}:{}", if authenticated_system_id != "unknown" && matches!(rlimit, RatelimitType::GenericUpdate) { authenticated_system_id } else { source_ip }, rlimit.key() ); let burst = 5; let period = 1; // seconds // local rate_limit_key = KEYS[1] // local burst = ARGV[1] // local rate = ARGV[2] // local period = ARGV[3] // return {remaining, tostring(retry_after), reset_after} let resp = redis .evalsha::<(i32, String, u64), String, Vec, Vec>( LUA_SCRIPT_SHA.to_string(), vec![rl_key.clone()], vec![burst, rlimit.rate(), period], ) .await; match resp { Ok((mut remaining, retry_after, reset_after)) => { // redis's lua doesn't support returning floats let retry_after: f64 = retry_after .parse() .expect("got something that isn't a f64 from redis"); let mut response = if remaining > 0 { next.run(request).await } else { let retry_after = (retry_after * 1_000_f64).ceil() as u64; debug!("ratelimited request from {rl_key}, retry_after={retry_after}",); increment_counter!("pk_http_requests_ratelimited"); json_err( StatusCode::TOO_MANY_REQUESTS, format!( r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#, rlimit.key(), ), ) }; // the redis script puts burst in remaining for ??? some reason remaining -= burst - rlimit.rate(); let reset_time = SystemTime::now() .checked_add(Duration::from_secs(reset_after)) .expect("invalid timestamp") .duration_since(std::time::UNIX_EPOCH) .expect("invalid duration") .as_secs(); let headers = response.headers_mut(); headers.insert( "X-RateLimit-Scope", HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"), ); headers.insert( "X-RateLimit-Limit", HeaderValue::from_str(format!("{}", rlimit.rate()).as_str()) .expect("invalid header value"), ); headers.insert( "X-RateLimit-Remaining", HeaderValue::from_str(format!("{}", remaining).as_str()) .expect("invalid header value"), ); headers.insert( "X-RateLimit-Reset", HeaderValue::from_str(format!("{}", reset_time).as_str()) .expect("invalid header value"), ); return response; } Err(err) => { tracing::error!("error getting ratelimit info: {}", err); return json_err( StatusCode::INTERNAL_SERVER_ERROR, r#"{"message": "500: internal server error", "code": 0}"#.to_string(), ); } } } next.run(request).await }