feat(api): update rust deps, move /private/meta endpoint to rust-api

This commit is contained in:
alyssa 2024-06-16 21:56:14 +09:00
parent f14c421e23
commit e415c6704f
20 changed files with 1835 additions and 244 deletions

View file

@ -5,13 +5,19 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = "0.6.4"
axum = { workspace = true }
fred = { workspace = true }
http = "0.2.8"
hyper-reverse-proxy = "0.5.1"
hyper = { version = "1.3.1", features = ["http1"] }
hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"] }
lazy_static = { workspace = true }
libpk = { path = "../../lib/libpk" }
metrics = { workspace = true }
prost = { workspace = true }
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
tracing = { workspace = true }

View file

@ -0,0 +1 @@
pub mod private;

View file

@ -0,0 +1,53 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use fred::interfaces::*;
use libpk::proto::ShardState;
use prost::Message;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
pub guild_count: i32,
pub channel_count: i32,
}
pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
let shard_status = ctx
.redis
.hgetall::<HashMap<String, Vec<u8>>, &str>("pluralkit:shardstatus")
.await
.unwrap()
.values()
.map(|v| ShardState::decode(v.as_slice()).unwrap())
.collect::<Vec<ShardState>>();
let cluster_stats = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:cluster_stats")
.await
.unwrap()
.values()
.map(|v| serde_json::from_str(v).unwrap())
.collect::<Vec<ClusterStats>>();
let db_stats = libpk::db::repository::get_stats(&ctx.db).await.unwrap();
let guild_count: i32 = cluster_stats.iter().map(|v| v.guild_count).sum();
let channel_count: i32 = cluster_stats.iter().map(|v| v.channel_count).sum();
Json(json!({
"shards": shard_status,
"stats": {
"system_count": db_stats.system_count,
"member_count": db_stats.member_count,
"group_count": db_stats.group_count,
"switch_count": db_stats.switch_count,
"message_count": db_stats.message_count,
"guild_count": guild_count,
"channel_count": channel_count,
}
}))
}

29
services/api/src/error.rs Normal file
View file

@ -0,0 +1,29 @@
use axum::http::StatusCode;
use std::fmt;
#[derive(Debug)]
pub struct PKError {
pub response_code: StatusCode,
pub json_code: i32,
pub message: &'static str,
}
impl fmt::Display for PKError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for PKError {}
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
const $name: PKError = PKError {
response_code: $response_code,
json_code: $json_code,
message: $message,
};
};
}
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }

View file

