mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
chore: reorganize rust crates
This commit is contained in:
parent
357122a892
commit
16ce67e02c
58 changed files with 6 additions and 13 deletions
28
crates/api/Cargo.toml
Normal file
28
crates/api/Cargo.toml
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
[package]
|
||||
name = "api"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
pluralkit_models = { path = "../models" }
|
||||
libpk = { path = "../libpk" }
|
||||
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
hyper = { version = "1.3.1", features = ["http1"] }
|
||||
hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"] }
|
||||
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
|
||||
serde_urlencoded = "0.7.1"
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.5.2", features = ["catch-panic"] }
|
||||
1
crates/api/src/endpoints/mod.rs
Normal file
1
crates/api/src/endpoints/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod private;
|
||||
203
crates/api/src/endpoints/private.rs
Normal file
203
crates/api/src/endpoints/private.rs
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
use crate::ApiContext;
|
||||
use axum::{extract::State, response::Json};
|
||||
use fred::interfaces::*;
|
||||
use libpk::state::ShardState;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ClusterStats {
|
||||
pub guild_count: i32,
|
||||
pub channel_count: i32,
|
||||
}
|
||||
|
||||
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
|
||||
let mut shard_status = ctx
|
||||
.redis
|
||||
.hgetall::<HashMap<String, String>, &str>("pluralkit:shardstatus")
|
||||
.await
|
||||
.unwrap()
|
||||
.values()
|
||||
.map(|v| serde_json::from_str(v).expect("could not deserialize shard"))
|
||||
.collect::<Vec<ShardState>>();
|
||||
|
||||
shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id));
|
||||
|
||||
Json(json!({
|
||||
"shards": shard_status,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
|
||||
let cluster_stats = ctx
|
||||
.redis
|
||||
.hgetall::<HashMap<String, String>, &str>("pluralkit:cluster_stats")
|
||||
.await
|
||||
.unwrap()
|
||||
.values()
|
||||
.map(|v| serde_json::from_str(v).unwrap())
|
||||
.collect::<Vec<ClusterStats>>();
|
||||
|
||||
let db_stats = libpk::db::repository::get_stats(&ctx.db).await.unwrap();
|
||||
|
||||
let guild_count: i32 = cluster_stats.iter().map(|v| v.guild_count).sum();
|
||||
let channel_count: i32 = cluster_stats.iter().map(|v| v.channel_count).sum();
|
||||
|
||||
Json(json!({
|
||||
"system_count": db_stats.system_count,
|
||||
"member_count": db_stats.member_count,
|
||||
"group_count": db_stats.group_count,
|
||||
"switch_count": db_stats.switch_count,
|
||||
"message_count": db_stats.message_count,
|
||||
"guild_count": guild_count,
|
||||
"channel_count": channel_count,
|
||||
}))
|
||||
}
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::util::json_err;
|
||||
use axum::{
|
||||
extract,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use libpk::config;
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig};
|
||||
use reqwest::ClientBuilder;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
pub struct CallbackRequestData {
|
||||
redirect_domain: String,
|
||||
code: String,
|
||||
// state: String,
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize)]
|
||||
struct CallbackDiscordData {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
grant_type: String,
|
||||
redirect_uri: String,
|
||||
code: String,
|
||||
}
|
||||
|
||||
pub async fn discord_callback(
|
||||
State(ctx): State<ApiContext>,
|
||||
extract::Json(request_data): extract::Json<CallbackRequestData>,
|
||||
) -> Response {
|
||||
let client = ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.timeout(Duration::from_secs(3))
|
||||
.build()
|
||||
.expect("error making client");
|
||||
|
||||
let reqbody = serde_urlencoded::to_string(&CallbackDiscordData {
|
||||
client_id: config.discord.as_ref().unwrap().client_id.get().to_string(),
|
||||
client_secret: config.discord.as_ref().unwrap().client_secret.clone(),
|
||||
grant_type: "authorization_code".to_string(),
|
||||
redirect_uri: request_data.redirect_domain, // change this!
|
||||
code: request_data.code,
|
||||
})
|
||||
.expect("could not serialize");
|
||||
|
||||
let discord_resp = client
|
||||
.post("https://discord.com/api/v10/oauth2/token")
|
||||
.header("content-type", "application/x-www-form-urlencoded")
|
||||
.body(reqbody)
|
||||
.send()
|
||||
.await
|
||||
.expect("failed to request discord");
|
||||
|
||||
let Value::Object(discord_data) = discord_resp
|
||||
.json::<Value>()
|
||||
.await
|
||||
.expect("failed to deserialize discord response as json")
|
||||
else {
|
||||
panic!("discord response is not an object")
|
||||
};
|
||||
|
||||
if !discord_data.contains_key("access_token") {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!(
|
||||
"{{\"error\":\"{}\"\"}}",
|
||||
discord_data
|
||||
.get("error_description")
|
||||
.expect("missing error_description from discord")
|
||||
.to_string()
|
||||
),
|
||||
);
|
||||
};
|
||||
|
||||
let token = format!(
|
||||
"Bearer {}",
|
||||
discord_data
|
||||
.get("access_token")
|
||||
.expect("missing access_token")
|
||||
.as_str()
|
||||
.unwrap()
|
||||
);
|
||||
|
||||
let discord_client = twilight_http::Client::new(token);
|
||||
|
||||
let user = discord_client
|
||||
.current_user()
|
||||
.await
|
||||
.expect("failed to get current user from discord")
|
||||
.model()
|
||||
.await
|
||||
.expect("failed to parse user model from discord");
|
||||
|
||||
let system: Option<PKSystem> = sqlx::query_as(
|
||||
r#"
|
||||
select systems.*
|
||||
from accounts
|
||||
left join systems on accounts.system = systems.id
|
||||
where accounts.uid = $1
|
||||
"#,
|
||||
)
|
||||
.bind(user.id.get() as i64)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query");
|
||||
|
||||
if system.is_none() {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"user does not have a system registered".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let system = system.unwrap();
|
||||
|
||||
let system_config: Option<PKSystemConfig> = sqlx::query_as(
|
||||
r#"
|
||||
select * from system_config where system = $1
|
||||
"#,
|
||||
)
|
||||
.bind(system.id)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query");
|
||||
|
||||
let system_config = system_config.unwrap();
|
||||
|
||||
// create dashboard token for system
|
||||
|
||||
let token = system.clone().token;
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"system": system.to_json(),
|
||||
"config": system_config.to_json(),
|
||||
"user": user,
|
||||
"token": token,
|
||||
}))
|
||||
.expect("should not error"),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
29
crates/api/src/error.rs
Normal file
29
crates/api/src/error.rs
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
use axum::http::StatusCode;
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PKError {
|
||||
pub response_code: StatusCode,
|
||||
pub json_code: i32,
|
||||
pub message: &'static str,
|
||||
}
|
||||
|
||||
impl fmt::Display for PKError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::error::Error for PKError {}
|
||||
|
||||
macro_rules! define_error {
|
||||
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
|
||||
const $name: PKError = PKError {
|
||||
response_code: $response_code,
|
||||
json_code: $json_code,
|
||||
message: $message,
|
||||
};
|
||||
};
|
||||
}
|
||||
|
||||
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
|
||||
169
crates/api/src/main.rs
Normal file
169
crates/api/src/main.rs
Normal file
|
|
@ -0,0 +1,169 @@
|
|||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
http::{Response, StatusCode, Uri},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
};
|
||||
use hyper_util::{
|
||||
client::legacy::{connect::HttpConnector, Client},
|
||||
rt::TokioExecutor,
|
||||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod middleware;
|
||||
mod util;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiContext {
|
||||
pub db: sqlx::postgres::PgPool,
|
||||
pub redis: fred::clients::RedisPool,
|
||||
|
||||
rproxy_uri: String,
|
||||
rproxy_client: Client<HttpConnector, Body>,
|
||||
}
|
||||
|
||||
async fn rproxy(
|
||||
State(ctx): State<ApiContext>,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Result<Response<Body>, StatusCode> {
|
||||
let path = req.uri().path();
|
||||
let path_query = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or(path);
|
||||
|
||||
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
|
||||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
Ok(ctx
|
||||
.rproxy_client
|
||||
.request(req)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?
|
||||
.into_response())
|
||||
}
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
fn router(ctx: ApiContext) -> Router {
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
Router::new()
|
||||
.route("/v2/systems/:system_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", get(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/members", get(rproxy))
|
||||
.route("/v2/members", post(rproxy))
|
||||
.route("/v2/members/:member_id", get(rproxy))
|
||||
.route("/v2/members/:member_id", patch(rproxy))
|
||||
.route("/v2/members/:member_id", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/groups", get(rproxy))
|
||||
.route("/v2/groups", post(rproxy))
|
||||
.route("/v2/groups/:group_id", get(rproxy))
|
||||
.route("/v2/groups/:group_id", patch(rproxy))
|
||||
.route("/v2/groups/:group_id", delete(rproxy))
|
||||
|
||||
.route("/v2/groups/:group_id/members", get(rproxy))
|
||||
.route("/v2/groups/:group_id/members/add", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/remove", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/groups", get(rproxy))
|
||||
.route("/v2/members/:member_id/groups/add", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/remove", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches", post(rproxy))
|
||||
.route("/v2/systems/:system_id/fronters", get(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
|
||||
|
||||
.route("/v2/messages/:message_id", get(rproxy))
|
||||
|
||||
.route("/private/bulk_privacy/member", post(rproxy))
|
||||
.route("/private/bulk_privacy/group", post(rproxy))
|
||||
.route("/private/discord/callback", post(rproxy))
|
||||
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
|
||||
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
|
||||
.route("/private/stats", get(endpoints::private::meta))
|
||||
|
||||
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
|
||||
.route("/v2/members/:member_id/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::cors))
|
||||
.layer(axum::middleware::from_fn(middleware::logger))
|
||||
|
||||
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
|
||||
|
||||
.with_state(ctx)
|
||||
|
||||
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
|
||||
}
|
||||
|
||||
libpk::main!("api");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_data_db().await?;
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let rproxy_uri = Uri::from_static(
|
||||
&libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.remote_url,
|
||||
)
|
||||
.to_string();
|
||||
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
|
||||
let ctx = ApiContext {
|
||||
db,
|
||||
redis,
|
||||
|
||||
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
|
||||
rproxy_client,
|
||||
};
|
||||
|
||||
let app = router(ctx);
|
||||
|
||||
let addr: &str = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.addr
|
||||
.as_ref();
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
45
crates/api/src/middleware/authnz.rs
Normal file
45
crates/api/src/middleware/authnz.rs
Normal file
|
|
@ -0,0 +1,45 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderValue,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::ApiContext;
|
||||
|
||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
headers.remove("x-pluralkit-systemid");
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten();
|
||||
let mut authenticated = false;
|
||||
if let Some(auth_header) = auth_header {
|
||||
if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
None
|
||||
}
|
||||
}
|
||||
{
|
||||
headers.append(
|
||||
"x-pluralkit-systemid",
|
||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
||||
);
|
||||
authenticated = true;
|
||||
}
|
||||
}
|
||||
let mut response = next.run(request).await;
|
||||
if authenticated {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
||||
}
|
||||
response
|
||||
}
|
||||
28
crates/api/src/middleware/cors.rs
Normal file
28
crates/api/src/middleware/cors.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
use axum::{
|
||||
extract::Request,
|
||||
http::{HeaderMap, HeaderValue, Method, StatusCode},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
#[rustfmt::skip]
|
||||
fn add_cors_headers(headers: &mut HeaderMap) {
|
||||
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
|
||||
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
|
||||
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
|
||||
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
|
||||
}
|
||||
|
||||
pub async fn cors(request: Request, next: Next) -> Response {
|
||||
let mut response = if request.method() == Method::OPTIONS {
|
||||
StatusCode::OK.into_response()
|
||||
} else {
|
||||
next.run(request).await
|
||||
};
|
||||
|
||||
add_cors_headers(response.headers_mut());
|
||||
|
||||
response
|
||||
}
|
||||
64
crates/api/src/middleware/ignore_invalid_routes.rs
Normal file
64
crates/api/src/middleware/ignore_invalid_routes.rs
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
use axum::{
|
||||
extract::MatchedPath,
|
||||
extract::Request,
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
|
||||
path.starts_with("/v2/s/")
|
||||
|| path.starts_with("/v2/m/")
|
||||
|| path.starts_with("/v2/a/")
|
||||
|| path.starts_with("/v2/msg/")
|
||||
|| path == "/v2/s"
|
||||
|| path == "/v2/m"
|
||||
}
|
||||
|
||||
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
|
||||
let path = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
if request.uri().path().starts_with("/v1") {
|
||||
(
|
||||
StatusCode::GONE,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
|
||||
(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Invalid path for API version","code":0}"#,
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
|
||||
else if !request.uri().clone().path().starts_with("/v2")
|
||||
&& !request.uri().clone().path().starts_with("/private")
|
||||
{
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"Unsupported API version","code":0}"#,
|
||||
)
|
||||
.into_response();
|
||||
} else if path == "unknown" {
|
||||
// current prod api responds with 404 with empty body to invalid endpoints
|
||||
// just doing that here as well but i'm not sure if it's the correct behaviour
|
||||
return StatusCode::NOT_FOUND.into_response();
|
||||
}
|
||||
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
|
||||
// but "unknown" isn't really a valid user-agent
|
||||
else if user_agent == "unknown" {
|
||||
// please set a valid user-agent
|
||||
return StatusCode::BAD_REQUEST.into_response();
|
||||
} else {
|
||||
next.run(request).await
|
||||
}
|
||||
}
|
||||
98
crates/api/src/middleware/logger.rs
Normal file
98
crates/api/src/middleware/logger.rs
Normal file
|
|
@ -0,0 +1,98 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
|
||||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
remote_ip,
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
user_agent
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let authenticated = {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
counter!(
|
||||
"pluralkit_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"pluralkit_api_requests_bucket",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
endpoint,
|
||||
elapsed
|
||||
);
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
counter!(
|
||||
"pluralkit_api_slow_requests_count",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
13
crates/api/src/middleware/mod.rs
Normal file
13
crates/api/src/middleware/mod.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
mod cors;
|
||||
pub use cors::cors;
|
||||
|
||||
mod logger;
|
||||
pub use logger::logger;
|
||||
|
||||
mod ignore_invalid_routes;
|
||||
pub use ignore_invalid_routes::ignore_invalid_routes;
|
||||
|
||||
pub mod ratelimit;
|
||||
|
||||
mod authnz;
|
||||
pub use authnz::authnz;
|
||||
61
crates/api/src/middleware/ratelimit.lua
Normal file
61
crates/api/src/middleware/ratelimit.lua
Normal file
|
|
@ -0,0 +1,61 @@
|
|||
-- this script has side-effects, so it requires replicate commands mode
|
||||
-- redis.replicate_commands()
|
||||
|
||||
local rate_limit_key = KEYS[1]
|
||||
local rate = ARGV[1]
|
||||
local period = ARGV[2]
|
||||
local cost = tonumber(ARGV[3])
|
||||
|
||||
local burst = rate
|
||||
|
||||
local emission_interval = period / rate
|
||||
local increment = emission_interval * cost
|
||||
local burst_offset = emission_interval * burst
|
||||
|
||||
-- redis returns time as an array containing two integers: seconds of the epoch
|
||||
-- time (10 digits) and microseconds (6 digits). for convenience we need to
|
||||
-- convert them to a floating point number. the resulting number is 16 digits,
|
||||
-- bordering on the limits of a 64-bit double-precision floating point number.
|
||||
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
|
||||
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
|
||||
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
|
||||
local jan_1_2017 = 1483228800
|
||||
local now = redis.call("TIME")
|
||||
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
|
||||
|
||||
local tat = redis.call("GET", rate_limit_key)
|
||||
|
||||
if not tat then
|
||||
tat = now
|
||||
else
|
||||
tat = tonumber(tat)
|
||||
end
|
||||
|
||||
tat = math.max(tat, now)
|
||||
|
||||
local new_tat = tat + increment
|
||||
local allow_at = new_tat - burst_offset
|
||||
|
||||
local diff = now - allow_at
|
||||
local remaining = diff / emission_interval
|
||||
|
||||
if remaining < 0 then
|
||||
local reset_after = tat - now
|
||||
local retry_after = diff * -1
|
||||
return {
|
||||
0, -- remaining
|
||||
tostring(retry_after),
|
||||
reset_after,
|
||||
}
|
||||
end
|
||||
|
||||
local reset_after = new_tat - now
|
||||
if reset_after > 0 then
|
||||
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
|
||||
end
|
||||
local retry_after = -1
|
||||
return {
|
||||
remaining,
|
||||
tostring(retry_after),
|
||||
reset_after
|
||||
}
|
||||
238
crates/api/src/middleware/ratelimit.rs
Normal file
238
crates/api/src/middleware/ratelimit.rs
Normal file
|
|
@ -0,0 +1,238 @@
|
|||
use std::time::{Duration, SystemTime};
|
||||
|
||||
use axum::{
|
||||
extract::{MatchedPath, Request, State},
|
||||
http::{HeaderValue, Method, StatusCode},
|
||||
middleware::{FromFnLayer, Next},
|
||||
response::Response,
|
||||
};
|
||||
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
|
||||
use metrics::counter;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
|
||||
}
|
||||
|
||||
// this is awful but it works
|
||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
||||
let redis = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.ratelimit_redis_addr
|
||||
.as_ref()
|
||||
.map(|val| {
|
||||
// todo: this should probably use the global pool
|
||||
let r = RedisPool::new(
|
||||
fred::types::RedisConfig::from_url_centralized(val.as_ref())
|
||||
.expect("redis url is invalid"),
|
||||
None,
|
||||
None,
|
||||
Some(Default::default()),
|
||||
10,
|
||||
)
|
||||
.expect("failed to connect to redis");
|
||||
|
||||
let handle = r.connect();
|
||||
|
||||
tokio::spawn(async move { handle });
|
||||
|
||||
let rscript = r.clone();
|
||||
tokio::spawn(async move {
|
||||
if let Ok(()) = rscript.wait_for_connect().await {
|
||||
match rscript
|
||||
.script_load::<String, String>(LUA_SCRIPT.to_string())
|
||||
.await
|
||||
{
|
||||
Ok(_) => info!("connected to redis for request rate limiting"),
|
||||
Err(err) => error!("could not load redis script: {}", err),
|
||||
}
|
||||
} else {
|
||||
error!("could not wait for connection to load redis script!");
|
||||
}
|
||||
});
|
||||
|
||||
r
|
||||
});
|
||||
|
||||
if redis.is_none() {
|
||||
warn!("running without request rate limiting!");
|
||||
}
|
||||
|
||||
axum::middleware::from_fn_with_state(redis, f)
|
||||
}
|
||||
|
||||
enum RatelimitType {
|
||||
GenericGet,
|
||||
GenericUpdate,
|
||||
Message,
|
||||
TempCustom,
|
||||
}
|
||||
|
||||
impl RatelimitType {
|
||||
fn key(&self) -> String {
|
||||
match self {
|
||||
RatelimitType::GenericGet => "generic_get",
|
||||
RatelimitType::GenericUpdate => "generic_update",
|
||||
RatelimitType::Message => "message",
|
||||
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
||||
fn rate(&self) -> i32 {
|
||||
match self {
|
||||
RatelimitType::GenericGet => 10,
|
||||
RatelimitType::GenericUpdate => 3,
|
||||
RatelimitType::Message => 10,
|
||||
RatelimitType::TempCustom => 20,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_request_ratelimited(
|
||||
State(redis): State<Option<RedisPool>>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/53667
|
||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if let Some(token2) = &libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
{
|
||||
if header.to_str().unwrap_or("invalid") == token2 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let rlimit = if is_temp_token2 {
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
} else if request.method() == Method::GET {
|
||||
RatelimitType::GenericGet
|
||||
} else {
|
||||
RatelimitType::GenericUpdate
|
||||
};
|
||||
|
||||
let rl_key = format!(
|
||||
"{}:{}",
|
||||
if authenticated_system_id != "unknown"
|
||||
&& matches!(rlimit, RatelimitType::GenericUpdate)
|
||||
{
|
||||
authenticated_system_id
|
||||
} else {
|
||||
source_ip
|
||||
},
|
||||
rlimit.key()
|
||||
);
|
||||
|
||||
let period = 1; // seconds
|
||||
let cost = 1; // todo: update this for group member endpoints
|
||||
|
||||
// local rate_limit_key = KEYS[1]
|
||||
// local rate = ARGV[1]
|
||||
// local period = ARGV[2]
|
||||
// return {remaining, tostring(retry_after), reset_after}
|
||||
|
||||
// todo: check if error is script not found and reload script
|
||||
let resp = redis
|
||||
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
|
||||
LUA_SCRIPT_SHA.to_string(),
|
||||
vec![rl_key.clone()],
|
||||
vec![rlimit.rate(), period, cost],
|
||||
)
|
||||
.await;
|
||||
|
||||
match resp {
|
||||
Ok((remaining, retry_after, reset_after)) => {
|
||||
// redis's lua doesn't support returning floats
|
||||
let retry_after: f64 = retry_after
|
||||
.parse()
|
||||
.expect("got something that isn't a f64 from redis");
|
||||
|
||||
let mut response = if remaining > 0 {
|
||||
next.run(request).await
|
||||
} else {
|
||||
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
|
||||
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
|
||||
counter!("pk_http_requests_ratelimited").increment(1);
|
||||
json_err(
|
||||
StatusCode::TOO_MANY_REQUESTS,
|
||||
format!(
|
||||
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
|
||||
rlimit.key(),
|
||||
),
|
||||
)
|
||||
};
|
||||
|
||||
let reset_time = SystemTime::now()
|
||||
.checked_add(Duration::from_secs(reset_after))
|
||||
.expect("invalid timestamp")
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.expect("invalid duration")
|
||||
.as_secs();
|
||||
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
"X-RateLimit-Scope",
|
||||
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Limit",
|
||||
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Remaining",
|
||||
HeaderValue::from_str(format!("{}", remaining).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
headers.insert(
|
||||
"X-RateLimit-Reset",
|
||||
HeaderValue::from_str(format!("{}", reset_time).as_str())
|
||||
.expect("invalid header value"),
|
||||
);
|
||||
|
||||
return response;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("error getting ratelimit info: {}", err);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(request).await
|
||||
}
|
||||
63
crates/api/src/util.rs
Normal file
63
crates/api/src/util.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
use crate::error::PKError;
|
||||
use axum::{
|
||||
http::{HeaderValue, StatusCode},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use serde_json::{json, to_string, Value};
|
||||
use tracing::error;
|
||||
|
||||
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
|
||||
if let Some(value) = header {
|
||||
match value.to_str() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
error!("failed to parse header value {:#?}: {:#?}", value, err);
|
||||
"failed to parse"
|
||||
}
|
||||
}
|
||||
} else {
|
||||
"unknown"
|
||||
}
|
||||
}
|
||||
|
||||
pub fn wrapper<F>(handler: F) -> impl Fn() -> axum::response::Response
|
||||
where
|
||||
F: Fn() -> anyhow::Result<Value>,
|
||||
{
|
||||
move || match handler() {
|
||||
Ok(v) => (StatusCode::OK, to_string(&v).unwrap()).into_response(),
|
||||
Err(error) => match error.downcast_ref::<PKError>() {
|
||||
Some(pkerror) => json_err(
|
||||
pkerror.response_code,
|
||||
to_string(&json!({ "message": pkerror.message, "code": pkerror.json_code }))
|
||||
.unwrap(),
|
||||
),
|
||||
None => {
|
||||
error!(
|
||||
"error in handler {}: {:#?}",
|
||||
std::any::type_name::<F>(),
|
||||
error
|
||||
);
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
)
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
error!("caught panic from handler: {:#?}", err);
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||
let mut response = (code, text).into_response();
|
||||
let headers = response.headers_mut();
|
||||
headers.insert("content-type", HeaderValue::from_static("application/json"));
|
||||
response
|
||||
}
|
||||
30
crates/avatars/Cargo.toml
Normal file
30
crates/avatars/Cargo.toml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
[package]
|
||||
name = "avatars"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[[bin]]
|
||||
name = "avatar_cleanup"
|
||||
path = "src/cleanup.rs"
|
||||
|
||||
[dependencies]
|
||||
libpk = { path = "../libpk" }
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
data-encoding = "2.5.0"
|
||||
gif = "0.13.1"
|
||||
image = { version = "0.24.8", default-features = false, features = ["gif", "jpeg", "png", "webp", "tiff"] }
|
||||
form_urlencoded = "1.2.1"
|
||||
rust-s3 = { version = "0.33.0", default-features = false, features = ["tokio-rustls-tls"] }
|
||||
sha2 = "0.10.8"
|
||||
thiserror = "1.0.56"
|
||||
webp = "0.2.6"
|
||||
146
crates/avatars/src/cleanup.rs
Normal file
146
crates/avatars/src/cleanup.rs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
use anyhow::Context;
|
||||
use reqwest::{ClientBuilder, StatusCode};
|
||||
use sqlx::prelude::FromRow;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tracing::{error, info};
|
||||
|
||||
libpk::main!("avatar_cleanup");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
.expect("missing avatar service config");
|
||||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
let credentials = s3::creds::Credentials::new(
|
||||
Some(&config.s3.application_id),
|
||||
Some(&config.s3.application_key),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let bucket = s3::Bucket::new(&config.s3.bucket, region, credentials)?;
|
||||
|
||||
Arc::new(bucket)
|
||||
};
|
||||
|
||||
let pool = libpk::db::init_data_db().await?;
|
||||
|
||||
loop {
|
||||
// no infinite loops
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
match cleanup_job(pool.clone(), bucket.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
error!("failed to run avatar cleanup job: {}", err);
|
||||
// sentry
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromRow)]
|
||||
struct CleanupJobEntry {
|
||||
id: String,
|
||||
}
|
||||
|
||||
async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Result<()> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
|
||||
r#"
|
||||
select id from image_cleanup_jobs
|
||||
where ts < now() - interval '1 day'
|
||||
for update skip locked limit 1;"#,
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
if image_id.is_none() {
|
||||
info!("no job to run, sleeping for 1 minute");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
|
||||
return Ok(());
|
||||
}
|
||||
let image_id = image_id.unwrap().id;
|
||||
info!("got image {image_id}, cleaning up...");
|
||||
|
||||
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
|
||||
if image_data.is_none() {
|
||||
info!("image {image_id} was already deleted, skipping");
|
||||
sqlx::query("delete from image_cleanup_jobs where id = $1")
|
||||
.bind(image_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
let image_data = image_data.unwrap();
|
||||
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
.expect("missing avatar service config");
|
||||
|
||||
let path = image_data
|
||||
.url
|
||||
.strip_prefix(config.cdn_url.as_str())
|
||||
.unwrap();
|
||||
|
||||
let s3_resp = bucket.delete_object(path).await?;
|
||||
match s3_resp.status_code() {
|
||||
204 => {
|
||||
info!("successfully deleted image {image_id} from s3");
|
||||
}
|
||||
_ => {
|
||||
anyhow::bail!("s3 returned bad error code {}", s3_resp.status_code());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(zone_id) = config.cloudflare_zone_id.as_ref() {
|
||||
let client = ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.timeout(Duration::from_secs(3))
|
||||
.build()
|
||||
.context("error making client")?;
|
||||
|
||||
let cf_resp = client
|
||||
.post(format!(
|
||||
"https://api.cloudflare.com/client/v4/zones/{zone_id}/purge_cache"
|
||||
))
|
||||
.header(
|
||||
"Authorization",
|
||||
format!("Bearer {}", config.cloudflare_token.as_ref().unwrap()),
|
||||
)
|
||||
.body(format!(r#"{{"files":["{}"]}}"#, image_data.url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
match cf_resp.status() {
|
||||
StatusCode::OK => {
|
||||
info!(
|
||||
"successfully purged url {} from cloudflare cache",
|
||||
image_data.url
|
||||
);
|
||||
}
|
||||
_ => {
|
||||
let status = cf_resp.status();
|
||||
tracing::info!("raw response from cloudflare: {:#?}", cf_resp.text().await?);
|
||||
anyhow::bail!("cloudflare returned bad error code {}", status);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sqlx::query("delete from images where id = $1")
|
||||
.bind(image_id.clone())
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
21
crates/avatars/src/hash.rs
Normal file
21
crates/avatars/src/hash.rs
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
use std::fmt::Display;
|
||||
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Hash([u8; 32]);
|
||||
|
||||
impl Hash {
|
||||
pub fn sha256(data: &[u8]) -> Hash {
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(data);
|
||||
Hash(hasher.finalize().into())
|
||||
}
|
||||
}
|
||||
|
||||
impl Display for Hash {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let encoding = data_encoding::BASE32_NOPAD;
|
||||
write!(f, "{}", encoding.encode(&self.0[..16]).to_lowercase())
|
||||
}
|
||||
}
|
||||
26
crates/avatars/src/init.sql
Normal file
26
crates/avatars/src/init.sql
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
create table if not exists images
|
||||
(
|
||||
id text primary key,
|
||||
url text not null,
|
||||
original_url text,
|
||||
original_file_size int,
|
||||
original_type text,
|
||||
original_attachment_id bigint,
|
||||
file_size int not null,
|
||||
width int not null,
|
||||
height int not null,
|
||||
kind text not null,
|
||||
uploaded_at timestamptz not null,
|
||||
uploaded_by_account bigint
|
||||
);
|
||||
|
||||
create index if not exists images_original_url_idx on images (original_url);
|
||||
create index if not exists images_original_attachment_id_idx on images (original_attachment_id);
|
||||
create index if not exists images_uploaded_by_account_idx on images (uploaded_by_account);
|
||||
|
||||
create table if not exists image_queue (itemid serial primary key, url text not null, kind text not null);
|
||||
|
||||
alter table images add column if not exists uploaded_by_system uuid;
|
||||
alter table images add column if not exists content_type text default 'image/webp';
|
||||
|
||||
create table image_cleanup_jobs(id text references images(id) on delete cascade);
|
||||
257
crates/avatars/src/main.rs
Normal file
257
crates/avatars/src/main.rs
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
mod hash;
|
||||
mod migrate;
|
||||
mod process;
|
||||
mod pull;
|
||||
mod store;
|
||||
|
||||
use anyhow::Context;
|
||||
use axum::extract::State;
|
||||
use axum::routing::get;
|
||||
use axum::{
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::post,
|
||||
Json, Router,
|
||||
};
|
||||
use libpk::_config::AvatarsConfig;
|
||||
use libpk::db::repository::avatars as db;
|
||||
use libpk::db::types::avatars::*;
|
||||
use reqwest::{Client, ClientBuilder};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum PKAvatarError {
|
||||
// todo: split off into logical groups (cdn/url error, image format error, etc)
|
||||
#[error("invalid cdn url")]
|
||||
InvalidCdnUrl,
|
||||
|
||||
#[error("discord cdn responded with status code: {0}")]
|
||||
BadCdnResponse(reqwest::StatusCode),
|
||||
|
||||
#[error("network error: {0}")]
|
||||
NetworkError(reqwest::Error),
|
||||
|
||||
#[error("response is missing header: {0}")]
|
||||
MissingHeader(&'static str),
|
||||
|
||||
#[error("unsupported content type: {0}")]
|
||||
UnsupportedContentType(String),
|
||||
|
||||
#[error("image file size too large ({0} > {1})")]
|
||||
ImageFileSizeTooLarge(u64, u64),
|
||||
|
||||
#[error("unsupported image format: {0:?}")]
|
||||
UnsupportedImageFormat(image::ImageFormat),
|
||||
|
||||
#[error("could not detect image format")]
|
||||
UnknownImageFormat,
|
||||
|
||||
#[error("original image dimensions too large: {0:?} > {1:?}")]
|
||||
ImageDimensionsTooLarge((u32, u32), (u32, u32)),
|
||||
|
||||
#[error("could not decode image, is it corrupted?")]
|
||||
ImageFormatError(#[from] image::ImageError),
|
||||
|
||||
#[error("unknown error")]
|
||||
InternalError(#[from] anyhow::Error),
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PullRequest {
|
||||
url: String,
|
||||
kind: ImageKind,
|
||||
uploaded_by: Option<u64>, // should be String? serde makes this hard :/
|
||||
system_id: Option<Uuid>,
|
||||
|
||||
#[serde(default)]
|
||||
force: bool,
|
||||
}
|
||||
|
||||
#[derive(Serialize)]
|
||||
pub struct PullResponse {
|
||||
url: String,
|
||||
new: bool,
|
||||
}
|
||||
|
||||
async fn pull(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<PullRequest>,
|
||||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
|
||||
if !req.force {
|
||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
// remove any pending image cleanup
|
||||
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
||||
return Ok(Json(PullResponse {
|
||||
url: existing.url,
|
||||
new: false,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
let result = crate::pull::pull(state.pull_client, &parsed).await?;
|
||||
|
||||
let original_file_size = result.data.len();
|
||||
let encoded = process::process_async(result.data, req.kind).await?;
|
||||
|
||||
let store_res = crate::store::store(&state.bucket, &encoded).await?;
|
||||
let final_url = format!("{}{}", state.config.cdn_url, store_res.path);
|
||||
let is_new = db::add_image(
|
||||
&state.pool,
|
||||
ImageMeta {
|
||||
id: store_res.id,
|
||||
url: final_url.clone(),
|
||||
content_type: encoded.format.mime_type().to_string(),
|
||||
original_url: Some(parsed.full_url),
|
||||
original_type: Some(result.content_type),
|
||||
original_file_size: Some(original_file_size as i32),
|
||||
original_attachment_id: Some(parsed.attachment_id as i64),
|
||||
file_size: encoded.data.len() as i32,
|
||||
width: encoded.width as i32,
|
||||
height: encoded.height as i32,
|
||||
kind: req.kind,
|
||||
uploaded_at: None,
|
||||
uploaded_by_account: req.uploaded_by.map(|x| x as i64),
|
||||
uploaded_by_system: req.system_id,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Json(PullResponse {
|
||||
url: final_url,
|
||||
new: is_new,
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn stats(State(state): State<AppState>) -> Result<Json<Stats>, PKAvatarError> {
|
||||
Ok(Json(db::get_stats(&state.pool).await?))
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
bucket: Arc<s3::Bucket>,
|
||||
pull_client: Arc<Client>,
|
||||
pool: PgPool,
|
||||
config: Arc<AvatarsConfig>,
|
||||
}
|
||||
|
||||
libpk::main!("avatars");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
.expect("missing avatar service config");
|
||||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
let credentials = s3::creds::Credentials::new(
|
||||
Some(&config.s3.application_id),
|
||||
Some(&config.s3.application_key),
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let bucket = s3::Bucket::new(&config.s3.bucket, region, credentials)?;
|
||||
|
||||
Arc::new(bucket)
|
||||
};
|
||||
|
||||
let pull_client = Arc::new(
|
||||
ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.timeout(Duration::from_secs(3))
|
||||
.user_agent("PluralKit-Avatars/0.1")
|
||||
.build()
|
||||
.context("error making client")?,
|
||||
);
|
||||
|
||||
let pool = libpk::db::init_data_db().await?;
|
||||
|
||||
let state = AppState {
|
||||
bucket,
|
||||
pull_client,
|
||||
pool,
|
||||
config: Arc::new(config.clone()),
|
||||
};
|
||||
|
||||
// migrations are done, disable this
|
||||
// migrate::spawn_migrate_workers(Arc::new(state.clone()), state.config.migrate_worker_count);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/pull", post(pull))
|
||||
.route("/stats", get(stats))
|
||||
.with_state(state);
|
||||
|
||||
let host = &config.bind_addr;
|
||||
info!("starting server on {}!", host);
|
||||
let listener = tokio::net::TcpListener::bind(host).await.unwrap();
|
||||
axum::serve(listener, app).await.unwrap();
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
struct AppError(anyhow::Error);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
error!("error handling request: {}", self.0);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: self.0.to_string(),
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for PKAvatarError {
|
||||
fn into_response(self) -> Response {
|
||||
let status_code = match self {
|
||||
PKAvatarError::InternalError(_) | PKAvatarError::NetworkError(_) => {
|
||||
StatusCode::INTERNAL_SERVER_ERROR
|
||||
}
|
||||
_ => StatusCode::BAD_REQUEST,
|
||||
};
|
||||
|
||||
// print inner error if otherwise hidden
|
||||
error!("error: {}", self.source().unwrap_or(&self));
|
||||
|
||||
(
|
||||
status_code,
|
||||
Json(ErrorResponse {
|
||||
error: self.to_string(),
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> From<E> for AppError
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn from(err: E) -> Self {
|
||||
Self(err.into())
|
||||
}
|
||||
}
|
||||
146
crates/avatars/src/migrate.rs
Normal file
146
crates/avatars/src/migrate.rs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
use crate::pull::parse_url;
|
||||
use crate::{db, process, AppState, PKAvatarError};
|
||||
use libpk::db::types::avatars::{ImageMeta, ImageQueueEntry};
|
||||
use reqwest::StatusCode;
|
||||
use std::error::Error;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use time::Instant;
|
||||
use tokio::sync::Semaphore;
|
||||
use tracing::{error, info, instrument, warn};
|
||||
|
||||
static PROCESS_SEMAPHORE: Semaphore = Semaphore::const_new(100);
|
||||
|
||||
pub async fn handle_item_inner(
|
||||
state: &AppState,
|
||||
item: &ImageQueueEntry,
|
||||
) -> Result<(), PKAvatarError> {
|
||||
let parsed = parse_url(&item.url).map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
|
||||
if let Some(_) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
info!(
|
||||
"attachment {} already migrated, skipping",
|
||||
parsed.attachment_id
|
||||
);
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let pulled = crate::pull::pull(state.pull_client.clone(), &parsed).await?;
|
||||
let data_len = pulled.data.len();
|
||||
|
||||
let encoded = {
|
||||
// Trying to reduce CPU load/potentially blocking the worker by adding a bottleneck on parallel encodes
|
||||
// no semaphore on the main api though, that one should ideally be low latency
|
||||
// todo: configurable?
|
||||
let time_before_semaphore = Instant::now();
|
||||
let permit = PROCESS_SEMAPHORE
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| PKAvatarError::InternalError(e.into()))?;
|
||||
let time_after_semaphore = Instant::now();
|
||||
let semaphore_time = time_after_semaphore - time_before_semaphore;
|
||||
if semaphore_time.whole_milliseconds() > 100 {
|
||||
warn!(
|
||||
"waited more than {} ms for process semaphore",
|
||||
semaphore_time.whole_milliseconds()
|
||||
);
|
||||
}
|
||||
|
||||
let encoded = process::process_async(pulled.data, item.kind).await?;
|
||||
drop(permit);
|
||||
encoded
|
||||
};
|
||||
let store_res = crate::store::store(&state.bucket, &encoded).await?;
|
||||
let final_url = format!("{}{}", state.config.cdn_url, store_res.path);
|
||||
|
||||
db::add_image(
|
||||
&state.pool,
|
||||
ImageMeta {
|
||||
id: store_res.id,
|
||||
url: final_url.clone(),
|
||||
content_type: encoded.format.mime_type().to_string(),
|
||||
original_url: Some(parsed.full_url),
|
||||
original_type: Some(pulled.content_type),
|
||||
original_file_size: Some(data_len as i32),
|
||||
original_attachment_id: Some(parsed.attachment_id as i64),
|
||||
file_size: encoded.data.len() as i32,
|
||||
width: encoded.width as i32,
|
||||
height: encoded.height as i32,
|
||||
kind: item.kind,
|
||||
uploaded_at: None,
|
||||
uploaded_by_account: None,
|
||||
uploaded_by_system: None,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
info!(
|
||||
"migrated {} ({}k -> {}k)",
|
||||
final_url,
|
||||
data_len,
|
||||
encoded.data.len()
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn handle_item(state: &AppState) -> Result<(), PKAvatarError> {
|
||||
// let queue_length = db::get_queue_length(&state.pool).await?;
|
||||
// info!("migrate queue length: {}", queue_length);
|
||||
|
||||
if let Some((mut tx, item)) = db::pop_queue(&state.pool).await? {
|
||||
match handle_item_inner(state, &item).await {
|
||||
Ok(_) => {
|
||||
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(
|
||||
// Errors that mean the image can't be migrated and doesn't need to be retried
|
||||
e @ (PKAvatarError::ImageDimensionsTooLarge(_, _)
|
||||
| PKAvatarError::UnknownImageFormat
|
||||
| PKAvatarError::UnsupportedImageFormat(_)
|
||||
| PKAvatarError::UnsupportedContentType(_)
|
||||
| PKAvatarError::ImageFileSizeTooLarge(_, _)
|
||||
| PKAvatarError::InvalidCdnUrl
|
||||
| PKAvatarError::BadCdnResponse(StatusCode::NOT_FOUND | StatusCode::FORBIDDEN)),
|
||||
) => {
|
||||
warn!("error migrating {}, skipping: {}", item.url, e);
|
||||
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
|
||||
Ok(())
|
||||
}
|
||||
Err(e @ PKAvatarError::ImageFormatError(_)) => {
|
||||
// will add this item back to the end of the queue
|
||||
db::push_queue(&mut *tx, &item.url, item.kind).await?;
|
||||
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
|
||||
Err(e)
|
||||
}
|
||||
Err(e) => Err(e),
|
||||
}
|
||||
} else {
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[instrument(skip(state))]
|
||||
pub async fn worker(worker_id: u32, state: Arc<AppState>) {
|
||||
info!("spawned migrate worker with id {}", worker_id);
|
||||
loop {
|
||||
match handle_item(&state).await {
|
||||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"error in migrate worker {}: {}",
|
||||
worker_id,
|
||||
e.source().unwrap_or(&e)
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn spawn_migrate_workers(state: Arc<AppState>, count: u32) {
|
||||
for i in 0..count {
|
||||
tokio::spawn(worker(i, state.clone()));
|
||||
}
|
||||
}
|
||||
257
crates/avatars/src/process.rs
Normal file
257
crates/avatars/src/process.rs
Normal file
|
|
@ -0,0 +1,257 @@
|
|||
use image::{DynamicImage, ImageFormat};
|
||||
use std::borrow::Cow;
|
||||
use std::io::Cursor;
|
||||
use time::Instant;
|
||||
use tracing::{debug, error, info, instrument};
|
||||
|
||||
use crate::{hash::Hash, ImageKind, PKAvatarError};
|
||||
|
||||
const MAX_DIMENSION: u32 = 4000;
|
||||
|
||||
pub struct ProcessOutput {
|
||||
pub width: u32,
|
||||
pub height: u32,
|
||||
pub hash: Hash,
|
||||
pub format: ProcessedFormat,
|
||||
pub data: Vec<u8>,
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub enum ProcessedFormat {
|
||||
Webp,
|
||||
Gif,
|
||||
}
|
||||
|
||||
impl ProcessedFormat {
|
||||
pub fn mime_type(&self) -> &'static str {
|
||||
match self {
|
||||
ProcessedFormat::Gif => "image/gif",
|
||||
ProcessedFormat::Webp => "image/webp",
|
||||
}
|
||||
}
|
||||
|
||||
pub fn extension(&self) -> &'static str {
|
||||
match self {
|
||||
ProcessedFormat::Webp => "webp",
|
||||
ProcessedFormat::Gif => "gif",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Moving Vec<u8> in here since the thread needs ownership of it now, it's fine, don't need it after
|
||||
pub async fn process_async(data: Vec<u8>, kind: ImageKind) -> Result<ProcessOutput, PKAvatarError> {
|
||||
tokio::task::spawn_blocking(move || process(&data, kind))
|
||||
.await
|
||||
.map_err(|je| PKAvatarError::InternalError(je.into()))?
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarError> {
|
||||
let time_before = Instant::now();
|
||||
let reader = reader_for(data);
|
||||
match reader.format() {
|
||||
Some(ImageFormat::Png | ImageFormat::WebP | ImageFormat::Jpeg | ImageFormat::Tiff) => {} // ok :)
|
||||
Some(ImageFormat::Gif) => {
|
||||
// animated gifs will need to be handled totally differently
|
||||
// so split off processing here and come back if it's not applicable
|
||||
// (non-banner gifs + 1-frame animated gifs still need to be webp'd)
|
||||
if let Some(output) = process_gif(data, kind)? {
|
||||
return Ok(output);
|
||||
}
|
||||
}
|
||||
Some(other) => return Err(PKAvatarError::UnsupportedImageFormat(other)),
|
||||
None => return Err(PKAvatarError::UnknownImageFormat),
|
||||
}
|
||||
|
||||
// want to check dimensions *before* decoding so we don't accidentally end up with a memory bomb
|
||||
// eg. a 16000x16000 png file is only 31kb and expands to almost a gig of memory
|
||||
let (width, height) = assert_dimensions(reader.into_dimensions()?)?;
|
||||
|
||||
// need to make a new reader??? why can't it just use the same one. reduce duplication?
|
||||
let reader = reader_for(data);
|
||||
|
||||
let time_after_parse = Instant::now();
|
||||
|
||||
// apparently `image` sometimes decodes webp images wrong/weird.
|
||||
// see: https://discord.com/channels/466707357099884544/667795132971614229/1209925940835262464
|
||||
// instead, for webp, we use libwebp itself to decode, as well.
|
||||
// (pls no cve)
|
||||
let image = if reader.format() == Some(ImageFormat::WebP) {
|
||||
let webp_image = webp::Decoder::new(data).decode().ok_or_else(|| {
|
||||
PKAvatarError::InternalError(anyhow::anyhow!("webp decode failed").into())
|
||||
})?;
|
||||
webp_image.to_image()
|
||||
} else {
|
||||
reader.decode().map_err(|e| {
|
||||
// print the ugly error, return the nice error
|
||||
error!("error decoding image: {}", e);
|
||||
PKAvatarError::ImageFormatError(e)
|
||||
})?
|
||||
};
|
||||
|
||||
let time_after_decode = Instant::now();
|
||||
let image = resize(image, kind);
|
||||
let time_after_resize = Instant::now();
|
||||
|
||||
let encoded = encode(image);
|
||||
let time_after = Instant::now();
|
||||
|
||||
info!(
|
||||
"{}: lossy size {}K (parse: {} ms, decode: {} ms, resize: {} ms, encode: {} ms)",
|
||||
encoded.hash,
|
||||
encoded.data.len() / 1024,
|
||||
(time_after_parse - time_before).whole_milliseconds(),
|
||||
(time_after_decode - time_after_parse).whole_milliseconds(),
|
||||
(time_after_resize - time_after_decode).whole_milliseconds(),
|
||||
(time_after - time_after_resize).whole_milliseconds(),
|
||||
);
|
||||
|
||||
debug!(
|
||||
"processed image {}: {} bytes, {}x{} -> {} bytes, {}x{}",
|
||||
encoded.hash,
|
||||
data.len(),
|
||||
width,
|
||||
height,
|
||||
encoded.data.len(),
|
||||
encoded.width,
|
||||
encoded.height
|
||||
);
|
||||
Ok(encoded)
|
||||
}
|
||||
|
||||
fn assert_dimensions((width, height): (u32, u32)) -> Result<(u32, u32), PKAvatarError> {
|
||||
if width > MAX_DIMENSION || height > MAX_DIMENSION {
|
||||
return Err(PKAvatarError::ImageDimensionsTooLarge(
|
||||
(width, height),
|
||||
(MAX_DIMENSION, MAX_DIMENSION),
|
||||
));
|
||||
}
|
||||
return Ok((width, height));
|
||||
}
|
||||
fn process_gif(input_data: &[u8], kind: ImageKind) -> Result<Option<ProcessOutput>, PKAvatarError> {
|
||||
// gifs only supported for banners
|
||||
if kind != ImageKind::Banner {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
// and we can't rescale gifs (i tried :/) so the max size is the real limit
|
||||
if kind != ImageKind::Banner {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let reader = gif::Decoder::new(Cursor::new(input_data)).map_err(Into::<anyhow::Error>::into)?;
|
||||
let (max_width, max_height) = kind.size();
|
||||
if reader.width() as u32 > max_width || reader.height() as u32 > max_height {
|
||||
return Err(PKAvatarError::ImageDimensionsTooLarge(
|
||||
(reader.width() as u32, reader.height() as u32),
|
||||
(max_width, max_height),
|
||||
));
|
||||
}
|
||||
Ok(process_gif_inner(reader).map_err(Into::<anyhow::Error>::into)?)
|
||||
}
|
||||
|
||||
fn process_gif_inner(
|
||||
mut reader: gif::Decoder<Cursor<&[u8]>>,
|
||||
) -> Result<Option<ProcessOutput>, anyhow::Error> {
|
||||
let time_before = Instant::now();
|
||||
|
||||
let (width, height) = (reader.width(), reader.height());
|
||||
|
||||
let mut writer = gif::Encoder::new(
|
||||
Vec::new(),
|
||||
width as u16,
|
||||
height as u16,
|
||||
reader.global_palette().unwrap_or(&[]),
|
||||
)?;
|
||||
writer.set_repeat(reader.repeat())?;
|
||||
|
||||
let mut frame_buf = Vec::new();
|
||||
|
||||
let mut frame_count = 0;
|
||||
while let Some(frame) = reader.next_frame_info()? {
|
||||
let mut frame = frame.clone();
|
||||
assert_dimensions((frame.width as u32, frame.height as u32))?;
|
||||
frame_buf.clear();
|
||||
frame_buf.resize(reader.buffer_size(), 0);
|
||||
reader.read_into_buffer(&mut frame_buf)?;
|
||||
frame.buffer = Cow::Borrowed(&frame_buf);
|
||||
|
||||
frame.make_lzw_pre_encoded();
|
||||
writer.write_lzw_pre_encoded_frame(&frame)?;
|
||||
frame_count += 1;
|
||||
}
|
||||
|
||||
if frame_count == 1 {
|
||||
// If there's only one frame, then this doesn't need to be a gif. webp it
|
||||
// (unfortunately we can't tell if there's only one frame until after the first frame's been decoded...)
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let data = writer.into_inner()?;
|
||||
let time_after = Instant::now();
|
||||
|
||||
let hash = Hash::sha256(&data);
|
||||
|
||||
let original_data = reader.into_inner();
|
||||
info!(
|
||||
"processed gif {}: {}K -> {}K ({} ms, frames: {})",
|
||||
hash,
|
||||
original_data.buffer().len() / 1024,
|
||||
data.len() / 1024,
|
||||
(time_after - time_before).whole_milliseconds(),
|
||||
frame_count
|
||||
);
|
||||
|
||||
Ok(Some(ProcessOutput {
|
||||
data,
|
||||
format: ProcessedFormat::Gif,
|
||||
hash,
|
||||
width: width as u32,
|
||||
height: height as u32,
|
||||
}))
|
||||
}
|
||||
|
||||
fn reader_for(data: &[u8]) -> image::io::Reader<Cursor<&[u8]>> {
|
||||
image::io::Reader::new(Cursor::new(data))
|
||||
.with_guessed_format()
|
||||
.expect("cursor i/o is infallible")
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
fn resize(image: DynamicImage, kind: ImageKind) -> DynamicImage {
|
||||
let (target_width, target_height) = kind.size();
|
||||
if image.width() <= target_width && image.height() <= target_height {
|
||||
// don't resize if already smaller
|
||||
return image;
|
||||
}
|
||||
|
||||
// todo: best filter?
|
||||
let resized = image.resize(
|
||||
target_width,
|
||||
target_height,
|
||||
image::imageops::FilterType::Lanczos3,
|
||||
);
|
||||
return resized;
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
// can't believe this is infallible
|
||||
fn encode(image: DynamicImage) -> ProcessOutput {
|
||||
let (width, height) = (image.width(), image.height());
|
||||
let image_buf = image.to_rgba8();
|
||||
|
||||
let encoded_lossy = webp::Encoder::new(&*image_buf, webp::PixelLayout::Rgba, width, height)
|
||||
.encode_simple(false, 90.0)
|
||||
.expect("encode should be infallible")
|
||||
.to_vec();
|
||||
|
||||
let hash = Hash::sha256(&encoded_lossy);
|
||||
|
||||
ProcessOutput {
|
||||
data: encoded_lossy,
|
||||
format: ProcessedFormat::Webp,
|
||||
hash,
|
||||
width,
|
||||
height,
|
||||
}
|
||||
}
|
||||
166
crates/avatars/src/pull.rs
Normal file
166
crates/avatars/src/pull.rs
Normal file
|
|
@ -0,0 +1,166 @@
|
|||
use std::time::Duration;
|
||||
use std::{str::FromStr, sync::Arc};
|
||||
|
||||
use crate::PKAvatarError;
|
||||
use anyhow::Context;
|
||||
use reqwest::{Client, ClientBuilder, StatusCode, Url};
|
||||
use time::Instant;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
const MAX_SIZE: u64 = 8 * 1024 * 1024;
|
||||
|
||||
pub struct PullResult {
|
||||
pub data: Vec<u8>,
|
||||
pub content_type: String,
|
||||
pub last_modified: Option<String>,
|
||||
}
|
||||
|
||||
#[instrument(skip_all)]
|
||||
pub async fn pull(
|
||||
client: Arc<Client>,
|
||||
parsed_url: &ParsedUrl,
|
||||
) -> Result<PullResult, PKAvatarError> {
|
||||
let time_before = Instant::now();
|
||||
let mut trimmed_url = trim_url_query(&parsed_url.full_url)?;
|
||||
if trimmed_url.host_str() == Some("media.discordapp.net") {
|
||||
trimmed_url
|
||||
.set_host(Some("cdn.discordapp.com"))
|
||||
.expect("set_host should not fail");
|
||||
}
|
||||
let response = client.get(trimmed_url.clone()).send().await.map_err(|e| {
|
||||
error!("network error for {}: {}", parsed_url.full_url, e);
|
||||
PKAvatarError::NetworkError(e)
|
||||
})?;
|
||||
let time_after_headers = Instant::now();
|
||||
let status = response.status();
|
||||
|
||||
if status != StatusCode::OK {
|
||||
return Err(PKAvatarError::BadCdnResponse(status));
|
||||
}
|
||||
|
||||
let size = match response.content_length() {
|
||||
None => return Err(PKAvatarError::MissingHeader("Content-Length")),
|
||||
Some(size) if size > MAX_SIZE => {
|
||||
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE))
|
||||
}
|
||||
Some(size) => size,
|
||||
};
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|x| x.to_str().ok()) // invalid (non-unicode) header = missing, why not
|
||||
.map(|mime| mime.split(';').next().unwrap_or("")) // cut off at ;
|
||||
.ok_or(PKAvatarError::MissingHeader("Content-Type"))?
|
||||
.to_owned();
|
||||
let mime = match content_type.as_str() {
|
||||
mime @ ("image/jpeg" | "image/png" | "image/gif" | "image/webp" | "image/tiff") => mime,
|
||||
_ => return Err(PKAvatarError::UnsupportedContentType(content_type)),
|
||||
};
|
||||
|
||||
let last_modified = response
|
||||
.headers()
|
||||
.get(reqwest::header::LAST_MODIFIED)
|
||||
.and_then(|x| x.to_str().ok())
|
||||
.map(|x| x.to_string());
|
||||
|
||||
let body = response.bytes().await.map_err(|e| {
|
||||
error!("network error for {}: {}", parsed_url.full_url, e);
|
||||
PKAvatarError::NetworkError(e)
|
||||
})?;
|
||||
if body.len() != size as usize {
|
||||
// ???does this ever happen?
|
||||
return Err(PKAvatarError::InternalError(anyhow::anyhow!(
|
||||
"server responded with wrong length"
|
||||
)));
|
||||
}
|
||||
let time_after_body = Instant::now();
|
||||
|
||||
let headers_time = time_after_headers - time_before;
|
||||
let body_time = time_after_body - time_after_headers;
|
||||
|
||||
// can't do dynamic log level lmao
|
||||
if status != StatusCode::OK {
|
||||
tracing::warn!(
|
||||
"{}: {} (headers: {}ms, body: {}ms)",
|
||||
status,
|
||||
&trimmed_url,
|
||||
headers_time.whole_milliseconds(),
|
||||
body_time.whole_milliseconds()
|
||||
);
|
||||
} else {
|
||||
tracing::info!(
|
||||
"{}: {} (headers: {}ms, body: {}ms)",
|
||||
status,
|
||||
&trimmed_url,
|
||||
headers_time.whole_milliseconds(),
|
||||
body_time.whole_milliseconds()
|
||||
);
|
||||
};
|
||||
|
||||
Ok(PullResult {
|
||||
data: body.to_vec(),
|
||||
content_type: mime.to_string(),
|
||||
last_modified,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ParsedUrl {
|
||||
pub channel_id: u64,
|
||||
pub attachment_id: u64,
|
||||
pub filename: String,
|
||||
pub full_url: String,
|
||||
}
|
||||
|
||||
pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
|
||||
// todo: should this return PKAvatarError::InvalidCdnUrl?
|
||||
let url = Url::from_str(url).context("invalid url")?;
|
||||
|
||||
match (url.scheme(), url.domain()) {
|
||||
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
||||
_ => anyhow::bail!("not a discord cdn url"),
|
||||
}
|
||||
|
||||
match url
|
||||
.path_segments()
|
||||
.map(|x| x.collect::<Vec<_>>())
|
||||
.as_deref()
|
||||
{
|
||||
Some([_, channel_id, attachment_id, filename]) => {
|
||||
let channel_id = u64::from_str(channel_id).context("invalid channel id")?;
|
||||
let attachment_id = u64::from_str(attachment_id).context("invalid channel id")?;
|
||||
|
||||
Ok(ParsedUrl {
|
||||
channel_id,
|
||||
attachment_id,
|
||||
filename: filename.to_string(),
|
||||
full_url: url.to_string(),
|
||||
})
|
||||
}
|
||||
_ => anyhow::bail!("invaild discord cdn url"),
|
||||
}
|
||||
}
|
||||
|
||||
fn trim_url_query(url: &str) -> anyhow::Result<Url> {
|
||||
let mut parsed = Url::parse(url)?;
|
||||
|
||||
let mut qs = form_urlencoded::Serializer::new(String::new());
|
||||
for (key, value) in parsed.query_pairs() {
|
||||
match key.as_ref() {
|
||||
"ex" | "is" | "hm" => {
|
||||
qs.append_pair(key.as_ref(), value.as_ref());
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
let new_query = qs.finish();
|
||||
parsed.set_query(if new_query.len() > 0 {
|
||||
Some(&new_query)
|
||||
} else {
|
||||
None
|
||||
});
|
||||
|
||||
Ok(parsed)
|
||||
}
|
||||
60
crates/avatars/src/store.rs
Normal file
60
crates/avatars/src/store.rs
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
use crate::process::ProcessOutput;
|
||||
use tracing::error;
|
||||
|
||||
pub struct StoreResult {
|
||||
pub id: String,
|
||||
pub path: String,
|
||||
}
|
||||
|
||||
pub async fn store(bucket: &s3::Bucket, res: &ProcessOutput) -> anyhow::Result<StoreResult> {
|
||||
// errors here are all going to be internal
|
||||
let encoded_hash = res.hash.to_string();
|
||||
let path = format!(
|
||||
"images/{}/{}.{}",
|
||||
&encoded_hash[..2],
|
||||
&encoded_hash[2..],
|
||||
res.format.extension()
|
||||
);
|
||||
|
||||
// todo: something better than these retries
|
||||
let mut retry_count = 0;
|
||||
loop {
|
||||
if retry_count == 2 {
|
||||
tokio::time::sleep(tokio::time::Duration::new(2, 0)).await;
|
||||
}
|
||||
if retry_count > 2 {
|
||||
anyhow::bail!("error uploading image to cdn, too many retries") // nicer user-facing error?
|
||||
}
|
||||
retry_count += 1;
|
||||
|
||||
let resp = bucket
|
||||
.put_object_with_content_type(&path, &res.data, res.format.mime_type())
|
||||
.await?;
|
||||
match resp.status_code() {
|
||||
200 => {
|
||||
tracing::debug!("uploaded image to {}", &path);
|
||||
|
||||
return Ok(StoreResult {
|
||||
id: encoded_hash,
|
||||
path,
|
||||
});
|
||||
}
|
||||
500 | 503 => {
|
||||
tracing::warn!(
|
||||
"got 503 uploading image to {} ({}), retrying... (try {}/3)",
|
||||
&path,
|
||||
resp.as_str()?,
|
||||
retry_count
|
||||
);
|
||||
continue;
|
||||
}
|
||||
_ => {
|
||||
error!(
|
||||
"storage backend responded status code {}",
|
||||
resp.status_code()
|
||||
);
|
||||
anyhow::bail!("error uploading image to cdn") // nicer user-facing error?
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
16
crates/dispatch/Cargo.toml
Normal file
16
crates/dispatch/Cargo.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "dispatch"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
hickory-client = "0.24.1"
|
||||
52
crates/dispatch/src/logger.rs
Normal file
52
crates/dispatch/src/logger.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
uri.path(),
|
||||
elapsed
|
||||
);
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
192
crates/dispatch/src/main.rs
Normal file
192
crates/dispatch/src/main.rs
Normal file
|
|
@ -0,0 +1,192 @@
|
|||
#![feature(ip)]
|
||||
|
||||
use hickory_client::{
|
||||
client::{AsyncClient, ClientHandle},
|
||||
rr::{DNSClass, Name, RData, RecordType},
|
||||
udp::UdpClientStream,
|
||||
};
|
||||
use reqwest::{redirect::Policy, StatusCode};
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{net::UdpSocket, sync::RwLock};
|
||||
use tracing::{debug, error, info};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use axum::{extract::State, http::Uri, routing::post, Json, Router};
|
||||
|
||||
mod logger;
|
||||
|
||||
// this package does not currently use libpk
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
info!("hello world");
|
||||
|
||||
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
|
||||
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
|
||||
let (client, bg) = AsyncClient::connect(stream).await?;
|
||||
tokio::spawn(bg);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(dispatch))
|
||||
.with_state(Arc::new(RwLock::new(DNSClient(client))))
|
||||
.layer(axum::middleware::from_fn(logger::logger));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct DispatchRequest {
|
||||
auth: String,
|
||||
url: String,
|
||||
payload: String,
|
||||
test: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum DispatchResponse {
|
||||
OK,
|
||||
BadData,
|
||||
ResolveFailed,
|
||||
NoIPs,
|
||||
InvalidIP,
|
||||
FetchFailed,
|
||||
InvalidResponseCode(StatusCode),
|
||||
TestFailed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DispatchResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
async fn dispatch(
|
||||
// not entirely sure if this RwLock is the right way to do it
|
||||
State(dns): State<Arc<RwLock<DNSClient>>>,
|
||||
Json(req): Json<DispatchRequest>,
|
||||
) -> String {
|
||||
// todo: fix
|
||||
if req.auth != std::env::var("HTTP_AUTH_TOKEN").unwrap() {
|
||||
return "".to_string();
|
||||
}
|
||||
|
||||
let uri = match req.url.parse::<Uri>() {
|
||||
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to parse uri {}", req.url);
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
_ => {
|
||||
error!("uri {} is invalid", req.url);
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
};
|
||||
let ips = {
|
||||
let mut dns = dns.write().await;
|
||||
match dns.resolve(uri.host().unwrap().to_string()).await {
|
||||
Ok(v) => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to resolve");
|
||||
return DispatchResponse::ResolveFailed.to_string();
|
||||
}
|
||||
}
|
||||
};
|
||||
if ips.iter().any(|ip| !ip.is_global()) {
|
||||
return DispatchResponse::InvalidIP.to_string();
|
||||
}
|
||||
|
||||
if ips.len() == 0 {
|
||||
return DispatchResponse::NoIPs.to_string();
|
||||
}
|
||||
|
||||
let ips: Vec<SocketAddr> = ips
|
||||
.iter()
|
||||
.map(|ip| SocketAddr::V4(SocketAddrV4::new(*ip, 443)))
|
||||
.collect();
|
||||
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.user_agent("PluralKit Dispatch (https://pluralkit.me/api/dispatch/)")
|
||||
.redirect(Policy::none())
|
||||
.timeout(Duration::from_secs(10))
|
||||
.http1_only()
|
||||
.use_rustls_tls()
|
||||
.https_only(true)
|
||||
.resolve_to_addrs(uri.host().unwrap(), &ips)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let res = client
|
||||
.post(req.url.clone())
|
||||
.header("content-type", "application/json")
|
||||
.body(req.payload)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(res) if res.status() != 200 => {
|
||||
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, url = req.url.clone(), "failed to fetch");
|
||||
return DispatchResponse::FetchFailed.to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if let Some(test) = req.test {
|
||||
let test_res = client
|
||||
.post(req.url.clone())
|
||||
.header("content-type", "application/json")
|
||||
.body(test)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match test_res {
|
||||
Ok(res) if res.status() != 401 => return DispatchResponse::TestFailed.to_string(),
|
||||
Err(error) => {
|
||||
error!(?error, url = req.url.clone(), "failed to fetch");
|
||||
return DispatchResponse::FetchFailed.to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
DispatchResponse::OK.to_string()
|
||||
}
|
||||
|
||||
struct DNSClient(AsyncClient);
|
||||
|
||||
impl DNSClient {
|
||||
async fn resolve(&mut self, host: String) -> anyhow::Result<Vec<Ipv4Addr>> {
|
||||
let resp = self
|
||||
.0
|
||||
.query(Name::from_ascii(host)?, DNSClass::IN, RecordType::A)
|
||||
.await?;
|
||||
|
||||
debug!("got dns response: {resp:?}");
|
||||
|
||||
Ok(resp
|
||||
.answers()
|
||||
.iter()
|
||||
.filter_map(|ans| {
|
||||
if let Some(RData::A(val)) = ans.data() {
|
||||
Some(val.0)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
27
crates/gateway/Cargo.toml
Normal file
27
crates/gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "gateway"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
metrics = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
twilight-gateway = { workspace = true }
|
||||
twilight-cache-inmemory = { workspace = true }
|
||||
twilight-util = { workspace = true }
|
||||
twilight-model = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
serde_variant = "0.1.3"
|
||||
183
crates/gateway/src/cache_api.rs
Normal file
183
crates/gateway/src/cache_api.rs
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use serde_json::{json, to_string};
|
||||
use tracing::{error, info};
|
||||
use twilight_model::id::Id;
|
||||
|
||||
use crate::discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn status_code(code: StatusCode, body: String) -> Response {
|
||||
(code, body).into_response()
|
||||
}
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/guilds/:guild_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild(Id::new(guild_id)) {
|
||||
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/members/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
|
||||
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => {
|
||||
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
|
||||
},
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, "failed to get own guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/:user_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/channels",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
|
||||
Some(channels) => channels.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut channels = Vec::new();
|
||||
for id in channel_ids {
|
||||
match cache.0.channel(id) {
|
||||
Some(channel) => channels.push(channel.to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
channel_id = id.get(),
|
||||
"referenced channel {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
|
||||
}
|
||||
match cache.0.channel(Id::new(channel_id)) {
|
||||
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string())
|
||||
}
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
|
||||
}
|
||||
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
|
||||
get(|| async { "todo" }),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/last_message",
|
||||
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/roles",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
|
||||
Some(roles) => roles.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut roles = Vec::new();
|
||||
for id in role_ids {
|
||||
match cache.0.role(id) {
|
||||
Some(role) => roles.push(role.value().resource().to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
role_id = id.get(),
|
||||
"referenced role {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
|
||||
})
|
||||
)
|
||||
|
||||
.route("/stats", get(|State(cache): State<Arc<DiscordCache>>| async move {
|
||||
let cluster = cluster_config();
|
||||
let has_been_up = cache.2.read().await.len() as u32 == if cluster.total_shards > 16 {16} else {cluster.total_shards};
|
||||
let stats = cache.0.stats();
|
||||
let stats = json!({
|
||||
"guild_count": stats.guilds(),
|
||||
"channel_count": stats.channels(),
|
||||
// just put this here until prom stats
|
||||
"unavailable_guild_count": stats.unavailable_guilds(),
|
||||
"up": has_been_up,
|
||||
});
|
||||
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
368
crates/gateway/src/discord/cache.rs
Normal file
368
crates/gateway/src/discord/cache.rs
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
use anyhow::format_err;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use twilight_cache_inmemory::{
|
||||
model::CachedMember,
|
||||
permission::{MemberRoles, RootError},
|
||||
traits::CacheableChannel,
|
||||
InMemoryCache, ResourceType,
|
||||
};
|
||||
use twilight_model::{
|
||||
channel::{Channel, ChannelType},
|
||||
guild::{Guild, Member, Permissions},
|
||||
id::{
|
||||
marker::{ChannelMarker, GuildMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
use twilight_util::permission_calculator::PermissionCalculator;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
|
||||
| Permissions::SEND_MESSAGES
|
||||
| Permissions::READ_MESSAGE_HISTORY
|
||||
| Permissions::ADD_REACTIONS
|
||||
| Permissions::ATTACH_FILES
|
||||
| Permissions::EMBED_LINKS
|
||||
| Permissions::USE_EXTERNAL_EMOJIS
|
||||
| Permissions::CONNECT
|
||||
| Permissions::SPEAK
|
||||
| Permissions::USE_VAD;
|
||||
}
|
||||
|
||||
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
|
||||
Channel {
|
||||
id,
|
||||
kind: ChannelType::Private,
|
||||
|
||||
application_id: None,
|
||||
applied_tags: None,
|
||||
available_tags: None,
|
||||
bitrate: None,
|
||||
default_auto_archive_duration: None,
|
||||
default_forum_layout: None,
|
||||
default_reaction_emoji: None,
|
||||
default_sort_order: None,
|
||||
default_thread_rate_limit_per_user: None,
|
||||
flags: None,
|
||||
guild_id: None,
|
||||
icon: None,
|
||||
invitable: None,
|
||||
last_message_id: None,
|
||||
last_pin_timestamp: None,
|
||||
managed: None,
|
||||
member: None,
|
||||
member_count: None,
|
||||
message_count: None,
|
||||
name: None,
|
||||
newly_created: None,
|
||||
nsfw: None,
|
||||
owner_id: None,
|
||||
parent_id: None,
|
||||
permission_overwrites: None,
|
||||
position: None,
|
||||
rate_limit_per_user: None,
|
||||
recipients: None,
|
||||
rtc_region: None,
|
||||
thread_metadata: None,
|
||||
topic: None,
|
||||
user_limit: None,
|
||||
video_quality_mode: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
|
||||
CachedMember {
|
||||
avatar: item.avatar,
|
||||
communication_disabled_until: item.communication_disabled_until,
|
||||
deaf: Some(item.deaf),
|
||||
flags: item.flags,
|
||||
joined_at: item.joined_at,
|
||||
mute: Some(item.mute),
|
||||
nick: item.nick,
|
||||
premium_since: item.premium_since,
|
||||
roles: item.roles,
|
||||
pending: false,
|
||||
user_id: id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> DiscordCache {
|
||||
let mut client_builder = twilight_http::Client::builder().token(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.clone(),
|
||||
);
|
||||
|
||||
if let Some(base_url) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
||||
let cache = Arc::new(
|
||||
InMemoryCache::builder()
|
||||
.resource_types(
|
||||
ResourceType::GUILD
|
||||
| ResourceType::CHANNEL
|
||||
| ResourceType::ROLE
|
||||
| ResourceType::USER_CURRENT
|
||||
| ResourceType::MEMBER_CURRENT,
|
||||
)
|
||||
.message_cache_size(0)
|
||||
.build(),
|
||||
);
|
||||
|
||||
DiscordCache(cache, client, RwLock::new(Vec::new()))
|
||||
}
|
||||
|
||||
pub struct DiscordCache(
|
||||
pub Arc<InMemoryCache>,
|
||||
pub Arc<twilight_http::Client>,
|
||||
pub RwLock<Vec<u32>>,
|
||||
);
|
||||
|
||||
impl DiscordCache {
|
||||
pub async fn guild_permissions(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| format_err!("guild not found"))?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id
|
||||
== libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id
|
||||
{
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or(format_err!("self member not found"))?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.root();
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
pub async fn channel_permissions(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
let channel = self
|
||||
.0
|
||||
.channel(channel_id)
|
||||
.ok_or(format_err!("channel not found"))?;
|
||||
|
||||
if channel.value().guild_id.is_none() {
|
||||
return Ok(*DM_PERMISSIONS);
|
||||
}
|
||||
|
||||
let guild_id = channel.value().guild_id.unwrap();
|
||||
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
channel_id = channel_id.get(),
|
||||
guild_id = guild_id.get(),
|
||||
"referenced guild from cached channel {channel_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id
|
||||
== libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id
|
||||
{
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
guild_id = guild_id.get(),
|
||||
"self member for cached guild {guild_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
|
||||
let overwrites = match channel.kind {
|
||||
ChannelType::AnnouncementThread
|
||||
| ChannelType::PrivateThread
|
||||
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
|
||||
_ => channel
|
||||
.value()
|
||||
.permission_overwrites()
|
||||
.unwrap_or_default()
|
||||
.to_vec(),
|
||||
};
|
||||
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
|
||||
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
|
||||
self.0.guild(id).map(|guild| {
|
||||
let channels = self
|
||||
.0
|
||||
.guild_channels(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|channel_id| {
|
||||
let channel = self.0.channel(*channel_id)?;
|
||||
|
||||
if channel.kind.is_thread() {
|
||||
None
|
||||
} else {
|
||||
Some(channel.value().clone())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let roles = self
|
||||
.0
|
||||
.guild_roles(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|role_id| {
|
||||
Some(self.0.role(*role_id)?.value().resource().clone())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Guild {
|
||||
afk_channel_id: guild.afk_channel_id(),
|
||||
afk_timeout: guild.afk_timeout(),
|
||||
application_id: guild.application_id(),
|
||||
approximate_member_count: None, // Only present in with_counts HTTP endpoint
|
||||
banner: guild.banner().map(ToOwned::to_owned),
|
||||
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
|
||||
channels,
|
||||
default_message_notifications: guild.default_message_notifications(),
|
||||
description: guild.description().map(ToString::to_string),
|
||||
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
|
||||
emojis: vec![],
|
||||
explicit_content_filter: guild.explicit_content_filter(),
|
||||
features: guild.features().cloned().collect(),
|
||||
icon: guild.icon().map(ToOwned::to_owned),
|
||||
id: guild.id(),
|
||||
joined_at: guild.joined_at(),
|
||||
large: guild.large(),
|
||||
max_members: guild.max_members(),
|
||||
max_presences: guild.max_presences(),
|
||||
max_video_channel_users: guild.max_video_channel_users(),
|
||||
member_count: guild.member_count(),
|
||||
members: vec![],
|
||||
mfa_level: guild.mfa_level(),
|
||||
name: guild.name().to_string(),
|
||||
nsfw_level: guild.nsfw_level(),
|
||||
owner_id: guild.owner_id(),
|
||||
owner: guild.owner(),
|
||||
permissions: guild.permissions(),
|
||||
public_updates_channel_id: guild.public_updates_channel_id(),
|
||||
preferred_locale: guild.preferred_locale().to_string(),
|
||||
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
|
||||
premium_subscription_count: guild.premium_subscription_count(),
|
||||
premium_tier: guild.premium_tier(),
|
||||
presences: vec![],
|
||||
roles,
|
||||
rules_channel_id: guild.rules_channel_id(),
|
||||
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
|
||||
splash: guild.splash().map(ToOwned::to_owned),
|
||||
stage_instances: vec![],
|
||||
stickers: vec![],
|
||||
system_channel_flags: guild.system_channel_flags(),
|
||||
system_channel_id: guild.system_channel_id(),
|
||||
threads: vec![],
|
||||
unavailable: false,
|
||||
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
|
||||
verification_level: guild.verification_level(),
|
||||
voice_states: vec![],
|
||||
widget_channel_id: guild.widget_channel_id(),
|
||||
widget_enabled: guild.widget_enabled(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
200
crates/gateway/src/discord/gateway.rs
Normal file
200
crates/gateway/src/discord/gateway.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
use futures::StreamExt;
|
||||
use libpk::_config::ClusterSettings;
|
||||
use metrics::counter;
|
||||
use std::sync::{mpsc::Sender, Arc};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{
|
||||
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
|
||||
};
|
||||
use twilight_model::gateway::{
|
||||
payload::outgoing::update_presence::UpdatePresencePayload,
|
||||
presence::{Activity, ActivityType, Status},
|
||||
Intents,
|
||||
};
|
||||
|
||||
use crate::discord::identify_queue::{self, RedisQueue};
|
||||
|
||||
use super::{cache::DiscordCache, shard_state::ShardStateManager};
|
||||
|
||||
pub fn cluster_config() -> ClusterSettings {
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.cluster
|
||||
.clone()
|
||||
.unwrap_or(libpk::_config::ClusterSettings {
|
||||
node_id: 0,
|
||||
total_shards: 1,
|
||||
total_nodes: 1,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
|
||||
let intents = Intents::GUILDS
|
||||
| Intents::DIRECT_MESSAGES
|
||||
| Intents::DIRECT_MESSAGE_REACTIONS
|
||||
| Intents::GUILD_MESSAGES
|
||||
| Intents::GUILD_MESSAGE_REACTIONS
|
||||
| Intents::MESSAGE_CONTENT;
|
||||
|
||||
let queue = identify_queue::new(redis);
|
||||
|
||||
let cluster_settings = cluster_config();
|
||||
|
||||
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
|
||||
warn!("we have less than 16 shards, assuming single gateway process");
|
||||
(0, (cluster_settings.total_shards - 1).into())
|
||||
} else {
|
||||
(
|
||||
(cluster_settings.node_id * 16).into(),
|
||||
(((cluster_settings.node_id + 1) * 16) - 1).into(),
|
||||
)
|
||||
};
|
||||
|
||||
let shards = create_iterator(
|
||||
start_shard..end_shard + 1,
|
||||
cluster_settings.total_shards,
|
||||
ConfigBuilder::new(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.to_owned(),
|
||||
intents,
|
||||
)
|
||||
.presence(presence("pk;help", false))
|
||||
.queue(queue.clone())
|
||||
.build(),
|
||||
|_, builder| builder.build(),
|
||||
);
|
||||
|
||||
let mut shards_vec = Vec::new();
|
||||
shards_vec.extend(shards);
|
||||
|
||||
Ok(shards_vec)
|
||||
}
|
||||
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
_tx: Sender<(ShardId, String)>,
|
||||
shard_state: ShardStateManager,
|
||||
cache: Arc<DiscordCache>,
|
||||
) {
|
||||
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next().await {
|
||||
let raw_event = match item {
|
||||
Ok(evt) => match evt {
|
||||
Message::Close(frame) => {
|
||||
info!(
|
||||
"shard {} closed: {}",
|
||||
shard.id().number(),
|
||||
if let Some(close) = frame {
|
||||
format!("{} ({})", close.code, close.reason)
|
||||
} else {
|
||||
"unknown".to_string()
|
||||
}
|
||||
);
|
||||
if let Err(error) = shard_state.socket_closed(shard.id().number()).await {
|
||||
error!("failed to update shard state for socket closure: {error}");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Message::Text(text) => text,
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let event = match twilight_gateway::parse(raw_event.clone(), EventTypeFlags::all()) {
|
||||
Ok(Some(parsed)) => Event::from(parsed),
|
||||
Ok(None) => {
|
||||
// we received an event type unknown to twilight
|
||||
// that's fine, we probably don't need it anyway
|
||||
continue;
|
||||
}
|
||||
Err(error) => {
|
||||
error!(
|
||||
"shard {} failed to parse gateway event: {}",
|
||||
shard.id().number(),
|
||||
error
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// log the event in metrics
|
||||
// event_type * shard_id is too many labels and prometheus fails to query it
|
||||
// so we split it into two metrics
|
||||
counter!(
|
||||
"pluralkit_gateway_events_type",
|
||||
"event_type" => serde_variant::to_variant_name(&event.kind()).unwrap(),
|
||||
)
|
||||
.increment(1);
|
||||
counter!(
|
||||
"pluralkit_gateway_events_shard",
|
||||
"shard_id" => shard.id().number().to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
// update shard state and discord cache
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(shard.id().number(), event.clone())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state");
|
||||
}
|
||||
// need to do heartbeat separately, to get the latency
|
||||
if let Event::GatewayHeartbeatAck = event
|
||||
&& let Err(error) = shard_state
|
||||
.heartbeated(shard.id().number(), shard.latency())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state for latency");
|
||||
}
|
||||
|
||||
if let Event::Ready(_) = event {
|
||||
if !cache.2.read().await.contains(&shard.id().number()) {
|
||||
cache.2.write().await.push(shard.id().number());
|
||||
}
|
||||
}
|
||||
cache.0.update(&event);
|
||||
|
||||
// okay, we've handled the event internally, let's send it to consumers
|
||||
// tx.send((shard.id(), raw_event)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
|
||||
UpdatePresencePayload {
|
||||
activities: vec![Activity {
|
||||
application_id: None,
|
||||
assets: None,
|
||||
buttons: vec![],
|
||||
created_at: None,
|
||||
details: None,
|
||||
id: None,
|
||||
state: None,
|
||||
url: None,
|
||||
emoji: None,
|
||||
flags: None,
|
||||
instance: None,
|
||||
kind: ActivityType::Playing,
|
||||
name: status.to_string(),
|
||||
party: None,
|
||||
secrets: None,
|
||||
timestamps: None,
|
||||
}],
|
||||
afk: false,
|
||||
since: None,
|
||||
status: if going_away {
|
||||
Status::Idle
|
||||
} else {
|
||||
Status::Online
|
||||
},
|
||||
}
|
||||
}
|
||||
88
crates/gateway/src/discord/identify_queue.rs
Normal file
88
crates/gateway/src/discord/identify_queue.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
use fred::{
|
||||
clients::RedisPool,
|
||||
error::RedisError,
|
||||
interfaces::KeysInterface,
|
||||
types::{Expiration, SetOptions},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{error, info};
|
||||
use twilight_gateway::queue::Queue;
|
||||
|
||||
pub fn new(redis: RedisPool) -> RedisQueue {
|
||||
RedisQueue {
|
||||
redis,
|
||||
concurrency: libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.max_concurrency,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisQueue {
|
||||
pub redis: RedisPool,
|
||||
pub concurrency: u32,
|
||||
}
|
||||
|
||||
impl Debug for RedisQueue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RedisQueue")
|
||||
.field("concurrency", &self.concurrency)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Queue for RedisQueue {
|
||||
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(request_inner(
|
||||
self.redis.clone(),
|
||||
self.concurrency,
|
||||
shard_id,
|
||||
tx,
|
||||
));
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
const EXPIRY: i64 = 6;
|
||||
const RETRY_INTERVAL: u64 = 500;
|
||||
|
||||
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
|
||||
let bucket = shard_id % concurrency;
|
||||
let key = format!("pluralkit:identify:{}", bucket);
|
||||
|
||||
info!(shard_id, bucket, "waiting for allowance...");
|
||||
loop {
|
||||
let done: Result<Option<String>, RedisError> = redis
|
||||
.set(
|
||||
key.to_string(),
|
||||
"1",
|
||||
Some(Expiration::EX(EXPIRY)),
|
||||
Some(SetOptions::NX),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
match done {
|
||||
Ok(Some(_)) => {
|
||||
info!(shard_id, bucket, "got allowance!");
|
||||
// if this fails, it's probably already doing something else
|
||||
let _ = tx.send(());
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
|
||||
}
|
||||
}
|
||||
4
crates/gateway/src/discord/mod.rs
Normal file
4
crates/gateway/src/discord/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod cache;
|
||||
pub mod gateway;
|
||||
pub mod identify_queue;
|
||||
pub mod shard_state;
|
||||
91
crates/gateway/src/discord/shard_state.rs
Normal file
91
crates/gateway/src/discord/shard_state.rs
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use metrics::{counter, gauge};
|
||||
use tracing::info;
|
||||
use twilight_gateway::{Event, Latency};
|
||||
|
||||
use libpk::{state::*, util::redis::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShardStateManager {
|
||||
redis: RedisPool,
|
||||
}
|
||||
|
||||
pub fn new(redis: RedisPool) -> ShardStateManager {
|
||||
ShardStateManager { redis }
|
||||
}
|
||||
|
||||
impl ShardStateManager {
|
||||
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
|
||||
match event {
|
||||
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
|
||||
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<String> = self
|
||||
.redis
|
||||
.hget("pluralkit:shardstatus", shard_id)
|
||||
.await
|
||||
.to_option_or_error()?;
|
||||
match data {
|
||||
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
|
||||
None => Ok(ShardState::default()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(
|
||||
"pluralkit:shardstatus",
|
||||
(
|
||||
shard_id.to_string(),
|
||||
serde_json::to_string(&info).expect("could not serialize shard"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"shard {} {}",
|
||||
shard_id,
|
||||
if resumed { "resumed" } else { "ready" }
|
||||
);
|
||||
counter!(
|
||||
"pluralkit_gateway_shard_reconnect",
|
||||
"shard_id" => shard_id.to_string(),
|
||||
"resumed" => resumed.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
gauge!("pluralkit_gateway_shard_up").increment(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.up = true;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_up").decrement(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = false;
|
||||
info.disconnection_count += 1;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = true;
|
||||
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.latency = latency
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
72
crates/gateway/src/logger.rs
Normal file
72
crates/gateway/src/logger.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{
|
||||
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
|
||||
};
|
||||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
counter!(
|
||||
"pluralkit_gateway_cache_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"pluralkit_gateway_cache_api_requests_bucket",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
if response.status() != StatusCode::FOUND {
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
uri.path(),
|
||||
elapsed
|
||||
);
|
||||
}
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
140
crates/gateway/src/main.rs
Normal file
140
crates/gateway/src/main.rs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(if_let_guard)]
|
||||
|
||||
use chrono::Timelike;
|
||||
use fred::{clients::RedisPool, interfaces::*};
|
||||
use signal_hook::{
|
||||
consts::{SIGINT, SIGTERM},
|
||||
iterator::Signals,
|
||||
};
|
||||
use std::{
|
||||
sync::{mpsc::channel, Arc},
|
||||
time::Duration,
|
||||
vec::Vec,
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod discord;
|
||||
mod logger;
|
||||
|
||||
libpk::main!("gateway");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let (shutdown_tx, shutdown_rx) = channel::<()>();
|
||||
let shutdown_tx = Arc::new(shutdown_tx);
|
||||
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
let (event_tx, _event_rx) = channel();
|
||||
|
||||
let mut senders = Vec::new();
|
||||
let mut signal_senders = Vec::new();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for shard in shards {
|
||||
senders.push((shard.id(), shard.sender()));
|
||||
signal_senders.push(shard.sender());
|
||||
set.spawn(tokio::spawn(discord::gateway::runner(
|
||||
shard,
|
||||
event_tx.clone(),
|
||||
shard_state.clone(),
|
||||
cache.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
set.spawn(tokio::spawn(
|
||||
async move { scheduled_task(redis, senders).await },
|
||||
));
|
||||
|
||||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}));
|
||||
|
||||
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
|
||||
|
||||
set.spawn(tokio::spawn(async move {
|
||||
for sig in signals.forever() {
|
||||
info!("received signal {:?}", sig);
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
}));
|
||||
|
||||
let _ = shutdown_rx.recv();
|
||||
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
set.abort_all();
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
(60 - chrono::offset::Utc::now().second()).into(),
|
||||
))
|
||||
.await;
|
||||
info!("running per-minute scheduled tasks");
|
||||
|
||||
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
|
||||
Ok(val) => Some(val),
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "failed to fetch bot status from redis");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence(
|
||||
if let Some(status) = status {
|
||||
format!("pk;help | {}", status)
|
||||
} else {
|
||||
"pk;help".to_string()
|
||||
}
|
||||
.as_str(),
|
||||
false,
|
||||
),
|
||||
};
|
||||
|
||||
for sender in senders.iter() {
|
||||
match sender.1.command(&presence) {
|
||||
Err(error) => {
|
||||
warn!(?error, "could not update presence on shard {}", sender.0)
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
24
crates/libpk/Cargo.toml
Normal file
24
crates/libpk/Cargo.toml
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
[package]
|
||||
name = "libpk"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
sentry = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
time = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true}
|
||||
twilight-model = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
|
||||
config = "0.14.0"
|
||||
json-subscriber = { version = "0.2.2", features = ["env-filter"] }
|
||||
metrics-exporter-prometheus = { version = "0.15.3", default-features = false, features = ["tokio", "http-listener", "tracing"] }
|
||||
145
crates/libpk/src/_config.rs
Normal file
145
crates/libpk/src/_config.rs
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
use config::Config;
|
||||
use lazy_static::lazy_static;
|
||||
use serde::Deserialize;
|
||||
use std::sync::Arc;
|
||||
|
||||
use twilight_model::id::{marker::UserMarker, Id};
|
||||
|
||||
#[derive(Clone, Deserialize, Debug)]
|
||||
pub struct ClusterSettings {
|
||||
pub node_id: u32,
|
||||
pub total_shards: u32,
|
||||
pub total_nodes: u32,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DiscordConfig {
|
||||
pub client_id: Id<UserMarker>,
|
||||
pub bot_token: String,
|
||||
pub client_secret: String,
|
||||
pub max_concurrency: u32,
|
||||
#[serde(default)]
|
||||
pub cluster: Option<ClusterSettings>,
|
||||
pub api_base_url: Option<String>,
|
||||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub cache_api_addr: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DatabaseConfig {
|
||||
pub(crate) data_db_uri: String,
|
||||
pub(crate) data_db_max_connections: Option<u32>,
|
||||
pub(crate) data_db_min_connections: Option<u32>,
|
||||
pub(crate) messages_db_uri: Option<String>,
|
||||
pub(crate) stats_db_uri: Option<String>,
|
||||
pub(crate) db_password: Option<String>,
|
||||
pub data_redis_addr: String,
|
||||
}
|
||||
|
||||
fn _default_api_addr() -> String {
|
||||
"[::]:5000".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct ApiConfig {
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub ratelimit_redis_addr: Option<String>,
|
||||
|
||||
pub remote_url: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub temp_token2: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct AvatarsConfig {
|
||||
pub s3: S3Config,
|
||||
pub cdn_url: String,
|
||||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub bind_addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub migrate_worker_count: u32,
|
||||
|
||||
#[serde(default)]
|
||||
pub cloudflare_zone_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub cloudflare_token: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct S3Config {
|
||||
pub bucket: String,
|
||||
pub application_id: String,
|
||||
pub application_key: String,
|
||||
pub endpoint: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct ScheduledTasksConfig {
|
||||
pub set_guild_count: bool,
|
||||
pub expected_gateway_count: usize,
|
||||
pub gateway_url: String,
|
||||
}
|
||||
|
||||
fn _metrics_default() -> bool {
|
||||
false
|
||||
}
|
||||
fn _json_log_default() -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct PKConfig {
|
||||
pub db: DatabaseConfig,
|
||||
|
||||
#[serde(default)]
|
||||
pub discord: Option<DiscordConfig>,
|
||||
#[serde(default)]
|
||||
pub api: Option<ApiConfig>,
|
||||
#[serde(default)]
|
||||
pub avatars: Option<AvatarsConfig>,
|
||||
#[serde(default)]
|
||||
pub scheduled_tasks: Option<ScheduledTasksConfig>,
|
||||
|
||||
#[serde(default = "_metrics_default")]
|
||||
pub run_metrics_server: bool,
|
||||
|
||||
#[serde(default = "_json_log_default")]
|
||||
pub(crate) json_log: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub sentry_url: Option<String>,
|
||||
}
|
||||
|
||||
impl PKConfig {
|
||||
pub fn api(self) -> ApiConfig {
|
||||
self.api.expect("missing api config")
|
||||
}
|
||||
|
||||
pub fn discord_config(self) -> DiscordConfig {
|
||||
self.discord.expect("missing discord config")
|
||||
}
|
||||
}
|
||||
|
||||
// todo: consider passing this down instead of making it global
|
||||
// especially since we have optional discord/api/avatars/etc config
|
||||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
pub static ref CONFIG: Arc<PKConfig> = {
|
||||
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var);
|
||||
}
|
||||
|
||||
Arc::new(Config::builder()
|
||||
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
|
||||
.build().unwrap()
|
||||
.try_deserialize::<PKConfig>().unwrap())
|
||||
};
|
||||
}
|
||||
96
crates/libpk/src/db/mod.rs
Normal file
96
crates/libpk/src/db/mod.rs
Normal file
|
|
@ -0,0 +1,96 @@
|
|||
use fred::clients::RedisPool;
|
||||
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
|
||||
use std::str::FromStr;
|
||||
use tracing::info;
|
||||
|
||||
pub mod repository;
|
||||
pub mod types;
|
||||
|
||||
pub async fn init_redis() -> anyhow::Result<RedisPool> {
|
||||
info!("connecting to redis");
|
||||
let redis = RedisPool::new(
|
||||
fred::types::RedisConfig::from_url_centralized(crate::config.db.data_redis_addr.as_ref())
|
||||
.expect("redis url is invalid"),
|
||||
None,
|
||||
None,
|
||||
Some(Default::default()),
|
||||
10,
|
||||
)?;
|
||||
|
||||
let redis_handle = redis.connect_pool();
|
||||
tokio::spawn(async move { redis_handle });
|
||||
|
||||
Ok(redis)
|
||||
}
|
||||
|
||||
pub async fn init_data_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(&crate::config.db.data_db_uri)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
let mut pool = PgPoolOptions::new();
|
||||
|
||||
if let Some(max_conns) = crate::config.db.data_db_max_connections {
|
||||
pool = pool.max_connections(max_conns);
|
||||
}
|
||||
|
||||
if let Some(min_conns) = crate::config.db.data_db_min_connections {
|
||||
pool = pool.min_connections(min_conns);
|
||||
}
|
||||
|
||||
Ok(pool.connect_with(options).await?)
|
||||
}
|
||||
|
||||
pub async fn init_messages_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to messages database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(
|
||||
&crate::config
|
||||
.db
|
||||
.messages_db_uri
|
||||
.as_ref()
|
||||
.expect("missing messages db uri"),
|
||||
)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
let mut pool = PgPoolOptions::new();
|
||||
|
||||
if let Some(max_conns) = crate::config.db.data_db_max_connections {
|
||||
pool = pool.max_connections(max_conns);
|
||||
}
|
||||
|
||||
if let Some(min_conns) = crate::config.db.data_db_min_connections {
|
||||
pool = pool.min_connections(min_conns);
|
||||
}
|
||||
|
||||
Ok(pool.connect_with(options).await?)
|
||||
}
|
||||
|
||||
pub async fn init_stats_db() -> anyhow::Result<PgPool> {
|
||||
info!("connecting to stats database");
|
||||
|
||||
let mut options = PgConnectOptions::from_str(
|
||||
&crate::config
|
||||
.db
|
||||
.stats_db_uri
|
||||
.as_ref()
|
||||
.expect("missing messages db uri"),
|
||||
)?;
|
||||
|
||||
if let Some(password) = crate::config.db.db_password.clone() {
|
||||
options = options.password(&password);
|
||||
}
|
||||
|
||||
Ok(PgPoolOptions::new()
|
||||
.max_connections(1)
|
||||
.min_connections(1)
|
||||
.connect_with(options)
|
||||
.await?)
|
||||
}
|
||||
20
crates/libpk/src/db/repository/auth.rs
Normal file
20
crates/libpk/src/db/repository/auth.rs
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
pub async fn legacy_token_auth(
|
||||
pool: &sqlx::postgres::PgPool,
|
||||
token: &str,
|
||||
) -> anyhow::Result<Option<i32>> {
|
||||
let mut system: Vec<LegacyTokenDbResponse> =
|
||||
sqlx::query_as("select id from systems where token = $1")
|
||||
.bind(token)
|
||||
.fetch_all(pool)
|
||||
.await?;
|
||||
Ok(if let Some(system) = system.pop() {
|
||||
Some(system.id)
|
||||
} else {
|
||||
None
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct LegacyTokenDbResponse {
|
||||
id: i32,
|
||||
}
|
||||
111
crates/libpk/src/db/repository/avatars.rs
Normal file
111
crates/libpk/src/db/repository/avatars.rs
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
|
||||
use crate::db::types::avatars::*;
|
||||
|
||||
pub async fn get_by_id(pool: &PgPool, id: String) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(sqlx::query_as("select * from images where id = $1")
|
||||
.bind(id)
|
||||
.fetch_optional(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_by_original_url(
|
||||
pool: &PgPool,
|
||||
original_url: &str,
|
||||
) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from images where original_url = $1")
|
||||
.bind(original_url)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn get_by_attachment_id(
|
||||
pool: &PgPool,
|
||||
attachment_id: u64,
|
||||
) -> anyhow::Result<Option<ImageMeta>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from images where original_attachment_id = $1")
|
||||
.bind(attachment_id as i64)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow::Result<()> {
|
||||
sqlx::query(
|
||||
r#"
|
||||
delete from image_cleanup_jobs
|
||||
where id in (
|
||||
select id from images
|
||||
where original_attachment_id = $1
|
||||
)
|
||||
"#,
|
||||
)
|
||||
.bind(attachment_id as i64)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn pop_queue(
|
||||
pool: &PgPool,
|
||||
) -> anyhow::Result<Option<(Transaction<Postgres>, ImageQueueEntry)>> {
|
||||
let mut tx = pool.begin().await?;
|
||||
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")
|
||||
.fetch_optional(&mut *tx).await?;
|
||||
Ok(res.map(|x| (tx, x)))
|
||||
}
|
||||
|
||||
pub async fn get_queue_length(pool: &PgPool) -> anyhow::Result<i64> {
|
||||
Ok(sqlx::query_scalar("select count(*) from image_queue")
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_stats(pool: &PgPool) -> anyhow::Result<Stats> {
|
||||
Ok(sqlx::query_as(
|
||||
"select count(*) as total_images, sum(file_size) as total_file_size from images",
|
||||
)
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn add_image(pool: &PgPool, meta: ImageMeta) -> anyhow::Result<bool> {
|
||||
let kind_str = match meta.kind {
|
||||
ImageKind::Avatar => "avatar",
|
||||
ImageKind::Banner => "banner",
|
||||
};
|
||||
|
||||
let res = sqlx::query("insert into images (id, url, content_type, original_url, file_size, width, height, original_file_size, original_type, original_attachment_id, kind, uploaded_by_account, uploaded_by_system, uploaded_at) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, (now() at time zone 'utc')) on conflict (id) do nothing")
|
||||
.bind(meta.id)
|
||||
.bind(meta.url)
|
||||
.bind(meta.content_type)
|
||||
.bind(meta.original_url)
|
||||
.bind(meta.file_size)
|
||||
.bind(meta.width)
|
||||
.bind(meta.height)
|
||||
.bind(meta.original_file_size)
|
||||
.bind(meta.original_type)
|
||||
.bind(meta.original_attachment_id)
|
||||
.bind(kind_str)
|
||||
.bind(meta.uploaded_by_account)
|
||||
.bind(meta.uploaded_by_system)
|
||||
.execute(pool).await?;
|
||||
Ok(res.rows_affected() > 0)
|
||||
}
|
||||
|
||||
pub async fn push_queue(
|
||||
conn: &mut sqlx::PgConnection,
|
||||
url: &str,
|
||||
kind: ImageKind,
|
||||
) -> anyhow::Result<()> {
|
||||
sqlx::query("insert into image_queue (url, kind) values ($1, $2)")
|
||||
.bind(url)
|
||||
.bind(kind)
|
||||
.execute(conn)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
7
crates/libpk/src/db/repository/mod.rs
Normal file
7
crates/libpk/src/db/repository/mod.rs
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
mod stats;
|
||||
pub use stats::*;
|
||||
|
||||
pub mod avatars;
|
||||
|
||||
mod auth;
|
||||
pub use auth::*;
|
||||
26
crates/libpk/src/db/repository/stats.rs
Normal file
26
crates/libpk/src/db/repository/stats.rs
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
pub async fn get_stats(pool: &sqlx::postgres::PgPool) -> anyhow::Result<Counts> {
|
||||
let counts: Counts = sqlx::query_as("select * from info").fetch_one(pool).await?;
|
||||
Ok(counts)
|
||||
}
|
||||
|
||||
pub async fn insert_stats(
|
||||
pool: &sqlx::postgres::PgPool,
|
||||
table: &str,
|
||||
value: i64,
|
||||
) -> anyhow::Result<()> {
|
||||
// danger sql injection
|
||||
sqlx::query(format!("insert into {table} values (now(), $1)").as_str())
|
||||
.bind(value)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, sqlx::FromRow)]
|
||||
pub struct Counts {
|
||||
pub system_count: i64,
|
||||
pub member_count: i64,
|
||||
pub group_count: i64,
|
||||
pub switch_count: i64,
|
||||
pub message_count: i64,
|
||||
}
|
||||
53
crates/libpk/src/db/types/avatars.rs
Normal file
53
crates/libpk/src/db/types/avatars.rs
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::FromRow;
|
||||
use time::OffsetDateTime;
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(FromRow)]
|
||||
pub struct ImageMeta {
|
||||
pub id: String,
|
||||
pub kind: ImageKind,
|
||||
pub content_type: String,
|
||||
pub url: String,
|
||||
pub file_size: i32,
|
||||
pub width: i32,
|
||||
pub height: i32,
|
||||
pub uploaded_at: Option<OffsetDateTime>,
|
||||
|
||||
pub original_url: Option<String>,
|
||||
pub original_attachment_id: Option<i64>,
|
||||
pub original_file_size: Option<i32>,
|
||||
pub original_type: Option<String>,
|
||||
pub uploaded_by_account: Option<i64>,
|
||||
pub uploaded_by_system: Option<Uuid>,
|
||||
}
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct Stats {
|
||||
pub total_images: i64,
|
||||
pub total_file_size: i64,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy, Debug, sqlx::Type, PartialEq)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
#[sqlx(rename_all = "snake_case", type_name = "text")]
|
||||
pub enum ImageKind {
|
||||
Avatar,
|
||||
Banner,
|
||||
}
|
||||
|
||||
impl ImageKind {
|
||||
pub fn size(&self) -> (u32, u32) {
|
||||
match self {
|
||||
Self::Avatar => (512, 512),
|
||||
Self::Banner => (1024, 1024),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromRow)]
|
||||
pub struct ImageQueueEntry {
|
||||
pub itemid: i32,
|
||||
pub url: String,
|
||||
pub kind: ImageKind,
|
||||
}
|
||||
1
crates/libpk/src/db/types/mod.rs
Normal file
1
crates/libpk/src/db/types/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod avatars;
|
||||
81
crates/libpk/src/lib.rs
Normal file
81
crates/libpk/src/lib.rs
Normal file
|
|
@ -0,0 +1,81 @@
|
|||
#![feature(let_chains)]
|
||||
use std::net::SocketAddr;
|
||||
|
||||
use metrics_exporter_prometheus::PrometheusBuilder;
|
||||
use sentry::IntoDsn;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
pub mod db;
|
||||
pub mod state;
|
||||
pub mod util;
|
||||
|
||||
pub mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
// functions in this file are only used by the main function below
|
||||
|
||||
pub fn init_logging(component: &str) -> anyhow::Result<()> {
|
||||
if config.json_log {
|
||||
let mut layer = json_subscriber::layer();
|
||||
layer.inner_layer_mut().add_static_field(
|
||||
"component",
|
||||
serde_json::Value::String(component.to_string()),
|
||||
);
|
||||
tracing_subscriber::registry()
|
||||
.with(layer)
|
||||
.with(EnvFilter::from_default_env())
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_metrics() -> anyhow::Result<()> {
|
||||
if config.run_metrics_server {
|
||||
PrometheusBuilder::new()
|
||||
.with_http_listener("[::]:9000".parse::<SocketAddr>().unwrap())
|
||||
.install()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_sentry() -> sentry::ClientInitGuard {
|
||||
sentry::init(sentry::ClientOptions {
|
||||
dsn: config
|
||||
.sentry_url
|
||||
.clone()
|
||||
.map(|u| u.into_dsn().unwrap())
|
||||
.flatten(),
|
||||
release: sentry::release_name!(),
|
||||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! main {
|
||||
($component:expr) => {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
// we might also be able to use env!("CARGO_CRATE_NAME") here
|
||||
libpk::init_logging($component)?;
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(err) = libpk::init_metrics() {
|
||||
tracing::error!("failed to init metrics collector: {err}");
|
||||
};
|
||||
tracing::info!("hello world");
|
||||
if let Err(err) = real_main().await {
|
||||
tracing::error!("failed to run service: {err}");
|
||||
};
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
12
crates/libpk/src/state.rs
Normal file
12
crates/libpk/src/state.rs
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
|
||||
pub struct ShardState {
|
||||
pub shard_id: i32,
|
||||
pub up: bool,
|
||||
pub disconnection_count: i32,
|
||||
/// milliseconds
|
||||
pub latency: i32,
|
||||
/// unix timestamp
|
||||
pub last_heartbeat: i32,
|
||||
pub last_connection: i32,
|
||||
pub cluster_id: Option<i32>,
|
||||
}
|
||||
1
crates/libpk/src/util/mod.rs
Normal file
1
crates/libpk/src/util/mod.rs
Normal file
|
|
@ -0,0 +1 @@
|
|||
pub mod redis;
|
||||
15
crates/libpk/src/util/redis.rs
Normal file
15
crates/libpk/src/util/redis.rs
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
use fred::error::RedisError;
|
||||
|
||||
pub trait RedisErrorExt<T> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
|
||||
}
|
||||
|
||||
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(error) if error.is_not_found() => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
13
crates/model_macros/Cargo.toml
Normal file
13
crates/model_macros/Cargo.toml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "model_macros"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[lib]
|
||||
proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
quote = "1.0"
|
||||
proc-macro2 = "1.0"
|
||||
syn = "2.0"
|
||||
|
||||
259
crates/model_macros/src/lib.rs
Normal file
259
crates/model_macros/src/lib.rs
Normal file
|
|
@ -0,0 +1,259 @@
|
|||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::quote;
|
||||
use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type};
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum ElemPatchability {
|
||||
None,
|
||||
Private,
|
||||
Public,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
struct ModelField {
|
||||
name: Ident,
|
||||
ty: Type,
|
||||
patch: ElemPatchability,
|
||||
json: Option<Expr>,
|
||||
is_privacy: bool,
|
||||
default: Option<Expr>,
|
||||
}
|
||||
|
||||
fn parse_field(field: syn::Field) -> ModelField {
|
||||
let mut f = ModelField {
|
||||
name: field.ident.expect("field missing ident"),
|
||||
ty: field.ty,
|
||||
patch: ElemPatchability::None,
|
||||
json: None,
|
||||
is_privacy: false,
|
||||
default: None,
|
||||
};
|
||||
|
||||
for attr in field.attrs.iter() {
|
||||
match &attr.meta {
|
||||
Meta::Path(path) => {
|
||||
let ident = path.segments[0].ident.to_string();
|
||||
match ident.as_str() {
|
||||
"private_patchable" => match f.patch {
|
||||
ElemPatchability::None => {
|
||||
f.patch = ElemPatchability::Private;
|
||||
}
|
||||
_ => {
|
||||
panic!("cannot have multiple patch tags on same field");
|
||||
}
|
||||
},
|
||||
"patchable" => match f.patch {
|
||||
ElemPatchability::None => {
|
||||
f.patch = ElemPatchability::Public;
|
||||
}
|
||||
_ => {
|
||||
panic!("cannot have multiple patch tags on same field");
|
||||
}
|
||||
},
|
||||
"privacy" => f.is_privacy = true,
|
||||
_ => panic!("unknown attribute"),
|
||||
}
|
||||
}
|
||||
Meta::NameValue(nv) => match nv.path.segments[0].ident.to_string().as_str() {
|
||||
"json" => {
|
||||
if f.json.is_some() {
|
||||
panic!("cannot set json multiple times for same field");
|
||||
}
|
||||
f.json = Some(nv.value.clone());
|
||||
}
|
||||
"default" => {
|
||||
if f.default.is_some() {
|
||||
panic!("cannot set default multiple times for same field");
|
||||
}
|
||||
f.default = Some(nv.value.clone());
|
||||
}
|
||||
_ => panic!("unknown attribute"),
|
||||
},
|
||||
Meta::List(_) => panic!("unknown attribute"),
|
||||
}
|
||||
}
|
||||
|
||||
if matches!(f.patch, ElemPatchability::Public) && f.json.is_none() {
|
||||
panic!("must have json name to be publicly patchable");
|
||||
}
|
||||
|
||||
if f.json.is_some() && f.is_privacy {
|
||||
panic!("cannot set custom json name for privacy field");
|
||||
}
|
||||
|
||||
f
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
let ast = parse_macro_input!(input as DeriveInput);
|
||||
let model_type = match ast.data {
|
||||
syn::Data::Struct(struct_data) => struct_data,
|
||||
_ => panic!("pk_model can only be used on a struct"),
|
||||
};
|
||||
|
||||
let tname = Ident::new(&format!("PK{}", ast.ident), Span::call_site());
|
||||
let patchable_name = Ident::new(&format!("PK{}Patch", ast.ident), Span::call_site());
|
||||
|
||||
let fields = if let syn::Fields::Named(fields) = model_type.fields {
|
||||
fields
|
||||
.named
|
||||
.iter()
|
||||
.map(|f| parse_field(f.clone()))
|
||||
.collect::<Vec<ModelField>>()
|
||||
} else {
|
||||
panic!("fields of a struct must be named");
|
||||
};
|
||||
|
||||
// println!("{}: {:#?}", tname, fields);
|
||||
|
||||
let tfields = mk_tfields(fields.clone());
|
||||
let from_json = mk_tfrom_json(fields.clone());
|
||||
let from_sql = mk_tfrom_sql(fields.clone());
|
||||
let to_json = mk_tto_json(fields.clone());
|
||||
|
||||
let fields: Vec<ModelField> = fields
|
||||
.iter()
|
||||
.filter(|f| !matches!(f.patch, ElemPatchability::None))
|
||||
.cloned()
|
||||
.collect();
|
||||
|
||||
let patch_fields = mk_patch_fields(fields.clone());
|
||||
let patch_from_json = mk_patch_from_json(fields.clone());
|
||||
let patch_validate = mk_patch_validate(fields.clone());
|
||||
let patch_to_json = mk_patch_to_json(fields.clone());
|
||||
let patch_to_sql = mk_patch_to_sql(fields.clone());
|
||||
|
||||
return quote! {
|
||||
#[derive(sqlx::FromRow, Debug, Clone)]
|
||||
pub struct #tname {
|
||||
#tfields
|
||||
}
|
||||
|
||||
impl #tname {
|
||||
pub fn from_json(input: String) -> Self {
|
||||
#from_json
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#to_json
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct #patchable_name {
|
||||
#patch_fields
|
||||
}
|
||||
|
||||
impl #patchable_name {
|
||||
pub fn from_json(input: String) -> Self {
|
||||
#patch_from_json
|
||||
}
|
||||
|
||||
pub fn validate(self) -> bool {
|
||||
#patch_validate
|
||||
}
|
||||
|
||||
pub fn to_sql(self) -> sea_query::UpdateStatement {
|
||||
// sea_query::Query::update()
|
||||
#patch_to_sql
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#patch_to_json
|
||||
}
|
||||
}
|
||||
}
|
||||
.into();
|
||||
}
|
||||
|
||||
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let name = f.name.clone();
|
||||
let ty = f.ty.clone();
|
||||
quote! {
|
||||
pub #name: #ty,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn mk_tfrom_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tfrom_sql(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
// todo: check privacy access
|
||||
let fielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
f.json.as_ref().map(|v| {
|
||||
let tname = f.name.clone();
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
quote! {
|
||||
#v: self.#tname.unwrap_or(#default),
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#v: self.#tname,
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
let privacyfielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
if f.is_privacy {
|
||||
let tname = f.name.clone();
|
||||
let tnamestr = f.name.clone().to_string();
|
||||
Some(quote! {
|
||||
#tnamestr: self.#tname,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
"privacy": {
|
||||
#privacyfielddefs
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn mk_patch_fields(fields: Vec<ModelField>) -> TokenStream {
|
||||
fields
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let name = f.name.clone();
|
||||
let ty = f.ty.clone();
|
||||
quote! {
|
||||
pub #name: Option<#ty>,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { true }
|
||||
}
|
||||
fn mk_patch_from_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_patch_to_sql(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_patch_to_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
quote! { unimplemented!(); }
|
||||
}
|
||||
13
crates/models/Cargo.toml
Normal file
13
crates/models/Cargo.toml
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
[package]
|
||||
name = "pluralkit_models"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
model_macros = { path = "../model_macros" }
|
||||
sea-query = "0.32.1"
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true, features = ["preserve_order"] }
|
||||
sqlx = { workspace = true, default-features = false, features = ["chrono"] }
|
||||
uuid = { workspace = true }
|
||||
35
crates/models/src/_util.rs
Normal file
35
crates/models/src/_util.rs
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
// postgres enums created in c# pluralkit implementations are "fake", i.e. they
|
||||
// are actually ints in the database rather than postgres enums, because dapper
|
||||
// does not support postgres enums
|
||||
// here, we add some impls to support this kind of enum in sqlx
|
||||
// there is probably a better way to do this, but works for now.
|
||||
// note: caller needs to implement From<i32> for their type
|
||||
macro_rules! fake_enum_impls {
|
||||
($n:ident) => {
|
||||
impl Type<Postgres> for $n {
|
||||
fn type_info() -> PgTypeInfo {
|
||||
PgTypeInfo::with_name("INT4")
|
||||
}
|
||||
}
|
||||
|
||||
impl From<$n> for i32 {
|
||||
fn from(enum_value: $n) -> Self {
|
||||
enum_value as i32
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, DB: Database> Decode<'r, DB> for $n
|
||||
where
|
||||
i32: Decode<'r, DB>,
|
||||
{
|
||||
fn decode(
|
||||
value: <DB as Database>::ValueRef<'r>,
|
||||
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
|
||||
let value = <i32 as Decode<DB>>::decode(value)?;
|
||||
Ok(Self::from(value))
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
pub(crate) use fake_enum_impls;
|
||||
11
crates/models/src/lib.rs
Normal file
11
crates/models/src/lib.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
mod _util;
|
||||
|
||||
macro_rules! model {
|
||||
($n:ident) => {
|
||||
mod $n;
|
||||
pub use $n::*;
|
||||
};
|
||||
}
|
||||
|
||||
model!(system);
|
||||
model!(system_config);
|
||||
80
crates/models/src/system.rs
Normal file
80
crates/models/src/system.rs
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
use std::error::Error;
|
||||
|
||||
use model_macros::pk_model;
|
||||
|
||||
use chrono::NaiveDateTime;
|
||||
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::_util::fake_enum_impls;
|
||||
|
||||
// todo: fix this
|
||||
pub type SystemId = i32;
|
||||
|
||||
// todo: move this
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PrivacyLevel {
|
||||
Public,
|
||||
Private,
|
||||
}
|
||||
|
||||
fake_enum_impls!(PrivacyLevel);
|
||||
|
||||
impl From<i32> for PrivacyLevel {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
1 => PrivacyLevel::Public,
|
||||
2 => PrivacyLevel::Private,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pk_model]
|
||||
struct System {
|
||||
id: SystemId,
|
||||
#[json = "id"]
|
||||
#[private_patchable]
|
||||
hid: String,
|
||||
#[json = "uuid"]
|
||||
uuid: Uuid,
|
||||
#[json = "name"]
|
||||
name: Option<String>,
|
||||
#[json = "description"]
|
||||
description: Option<String>,
|
||||
#[json = "tag"]
|
||||
tag: Option<String>,
|
||||
#[json = "pronouns"]
|
||||
pronouns: Option<String>,
|
||||
#[json = "avatar_url"]
|
||||
avatar_url: Option<String>,
|
||||
#[json = "banner_image"]
|
||||
banner_image: Option<String>,
|
||||
#[json = "color"]
|
||||
color: Option<String>,
|
||||
token: Option<String>,
|
||||
#[json = "webhook_url"]
|
||||
webhook_url: Option<String>,
|
||||
webhook_token: Option<String>,
|
||||
#[json = "created"]
|
||||
created: NaiveDateTime,
|
||||
#[privacy]
|
||||
name_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
avatar_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
description_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
banner_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
member_list_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
front_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
front_history_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
group_list_privacy: PrivacyLevel,
|
||||
#[privacy]
|
||||
pronoun_privacy: PrivacyLevel,
|
||||
}
|
||||
89
crates/models/src/system_config.rs
Normal file
89
crates/models/src/system_config.rs
Normal file
|
|
@ -0,0 +1,89 @@
|
|||
use model_macros::pk_model;
|
||||
|
||||
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
|
||||
use std::error::Error;
|
||||
|
||||
use crate::{SystemId, _util::fake_enum_impls};
|
||||
|
||||
pub const DEFAULT_MEMBER_LIMIT: i32 = 1000;
|
||||
pub const DEFAULT_GROUP_LIMIT: i32 = 250;
|
||||
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum HidPadFormat {
|
||||
#[serde(rename = "off")]
|
||||
None,
|
||||
Left,
|
||||
Right,
|
||||
}
|
||||
fake_enum_impls!(HidPadFormat);
|
||||
|
||||
impl From<i32> for HidPadFormat {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
0 => HidPadFormat::None,
|
||||
1 => HidPadFormat::Left,
|
||||
2 => HidPadFormat::Right,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(serde::Serialize, Debug, Clone)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
enum ProxySwitchAction {
|
||||
Off,
|
||||
New,
|
||||
Add,
|
||||
}
|
||||
fake_enum_impls!(ProxySwitchAction);
|
||||
|
||||
impl From<i32> for ProxySwitchAction {
|
||||
fn from(value: i32) -> Self {
|
||||
match value {
|
||||
0 => ProxySwitchAction::Off,
|
||||
1 => ProxySwitchAction::New,
|
||||
2 => ProxySwitchAction::Add,
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[pk_model]
|
||||
struct SystemConfig {
|
||||
system: SystemId,
|
||||
#[json = "timezone"]
|
||||
ui_tz: String,
|
||||
#[json = "pings_enabled"]
|
||||
pings_enabled: bool,
|
||||
#[json = "latch_timeout"]
|
||||
latch_timeout: Option<i32>,
|
||||
#[json = "member_default_private"]
|
||||
member_default_private: bool,
|
||||
#[json = "group_default_private"]
|
||||
group_default_private: bool,
|
||||
#[json = "show_private_info"]
|
||||
show_private_info: bool,
|
||||
#[json = "member_limit"]
|
||||
#[default = DEFAULT_MEMBER_LIMIT]
|
||||
member_limit_override: Option<i32>,
|
||||
#[json = "group_limit"]
|
||||
#[default = DEFAULT_GROUP_LIMIT]
|
||||
group_limit_override: Option<i32>,
|
||||
#[json = "case_sensitive_proxy_tags"]
|
||||
case_sensitive_proxy_tags: bool,
|
||||
#[json = "proxy_error_message_enabled"]
|
||||
proxy_error_message_enabled: bool,
|
||||
#[json = "hid_display_split"]
|
||||
hid_display_split: bool,
|
||||
#[json = "hid_display_caps"]
|
||||
hid_display_caps: bool,
|
||||
#[json = "hid_list_padding"]
|
||||
hid_list_padding: HidPadFormat,
|
||||
#[json = "proxy_switch"]
|
||||
proxy_switch: ProxySwitchAction,
|
||||
#[json = "name_format"]
|
||||
name_format: String,
|
||||
#[json = "description_templates"]
|
||||
description_templates: Vec<String>,
|
||||
}
|
||||
20
crates/scheduled_tasks/Cargo.toml
Normal file
20
crates/scheduled_tasks/Cargo.toml
Normal file
|
|
@ -0,0 +1,20 @@
|
|||
[package]
|
||||
name = "scheduled_tasks"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libpk = { path = "../libpk" }
|
||||
|
||||
anyhow = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
croner = "2.1.0"
|
||||
num-format = "0.4.4"
|
||||
91
crates/scheduled_tasks/src/main.rs
Normal file
91
crates/scheduled_tasks/src/main.rs
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
use chrono::Utc;
|
||||
use croner::Cron;
|
||||
use fred::prelude::RedisPool;
|
||||
use sqlx::PgPool;
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{debug, error, info};
|
||||
|
||||
mod tasks;
|
||||
use tasks::*;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppCtx {
|
||||
pub data: PgPool,
|
||||
pub messages: PgPool,
|
||||
pub stats: PgPool,
|
||||
pub redis: RedisPool,
|
||||
}
|
||||
|
||||
libpk::main!("scheduled_tasks");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let ctx = AppCtx {
|
||||
data: libpk::db::init_data_db().await?,
|
||||
messages: libpk::db::init_messages_db().await?,
|
||||
stats: libpk::db::init_stats_db().await?,
|
||||
redis: libpk::db::init_redis().await?,
|
||||
};
|
||||
|
||||
info!("starting scheduled tasks runner");
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
|
||||
// i couldn't be bothered to figure out the types of passing in an async
|
||||
// function to another function... so macro it is
|
||||
macro_rules! doforever {
|
||||
($cron:expr, $desc:expr, $fn:ident) => {
|
||||
let ctx = ctx.clone();
|
||||
let cron = Cron::new($cron)
|
||||
.with_seconds_optional()
|
||||
.parse()
|
||||
.expect("invalid cron");
|
||||
set.spawn(tokio::spawn(async move {
|
||||
loop {
|
||||
let ctx = ctx.clone();
|
||||
let next_iter_time = cron.find_next_occurrence(&Utc::now(), false).unwrap();
|
||||
debug!("next execution of {} at {:?}", $desc, next_iter_time);
|
||||
let dur = next_iter_time - Utc::now();
|
||||
tokio::time::sleep(dur.to_std().unwrap()).await;
|
||||
|
||||
info!("running {}", $desc);
|
||||
let before = std::time::Instant::now();
|
||||
if let Err(error) = $fn(ctx).await {
|
||||
error!("failed to run {}: {}", $desc, error);
|
||||
// sentry
|
||||
}
|
||||
let duration = before.elapsed();
|
||||
info!("ran {} in {duration:?}", $desc);
|
||||
// add prometheus log
|
||||
}
|
||||
}))
|
||||
};
|
||||
}
|
||||
|
||||
// every 10 seconds
|
||||
doforever!(
|
||||
"0,10,20,30,40,50 * * * * *",
|
||||
"prometheus updater",
|
||||
update_prometheus
|
||||
);
|
||||
// every minute
|
||||
doforever!("* * * * *", "database stats updater", update_db_meta);
|
||||
// every 10 minutes
|
||||
doforever!(
|
||||
"0,10,20,30,40,50 * * * *",
|
||||
"message stats updater",
|
||||
update_db_message_meta
|
||||
);
|
||||
// every minute
|
||||
doforever!("* * * * *", "discord stats updater", update_discord_stats);
|
||||
// on :00 and :30
|
||||
doforever!(
|
||||
"0,30 * * * *",
|
||||
"queue deleted image cleanup job",
|
||||
queue_deleted_image_cleanup
|
||||
);
|
||||
|
||||
set.join_next()
|
||||
.await
|
||||
.ok_or(anyhow::anyhow!("could not join_next"))???;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
151
crates/scheduled_tasks/src/tasks.rs
Normal file
151
crates/scheduled_tasks/src/tasks.rs
Normal file
|
|
@ -0,0 +1,151 @@
|
|||
use std::time::Duration;
|
||||
|
||||
use anyhow::anyhow;
|
||||
use fred::prelude::KeysInterface;
|
||||
use libpk::{
|
||||
config,
|
||||
db::repository::{get_stats, insert_stats},
|
||||
};
|
||||
use metrics::gauge;
|
||||
use num_format::{Locale, ToFormattedString};
|
||||
use reqwest::ClientBuilder;
|
||||
use sqlx::Executor;
|
||||
|
||||
use crate::AppCtx;
|
||||
|
||||
pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct Count {
|
||||
count: i64,
|
||||
}
|
||||
let count: Count = sqlx::query_as("select count(*) from image_cleanup_jobs")
|
||||
.fetch_one(&ctx.data)
|
||||
.await?;
|
||||
|
||||
gauge!("pluralkit_image_cleanup_queue_length").set(count.count as f64);
|
||||
|
||||
// todo: remaining shard session_start_limit
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_db_meta(ctx: AppCtx) -> anyhow::Result<()> {
|
||||
ctx.data
|
||||
.execute(
|
||||
r#"
|
||||
update info set
|
||||
system_count = (select count(*) from systems),
|
||||
member_count = (select count(*) from systems),
|
||||
group_count = (select count(*) from systems),
|
||||
switch_count = (select count(*) from systems)
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
let new_stats = get_stats(&ctx.data).await?;
|
||||
insert_stats(&ctx.stats, "systems", new_stats.system_count).await?;
|
||||
insert_stats(&ctx.stats, "members", new_stats.member_count).await?;
|
||||
insert_stats(&ctx.stats, "groups", new_stats.group_count).await?;
|
||||
insert_stats(&ctx.stats, "switches", new_stats.switch_count).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_db_message_meta(ctx: AppCtx) -> anyhow::Result<()> {
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct MessageCount {
|
||||
count: i64,
|
||||
}
|
||||
let message_count: MessageCount = sqlx::query_as("select count(*) from messages")
|
||||
.fetch_one(&ctx.messages)
|
||||
.await?;
|
||||
sqlx::query("update info set message_count = $1")
|
||||
.bind(message_count.count)
|
||||
.execute(&ctx.data)
|
||||
.await?;
|
||||
insert_stats(&ctx.stats, "messages", message_count.count).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn update_discord_stats(ctx: AppCtx) -> anyhow::Result<()> {
|
||||
let client = ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(3))
|
||||
.timeout(Duration::from_secs(3))
|
||||
.build()
|
||||
.expect("error making client");
|
||||
|
||||
let cfg = config
|
||||
.scheduled_tasks
|
||||
.as_ref()
|
||||
.expect("missing scheduled_tasks config");
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
struct GatewayStatus {
|
||||
up: bool,
|
||||
guild_count: i64,
|
||||
channel_count: i64,
|
||||
}
|
||||
|
||||
let mut guild_count = 0;
|
||||
let mut channel_count = 0;
|
||||
|
||||
for idx in 0..=cfg.expected_gateway_count {
|
||||
let res = client
|
||||
.get(format!("http://cluster{idx}.{}/stats", cfg.gateway_url))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let stat: GatewayStatus = res.json().await?;
|
||||
|
||||
if !stat.up {
|
||||
return Err(anyhow!("cluster {idx} is not up"));
|
||||
}
|
||||
|
||||
guild_count += stat.guild_count;
|
||||
channel_count += stat.channel_count;
|
||||
}
|
||||
|
||||
insert_stats(&ctx.stats, "guilds", guild_count).await?;
|
||||
insert_stats(&ctx.stats, "channels", channel_count).await?;
|
||||
|
||||
if cfg.set_guild_count {
|
||||
ctx.redis
|
||||
.set::<(), &str, String>(
|
||||
"pluralkit:botstatus",
|
||||
format!(
|
||||
"in {} servers",
|
||||
guild_count.to_formatted_string(&Locale::en)
|
||||
),
|
||||
None,
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn queue_deleted_image_cleanup(ctx: AppCtx) -> anyhow::Result<()> {
|
||||
// todo: we want to delete immediately when system is deleted, but after a
|
||||
// delay if member is deleted
|
||||
ctx.data
|
||||
.execute(
|
||||
r#"
|
||||
insert into image_cleanup_jobs
|
||||
select id, now() from images where
|
||||
not exists (select from image_cleanup_jobs j where j.id = images.id)
|
||||
and not exists (select from systems where avatar_url = images.url)
|
||||
and not exists (select from systems where banner_image = images.url)
|
||||
and not exists (select from system_guild where avatar_url = images.url)
|
||||
|
||||
and not exists (select from members where avatar_url = images.url)
|
||||
and not exists (select from members where banner_image = images.url)
|
||||
and not exists (select from members where webhook_avatar_url = images.url)
|
||||
and not exists (select from member_guild where avatar_url = images.url)
|
||||
|
||||
and not exists (select from groups where icon = images.url)
|
||||
and not exists (select from groups where banner_image = images.url);
|
||||
"#,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue