PluralKit/services/api/src/middleware/ratelimit.rs

223 lines
7.4 KiB
Rust
Raw Normal View History

2023-02-15 19:27:36 -05:00
use std::time::{Duration, SystemTime};
use axum::{
2024-08-04 07:48:16 +09:00
extract::{MatchedPath, Request, State},
http::{HeaderValue, Method, StatusCode},
2023-02-15 19:27:36 -05:00
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
2023-03-18 23:06:55 -04:00
use metrics::increment_counter;
use tracing::{debug, error, info, warn};
2023-02-15 19:27:36 -05:00
use crate::util::{header_or_unknown, json_err};
/// Lua rate limiting script, embedded at compile time from `ratelimit.lua`
/// (resolved relative to this source file by `include_str!`).
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
/// SHA1 of [`LUA_SCRIPT`]. The middleware invokes the script by this hash
/// (`EVALSHA`), so the script must already be cached in redis; `ratelimiter`
/// loads it with `script_load` once the pool connects.
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
}
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
let redis = libpk::config.api.ratelimit_redis_addr.as_ref().map(|val| {
let r = fred::pool::RedisPool::new(
fred::types::RedisConfig::from_url_centralized(val.as_ref())
.expect("redis url is invalid"),
10,
)
.expect("failed to connect to redis");
let handle = r.connect(Some(ReconnectPolicy::default()));
tokio::spawn(async move { handle });
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript.script_load(LUA_SCRIPT).await {
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
if redis.is_none() {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
}
2024-08-04 07:48:16 +09:00
/// The rate limit bucket a request is counted against.
///
/// Each bucket has its own name ([`RatelimitType::key`]), which is used as the redis
/// key suffix and the reported scope. It also has its own rate ([`RatelimitType::rate`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RatelimitType {
    /// `GET` requests that don't fall into a more specific bucket.
    GenericGet,
    /// Non-`GET` requests that don't fall into a more specific bucket.
    GenericUpdate,
    /// Requests to the `/v2/messages/:message_id` route.
    Message,
    /// Requests carrying the configured temporary app token (`X-PluralKit-App`).
    TempCustom,
}

impl RatelimitType {
    /// Returns the bucket name: the suffix of the redis key and the `scope` reported
    /// to clients.
    fn key(&self) -> String {
        let name = match self {
            RatelimitType::GenericGet => "generic_get",
            RatelimitType::GenericUpdate => "generic_update",
            RatelimitType::Message => "message",
            // this should be "app_custom" or something; renaming it would change the
            // redis keys and the scope reported to clients
            RatelimitType::TempCustom => "token2",
        };
        name.to_string()
    }

    /// Returns the bucket's rate, which is passed to the redis script as `rate` and
    /// reported to clients as `X-RateLimit-Limit`.
    fn rate(&self) -> i32 {
        match self {
            RatelimitType::GenericGet => 10,
            RatelimitType::GenericUpdate => 3,
            RatelimitType::Message => 10,
            RatelimitType::TempCustom => 20,
        }
    }
}
pub async fn do_request_ratelimited(
2023-02-15 19:27:36 -05:00
State(redis): State<Option<RedisPool>>,
request: Request,
next: Next,
2023-02-15 19:27:36 -05:00
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
2024-08-04 07:48:16 +09:00
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
2023-02-15 19:27:36 -05:00
2023-03-19 11:18:09 -04:00
// https://github.com/rust-lang/rust/issues/53667
2024-08-04 07:48:16 +09:00
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
2023-02-15 19:27:36 -05:00
{
2023-03-19 11:18:09 -04:00
if let Some(token2) = &libpk::config.api.temp_token2 {
if header.to_str().unwrap_or("invalid") == token2 {
2024-08-04 07:48:16 +09:00
true
2023-03-19 11:18:09 -04:00
} else {
2024-08-04 07:48:16 +09:00
false
2023-03-19 11:18:09 -04:00
}
2023-02-15 19:27:36 -05:00
} else {
2024-08-04 07:48:16 +09:00
false
2023-02-15 19:27:36 -05:00
}
} else {
2024-08-04 07:48:16 +09:00
false
2023-02-15 19:27:36 -05:00
};
2024-08-04 07:48:16 +09:00
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {
RatelimitType::GenericGet
} else {
RatelimitType::GenericUpdate
};
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
} else {
source_ip
},
rlimit.key()
);
2023-02-15 19:27:36 -05:00
let burst = 5;
let period = 1; // seconds
// local rate_limit_key = KEYS[1]
// local burst = ARGV[1]
// local rate = ARGV[2]
// local period = ARGV[3]
2023-03-18 20:44:50 -04:00
// return {remaining, tostring(retry_after), reset_after}
2023-02-15 19:27:36 -05:00
let resp = redis
2024-08-04 07:48:16 +09:00
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
LUA_SCRIPT_SHA.to_string(),
vec![rl_key.clone()],
vec![burst, rlimit.rate(), period],
2023-02-15 19:27:36 -05:00
)
.await;
match resp {
Ok((mut remaining, retry_after, reset_after)) => {
2023-03-18 20:44:50 -04:00
// redis's lua doesn't support returning floats
let retry_after: f64 = retry_after
.parse()
.expect("got something that isn't a f64 from redis");
2023-02-15 19:27:36 -05:00
let mut response = if remaining > 0 {
next.run(request).await
} else {
2023-03-18 23:06:55 -04:00
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
2024-08-04 07:48:16 +09:00
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
2023-03-18 23:06:55 -04:00
increment_counter!("pk_http_requests_ratelimited");
2023-02-15 19:27:36 -05:00
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
2024-08-04 07:48:16 +09:00
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
rlimit.key(),
2023-02-15 19:27:36 -05:00
),
)
};
// the redis script puts burst in remaining for ??? some reason
2024-08-04 07:48:16 +09:00
remaining -= burst - rlimit.rate();
2023-02-15 19:27:36 -05:00
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
.expect("invalid timestamp")
.duration_since(std::time::UNIX_EPOCH)
.expect("invalid duration")
.as_secs();
let headers = response.headers_mut();
2024-08-04 07:48:16 +09:00
headers.insert(
"X-RateLimit-Scope",
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
);
2023-02-15 19:27:36 -05:00
headers.insert(
"X-RateLimit-Limit",
2024-08-04 07:48:16 +09:00
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
2023-02-15 19:27:36 -05:00
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Remaining",
HeaderValue::from_str(format!("{}", remaining).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Reset",
HeaderValue::from_str(format!("{}", reset_time).as_str())
.expect("invalid header value"),
);
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
}
}
next.run(request).await
}