mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-06 13:57:54 +00:00
chore: reorganize rust crates
This commit is contained in:
parent
357122a892
commit
16ce67e02c
58 changed files with 6 additions and 13 deletions
45
crates/api/src/middleware/authnz.rs
Normal file
45
crates/api/src/middleware/authnz.rs
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderValue,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::ApiContext;
|
||||
|
||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
headers.remove("x-pluralkit-systemid");
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten();
|
||||
let mut authenticated = false;
|
||||
if let Some(auth_header) = auth_header {
|
||||
if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
None
|
||||
}
|
||||
}
|
||||
{
|
||||
headers.append(
|
||||
"x-pluralkit-systemid",
|
||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
||||
);
|
||||
authenticated = true;
|
||||
}
|
||||
}
|
||||
let mut response = next.run(request).await;
|
||||
if authenticated {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
||||
}
|
||||
response
|
||||
}
|
||||
28
crates/api/src/middleware/cors.rs
Normal file
28
crates/api/src/middleware/cors.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
use axum::{
|
||||
extract::Request,
|
||||
http::{HeaderMap, HeaderValue, Method, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn add_cors_headers(headers: &mut HeaderMap) {
|
||||
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
|
||||
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
|
||||
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
|
||||
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
|
||||
}
|
||||
|
||||
pub async fn cors(request: Request, next: Next) -> Response {
|
||||
let mut response = if request.method() == Method::OPTIONS {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
next.run(request).await
|
||||
};
|
||||
|
||||
add_cors_headers(response.headers_mut());
|
||||
|
||||
response
|
||||
}
|
||||
64
crates/api/src/middleware/ignore_invalid_routes.rs
Normal file
64
crates/api/src/middleware/ignore_invalid_routes.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
use axum::{
|
||||
extract::MatchedPath,
|
||||
extract::Request,
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
|
||||
path.starts_with("/v2/s/")
|
||||
|| path.starts_with("/v2/m/")
|
||||
|| path.starts_with("/v2/a/")
|
||||
|| path.starts_with("/v2/msg/")
|
||||
|| path == "/v2/s"
|
||||
|| path == "/v2/m"
|
||||
}
|
||||
|
||||
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
|
||||
let path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
if request.uri().path().starts_with("/v1") {
|
||||
(
|
||||
StatusCode::GONE,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Invalid path for API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
|
||||
else if !request.uri().clone().path().starts_with("/v2")
|
||||
&& !request.uri().clone().path().starts_with("/private")
|
||||
{
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response();
|
||||
} else if path == "unknown" {
|
||||
// current prod api responds with 404 with empty body to invalid endpoints
|
||||
// just doing that here as well but i'm not sure if it's the correct behaviour
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
|
||||
// but "unknown" isn't really a valid user-agent
|
||||
else if user_agent == "unknown" {
|
||||
// please set a valid user-agent
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
||||
98
crates/api/src/middleware/logger.rs
Normal file
98
crates/api/src/middleware/logger.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
|
||||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
remote_ip,
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
user_agent
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let authenticated = {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
counter!(
|
||||
"pluralkit_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"pluralkit_api_requests_bucket",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
endpoint,
|
||||
elapsed
|
||||
);
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
counter!(
|
||||
"pluralkit_api_slow_requests_count",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
13
crates/api/src/middleware/mod.rs
Normal file
13
crates/api/src/middleware/mod.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
mod cors;
|
||||
pub use cors::cors;
|
||||
|
||||
mod logger;
|
||||
pub use logger::logger;
|
||||
|
||||
mod ignore_invalid_routes;
|
||||
pub use ignore_invalid_routes::ignore_invalid_routes;
|
||||
|
||||
pub mod ratelimit;
|
||||
|
||||
mod authnz;
|
||||
pub use authnz::authnz;
|
||||
61
crates/api/src/middleware/ratelimit.lua
Normal file
61
crates/api/src/middleware/ratelimit.lua
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
-- this script has side-effects, so it requires replicate commands mode
|
||||
-- redis.replicate_commands()
|
||||
|
||||
local rate_limit_key = KEYS[1]
|
||||
local rate = ARGV[1]
|
||||
local period = ARGV[2]
|
||||
local cost = tonumber(ARGV[3])
|
||||
|
||||
local burst = rate
|
||||
|
||||
local emission_interval = period / rate
|
||||
local increment = emission_interval * cost
|
||||
local burst_offset = emission_interval * burst
|
||||
|
||||
-- redis returns time as an array containing two integers: seconds of the epoch
|
||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
|
||||
-- convert them to a floating point number. the resulting number is 16 digits,
|
||||
-- bordering on the limits of a 64-bit double-precision floating point number.
|
||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
|
||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
|
||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
|
||||
local jan_1_2017 = 1483228800
|
||||
local now = redis.call("TIME")
|
||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
|
||||
|
||||
local tat = redis.call("GET", rate_limit_key)
|
||||
|
||||
if not tat then
|
||||
tat = now
|
||||
else
|
||||
tat = tonumber(tat)
|
||||
end
|
||||
|
||||
tat = math.max(tat, now)
|
||||
|
||||
local new_tat = tat + increment
|
||||
local allow_at = new_tat - burst_offset
|
||||
|
||||
local diff = now - allow_at
|
||||
local remaining = diff / emission_interval
|
||||
|
||||
if remaining < 0 then
|
||||
local reset_after = tat - now
|
||||
local retry_after = diff * -1
|
||||
return {
|
||||
0, -- remaining
|
||||
tostring(retry_after),
|
||||
reset_after,
|
||||
}
|
||||
end
|
||||
|
||||
local reset_after = new_tat - now
|
||||
if reset_after > 0 then
|
||||
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
|
||||
end
|
||||
local retry_after = -1
|
||||
return {
|
||||
remaining,
|
||||
tostring(retry_after),
|
||||
reset_after
|
||||
}
|
||||
238
crates/api/src/middleware/ratelimit.rs
Normal file
238
crates/api/src/middleware/ratelimit.rs
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use axum::{
|
||||
extract::{MatchedPath, Request, State},
|
||||
http::{HeaderValue, Method, StatusCode},
|
||||
middleware::{FromFnLayer, Next},
|
||||
response::Response,
|
||||
};
|
||||
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
|
||||
use metrics::counter;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
|
||||
}
|
||||
|
||||
// this is awful but it works
|
||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
||||
let redis = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.ratelimit_redis_addr
|
||||
.as_ref()
|
||||
.map(|val| {
|
||||
// todo: this should probably use the global pool
|
||||
let r = RedisPool::new(
|
||||
fred::types::RedisConfig::from_url_centralized(val.as_ref())
|
||||
.expect("redis url is invalid"),
|
||||
None,
|
||||
None,
|
||||
Some(Default::default()),
|
||||
10,
|
||||
)
|
||||
.expect("failed to connect to redis");
|
||||
|
||||
let handle = r.connect();
|
||||
|
||||
tokio::spawn(async move { handle });
|
||||
|
||||
let rscript = r.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(()) = rscript.wait_for_connect().await {
|
||||
match rscript
|
||||
.script_load::<String, String>(LUA_SCRIPT.to_string())
|
||||
.await
|
||||
{
|
||||
Ok(_) => info!("connected to redis for request rate limiting"),
|
||||
Err(err) => error!("could not load redis script: {}", err),
|
||||
}
|
||||
} else {
|
||||
error!("could not wait for connection to load redis script!");
|
||||
}
|
||||
});
|
||||
|
||||
r
|
||||
});
|
||||
|
||||
if redis.is_none() {
|
||||
warn!("running without request rate limiting!");
|
||||
}
|
||||
|
||||
axum::middleware::from_fn_with_state(redis, f)
|
||||
}
|
||||
|
||||
enum RatelimitType {
|
||||
GenericGet,
|
||||
GenericUpdate,
|
||||
Message,
|
||||
TempCustom,
|
||||
}
|
||||
|
||||
impl RatelimitType {
|
||||
fn key(&self) -> String {
|
||||
match self {
|
||||
RatelimitType::GenericGet => "generic_get",
|
||||
RatelimitType::GenericUpdate => "generic_update",
|
||||
RatelimitType::Message => "message",
|
||||
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn rate(&self) -> i32 {
|
||||
match self {
|
||||
RatelimitType::GenericGet => 10,
|
||||
RatelimitType::GenericUpdate => 3,
|
||||
RatelimitType::Message => 10,
|
||||
RatelimitType::TempCustom => 20,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_request_ratelimited(
|
||||
State(redis): State<Option<RedisPool>>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/53667
|
||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if let Some(token2) = &libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
{
|
||||
if header.to_str().unwrap_or("invalid") == token2 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let rlimit = if is_temp_token2 {
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
} else if request.method() == Method::GET {
|
||||
RatelimitType::GenericGet
|
||||
} else {
|
||||
RatelimitType::GenericUpdate
|
||||
};
|
||||
|
||||
let rl_key = format!(
|
||||
"{}:{}",
|
||||
if authenticated_system_id != "unknown"
|
||||
&& matches!(rlimit, RatelimitType::GenericUpdate)
|
||||
{
|
||||
authenticated_system_id
|
||||
} else {
|
||||
source_ip
|
||||
},
|
||||
rlimit.key()
|
||||
);
|
||||
|
||||
let period = 1; // seconds
|
||||
let cost = 1; // todo: update this for group member endpoints
|
||||
|
||||
// local rate_limit_key = KEYS[1]
|
||||
// local rate = ARGV[1]
|
||||
// local period = ARGV[2]
|
||||
// return {remaining, tostring(retry_after), reset_after}
|
||||
|
||||
// todo: check if error is script not found and reload script
|
||||
let resp = redis
|
||||
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
|
||||
LUA_SCRIPT_SHA.to_string(),
|
||||
vec![rl_key.clone()],
|
||||
vec![rlimit.rate(), period, cost],
|
||||
)
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok((remaining, retry_after, reset_after)) => {
|
||||
// redis's lua doesn't support returning floats
|
||||
let retry_after: f64 = retry_after
|
||||
.parse()
|
||||
.expect("got something that isn't a f64 from redis");
|
||||
|
||||
let mut response = if remaining > 0 {
|
||||
next.run(request).await
|
||||
} else {
|
||||
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
|
||||
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
|
||||
counter!("pk_http_requests_ratelimited").increment(1);
|
||||
json_err(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!(
|
||||
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
|
||||
rlimit.key(),
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let reset_time = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(reset_after))
|
||||
.expect("invalid timestamp")
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("invalid duration")
|
||||
.as_secs();
|
||||
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
"X-RateLimit-Scope",
|
||||
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Limit",
|
||||
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Remaining",
|
||||
HeaderValue::from_str(format!("{}", remaining).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Reset",
|
||||
HeaderValue::from_str(format!("{}", reset_time).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
|
||||
return response;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("error getting ratelimit info: {}", err);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue