feat(api): new ratelimit handling

This commit is contained in:
alyssa 2024-08-04 07:48:16 +09:00
parent cfde105e19
commit e23528383f
3 changed files with 91 additions and 20 deletions

View file

@ -7,6 +7,7 @@ local rate = ARGV[2]
local period = ARGV[3]
-- we're only ever asking for 1 request at a time
-- todo: this is no longer true
local cost = 1 --local cost = tonumber(ARGV[4])
local emission_interval = period / rate

View file

@ -1,8 +1,8 @@
use std::time::{Duration, SystemTime};
use axum::http::{HeaderValue, StatusCode};
use axum::{
extract::{Request, State},
extract::{MatchedPath, Request, State},
http::{HeaderValue, Method, StatusCode},
middleware::{FromFnLayer, Next},
response::Response,
};
@ -54,6 +54,34 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
axum::middleware::from_fn_with_state(redis, f)
}
enum RatelimitType {
GenericGet,
GenericUpdate,
Message,
TempCustom,
}
impl RatelimitType {
fn key(&self) -> String {
match self {
RatelimitType::GenericGet => "generic_get",
RatelimitType::GenericUpdate => "generic_update",
RatelimitType::Message => "message",
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
}
.to_string()
}
fn rate(&self) -> i32 {
match self {
RatelimitType::GenericGet => 10,
RatelimitType::GenericUpdate => 3,
RatelimitType::Message => 10,
RatelimitType::TempCustom => 20,
}
}
}
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
request: Request,
@ -62,40 +90,66 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
// https://github.com/rust-lang/rust/issues/53667
let (rl_key, rate) = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if let Some(token2) = &libpk::config.api.temp_token2 {
if header.to_str().unwrap_or("invalid") == token2 {
("token2", 20)
true
} else {
(source_ip, 2)
false
}
} else {
(source_ip, 2)
false
}
} else {
(source_ip, 2)
false
};
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {
RatelimitType::GenericGet
} else {
RatelimitType::GenericUpdate
};
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
} else {
source_ip
},
rlimit.key()
);
let burst = 5;
let period = 1; // seconds
// todo: make this static
// though even if it's not static, it's probably cheaper than sending the entire script to redis every time
let scriptsha = sha1_hash(&LUA_SCRIPT);
// local rate_limit_key = KEYS[1]
// local burst = ARGV[1]
// local rate = ARGV[2]
// local period = ARGV[3]
// return {remaining, tostring(retry_after), reset_after}
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<&str>, Vec<i32>>(
scriptsha,
vec![rl_key],
vec![burst, rate, period],
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
LUA_SCRIPT_SHA.to_string(),
vec![rl_key.clone()],
vec![burst, rlimit.rate(), period],
)
.await;
@ -110,18 +164,19 @@ pub async fn do_request_ratelimited(
next.run(request).await
} else {
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
debug!("ratelimited request from {rl_key}, retry_after={retry_after}");
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
increment_counter!("pk_http_requests_ratelimited");
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
r#"{{"message":"429: too many requests","retry_after":{retry_after},"code":0}}"#,
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
rlimit.key(),
),
)
};
// the redis script puts burst in remaining for ??? some reason
remaining -= burst - rate;
remaining -= burst - rlimit.rate();
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
@ -131,9 +186,13 @@ pub async fn do_request_ratelimited(
.as_secs();
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Scope",
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Limit",
HeaderValue::from_str(format!("{}", rate).as_str())
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
.expect("invalid header value"),
);
headers.insert(