chore: reorganize rust crates

This commit is contained in:
alyssa 2025-01-02 00:50:36 +00:00
parent 357122a892
commit 16ce67e02c
58 changed files with 6 additions and 13 deletions

View file

@ -0,0 +1 @@
pub mod private;

View file

@ -0,0 +1,203 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use fred::interfaces::*;
use libpk::state::ShardState;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
pub guild_count: i32,
pub channel_count: i32,
}
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
let mut shard_status = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:shardstatus")
.await
.unwrap()
.values()
.map(|v| serde_json::from_str(v).expect("could not deserialize shard"))
.collect::<Vec<ShardState>>();
shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id));
Json(json!({
"shards": shard_status,
}))
}
pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
let cluster_stats = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:cluster_stats")
.await
.unwrap()
.values()
.map(|v| serde_json::from_str(v).unwrap())
.collect::<Vec<ClusterStats>>();
let db_stats = libpk::db::repository::get_stats(&ctx.db).await.unwrap();
let guild_count: i32 = cluster_stats.iter().map(|v| v.guild_count).sum();
let channel_count: i32 = cluster_stats.iter().map(|v| v.channel_count).sum();
Json(json!({
"system_count": db_stats.system_count,
"member_count": db_stats.member_count,
"group_count": db_stats.group_count,
"switch_count": db_stats.switch_count,
"message_count": db_stats.message_count,
"guild_count": guild_count,
"channel_count": channel_count,
}))
}
use std::time::Duration;
use crate::util::json_err;
use axum::{
extract,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
pub struct CallbackRequestData {
redirect_domain: String,
code: String,
// state: String,
}
#[derive(serde::Serialize)]
struct CallbackDiscordData {
client_id: String,
client_secret: String,
grant_type: String,
redirect_uri: String,
code: String,
}
pub async fn discord_callback(
State(ctx): State<ApiContext>,
extract::Json(request_data): extract::Json<CallbackRequestData>,
) -> Response {
let client = ClientBuilder::new()
.connect_timeout(Duration::from_secs(3))
.timeout(Duration::from_secs(3))
.build()
.expect("error making client");
let reqbody = serde_urlencoded::to_string(&CallbackDiscordData {
client_id: config.discord.as_ref().unwrap().client_id.get().to_string(),
client_secret: config.discord.as_ref().unwrap().client_secret.clone(),
grant_type: "authorization_code".to_string(),
redirect_uri: request_data.redirect_domain, // change this!
code: request_data.code,
})
.expect("could not serialize");
let discord_resp = client
.post("https://discord.com/api/v10/oauth2/token")
.header("content-type", "application/x-www-form-urlencoded")
.body(reqbody)
.send()
.await
.expect("failed to request discord");
let Value::Object(discord_data) = discord_resp
.json::<Value>()
.await
.expect("failed to deserialize discord response as json")
else {
panic!("discord response is not an object")
};
if !discord_data.contains_key("access_token") {
return json_err(
StatusCode::BAD_REQUEST,
format!(
"{{\"error\":\"{}\"\"}}",
discord_data
.get("error_description")
.expect("missing error_description from discord")
.to_string()
),
);
};
let token = format!(
"Bearer {}",
discord_data
.get("access_token")
.expect("missing access_token")
.as_str()
.unwrap()
);
let discord_client = twilight_http::Client::new(token);
let user = discord_client
.current_user()
.await
.expect("failed to get current user from discord")
.model()
.await
.expect("failed to parse user model from discord");
let system: Option<PKSystem> = sqlx::query_as(
r#"
select systems.*
from accounts
left join systems on accounts.system = systems.id
where accounts.uid = $1
"#,
)
.bind(user.id.get() as i64)
.fetch_optional(&ctx.db)
.await
.expect("failed to query");
if system.is_none() {
return json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
}
let system = system.unwrap();
let system_config: Option<PKSystemConfig> = sqlx::query_as(
r#"
select * from system_config where system = $1
"#,
)
.bind(system.id)
.fetch_optional(&ctx.db)
.await
.expect("failed to query");
let system_config = system_config.unwrap();
// create dashboard token for system
let token = system.clone().token;
(
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(),
"config": system_config.to_json(),
"user": user,
"token": token,
}))
.expect("should not error"),
)
.into_response()
}

29
crates/api/src/error.rs Normal file
View file

@ -0,0 +1,29 @@
use axum::http::StatusCode;
use std::fmt;
#[derive(Debug)]
pub struct PKError {
pub response_code: StatusCode,
pub json_code: i32,
pub message: &'static str,
}
impl fmt::Display for PKError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for PKError {}
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
const $name: PKError = PKError {
response_code: $response_code,
json_code: $json_code,
message: $message,
};
};
}
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }

