feat(api): update rust deps, move /private/meta endpoint to rust-api
parent f14c421e23
commit e415c6704f
20 changed files with 1835 additions and 244 deletions
Cargo.lock (generated) — 1655 lines changed; file diff suppressed because it is too large.
@@ -6,9 +6,16 @@ members = [

[workspace.dependencies]
anyhow = "1"
axum = "0.7.5"
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
lazy_static = "1.4.0"
metrics = "0.20.1"
serde = "1.0.152"
serde_json = "1.0.117"
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "macros"] }
tokio = { version = "1.25.0", features = ["full"] }
tracing = "0.1.37"

prost = "0.12"
prost-types = "0.12"
prost-build = "0.12"
@@ -6,12 +6,20 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
config = "0.13.3"
fred = { workspace = true }
gethostname = "0.4.1"
lazy_static = { workspace = true }
metrics = { workspace = true }
metrics-exporter-prometheus = { version = "0.11.0", default-features = false, features = ["tokio", "http-listener", "tracing"] }
serde = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-gelf = "0.7.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }

prost = { workspace = true }
prost-types = { workspace = true }

[build-dependencies]
prost-build = { workspace = true }
lib/libpk/build.rs (new file, 8 lines)
@@ -0,0 +1,8 @@
use std::io::Result;

fn main() -> Result<()> {
    prost_build::Config::new()
        .type_attribute(".ShardState", "#[derive(serde::Serialize)]")
        .compile_protos(&["../../proto/state.proto"], &["../../proto/"])?;
    Ok(())
}
@@ -12,9 +12,11 @@ pub struct DiscordConfig {

#[derive(Deserialize, Debug)]
pub struct DatabaseConfig {
    pub(crate) _data_db_uri: String,
    pub(crate) _messages_db_uri: String,
    pub(crate) _db_password: Option<String>,
    pub(crate) data_db_uri: String,
    pub(crate) data_db_max_connections: Option<u32>,
    pub(crate) data_db_min_connections: Option<u32>,
    // pub(crate) _messages_db_uri: String,
    pub(crate) db_password: Option<String>,
    pub data_redis_addr: String,
}

@@ -42,6 +44,8 @@ fn _metrics_default() -> bool {

#[derive(Deserialize, Debug)]
pub struct PKConfig {
    pub db: DatabaseConfig,

    pub discord: DiscordConfig,
    pub api: ApiConfig,
lib/libpk/src/db/mod.rs (new file, 42 lines)
@@ -0,0 +1,42 @@
use fred::pool::RedisPool;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use std::str::FromStr;
use tracing::info;

pub mod repository;

pub async fn init_redis() -> anyhow::Result<RedisPool> {
    info!("connecting to redis");
    let redis = fred::pool::RedisPool::new(
        fred::types::RedisConfig::from_url_centralized(crate::config.db.data_redis_addr.as_ref())
            .expect("redis url is invalid"),
        10,
    )?;

    let redis_handle = redis.connect(Some(fred::types::ReconnectPolicy::default()));
    tokio::spawn(async move { redis_handle });

    Ok(redis)
}

pub async fn init_data_db() -> anyhow::Result<PgPool> {
    info!("connecting to database");

    let mut options = PgConnectOptions::from_str(&crate::config.db.data_db_uri)?;

    if let Some(password) = crate::config.db.db_password.clone() {
        options = options.password(&password);
    }

    let mut pool = PgPoolOptions::new();

    if let Some(max_conns) = crate::config.db.data_db_max_connections {
        pool = pool.max_connections(max_conns);
    }

    if let Some(min_conns) = crate::config.db.data_db_min_connections {
        pool = pool.min_connections(min_conns);
    }

    Ok(pool.connect_with(options).await?)
}
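Note: these initializers are meant to be called once at startup. A minimal usage sketch, mirroring what services/api/src/main.rs does later in this diff (assumes an async context that can propagate anyhow errors):

    // Sketch: connect to Postgres and Redis before building the app state.
    let db = libpk::db::init_data_db().await?;
    let redis = libpk::db::init_redis().await?;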
lib/libpk/src/db/repository/mod.rs (new file, 2 lines)
@@ -0,0 +1,2 @@
mod stats;
pub use stats::*;
lib/libpk/src/db/repository/stats.rs (new file, 13 lines)
@@ -0,0 +1,13 @@
pub async fn get_stats(pool: &sqlx::postgres::PgPool) -> anyhow::Result<Counts> {
    let counts: Counts = sqlx::query_as("select * from info").fetch_one(pool).await?;
    Ok(counts)
}

#[derive(serde::Serialize, sqlx::FromRow)]
pub struct Counts {
    pub system_count: i64,
    pub member_count: i64,
    pub group_count: i64,
    pub switch_count: i64,
    pub message_count: i64,
}
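Note: sqlx::query_as maps the single row returned by "select * from info" onto Counts through the derived FromRow impl, so each column name must match a struct field. A hedged usage sketch (the /private/meta handler in services/api/src/endpoints/private.rs below does essentially this):

    // Sketch: fetch the aggregate counts for the /private/meta response.
    let stats = libpk::db::repository::get_stats(&pool).await?;
    println!("{} systems, {} members", stats.system_count, stats.member_count);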
@@ -2,7 +2,10 @@ use gethostname::gethostname;
use metrics_exporter_prometheus::PrometheusBuilder;
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry};

mod _config;
pub mod db;
pub mod proto;

pub mod _config;
pub use crate::_config::CONFIG as config;

pub fn init_logging(component: &str) -> anyhow::Result<()> {
lib/libpk/src/proto.rs (new file, 1 line)
@@ -0,0 +1 @@
include!(concat!(env!("OUT_DIR"), "/_.rs"));
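Note: build.rs (above) compiles proto/state.proto with prost-build at build time, and this include! pulls the generated module out of OUT_DIR (the _.rs name indicates state.proto declares no package), which is what makes libpk::proto::ShardState available. A hedged sketch of consuming it — the state.proto definition itself is not part of this diff, and this mirrors what the /private/meta handler does below:

    // Sketch: decode a prost-encoded ShardState blob (e.g. a value read from Redis).
    use prost::Message;

    fn decode_shard(raw: &[u8]) -> Result<libpk::proto::ShardState, prost::DecodeError> {
        libpk::proto::ShardState::decode(raw)
    }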
@@ -5,13 +5,19 @@ edition = "2021"

[dependencies]
anyhow = { workspace = true }
axum = "0.6.4"
axum = { workspace = true }
fred = { workspace = true }
http = "0.2.8"
hyper-reverse-proxy = "0.5.1"
hyper = { version = "1.3.1", features = ["http1"] }
hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"] }
lazy_static = { workspace = true }
libpk = { path = "../../lib/libpk" }
metrics = { workspace = true }
prost = { workspace = true }
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
tracing = { workspace = true }
services/api/src/endpoints/mod.rs (new file, 1 line)
@@ -0,0 +1 @@
pub mod private;
services/api/src/endpoints/private.rs (new file, 53 lines)
@@ -0,0 +1,53 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use fred::interfaces::*;
use libpk::proto::ShardState;
use prost::Message;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;

#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
    pub guild_count: i32,
    pub channel_count: i32,
}

pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
    let shard_status = ctx
        .redis
        .hgetall::<HashMap<String, Vec<u8>>, &str>("pluralkit:shardstatus")
        .await
        .unwrap()
        .values()
        .map(|v| ShardState::decode(v.as_slice()).unwrap())
        .collect::<Vec<ShardState>>();

    let cluster_stats = ctx
        .redis
        .hgetall::<HashMap<String, String>, &str>("pluralkit:cluster_stats")
        .await
        .unwrap()
        .values()
        .map(|v| serde_json::from_str(v).unwrap())
        .collect::<Vec<ClusterStats>>();

    let db_stats = libpk::db::repository::get_stats(&ctx.db).await.unwrap();

    let guild_count: i32 = cluster_stats.iter().map(|v| v.guild_count).sum();
    let channel_count: i32 = cluster_stats.iter().map(|v| v.channel_count).sum();

    Json(json!({
        "shards": shard_status,
        "stats": {
            "system_count": db_stats.system_count,
            "member_count": db_stats.member_count,
            "group_count": db_stats.group_count,
            "switch_count": db_stats.switch_count,
            "message_count": db_stats.message_count,
            "guild_count": guild_count,
            "channel_count": channel_count,
        }
    }))
}
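Note: this handler replaces the previously reverse-proxied /private/meta route. A sketch of how it is mounted (mirroring services/api/src/main.rs later in this diff; ApiContext carries the Postgres and Redis pools the handler reads from):

    // Sketch: mounting the native /private/meta handler on the router.
    use axum::{routing::get, Router};

    fn router(ctx: ApiContext) -> Router {
        Router::new()
            .route("/private/meta", get(endpoints::private::meta))
            .with_state(ctx)
    }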
services/api/src/error.rs (new file, 29 lines)
@@ -0,0 +1,29 @@
use axum::http::StatusCode;
use std::fmt;

#[derive(Debug)]
pub struct PKError {
    pub response_code: StatusCode,
    pub json_code: i32,
    pub message: &'static str,
}

impl fmt::Display for PKError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{:?}", self)
    }
}

impl std::error::Error for PKError {}

macro_rules! define_error {
    ( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
        const $name: PKError = PKError {
            response_code: $response_code,
            json_code: $json_code,
            message: $message,
        };
    };
}

define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
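Note: a hedged sketch of the intended flow — util::wrapper later in this diff downcasts anyhow errors back to PKError to build the JSON error body; the validate function here is illustrative, not part of this commit:

    // Sketch: a fallible handler body returning one of the define_error! constants
    // as an anyhow error; util::wrapper turns it into a JSON error response.
    fn validate(input: &str) -> anyhow::Result<serde_json::Value> {
        if input.is_empty() {
            return Err(GENERIC_BAD_REQUEST.into());
        }
        Ok(serde_json::json!({ "ok": true }))
    }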
@@ -1,13 +1,58 @@
use axum::{
    body::Body,
    extract::{Request as ExtractRequest, State},
    http::{Response, StatusCode, Uri},
    response::IntoResponse,
    routing::{delete, get, patch, post},
    Router,
};
use tracing::info;
use hyper_util::{
    client::legacy::{connect::HttpConnector, Client},
    rt::TokioExecutor,
};
use tracing::{error, info};

mod endpoints;
mod error;
mod middleware;
mod util;

// this function is manually formatted for easier legibility of routes
#[derive(Clone)]
pub struct ApiContext {
    pub db: sqlx::postgres::PgPool,
    pub redis: fred::pool::RedisPool,

    rproxy_uri: String,
    rproxy_client: Client<HttpConnector, Body>,
}

async fn rproxy(
    State(ctx): State<ApiContext>,
    mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
    let path = req.uri().path();
    let path_query = req
        .uri()
        .path_and_query()
        .map(|v| v.as_str())
        .unwrap_or(path);

    let uri = format!("{}{}", ctx.rproxy_uri, path_query);

    *req.uri_mut() = Uri::try_from(uri).unwrap();

    Ok(ctx
        .rproxy_client
        .request(req)
        .await
        .map_err(|err| {
            error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
            StatusCode::BAD_GATEWAY
        })?
        .into_response())
}

// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
#[tokio::main]
async fn main() -> anyhow::Result<()> {
@@ -15,75 +60,93 @@ async fn main() -> anyhow::Result<()> {
    libpk::init_metrics()?;
    info!("hello world");

    let db = libpk::db::init_data_db().await?;
    let redis = libpk::db::init_redis().await?;

    let rproxy_uri = Uri::from_static(&libpk::config.api.remote_url).to_string();
    let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
        .build(HttpConnector::new());

    let ctx = ApiContext {
        db,
        redis,

        rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
        rproxy_client,
    };

    // processed upside down (???) so we have to put middleware at the end
    let app = Router::new()
        .route("/v2/systems/:system_id", get(util::rproxy))
        .route("/v2/systems/:system_id", patch(util::rproxy))
        .route("/v2/systems/:system_id/settings", get(util::rproxy))
        .route("/v2/systems/:system_id/settings", patch(util::rproxy))
        .route("/v2/systems/:system_id", get(rproxy))
        .route("/v2/systems/:system_id", patch(rproxy))
        .route("/v2/systems/:system_id/settings", get(rproxy))
        .route("/v2/systems/:system_id/settings", patch(rproxy))

        .route("/v2/systems/:system_id/members", get(util::rproxy))
        .route("/v2/members", post(util::rproxy))
        .route("/v2/members/:member_id", get(util::rproxy))
        .route("/v2/members/:member_id", patch(util::rproxy))
        .route("/v2/members/:member_id", delete(util::rproxy))
        .route("/v2/systems/:system_id/members", get(rproxy))
        .route("/v2/members", post(rproxy))
        .route("/v2/members/:member_id", get(rproxy))
        .route("/v2/members/:member_id", patch(rproxy))
        .route("/v2/members/:member_id", delete(rproxy))

        .route("/v2/systems/:system_id/groups", get(util::rproxy))
        .route("/v2/groups", post(util::rproxy))
        .route("/v2/groups/:group_id", get(util::rproxy))
        .route("/v2/groups/:group_id", patch(util::rproxy))
        .route("/v2/groups/:group_id", delete(util::rproxy))
        .route("/v2/systems/:system_id/groups", get(rproxy))
        .route("/v2/groups", post(rproxy))
        .route("/v2/groups/:group_id", get(rproxy))
        .route("/v2/groups/:group_id", patch(rproxy))
        .route("/v2/groups/:group_id", delete(rproxy))

        .route("/v2/groups/:group_id/members", get(util::rproxy))
        .route("/v2/groups/:group_id/members/add", post(util::rproxy))
        .route("/v2/groups/:group_id/members/remove", post(util::rproxy))
        .route("/v2/groups/:group_id/members/overwrite", post(util::rproxy))
        .route("/v2/groups/:group_id/members", get(rproxy))
        .route("/v2/groups/:group_id/members/add", post(rproxy))
        .route("/v2/groups/:group_id/members/remove", post(rproxy))
        .route("/v2/groups/:group_id/members/overwrite", post(rproxy))

        .route("/v2/members/:member_id/groups", get(util::rproxy))
        .route("/v2/members/:member_id/groups/add", post(util::rproxy))
        .route("/v2/members/:member_id/groups/remove", post(util::rproxy))
        .route("/v2/members/:member_id/groups/overwrite", post(util::rproxy))
        .route("/v2/members/:member_id/groups", get(rproxy))
        .route("/v2/members/:member_id/groups/add", post(rproxy))
        .route("/v2/members/:member_id/groups/remove", post(rproxy))
        .route("/v2/members/:member_id/groups/overwrite", post(rproxy))

        .route("/v2/systems/:system_id/switches", get(util::rproxy))
        .route("/v2/systems/:system_id/switches", post(util::rproxy))
        .route("/v2/systems/:system_id/fronters", get(util::rproxy))
        .route("/v2/systems/:system_id/switches", get(rproxy))
        .route("/v2/systems/:system_id/switches", post(rproxy))
        .route("/v2/systems/:system_id/fronters", get(rproxy))

        .route("/v2/systems/:system_id/switches/:switch_id", get(util::rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id", patch(util::rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id/members", patch(util::rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id", delete(util::rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
        .route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))

        .route("/v2/systems/:system_id/guilds/:guild_id", get(util::rproxy))
        .route("/v2/systems/:system_id/guilds/:guild_id", patch(util::rproxy))
        .route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
        .route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))

        .route("/v2/members/:member_id/guilds/:guild_id", get(util::rproxy))
        .route("/v2/members/:member_id/guilds/:guild_id", patch(util::rproxy))
        .route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
        .route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))

        .route("/v2/systems/:system_id/autoproxy", get(util::rproxy))
        .route("/v2/systems/:system_id/autoproxy", patch(util::rproxy))
        .route("/v2/systems/:system_id/autoproxy", get(rproxy))
        .route("/v2/systems/:system_id/autoproxy", patch(rproxy))

        .route("/v2/messages/:message_id", get(util::rproxy))
        .route("/v2/messages/:message_id", get(rproxy))

        .route("/private/meta", get(util::rproxy))
        .route("/private/bulk_privacy/member", post(util::rproxy))
        .route("/private/bulk_privacy/group", post(util::rproxy))
        .route("/private/discord/callback", post(util::rproxy))
        .route("/private/meta", get(endpoints::private::meta))
        .route("/private/bulk_privacy/member", post(rproxy))
        .route("/private/bulk_privacy/group", post(rproxy))
        .route("/private/discord/callback", post(rproxy))

        .route("/v2/systems/:system_id/oembed.json", get(util::rproxy))
        .route("/v2/members/:member_id/oembed.json", get(util::rproxy))
        .route("/v2/groups/:group_id/oembed.json", get(util::rproxy))
        .route("/v2/systems/:system_id/oembed.json", get(rproxy))
        .route("/v2/members/:member_id/oembed.json", get(rproxy))
        .route("/v2/groups/:group_id/oembed.json", get(rproxy))

        .layer(axum::middleware::from_fn(middleware::logger))
        .layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
        .layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
        .layer(axum::middleware::from_fn(middleware::cors))

        .layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))

        .with_state(ctx)

        .route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }));

    let addr: &str = libpk::config.api.addr.as_ref();
    axum::Server::bind(&addr.parse()?)
        .serve(app.into_make_service())
        .await?;
    let listener = tokio::net::TcpListener::bind(addr).await?;
    axum::serve(listener, app).await?;

    Ok(())
}
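Note on the serving change at the end of main(): with the workspace-wide move to axum 0.7, axum::Server::bind and into_make_service are gone; the listener is bound with tokio and handed to axum::serve. A minimal sketch of the new pattern (the literal address here is a placeholder; the real value comes from libpk::config.api.addr):

    // Sketch: axum 0.7 serving pattern, replacing axum::Server::bind.
    let listener = tokio::net::TcpListener::bind("127.0.0.1:3000").await?;
    axum::serve(listener, app).await?;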
@@ -1,5 +1,6 @@
use axum::{
    http::{HeaderMap, HeaderValue, Method, Request, StatusCode},
    extract::Request,
    http::{HeaderMap, HeaderValue, Method, StatusCode},
    middleware::Next,
    response::{IntoResponse, Response},
};

@@ -14,7 +15,7 @@ fn add_cors_headers(headers: &mut HeaderMap) {
    headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}

pub async fn cors<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn cors(request: Request, next: Next) -> Response {
    let mut response = if request.method() == Method::OPTIONS {
        StatusCode::OK.into_response()
    } else {
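Note: this file and the middleware below are adapted to the axum 0.7 middleware API, where Request and Next are no longer generic over the body type. The general shape is now (sketch):

    // Sketch: axum 0.7 from_fn-style middleware signature.
    use axum::{extract::Request, middleware::Next, response::Response};

    async fn example_middleware(request: Request, next: Next) -> Response {
        next.run(request).await
    }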
@@ -1,6 +1,7 @@
use axum::{
    extract::MatchedPath,
    http::{Request, StatusCode},
    extract::Request,
    http::StatusCode,
    middleware::Next,
    response::{IntoResponse, Response},
};

@@ -16,7 +17,7 @@ fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
        || path == "/v2/m"
}

pub async fn ignore_invalid_routes<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
    let path = request
        .extensions()
        .get::<MatchedPath>()
@@ -1,6 +1,6 @@
use std::time::Instant;

use axum::{extract::MatchedPath, http::Request, middleware::Next, response::Response};
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::histogram;
use tracing::{info, span, warn, Instrument, Level};

@@ -10,7 +10,7 @@ use crate::util::header_or_unknown;
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;

pub async fn logger<B>(request: Request<B>, next: Next<B>) -> Response {
pub async fn logger(request: Request, next: Next) -> Response {
    let method = request.method().clone();

    let request_id = header_or_unknown(request.headers().get("Fly-Request-Id"));
@@ -1,13 +1,12 @@
use std::time::{Duration, SystemTime};

use axum::http::{HeaderValue, StatusCode};
use axum::{
    extract::State,
    http::Request,
    extract::{Request, State},
    middleware::{FromFnLayer, Next},
    response::Response,
};
use fred::{pool::RedisPool, prelude::LuaInterface, types::ReconnectPolicy, util::sha1_hash};
use http::{HeaderValue, StatusCode};
use metrics::increment_counter;
use tracing::{debug, error, info, warn};

@@ -55,10 +54,10 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
    axum::middleware::from_fn_with_state(redis, f)
}

pub async fn do_request_ratelimited<B>(
pub async fn do_request_ratelimited(
    State(redis): State<Option<RedisPool>>,
    request: Request<B>,
    next: Next<B>,
    request: Request,
    next: Next,
) -> Response {
    if let Some(redis) = redis {
        let headers = request.headers().clone();
@@ -1,8 +1,9 @@
use crate::error::PKError;
use axum::{
    body::Body,
    http::{HeaderValue, Request, Response, StatusCode, Uri},
    http::{HeaderValue, StatusCode},
    response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use tracing::error;

pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {

@@ -19,21 +20,41 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
    }
}

pub async fn rproxy(req: Request<Body>) -> Response<Body> {
    let uri = Uri::from_static(&libpk::config.api.remote_url).to_string();

    match hyper_reverse_proxy::call("0.0.0.0".parse().unwrap(), &uri[..uri.len() - 1], req).await {
        Ok(response) => response,
        Err(error) => {
            error!("error proxying request: {:?}", error);
            Response::builder()
                .status(StatusCode::INTERNAL_SERVER_ERROR)
                .body(Body::empty())
                .unwrap()
        }
pub fn wrapper<F>(handler: F) -> impl Fn() -> axum::response::Response
where
    F: Fn() -> anyhow::Result<Value>,
{
    move || match handler() {
        Ok(v) => (StatusCode::OK, to_string(&v).unwrap()).into_response(),
        Err(error) => match error.downcast_ref::<PKError>() {
            Some(pkerror) => json_err(
                pkerror.response_code,
                to_string(&json!({ "message": pkerror.message, "code": pkerror.json_code }))
                    .unwrap(),
            ),
            None => {
                error!(
                    "error in handler {}: {:#?}",
                    std::any::type_name::<F>(),
                    error
                );
                json_err(
                    StatusCode::INTERNAL_SERVER_ERROR,
                    r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
                )
            }
        },
    }
}

pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
    error!("caught panic from handler: {:#?}", err);
    json_err(
        StatusCode::INTERNAL_SERVER_ERROR,
        r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
    )
}

pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
    let mut response = (code, text).into_response();
    let headers = response.headers_mut();
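Note: wrapper is a small adapter around synchronous handlers returning anyhow::Result<Value>. A hypothetical usage sketch (the closure and its body here are illustrative, not part of this commit):

    // Sketch: adapt a fallible JSON-producing closure into an axum response.
    let handler = wrapper(|| Ok(serde_json::json!({ "ok": true })));
    let response: axum::response::Response = handler();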