@ -1,13 +1,58 @@
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Router,
};
use tracing::info;
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use tracing::{error, info};
mod endpoints;
mod error;
mod middleware;
mod util;
// this function is manually formatted for easier legibility of routes
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::pool::RedisPool,
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
}
async fn rproxy(
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|err| {
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
StatusCode::BAD_GATEWAY
})?
.into_response())
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@ -15,75 +60,93 @@ async fn main() -> anyhow::Result<()> {
libpk::init_metrics()?;
info!("hello world");
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let rproxy_uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let ctx = ApiContext {
db,
redis,
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
};
// processed upside down (???) so we have to put middleware at the end
let app = Router::new()
.route("/v2/systems/:system_id", get(util::rproxy))
.route("/v2/systems/:system_id", patch(util::rproxy))
.route("/v2/systems/:system_id/settings", get(util::rproxy))
.route("/v2/systems/:system_id/settings", patch(util::rproxy))
.route("/v2/systems/:system_id", get(rproxy))
.route("/v2/systems/:system_id", patch(rproxy))
.route("/v2/systems/:system_id/settings", get(rproxy))
.route("/v2/systems/:system_id/settings", patch(rproxy))
.route("/v2/systems/:system_id/members", get(util::rproxy))
.route("/v2/members", post(util::rproxy))
.route("/v2/members/:member_id", get(util::rproxy))
.route("/v2/members/:member_id", patch(util::rproxy))
.route("/v2/members/:member_id", delete(util::rproxy))
.route("/v2/systems/:system_id/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/:member_id", get(rproxy))
.route("/v2/members/:member_id", patch(rproxy))
.route("/v2/members/:member_id", delete(rproxy))
.route("/v2/systems/:system_id/groups", get(util::rproxy))
.route("/v2/groups", post(util::rproxy))
.route("/v2/groups/:group_id", get(util::rproxy))
.route("/v2/groups/:group_id", patch(util::rproxy))
.route("/v2/groups/:group_id", delete(util::rproxy))
.route("/v2/systems/:system_id/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/:group_id", get(rproxy))
.route("/v2/groups/:group_id", patch(rproxy))
.route("/v2/groups/:group_id", delete(rproxy))
.route("/v2/groups/:group_id/members", get(util::rproxy))
.route("/v2/groups/:group_id/members/add", post(util::rproxy))
.route("/v2/groups/:group_id/members/remove", post(util::rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
.route("/v2/groups/:group_id/members", get(rproxy))
.route("/v2/groups/:group_id/members/add", post(rproxy))
.route("/v2/groups/:group_id/members/remove", post(rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
.route("/v2/members/:member_id/groups", get(util::rproxy))
.route("/v2/members/:member_id/groups/add", post(util::rproxy))
.route("/v2/members/:member_id/groups/remove", post(util::rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
.route("/v2/members/:member_id/groups", get(rproxy))
.route("/v2/members/:member_id/groups/add", post(rproxy))
.route("/v2/members/:member_id/groups/remove", post(rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
.route("/v2/systems/:system_id/switches", get(util::rproxy))
.route("/v2/systems/:system_id/switches", post(util::rproxy))
.route("/v2/systems/:system_id/fronters", get(util::rproxy))
.route("/v2/systems/:system_id/switches", get(rproxy))
.route("/v2/systems/:system_id/switches", post(rproxy))
.route("/v2/systems/:system_id/fronters", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
.route("/v2/systems/:system_id/autoproxy", get(util::rproxy))
.route("/v2/systems/:system_id/autoproxy", patch(util::rproxy))
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
.route("/v2/messages/:message_id", get(util::rproxy))
.route("/v2/messages/:message_id", get(rproxy))
.route("/private/meta", get(util::rproxy))
.route("/private/bulk_privacy/member", post(util::rproxy))
.route("/private/bulk_privacy/group", post(util::rproxy))
.route("/private/discord/callback", post(util::rproxy))
.route("/private/meta", get(endpoints::private::meta))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))
.route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
.route("/v2/members/:member_id/oembed.json", get(util::rproxy))
.route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.layer(axum::middleware::from_fn(middleware::logger))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));
let addr: &str = libpk::config.api.addr.as_ref();
axum::Server::bind(&addr.parse()?)
.serve(app.into_make_service())
.await?;
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -1,5 +1,6 @@
use axum::{
http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
extract::Request,
http::{HeaderMap, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
@ -14,7 +15,7 @@ fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}
pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn cors(request: Request, next: Next) -> Response {
let mut response = if request.method() == Method::OPTIONS {
StatusCode::OK.into_response()
} else {

View file

@ -1,6 +1,7 @@
use axum::{
extract::MatchedPath,
http::{Request, StatusCode},
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
@ -16,7 +17,7 @@ fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
|| path == "/v2/m"
}
pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
let path = request
.extensions()
.get::<MatchedPath>()

View file

@ -1,6 +1,6 @@
use std::time::Instant;
use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::histogram;
use tracing::{info, span, warn, Instrument, Level};
@ -10,7 +10,7 @@ use crate::util::header_or_unknown;
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));

View file

@ -1,13 +1,12 @@
use std::time::{Duration, SystemTime};
use axum::http::{HeaderValue, StatusCode};
use axum::{
extract::State,
http::Request,
extract::{Request, State},
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
use http::{HeaderValue, StatusCode};
use metrics::increment_counter;
use tracing::{debug, error, info, warn};
@ -55,10 +54,10 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
axum::middleware::from_fn_with_state(redis, f)
}
pub async fn do_request_ratelimited<B>(
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
request: Request<B>,
next: Next<B>,
request: Request,
next: Next,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();

View file

@ -1,8 +1,9 @@
use crate::error::PKError;
use axum::{
body::Body,
http::{HeaderValue, Request, Response, StatusCode, Uri},
http::{HeaderValue, StatusCode},
response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
@ -19,21 +20,41 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
}
}
pub async fn rproxy(req: Request<Body>) -> Response<Body> {
let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
Ok(response) => response,
Err(error) => {
error!("error proxying request: {:?}", error);
Response::builder()
.status(StatusCode::INTERNAL_SERVER_ERROR)
.body(Body::empty())
.unwrap()
}
pub fn wrapper<F>(handler: F) -> impl Fn() -> axum::response::Response
where
F: Fn() -> anyhow::Result<Value>,
{
move || match handler() {
Ok(v) => (StatusCode::OK, to_string(&v).unwrap()).into_response(),
Err(error) => match error.downcast_ref::<PKError>() {
Some(pkerror) => json_err(
pkerror.response_code,
to_string(&json!({ "message": pkerror.message, "code": pkerror.json_code }))
.unwrap(),
),
None => {
error!(
"error in handler {}: {:#?}",
std::any::type_name::<F>(),
error
);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
},
}
}
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!("caught panic from handler: {:#?}", err);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();