mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
feat(api): improve auth middleware
This commit is contained in:
parent
50900ee640
commit
c56fd36023
6 changed files with 87 additions and 75 deletions
22
crates/api/src/auth.rs
Normal file
22
crates/api/src/auth.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
system_id: Option<i32>,
|
||||
app_id: Option<i32>,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
|
||||
Self { system_id, app_id }
|
||||
}
|
||||
|
||||
pub fn system_id(&self) -> Option<i32> {
|
||||
self.system_id
|
||||
}
|
||||
|
||||
pub fn app_id(&self) -> Option<i32> {
|
||||
self.app_id
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,13 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
http::{Response, StatusCode, Uri},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
Extension, Router,
|
||||
};
|
||||
use hyper_util::{
|
||||
client::legacy::{connect::HttpConnector, Client},
|
||||
|
|
@ -14,6 +15,7 @@ use hyper_util::{
|
|||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
mod auth;
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod middleware;
|
||||
|
|
@ -29,6 +31,7 @@ pub struct ApiContext {
|
|||
}
|
||||
|
||||
async fn rproxy(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
State(ctx): State<ApiContext>,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Result<Response<Body>, StatusCode> {
|
||||
|
|
@ -43,6 +46,19 @@ async fn rproxy(
|
|||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
let headers = req.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
|
||||
Ok(ctx
|
||||
.rproxy_client
|
||||
.request(req)
|
||||
|
|
@ -118,11 +134,12 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::cors))
|
||||
.layer(axum::middleware::from_fn(middleware::logger))
|
||||
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
|
||||
|
||||
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
|
||||
|
||||
.with_state(ctx)
|
||||
|
|
|
|||
|
|
@ -4,27 +4,19 @@ use axum::{
|
|||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use tracing::error;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
|
||||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
|
||||
// todo: auth should pass down models in request context
|
||||
// not numerical ids in headers
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let mut authed_system_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<i32> = None;
|
||||
|
||||
// fetch user authorization
|
||||
if let Some(system_auth_header) = headers
|
||||
if let Some(system_auth_header) = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
|
|
@ -45,7 +37,8 @@ pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: N
|
|||
|
||||
// fetch app authorization
|
||||
// todo: actually fetch it from db
|
||||
if let Some(app_auth_header) = headers
|
||||
if let Some(app_auth_header) = req
|
||||
.headers()
|
||||
.get("x-pluralkit-app")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
|
|
@ -62,29 +55,8 @@ pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: N
|
|||
authed_app_id = Some(1);
|
||||
}
|
||||
|
||||
// add headers for ratelimiter / dotnet-api
|
||||
{
|
||||
let headers = request.headers_mut();
|
||||
if let Some(sid) = authed_system_id {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
if let Some(aid) = authed_app_id {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
}
|
||||
req.extensions_mut()
|
||||
.insert(AuthState::new(authed_system_id, authed_app_id));
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
|
||||
// add headers for logger module (ugh)
|
||||
{
|
||||
let headers = response.headers_mut();
|
||||
if let Some(sid) = authed_system_id {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
if let Some(aid) = authed_app_id {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
@ -4,10 +4,7 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
|
|||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::{
|
||||
middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
|
||||
util::header_or_unknown,
|
||||
};
|
||||
use crate::{auth::AuthState, util::header_or_unknown};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
|
|
@ -19,13 +16,18 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_span = span!(
|
||||
|
|
@ -38,24 +40,18 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let (system_id, app_id) = {
|
||||
let headers = response.headers_mut();
|
||||
(
|
||||
headers
|
||||
.remove(INTERNAL_SYSTEMID_HEADER)
|
||||
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||
.flatten()
|
||||
.unwrap_or("none".to_string()),
|
||||
headers
|
||||
.remove(INTERNAL_APPID_HEADER)
|
||||
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||
.flatten()
|
||||
.unwrap_or("none".to_string()),
|
||||
)
|
||||
};
|
||||
let system_id = auth
|
||||
.system_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
let app_id = auth
|
||||
.app_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
counter!(
|
||||
"pluralkit_api_requests",
|
||||
|
|
|
|||
|
|
@ -9,5 +9,4 @@ pub use ignore_invalid_routes::ignore_invalid_routes;
|
|||
|
||||
pub mod ratelimit;
|
||||
|
||||
mod authnz;
|
||||
pub use authnz::authnz;
|
||||
pub mod auth;
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
|
|||
use metrics::counter;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
|
||||
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use crate::{
|
||||
auth::AuthState,
|
||||
util::{header_or_unknown, json_err},
|
||||
};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
|
|
@ -105,23 +106,28 @@ pub async fn do_request_ratelimited(
|
|||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
|
||||
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
// looks like this chooses the tokens/sec by app_id or endpoint
|
||||
// then chooses the key by system_id or source_ip
|
||||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if authenticated_app_id == "1" {
|
||||
let rlimit = if let Some(app_id) = auth.app_id()
|
||||
&& app_id == 1
|
||||
{
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
|
|
@ -133,12 +139,12 @@ pub async fn do_request_ratelimited(
|
|||
|
||||
let rl_key = format!(
|
||||
"{}:{}",
|
||||
if authenticated_system_id != "unknown"
|
||||
if let Some(system_id) = auth.system_id()
|
||||
&& matches!(rlimit, RatelimitType::GenericUpdate)
|
||||
{
|
||||
authenticated_system_id
|
||||
system_id.to_string()
|
||||
} else {
|
||||
source_ip
|
||||
source_ip.to_string()
|
||||
},
|
||||
rlimit.key()
|
||||
);
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue