feat(api): improve auth middleware

This commit is contained in:
alyssa 2025-05-17 20:39:29 +00:00
parent 50900ee640
commit c56fd36023
6 changed files with 87 additions and 75 deletions

22
crates/api/src/auth.rs Normal file
View file

@ -0,0 +1,22 @@
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
#[derive(Clone)]
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
Self { system_id, app_id }
}
pub fn system_id(&self) -> Option<i32> {
self.system_id
}
pub fn app_id(&self) -> Option<i32> {
self.app_id
}
}

View file

@ -1,12 +1,13 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Router,
Extension, Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
@ -14,6 +15,7 @@ use hyper_util::{
};
use tracing::{error, info};
mod auth;
mod endpoints;
mod error;
mod middleware;
@ -29,6 +31,7 @@ pub struct ApiContext {
}
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
@ -43,6 +46,19 @@ async fn rproxy(
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx
.rproxy_client
.request(req)
@ -118,11 +134,12 @@ fn router(ctx: ApiContext) -> Router {
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.layer(axum::middleware::from_fn(middleware::logger))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)

View file

@ -4,27 +4,19 @@ use axum::{
middleware::Next,
response::Response,
};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
// todo: auth should pass down models in request context
// not numerical ids in headers
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
// fetch user authorization
if let Some(system_auth_header) = headers
if let Some(system_auth_header) = req
.headers()
.get("authorization")
.map(|h| h.to_str().ok())
.flatten()
@ -45,7 +37,8 @@ pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: N
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = headers
if let Some(app_auth_header) = req
.headers()
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
@ -62,29 +55,8 @@ pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: N
authed_app_id = Some(1);
}
// add headers for ratelimiter / dotnet-api
{
let headers = request.headers_mut();
if let Some(sid) = authed_system_id {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = authed_app_id {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
}
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id));
let mut response = next.run(request).await;
// add headers for logger module (ugh)
{
let headers = response.headers_mut();
if let Some(sid) = authed_system_id {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = authed_app_id {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
}
response
next.run(req).await
}

View file

@ -4,10 +4,7 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::{
middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
util::header_or_unknown,
};
use crate::{auth::AuthState, util::header_or_unknown};
// log any requests that take longer than 2 seconds
// todo: change as necessary
@ -19,13 +16,18 @@ pub async fn logger(request: Request, next: Next) -> Response {
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let endpoint = request
.extensions()
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
let uri = request.uri().clone();
let request_span = span!(
@ -38,24 +40,18 @@ pub async fn logger(request: Request, next: Next) -> Response {
);
let start = Instant::now();
let mut response = next.run(request).instrument(request_span).await;
let response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let (system_id, app_id) = {
let headers = response.headers_mut();
(
headers
.remove(INTERNAL_SYSTEMID_HEADER)
.map(|h| h.to_str().ok().map(|v| v.to_string()))
.flatten()
.unwrap_or("none".to_string()),
headers
.remove(INTERNAL_APPID_HEADER)
.map(|h| h.to_str().ok().map(|v| v.to_string()))
.flatten()
.unwrap_or("none".to_string()),
)
};
let system_id = auth
.system_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
let app_id = auth
.app_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
counter!(
"pluralkit_api_requests",

View file

@ -9,5 +9,4 @@ pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod ratelimit;
mod authnz;
pub use authnz::authnz;
pub mod auth;

View file

@ -10,9 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use crate::{
auth::AuthState,
util::{header_or_unknown, json_err},
};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -105,23 +106,28 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
let endpoint = request
.extensions()
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
// looks like this chooses the tokens/sec by app_id or endpoint
// then chooses the key by system_id or source_ip
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if authenticated_app_id == "1" {
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
@ -133,12 +139,12 @@ pub async fn do_request_ratelimited(
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
if let Some(system_id) = auth.system_id()
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
system_id.to_string()
} else {
source_ip
source_ip.to_string()
},
rlimit.key()
);