mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-10 15:57:53 +00:00
[WIP] feat: scoped api keys
This commit is contained in:
parent
e7ee593a85
commit
06cb160f95
45 changed files with 1264 additions and 154 deletions
|
|
@ -8,12 +8,15 @@ use axum::{
|
|||
};
|
||||
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
|
||||
use metrics::counter;
|
||||
use sqlx::Postgres;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
ApiContext,
|
||||
auth::AuthState,
|
||||
util::{header_or_unknown, json_err},
|
||||
};
|
||||
use pluralkit_models::PKExternalApp;
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
|
|
@ -22,7 +25,10 @@ lazy_static::lazy_static! {
|
|||
}
|
||||
|
||||
// this is awful but it works
|
||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
||||
pub fn ratelimiter<F, T>(
|
||||
ctx: ApiContext,
|
||||
f: F,
|
||||
) -> FromFnLayer<F, (ApiContext, Option<RedisPool>), T> {
|
||||
let redis = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
|
|
@ -52,14 +58,14 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
|||
warn!("running without request rate limiting!");
|
||||
}
|
||||
|
||||
axum::middleware::from_fn_with_state(redis, f)
|
||||
axum::middleware::from_fn_with_state((ctx, redis), f)
|
||||
}
|
||||
|
||||
enum RatelimitType {
|
||||
GenericGet,
|
||||
GenericUpdate,
|
||||
Message,
|
||||
TempCustom,
|
||||
AppCustom(i32),
|
||||
}
|
||||
|
||||
impl RatelimitType {
|
||||
|
|
@ -68,7 +74,7 @@ impl RatelimitType {
|
|||
RatelimitType::GenericGet => "generic_get",
|
||||
RatelimitType::GenericUpdate => "generic_update",
|
||||
RatelimitType::Message => "message",
|
||||
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
|
||||
RatelimitType::AppCustom(_) => "app_custom",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
|
@ -78,21 +84,41 @@ impl RatelimitType {
|
|||
RatelimitType::GenericGet => 10,
|
||||
RatelimitType::GenericUpdate => 3,
|
||||
RatelimitType::Message => 10,
|
||||
RatelimitType::TempCustom => 20,
|
||||
RatelimitType::AppCustom(n) => *n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_request_ratelimited(
|
||||
State(redis): State<Option<RedisPool>>,
|
||||
State((ctx, redis)): State<(ApiContext, Option<RedisPool>)>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
if headers.get("x-pluralkit-internal").is_some() {
|
||||
// bypass ratelimiting entirely for internal requests
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let extensions = request.extensions().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
|
||||
let mut app_rate: Option<i32> = None;
|
||||
if let Some(app_header) = request.headers().clone().get("x-pluralkit-app") {
|
||||
let app_token = app_header.to_str().unwrap_or("invalid");
|
||||
if app_token.starts_with("pkap2:")
|
||||
&& let Some(app) = sqlx::query_as::<Postgres, PKExternalApp>(
|
||||
"select * from external_apps where api_rl_token = $1",
|
||||
)
|
||||
.bind(&app_token[6..])
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query external app in postgres")
|
||||
{
|
||||
app_rate = Some(app.api_rl_rate.expect("external app has no api_rl_rate"));
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
|
|
@ -109,11 +135,8 @@ pub async fn do_request_ratelimited(
|
|||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if let Some(app_id) = auth.app_id()
|
||||
&& app_id == 1
|
||||
{
|
||||
RatelimitType::TempCustom
|
||||
let rlimit = if let Some(r) = app_rate {
|
||||
RatelimitType::AppCustom(r)
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
} else if request.method() == Method::GET {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue