feat(api): improve auth middleware

This commit is contained in:
alyssa 2025-05-17 20:39:29 +00:00
parent 50900ee640
commit c56fd36023
6 changed files with 87 additions and 75 deletions

View file

@ -10,9 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use crate::{
auth::AuthState,
util::{header_or_unknown, json_err},
};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -105,23 +106,28 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
let endpoint = request
.extensions()
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
// looks like this chooses the tokens/sec by app_id or endpoint
// then chooses the key by system_id or source_ip
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if authenticated_app_id == "1" {
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
@ -133,12 +139,12 @@ pub async fn do_request_ratelimited(
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
if let Some(system_id) = auth.system_id()
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
system_id.to_string()
} else {
source_ip
source_ip.to_string()
},
rlimit.key()
);