169
crates/api/src/main.rs Normal file
View file

@ -0,0 +1,169 @@
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use tracing::{error, info};
mod endpoints;
mod error;
mod middleware;
mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
}
async fn rproxy(
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|err| {
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
StatusCode::BAD_GATEWAY
})?
.into_response())
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/v2/systems/:system_id", get(rproxy))
.route("/v2/systems/:system_id", patch(rproxy))
.route("/v2/systems/:system_id/settings", get(rproxy))
.route("/v2/systems/:system_id/settings", patch(rproxy))
.route("/v2/systems/:system_id/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/:member_id", get(rproxy))
.route("/v2/members/:member_id", patch(rproxy))
.route("/v2/members/:member_id", delete(rproxy))
.route("/v2/systems/:system_id/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/:group_id", get(rproxy))
.route("/v2/groups/:group_id", patch(rproxy))
.route("/v2/groups/:group_id", delete(rproxy))
.route("/v2/groups/:group_id/members", get(rproxy))
.route("/v2/groups/:group_id/members/add", post(rproxy))
.route("/v2/groups/:group_id/members/remove", post(rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
.route("/v2/members/:member_id/groups", get(rproxy))
.route("/v2/members/:member_id/groups/add", post(rproxy))
.route("/v2/members/:member_id/groups/remove", post(rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
.route("/v2/systems/:system_id/switches", get(rproxy))
.route("/v2/systems/:system_id/switches", post(rproxy))
.route("/v2/systems/:system_id/fronters", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
.route("/v2/messages/:message_id", get(rproxy))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.layer(axum::middleware::from_fn(middleware::logger))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
}
libpk::main!("api");
async fn real_main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let rproxy_uri = Uri::from_static(
&libpk::config
.api
.as_ref()
.expect("missing api config")
.remote_url,
)
.to_string();
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let ctx = ApiContext {
db,
redis,
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
};
let app = router(ctx);
let addr: &str = libpk::config
.api
.as_ref()
.expect("missing api config")
.addr
.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,45 @@
use axum::{
extract::{Request, State},
http::HeaderValue,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::ApiContext;
use super::logger::DID_AUTHENTICATE_HEADER;
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove("x-pluralkit-systemid");
let auth_header = headers
.get("authorization")
.map(|h| h.to_str().ok())
.flatten();
let mut authenticated = false;
if let Some(auth_header) = auth_header {
if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
None
}
}
{
headers.append(
"x-pluralkit-systemid",
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
);
authenticated = true;
}
}
let mut response = next.run(request).await;
if authenticated {
response
.headers_mut()
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
}
response
}

View file

@ -0,0 +1,28 @@
use axum::{
extract::Request,
http::{HeaderMap, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[rustfmt::skip]
fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}
pub async fn cors(request: Request, next: Next) -> Response {
let mut response = if request.method() == Method::OPTIONS {
StatusCode::OK.into_response()
} else {
next.run(request).await
};
add_cors_headers(response.headers_mut());
response
}

View file

@ -0,0 +1,64 @@
use axum::{
extract::MatchedPath,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use crate::util::header_or_unknown;
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
path.starts_with("/v2/s/")
|| path.starts_with("/v2/m/")
|| path.starts_with("/v2/a/")
|| path.starts_with("/v2/msg/")
|| path == "/v2/s"
|| path == "/v2/m"
}
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
if request.uri().path().starts_with("/v1") {
(
StatusCode::GONE,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response()
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
(
StatusCode::BAD_REQUEST,
r#"{"message":"Invalid path for API version","code":0}"#,
)
.into_response()
}
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().clone().path().starts_with("/v2")
&& !request.uri().clone().path().starts_with("/private")
{
return (
StatusCode::BAD_REQUEST,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response();
} else if path == "unknown" {
// current prod api responds with 404 with empty body to invalid endpoints
// just doing that here as well but i'm not sure if it's the correct behaviour
return StatusCode::NOT_FOUND.into_response();
}
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
// but "unknown" isn't really a valid user-agent
else if user_agent == "unknown" {
// please set a valid user-agent
return StatusCode::BAD_REQUEST.into_response();
} else {
next.run(request).await
}
}

View file

@ -0,0 +1,98 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::util::header_or_unknown;
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_span = span!(
Level::INFO,
"request",
remote_ip,
method = method.as_str(),
endpoint = endpoint.clone(),
user_agent
);
let start = Instant::now();
let mut response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let authenticated = {
let headers = response.headers_mut();
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
headers.remove(DID_AUTHENTICATE_HEADER);
true
} else {
false
}
};
counter!(
"pluralkit_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.increment(1);
histogram!(
"pluralkit_api_requests_bucket",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.record(elapsed as f64 / 1_000_f64);
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
endpoint,
elapsed
);
if elapsed > MIN_LOG_TIME {
counter!(
"pluralkit_api_slow_requests_count",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.increment(1);
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

View file

@ -0,0 +1,13 @@
mod cors;
pub use cors::cors;
mod logger;
pub use logger::logger;
mod ignore_invalid_routes;
pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod ratelimit;
mod authnz;
pub use authnz::authnz;

View file

@ -0,0 +1,61 @@
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local rate = ARGV[1]
local period = ARGV[2]
local cost = tonumber(ARGV[3])
local burst = rate
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- remaining
tostring(retry_after),
reset_after,
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {
remaining,
tostring(retry_after),
reset_after
}

View file

@ -0,0 +1,238 @@
use std::time::{Duration, SystemTime};
use axum::{
extract::{MatchedPath, Request, State},
http::{HeaderValue, Method, StatusCode},
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
}
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
let redis = libpk::config
.api
.as_ref()
.expect("missing api config")
.ratelimit_redis_addr
.as_ref()
.map(|val| {
// todo: this should probably use the global pool
let r = RedisPool::new(
fred::types::RedisConfig::from_url_centralized(val.as_ref())
.expect("redis url is invalid"),
None,
None,
Some(Default::default()),
10,
)
.expect("failed to connect to redis");
let handle = r.connect();
tokio::spawn(async move { handle });
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript
.script_load::<String, String>(LUA_SCRIPT.to_string())
.await
{
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
if redis.is_none() {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
}
enum RatelimitType {
GenericGet,
GenericUpdate,
Message,
TempCustom,
}
impl RatelimitType {
fn key(&self) -> String {
match self {
RatelimitType::GenericGet => "generic_get",
RatelimitType::GenericUpdate => "generic_update",
RatelimitType::Message => "message",
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
}
.to_string()
}
fn rate(&self) -> i32 {
match self {
RatelimitType::GenericGet => 10,
RatelimitType::GenericUpdate => 3,
RatelimitType::Message => 10,
RatelimitType::TempCustom => 20,
}
}
}
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
request: Request,
next: Next,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
// https://github.com/rust-lang/rust/issues/53667
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if let Some(token2) = &libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
{
if header.to_str().unwrap_or("invalid") == token2 {
true
} else {
false
}
} else {
false
}
} else {
false
};
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {
RatelimitType::GenericGet
} else {
RatelimitType::GenericUpdate
};
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
} else {
source_ip
},
rlimit.key()
);
let period = 1; // seconds
let cost = 1; // todo: update this for group member endpoints
// local rate_limit_key = KEYS[1]
// local rate = ARGV[1]
// local period = ARGV[2]
// return {remaining, tostring(retry_after), reset_after}
// todo: check if error is script not found and reload script
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
LUA_SCRIPT_SHA.to_string(),
vec![rl_key.clone()],
vec![rlimit.rate(), period, cost],
)
.await;
match resp {
Ok((remaining, retry_after, reset_after)) => {
// redis's lua doesn't support returning floats
let retry_after: f64 = retry_after
.parse()
.expect("got something that isn't a f64 from redis");
let mut response = if remaining > 0 {
next.run(request).await
} else {
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
counter!("pk_http_requests_ratelimited").increment(1);
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
rlimit.key(),
),
)
};
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
.expect("invalid timestamp")
.duration_since(std::time::UNIX_EPOCH)
.expect("invalid duration")
.as_secs();
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Scope",
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Limit",
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Remaining",
HeaderValue::from_str(format!("{}", remaining).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Reset",
HeaderValue::from_str(format!("{}", reset_time).as_str())
.expect("invalid header value"),
);
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
}
}
next.run(request).await
}

63
crates/api/src/util.rs Normal file
View file

@ -0,0 +1,63 @@
use crate::error::PKError;
use axum::{
http::{HeaderValue, StatusCode},
response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
if let Some(value) = header {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!("failed to parse header value {:#?}: {:#?}", value, err);
"failed to parse"
}
}
} else {
"unknown"
}
}
pub fn wrapper<F>(handler: F) -> impl Fn() -> axum::response::Response
where
F: Fn() -> anyhow::Result<Value>,
{
move || match handler() {
Ok(v) => (StatusCode::OK, to_string(&v).unwrap()).into_response(),
Err(error) => match error.downcast_ref::<PKError>() {
Some(pkerror) => json_err(
pkerror.response_code,
to_string(&json!({ "message": pkerror.message, "code": pkerror.json_code }))
.unwrap(),
),
None => {
error!(
"error in handler {}: {:#?}",
std::any::type_name::<F>(),
error
);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
},
}
}
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!("caught panic from handler: {:#?}", err);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();
headers.insert("content-type", HeaderValue::from_static("application/json"));
response
}