chore: reorganize rust crates

This commit is contained in:
alyssa 2025-01-02 00:50:36 +00:00
parent 357122a892
commit 16ce67e02c
58 changed files with 6 additions and 13 deletions

crates/api/Cargo.toml Normal file

@@ -0,0 +1,28 @@
[package]
name = "api"
version = "0.1.0"
edition = "2021"
[dependencies]
pluralkit_models = { path = "../models" }
libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
hyper = { version = "1.3.1", features = ["http1"] }
hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"] }
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde_urlencoded = "0.7.1"
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }


@@ -0,0 +1 @@
pub mod private;


@@ -0,0 +1,203 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use fred::interfaces::*;
use libpk::state::ShardState;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
pub guild_count: i32,
pub channel_count: i32,
}
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
let mut shard_status = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:shardstatus")
.await
.unwrap()
.values()
.map(|v| serde_json::from_str(v).expect("could not deserialize shard"))
.collect::<Vec<ShardState>>();
shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id));
Json(json!({
"shards": shard_status,
}))
}
pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
let cluster_stats = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:cluster_stats")
.await
.unwrap()
.values()
.map(|v| serde_json::from_str(v).unwrap())
.collect::<Vec<ClusterStats>>();
let db_stats = libpk::db::repository::get_stats(&ctx.db).await.unwrap();
let guild_count: i32 = cluster_stats.iter().map(|v| v.guild_count).sum();
let channel_count: i32 = cluster_stats.iter().map(|v| v.channel_count).sum();
Json(json!({
"system_count": db_stats.system_count,
"member_count": db_stats.member_count,
"group_count": db_stats.group_count,
"switch_count": db_stats.switch_count,
"message_count": db_stats.message_count,
"guild_count": guild_count,
"channel_count": channel_count,
}))
}
use std::time::Duration;
use crate::util::json_err;
use axum::{
extract,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
pub struct CallbackRequestData {
redirect_domain: String,
code: String,
// state: String,
}
#[derive(serde::Serialize)]
struct CallbackDiscordData {
client_id: String,
client_secret: String,
grant_type: String,
redirect_uri: String,
code: String,
}
pub async fn discord_callback(
State(ctx): State<ApiContext>,
extract::Json(request_data): extract::Json<CallbackRequestData>,
) -> Response {
let client = ClientBuilder::new()
.connect_timeout(Duration::from_secs(3))
.timeout(Duration::from_secs(3))
.build()
.expect("error making client");
let reqbody = serde_urlencoded::to_string(&CallbackDiscordData {
client_id: config.discord.as_ref().unwrap().client_id.get().to_string(),
client_secret: config.discord.as_ref().unwrap().client_secret.clone(),
grant_type: "authorization_code".to_string(),
redirect_uri: request_data.redirect_domain, // change this!
code: request_data.code,
})
.expect("could not serialize");
let discord_resp = client
.post("https://discord.com/api/v10/oauth2/token")
.header("content-type", "application/x-www-form-urlencoded")
.body(reqbody)
.send()
.await
.expect("failed to request discord");
let Value::Object(discord_data) = discord_resp
.json::<Value>()
.await
.expect("failed to deserialize discord response as json")
else {
panic!("discord response is not an object")
};
if !discord_data.contains_key("access_token") {
return json_err(
StatusCode::BAD_REQUEST,
format!(
// Value's Display impl already quotes strings, so interpolate it directly
"{{\"error\":{}}}",
discord_data
.get("error_description")
.expect("missing error_description from discord")
),
);
};
let token = format!(
"Bearer {}",
discord_data
.get("access_token")
.expect("missing access_token")
.as_str()
.unwrap()
);
let discord_client = twilight_http::Client::new(token);
let user = discord_client
.current_user()
.await
.expect("failed to get current user from discord")
.model()
.await
.expect("failed to parse user model from discord");
let system: Option<PKSystem> = sqlx::query_as(
r#"
select systems.*
from accounts
left join systems on accounts.system = systems.id
where accounts.uid = $1
"#,
)
.bind(user.id.get() as i64)
.fetch_optional(&ctx.db)
.await
.expect("failed to query");
if system.is_none() {
return json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
}
let system = system.unwrap();
let system_config: Option<PKSystemConfig> = sqlx::query_as(
r#"
select * from system_config where system = $1
"#,
)
.bind(system.id)
.fetch_optional(&ctx.db)
.await
.expect("failed to query");
let system_config = system_config.unwrap();
// create dashboard token for system
let token = system.token.clone();
(
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(),
"config": system_config.to_json(),
"user": user,
"token": token,
}))
.expect("should not error"),
)
.into_response()
}
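As a sketch only, this is roughly how a dashboard client might drive the callback handler above, via the /private/discord/callback2 route wired up in main.rs; the host, port, and field values here are placeholders, not the real deployment:

// hypothetical client call; field names match CallbackRequestData above
async fn demo() -> anyhow::Result<()> {
    let resp = reqwest::Client::new()
        .post("http://localhost:8080/private/discord/callback2")
        .json(&serde_json::json!({
            "redirect_domain": "https://dash.example.test/callback",
            "code": "<oauth code from discord>",
        }))
        .send()
        .await?;
    // 200 responses carry {"system", "config", "user", "token"} as built above
    println!("{}", resp.text().await?);
    Ok(())
}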

crates/api/src/error.rs Normal file

@@ -0,0 +1,29 @@
use axum::http::StatusCode;
use std::fmt;
#[derive(Debug)]
pub struct PKError {
pub response_code: StatusCode,
pub json_code: i32,
pub message: &'static str,
}
impl fmt::Display for PKError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{:?}", self)
}
}
impl std::error::Error for PKError {}
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
const $name: PKError = PKError {
response_code: $response_code,
json_code: $json_code,
message: $message,
};
};
}
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
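For reference, the invocation above expands to a plain constant with the same fields, roughly:

// hand-expanded form of the define_error! call above (renamed to avoid clashing)
const GENERIC_BAD_REQUEST_EXPANDED: PKError = PKError {
    response_code: StatusCode::BAD_REQUEST,
    json_code: 0,
    message: "400: Bad Request",
};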

crates/api/src/main.rs Normal file

@@ -0,0 +1,169 @@
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use tracing::{error, info};
mod endpoints;
mod error;
mod middleware;
mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
}
async fn rproxy(
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|err| {
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
StatusCode::BAD_GATEWAY
})?
.into_response())
}
// this function is manually formatted for easier legibility of the route definitions
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
// axum applies layers outside-in: the last .layer added runs first, so middleware goes at the end
Router::new()
.route("/v2/systems/:system_id", get(rproxy))
.route("/v2/systems/:system_id", patch(rproxy))
.route("/v2/systems/:system_id/settings", get(rproxy))
.route("/v2/systems/:system_id/settings", patch(rproxy))
.route("/v2/systems/:system_id/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/:member_id", get(rproxy))
.route("/v2/members/:member_id", patch(rproxy))
.route("/v2/members/:member_id", delete(rproxy))
.route("/v2/systems/:system_id/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/:group_id", get(rproxy))
.route("/v2/groups/:group_id", patch(rproxy))
.route("/v2/groups/:group_id", delete(rproxy))
.route("/v2/groups/:group_id/members", get(rproxy))
.route("/v2/groups/:group_id/members/add", post(rproxy))
.route("/v2/groups/:group_id/members/remove", post(rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
.route("/v2/members/:member_id/groups", get(rproxy))
.route("/v2/members/:member_id/groups/add", post(rproxy))
.route("/v2/members/:member_id/groups/remove", post(rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
.route("/v2/systems/:system_id/switches", get(rproxy))
.route("/v2/systems/:system_id/switches", post(rproxy))
.route("/v2/systems/:system_id/fronters", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
.route("/v2/messages/:message_id", get(rproxy))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.layer(axum::middleware::from_fn(middleware::logger))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
}
libpk::main!("api");
async fn real_main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let rproxy_uri = Uri::from_static(
&libpk::config
.api
.as_ref()
.expect("missing api config")
.remote_url,
)
.to_string();
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let ctx = ApiContext {
db,
redis,
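// Uri's Display adds a trailing slash for an empty path; trim it so path concatenation in rproxy() doesn't double up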
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
};
let app = router(ctx);
let addr: &str = libpk::config
.api
.as_ref()
.expect("missing api config")
.addr
.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}
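A minimal standalone sketch of the layer ordering noted in router() above (handler and middleware names here are illustrative, not part of this crate): each .layer wraps the stack built so far, so the layer added last is outermost and sees the request first.

use axum::{extract::Request, middleware::Next, response::Response, routing::get, Router};

async fn handler() -> &'static str {
    "ok"
}

async fn outer(req: Request, next: Next) -> Response {
    // added last below, so it wraps everything and runs first on each request
    next.run(req).await
}

async fn inner(req: Request, next: Next) -> Response {
    // added first, so it sits closest to the handler and runs second
    next.run(req).await
}

fn app() -> Router {
    Router::new()
        .route("/", get(handler))
        .layer(axum::middleware::from_fn(inner))
        .layer(axum::middleware::from_fn(outer))
}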


@@ -0,0 +1,45 @@
use axum::{
extract::{Request, State},
http::HeaderValue,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::ApiContext;
use super::logger::DID_AUTHENTICATE_HEADER;
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove("x-pluralkit-systemid");
let auth_header = headers
.get("authorization")
.and_then(|h| h.to_str().ok());
let mut authenticated = false;
if let Some(auth_header) = auth_header {
if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
None
}
}
{
headers.append(
"x-pluralkit-systemid",
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
);
authenticated = true;
}
}
let mut response = next.run(request).await;
if authenticated {
response
.headers_mut()
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
}
response
}


@@ -0,0 +1,28 @@
use axum::{
extract::Request,
http::{HeaderMap, HeaderValue, Method, StatusCode},
middleware::Next,
response::{IntoResponse, Response},
};
#[rustfmt::skip]
fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}
pub async fn cors(request: Request, next: Next) -> Response {
let mut response = if request.method() == Method::OPTIONS {
StatusCode::OK.into_response()
} else {
next.run(request).await
};
add_cors_headers(response.headers_mut());
response
}


@@ -0,0 +1,64 @@
use axum::{
extract::MatchedPath,
extract::Request,
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use crate::util::header_or_unknown;
fn is_trying_to_use_v1_path_on_v2(path: &str) -> bool {
path.starts_with("/v2/s/")
|| path.starts_with("/v2/m/")
|| path.starts_with("/v2/a/")
|| path.starts_with("/v2/msg/")
|| path == "/v2/s"
|| path == "/v2/m"
}
pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
let path = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
if request.uri().path().starts_with("/v1") {
(
StatusCode::GONE,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response()
} else if is_trying_to_use_v1_path_on_v2(request.uri().path()) {
(
StatusCode::BAD_REQUEST,
r#"{"message":"Invalid path for API version","code":0}"#,
)
.into_response()
}
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().path().starts_with("/v2")
&& !request.uri().path().starts_with("/private")
{
return (
StatusCode::BAD_REQUEST,
r#"{"message":"Unsupported API version","code":0}"#,
)
.into_response();
} else if path == "unknown" {
// current prod api responds with 404 with empty body to invalid endpoints
// just doing that here as well but i'm not sure if it's the correct behaviour
return StatusCode::NOT_FOUND.into_response();
}
// yes, technically because of how we parse headers this will break for user-agents literally set to "unknown"
// but "unknown" isn't really a valid user-agent
else if user_agent == "unknown" {
// please set a valid user-agent
return StatusCode::BAD_REQUEST.into_response();
} else {
next.run(request).await
}
}


@@ -0,0 +1,98 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::util::header_or_unknown;
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub const DID_AUTHENTICATE_HEADER: &str = "x-pluralkit-didauthenticate";
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_span = span!(
Level::INFO,
"request",
remote_ip,
method = method.as_str(),
endpoint = endpoint.clone(),
user_agent
);
let start = Instant::now();
let mut response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let authenticated = {
let headers = response.headers_mut();
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
headers.remove(DID_AUTHENTICATE_HEADER);
true
} else {
false
}
};
counter!(
"pluralkit_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.increment(1);
histogram!(
"pluralkit_api_requests_bucket",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.record(elapsed as f64 / 1_000_f64);
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
endpoint,
elapsed
);
if elapsed > MIN_LOG_TIME {
counter!(
"pluralkit_api_slow_requests_count",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
)
.increment(1);
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}


@@ -0,0 +1,13 @@
mod cors;
pub use cors::cors;
mod logger;
pub use logger::logger;
mod ignore_invalid_routes;
pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod ratelimit;
mod authnz;
pub use authnz::authnz;


@@ -0,0 +1,61 @@
-- this script has side-effects, so it requires replicate commands mode
-- redis.replicate_commands()
local rate_limit_key = KEYS[1]
local rate = ARGV[1]
local period = ARGV[2]
local cost = tonumber(ARGV[3])
local burst = rate
local emission_interval = period / rate
local increment = emission_interval * cost
local burst_offset = emission_interval * burst
-- redis returns time as an array containing two integers: seconds of the epoch
-- time (10 digits) and microseconds (6 digits). for convenience we need to
-- convert them to a floating point number. the resulting number is 16 digits,
-- bordering on the limits of a 64-bit double-precision floating point number.
-- adjust the epoch to be relative to Jan 1, 2017 00:00:00 GMT to avoid floating
-- point problems. this approach is good until "now" is 2,483,228,799 (Wed, 09
-- Sep 2048 01:46:39 GMT), when the adjusted value is 16 digits.
local jan_1_2017 = 1483228800
local now = redis.call("TIME")
now = (now[1] - jan_1_2017) + (now[2] / 1000000)
local tat = redis.call("GET", rate_limit_key)
if not tat then
tat = now
else
tat = tonumber(tat)
end
tat = math.max(tat, now)
local new_tat = tat + increment
local allow_at = new_tat - burst_offset
local diff = now - allow_at
local remaining = diff / emission_interval
if remaining < 0 then
local reset_after = tat - now
local retry_after = diff * -1
return {
0, -- remaining
tostring(retry_after),
reset_after,
}
end
local reset_after = new_tat - now
if reset_after > 0 then
redis.call("SET", rate_limit_key, new_tat, "EX", math.ceil(reset_after))
end
local retry_after = -1
return {
remaining,
tostring(retry_after),
reset_after
}
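The script above is a GCRA-style limiter (the "tat" variable is the theoretical arrival time). A worked example of its arithmetic, written in Rust as a sketch, using the generic_get numbers from ratelimit.rs below (rate 10, period 1s, cost 1):

fn main() {
    let (rate, period, cost) = (10.0_f64, 1.0_f64, 1.0_f64);
    let emission_interval = period / rate; // 0.1s per request "slot"
    let burst_offset = emission_interval * rate; // 1.0s of burst allowance
    // first request on an empty key: tat starts at now
    let now = 100.0;
    let tat: f64 = now;
    let new_tat = tat + emission_interval * cost; // 100.1
    let allow_at = new_tat - burst_offset; // 99.1
    let remaining = (now - allow_at) / emission_interval; // ~9.0
    assert_eq!(remaining.round() as i64, 9); // 9 requests left this second
}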


@@ -0,0 +1,238 @@
use std::time::{Duration, SystemTime};
use axum::{
extract::{MatchedPath, Request, State},
http::{HeaderValue, Method, StatusCode},
middleware::{FromFnLayer, Next},
response::Response,
};
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
static ref LUA_SCRIPT_SHA: String = sha1_hash(LUA_SCRIPT);
}
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
let redis = libpk::config
.api
.as_ref()
.expect("missing api config")
.ratelimit_redis_addr
.as_ref()
.map(|val| {
// todo: this should probably use the global pool
let r = RedisPool::new(
fred::types::RedisConfig::from_url_centralized(val.as_ref())
.expect("redis url is invalid"),
None,
None,
Some(Default::default()),
10,
)
.expect("failed to connect to redis");
// connect() already spawns the connection task internally; dropping the handle doesn't cancel it
let _connect_handle = r.connect();
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript
.script_load::<String, String>(LUA_SCRIPT.to_string())
.await
{
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
if redis.is_none() {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
}
enum RatelimitType {
GenericGet,
GenericUpdate,
Message,
TempCustom,
}
impl RatelimitType {
fn key(&self) -> String {
match self {
RatelimitType::GenericGet => "generic_get",
RatelimitType::GenericUpdate => "generic_update",
RatelimitType::Message => "message",
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
}
.to_string()
}
fn rate(&self) -> i32 {
match self {
RatelimitType::GenericGet => 10,
RatelimitType::GenericUpdate => 3,
RatelimitType::Message => 10,
RatelimitType::TempCustom => 20,
}
}
}
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
request: Request,
next: Next,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
// https://github.com/rust-lang/rust/issues/53667
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if let Some(token2) = &libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
{
header.to_str().unwrap_or("invalid") == token2
} else {
false
}
} else {
false
};
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {
RatelimitType::GenericGet
} else {
RatelimitType::GenericUpdate
};
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
} else {
source_ip
},
rlimit.key()
);
let period = 1; // seconds
let cost = 1; // todo: update this for group member endpoints
// local rate_limit_key = KEYS[1]
// local rate = ARGV[1]
// local period = ARGV[2]
// return {remaining, tostring(retry_after), reset_after}
// todo: check if error is script not found and reload script
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
LUA_SCRIPT_SHA.to_string(),
vec![rl_key.clone()],
vec![rlimit.rate(), period, cost],
)
.await;
match resp {
Ok((remaining, retry_after, reset_after)) => {
// redis's lua doesn't support returning floats
let retry_after: f64 = retry_after
.parse()
.expect("got something that isn't a f64 from redis");
let mut response = if remaining > 0 {
next.run(request).await
} else {
let retry_after = (retry_after * 1_000_f64).ceil() as u64;
debug!("ratelimited request from {rl_key}, retry_after={retry_after}",);
counter!("pk_http_requests_ratelimited").increment(1);
json_err(
StatusCode::TOO_MANY_REQUESTS,
format!(
r#"{{"message":"429: too many requests","retry_after":{retry_after},"scope":"{}","code":0}}"#,
rlimit.key(),
),
)
};
let reset_time = SystemTime::now()
.checked_add(Duration::from_secs(reset_after))
.expect("invalid timestamp")
.duration_since(std::time::UNIX_EPOCH)
.expect("invalid duration")
.as_secs();
let headers = response.headers_mut();
headers.insert(
"X-RateLimit-Scope",
HeaderValue::from_str(rlimit.key().as_str()).expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Limit",
HeaderValue::from_str(format!("{}", rlimit.rate()).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Remaining",
HeaderValue::from_str(format!("{}", remaining).as_str())
.expect("invalid header value"),
);
headers.insert(
"X-RateLimit-Reset",
HeaderValue::from_str(format!("{}", reset_time).as_str())
.expect("invalid header value"),
);
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
}
}
next.run(request).await
}
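As a sketch of the consumer side, a client could pace itself off the headers set above (resp here is assumed to be a reqwest::Response from any API call):

// names match the X-RateLimit-* headers inserted above
fn remaining_requests(resp: &reqwest::Response) -> u32 {
    resp.headers()
        .get("X-RateLimit-Remaining")
        .and_then(|v| v.to_str().ok())
        .and_then(|v| v.parse().ok())
        .unwrap_or(0)
    // on 429 the JSON body also carries retry_after in milliseconds (see the json_err call above)
}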

crates/api/src/util.rs Normal file

@@ -0,0 +1,63 @@
use crate::error::PKError;
use axum::{
http::{HeaderValue, StatusCode},
response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
if let Some(value) = header {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!("failed to parse header value {:#?}: {:#?}", value, err);
"failed to parse"
}
}
} else {
"unknown"
}
}
pub fn wrapper<F>(handler: F) -> impl Fn() -> axum::response::Response
where
F: Fn() -> anyhow::Result<Value>,
{
move || match handler() {
Ok(v) => (StatusCode::OK, to_string(&v).unwrap()).into_response(),
Err(error) => match error.downcast_ref::<PKError>() {
Some(pkerror) => json_err(
pkerror.response_code,
to_string(&json!({ "message": pkerror.message, "code": pkerror.json_code }))
.unwrap(),
),
None => {
error!(
"error in handler {}: {:#?}",
std::any::type_name::<F>(),
error
);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
},
}
}
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!("caught panic from handler: {:#?}", err);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();
headers.insert("content-type", HeaderValue::from_static("application/json"));
response
}

crates/avatars/Cargo.toml Normal file

@@ -0,0 +1,30 @@
[package]
name = "avatars"
version = "0.1.0"
edition = "2021"
[[bin]]
name = "avatar_cleanup"
path = "src/cleanup.rs"
[dependencies]
libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
futures = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
sqlx = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
uuid = { workspace = true }
data-encoding = "2.5.0"
gif = "0.13.1"
image = { version = "0.24.8", default-features = false, features = ["gif", "jpeg", "png", "webp", "tiff"] }
form_urlencoded = "1.2.1"
rust-s3 = { version = "0.33.0", default-features = false, features = ["tokio-rustls-tls"] }
sha2 = "0.10.8"
thiserror = "1.0.56"
webp = "0.2.6"


@@ -0,0 +1,146 @@
use anyhow::Context;
use reqwest::{ClientBuilder, StatusCode};
use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info};
libpk::main!("avatar_cleanup");
async fn real_main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
.expect("missing avatar service config");
let bucket = {
let region = s3::Region::Custom {
region: "s3".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
let credentials = s3::creds::Credentials::new(
Some(&config.s3.application_id),
Some(&config.s3.application_key),
None,
None,
None,
)
.unwrap();
let bucket = s3::Bucket::new(&config.s3.bucket, region, credentials)?;
Arc::new(bucket)
};
let pool = libpk::db::init_data_db().await?;
loop {
// sleep between iterations so failures can't turn this into a busy loop
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
match cleanup_job(pool.clone(), bucket.clone()).await {
Ok(()) => {}
Err(err) => {
error!("failed to run avatar cleanup job: {}", err);
// sentry
}
}
}
}
#[derive(FromRow)]
struct CleanupJobEntry {
id: String,
}
async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Result<()> {
let mut tx = pool.begin().await?;
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
r#"
select id from image_cleanup_jobs
where ts < now() - interval '1 day'
for update skip locked limit 1;"#,
)
.fetch_optional(&mut *tx)
.await?;
if image_id.is_none() {
info!("no job to run, sleeping for 1 minute");
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
return Ok(());
}
let image_id = image_id.unwrap().id;
info!("got image {image_id}, cleaning up...");
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
if image_data.is_none() {
info!("image {image_id} was already deleted, skipping");
sqlx::query("delete from image_cleanup_jobs where id = $1")
.bind(image_id)
.execute(&mut *tx)
.await?;
return Ok(());
}
let image_data = image_data.unwrap();
let config = libpk::config
.avatars
.as_ref()
.expect("missing avatar service config");
let path = image_data
.url
.strip_prefix(config.cdn_url.as_str())
.unwrap();
let s3_resp = bucket.delete_object(path).await?;
match s3_resp.status_code() {
204 => {
info!("successfully deleted image {image_id} from s3");
}
_ => {
anyhow::bail!("s3 returned bad error code {}", s3_resp.status_code());
}
}
if let Some(zone_id) = config.cloudflare_zone_id.as_ref() {
let client = ClientBuilder::new()
.connect_timeout(Duration::from_secs(3))
.timeout(Duration::from_secs(3))
.build()
.context("error making client")?;
let cf_resp = client
.post(format!(
"https://api.cloudflare.com/client/v4/zones/{zone_id}/purge_cache"
))
.header(
"Authorization",
format!("Bearer {}", config.cloudflare_token.as_ref().unwrap()),
)
.body(format!(r#"{{"files":["{}"]}}"#, image_data.url))
.send()
.await?;
match cf_resp.status() {
StatusCode::OK => {
info!(
"successfully purged url {} from cloudflare cache",
image_data.url
);
}
_ => {
let status = cf_resp.status();
tracing::info!("raw response from cloudflare: {:#?}", cf_resp.text().await?);
anyhow::bail!("cloudflare returned bad error code {}", status);
}
}
}
sqlx::query("delete from images where id = $1")
.bind(image_id.clone())
.execute(&mut *tx)
.await?;
tx.commit().await?;
Ok(())
}


@@ -0,0 +1,21 @@
use std::fmt::Display;
use sha2::{Digest, Sha256};
#[derive(Debug)]
pub struct Hash([u8; 32]);
impl Hash {
pub fn sha256(data: &[u8]) -> Hash {
let mut hasher = Sha256::new();
hasher.update(data);
Hash(hasher.finalize().into())
}
}
impl Display for Hash {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let encoding = data_encoding::BASE32_NOPAD;
write!(f, "{}", encoding.encode(&self.0[..16]).to_lowercase())
}
}
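A small sketch of how this Display impl feeds the cdn object paths (mirroring store::store further down; the input bytes are arbitrary):

// assumes this crate's Hash type is in scope
fn object_path(data: &[u8]) -> String {
    // 26 lowercase base32 chars: the first 16 of the 32 sha256 bytes
    let encoded = Hash::sha256(data).to_string();
    format!("images/{}/{}.webp", &encoded[..2], &encoded[2..])
}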


@@ -0,0 +1,26 @@
create table if not exists images
(
id text primary key,
url text not null,
original_url text,
original_file_size int,
original_type text,
original_attachment_id bigint,
file_size int not null,
width int not null,
height int not null,
kind text not null,
uploaded_at timestamptz not null,
uploaded_by_account bigint
);
create index if not exists images_original_url_idx on images (original_url);
create index if not exists images_original_attachment_id_idx on images (original_attachment_id);
create index if not exists images_uploaded_by_account_idx on images (uploaded_by_account);
create table if not exists image_queue (itemid serial primary key, url text not null, kind text not null);
alter table images add column if not exists uploaded_by_system uuid;
alter table images add column if not exists content_type text default 'image/webp';
create table image_cleanup_jobs(id text references images(id) on delete cascade);

crates/avatars/src/main.rs Normal file

@@ -0,0 +1,257 @@
mod hash;
mod migrate;
mod process;
mod pull;
mod store;
use anyhow::Context;
use axum::extract::State;
use axum::routing::get;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use libpk::_config::AvatarsConfig;
use libpk::db::repository::avatars as db;
use libpk::db::types::avatars::*;
use reqwest::{Client, ClientBuilder};
use serde::{Deserialize, Serialize};
use sqlx::PgPool;
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;
use thiserror::Error;
use tracing::{error, info};
use uuid::Uuid;
#[derive(Error, Debug)]
pub enum PKAvatarError {
// todo: split off into logical groups (cdn/url error, image format error, etc)
#[error("invalid cdn url")]
InvalidCdnUrl,
#[error("discord cdn responded with status code: {0}")]
BadCdnResponse(reqwest::StatusCode),
#[error("network error: {0}")]
NetworkError(reqwest::Error),
#[error("response is missing header: {0}")]
MissingHeader(&'static str),
#[error("unsupported content type: {0}")]
UnsupportedContentType(String),
#[error("image file size too large ({0} > {1})")]
ImageFileSizeTooLarge(u64, u64),
#[error("unsupported image format: {0:?}")]
UnsupportedImageFormat(image::ImageFormat),
#[error("could not detect image format")]
UnknownImageFormat,
#[error("original image dimensions too large: {0:?} > {1:?}")]
ImageDimensionsTooLarge((u32, u32), (u32, u32)),
#[error("could not decode image, is it corrupted?")]
ImageFormatError(#[from] image::ImageError),
#[error("unknown error")]
InternalError(#[from] anyhow::Error),
}
#[derive(Deserialize, Debug)]
pub struct PullRequest {
url: String,
kind: ImageKind,
uploaded_by: Option<u64>, // should be String? serde makes this hard :/
system_id: Option<Uuid>,
#[serde(default)]
force: bool,
}
#[derive(Serialize)]
pub struct PullResponse {
url: String,
new: bool,
}
async fn pull(
State(state): State<AppState>,
Json(req): Json<PullRequest>,
) -> Result<Json<PullResponse>, PKAvatarError> {
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
if !req.force {
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
// remove any pending image cleanup
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
return Ok(Json(PullResponse {
url: existing.url,
new: false,
}));
}
}
let result = crate::pull::pull(state.pull_client, &parsed).await?;
let original_file_size = result.data.len();
let encoded = process::process_async(result.data, req.kind).await?;
let store_res = crate::store::store(&state.bucket, &encoded).await?;
let final_url = format!("{}{}", state.config.cdn_url, store_res.path);
let is_new = db::add_image(
&state.pool,
ImageMeta {
id: store_res.id,
url: final_url.clone(),
content_type: encoded.format.mime_type().to_string(),
original_url: Some(parsed.full_url),
original_type: Some(result.content_type),
original_file_size: Some(original_file_size as i32),
original_attachment_id: Some(parsed.attachment_id as i64),
file_size: encoded.data.len() as i32,
width: encoded.width as i32,
height: encoded.height as i32,
kind: req.kind,
uploaded_at: None,
uploaded_by_account: req.uploaded_by.map(|x| x as i64),
uploaded_by_system: req.system_id,
},
)
.await?;
Ok(Json(PullResponse {
url: final_url,
new: is_new,
}))
}
pub async fn stats(State(state): State<AppState>) -> Result<Json<Stats>, PKAvatarError> {
Ok(Json(db::get_stats(&state.pool).await?))
}
#[derive(Clone)]
pub struct AppState {
bucket: Arc<s3::Bucket>,
pull_client: Arc<Client>,
pool: PgPool,
config: Arc<AvatarsConfig>,
}
libpk::main!("avatars");
async fn real_main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
.expect("missing avatar service config");
let bucket = {
let region = s3::Region::Custom {
region: "s3".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
let credentials = s3::creds::Credentials::new(
Some(&config.s3.application_id),
Some(&config.s3.application_key),
None,
None,
None,
)
.unwrap();
let bucket = s3::Bucket::new(&config.s3.bucket, region, credentials)?;
Arc::new(bucket)
};
let pull_client = Arc::new(
ClientBuilder::new()
.connect_timeout(Duration::from_secs(3))
.timeout(Duration::from_secs(3))
.user_agent("PluralKit-Avatars/0.1")
.build()
.context("error making client")?,
);
let pool = libpk::db::init_data_db().await?;
let state = AppState {
bucket,
pull_client,
pool,
config: Arc::new(config.clone()),
};
// migrations are done, disable this
// migrate::spawn_migrate_workers(Arc::new(state.clone()), state.config.migrate_worker_count);
let app = Router::new()
.route("/pull", post(pull))
.route("/stats", get(stats))
.with_state(state);
let host = &config.bind_addr;
info!("starting server on {}!", host);
let listener = tokio::net::TcpListener::bind(host).await.unwrap();
axum::serve(listener, app).await.unwrap();
Ok(())
}
struct AppError(anyhow::Error);
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
error!("error handling request: {}", self.0);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: self.0.to_string(),
}),
)
.into_response()
}
}
impl IntoResponse for PKAvatarError {
fn into_response(self) -> Response {
let status_code = match self {
PKAvatarError::InternalError(_) | PKAvatarError::NetworkError(_) => {
StatusCode::INTERNAL_SERVER_ERROR
}
_ => StatusCode::BAD_REQUEST,
};
// print inner error if otherwise hidden
error!("error: {}", self.source().unwrap_or(&self));
(
status_code,
Json(ErrorResponse {
error: self.to_string(),
}),
)
.into_response()
}
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}
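A sketch of exercising the /pull route above; the host and the ImageKind wire format are assumptions here (see PullRequest for the actual fields, and the serde derive on ImageKind for the real "kind" values):

// hypothetical client call against the router defined in real_main
async fn demo() -> anyhow::Result<()> {
    let resp = reqwest::Client::new()
        .post("http://localhost:3000/pull")
        .json(&serde_json::json!({
            "url": "https://cdn.discordapp.com/attachments/123/456/avatar.png",
            "kind": "avatar",
            "force": false,
        }))
        .send()
        .await?;
    // 200 => {"url": "<cdn url>", "new": true|false}, per PullResponse above
    println!("{}", resp.text().await?);
    Ok(())
}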


@@ -0,0 +1,146 @@
use crate::pull::parse_url;
use crate::{db, process, AppState, PKAvatarError};
use libpk::db::types::avatars::{ImageMeta, ImageQueueEntry};
use reqwest::StatusCode;
use std::error::Error;
use std::sync::Arc;
use std::time::Duration;
use time::Instant;
use tokio::sync::Semaphore;
use tracing::{error, info, instrument, warn};
static PROCESS_SEMAPHORE: Semaphore = Semaphore::const_new(100);
pub async fn handle_item_inner(
state: &AppState,
item: &ImageQueueEntry,
) -> Result<(), PKAvatarError> {
let parsed = parse_url(&item.url).map_err(|_| PKAvatarError::InvalidCdnUrl)?;
if db::get_by_attachment_id(&state.pool, parsed.attachment_id).await?.is_some() {
info!(
"attachment {} already migrated, skipping",
parsed.attachment_id
);
return Ok(());
}
let pulled = crate::pull::pull(state.pull_client.clone(), &parsed).await?;
let data_len = pulled.data.len();
let encoded = {
// bound the number of parallel encodes so they can't peg the CPU or starve the worker
// no semaphore on the main api though, that one should ideally be low latency
// todo: configurable?
let time_before_semaphore = Instant::now();
let permit = PROCESS_SEMAPHORE
.acquire()
.await
.map_err(|e| PKAvatarError::InternalError(e.into()))?;
let time_after_semaphore = Instant::now();
let semaphore_time = time_after_semaphore - time_before_semaphore;
if semaphore_time.whole_milliseconds() > 100 {
warn!(
"waited more than {} ms for process semaphore",
semaphore_time.whole_milliseconds()
);
}
let encoded = process::process_async(pulled.data, item.kind).await?;
drop(permit);
encoded
};
let store_res = crate::store::store(&state.bucket, &encoded).await?;
let final_url = format!("{}{}", state.config.cdn_url, store_res.path);
db::add_image(
&state.pool,
ImageMeta {
id: store_res.id,
url: final_url.clone(),
content_type: encoded.format.mime_type().to_string(),
original_url: Some(parsed.full_url),
original_type: Some(pulled.content_type),
original_file_size: Some(data_len as i32),
original_attachment_id: Some(parsed.attachment_id as i64),
file_size: encoded.data.len() as i32,
width: encoded.width as i32,
height: encoded.height as i32,
kind: item.kind,
uploaded_at: None,
uploaded_by_account: None,
uploaded_by_system: None,
},
)
.await?;
info!(
"migrated {} ({}k -> {}k)",
final_url,
data_len,
encoded.data.len()
);
Ok(())
}
pub async fn handle_item(state: &AppState) -> Result<(), PKAvatarError> {
// let queue_length = db::get_queue_length(&state.pool).await?;
// info!("migrate queue length: {}", queue_length);
if let Some((mut tx, item)) = db::pop_queue(&state.pool).await? {
match handle_item_inner(state, &item).await {
Ok(_) => {
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
Ok(())
}
Err(
// Errors that mean the image can't be migrated and doesn't need to be retried
e @ (PKAvatarError::ImageDimensionsTooLarge(_, _)
| PKAvatarError::UnknownImageFormat
| PKAvatarError::UnsupportedImageFormat(_)
| PKAvatarError::UnsupportedContentType(_)
| PKAvatarError::ImageFileSizeTooLarge(_, _)
| PKAvatarError::InvalidCdnUrl
| PKAvatarError::BadCdnResponse(StatusCode::NOT_FOUND | StatusCode::FORBIDDEN)),
) => {
warn!("error migrating {}, skipping: {}", item.url, e);
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
Ok(())
}
Err(e @ PKAvatarError::ImageFormatError(_)) => {
// will add this item back to the end of the queue
db::push_queue(&mut *tx, &item.url, item.kind).await?;
tx.commit().await.map_err(Into::<anyhow::Error>::into)?;
Err(e)
}
Err(e) => Err(e),
}
} else {
tokio::time::sleep(Duration::from_secs(5)).await;
Ok(())
}
}
#[instrument(skip(state))]
pub async fn worker(worker_id: u32, state: Arc<AppState>) {
info!("spawned migrate worker with id {}", worker_id);
loop {
match handle_item(&state).await {
Ok(()) => {}
Err(e) => {
error!(
"error in migrate worker {}: {}",
worker_id,
e.source().unwrap_or(&e)
);
tokio::time::sleep(Duration::from_secs(5)).await;
}
}
}
}
pub fn spawn_migrate_workers(state: Arc<AppState>, count: u32) {
for i in 0..count {
tokio::spawn(worker(i, state.clone()));
}
}


@@ -0,0 +1,257 @@
use image::{DynamicImage, ImageFormat};
use std::borrow::Cow;
use std::io::Cursor;
use time::Instant;
use tracing::{debug, error, info, instrument};
use crate::{hash::Hash, ImageKind, PKAvatarError};
const MAX_DIMENSION: u32 = 4000;
pub struct ProcessOutput {
pub width: u32,
pub height: u32,
pub hash: Hash,
pub format: ProcessedFormat,
pub data: Vec<u8>,
}
#[derive(Copy, Clone, Debug)]
pub enum ProcessedFormat {
Webp,
Gif,
}
impl ProcessedFormat {
pub fn mime_type(&self) -> &'static str {
match self {
ProcessedFormat::Gif => "image/gif",
ProcessedFormat::Webp => "image/webp",
}
}
pub fn extension(&self) -> &'static str {
match self {
ProcessedFormat::Webp => "webp",
ProcessedFormat::Gif => "gif",
}
}
}
// takes Vec<u8> by value: the blocking task needs ownership, and the caller doesn't need the data afterwards
pub async fn process_async(data: Vec<u8>, kind: ImageKind) -> Result<ProcessOutput, PKAvatarError> {
tokio::task::spawn_blocking(move || process(&data, kind))
.await
.map_err(|je| PKAvatarError::InternalError(je.into()))?
}
#[instrument(skip_all)]
pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarError> {
let time_before = Instant::now();
let reader = reader_for(data);
match reader.format() {
Some(ImageFormat::Png | ImageFormat::WebP | ImageFormat::Jpeg | ImageFormat::Tiff) => {} // ok :)
Some(ImageFormat::Gif) => {
// animated gifs will need to be handled totally differently
// so split off processing here and come back if it's not applicable
// (non-banner gifs + 1-frame animated gifs still need to be webp'd)
if let Some(output) = process_gif(data, kind)? {
return Ok(output);
}
}
Some(other) => return Err(PKAvatarError::UnsupportedImageFormat(other)),
None => return Err(PKAvatarError::UnknownImageFormat),
}
// want to check dimensions *before* decoding so we don't accidentally end up with a memory bomb
// eg. a 16000x16000 png file is only 31kb and expands to almost a gig of memory
let (width, height) = assert_dimensions(reader.into_dimensions()?)?;
// into_dimensions() consumed the first reader above, so build a fresh one for decoding
let reader = reader_for(data);
let time_after_parse = Instant::now();
// apparently `image` sometimes decodes webp images wrong/weird.
// see: https://discord.com/channels/466707357099884544/667795132971614229/1209925940835262464
// instead, for webp, we use libwebp itself to decode, as well.
// (pls no cve)
let image = if reader.format() == Some(ImageFormat::WebP) {
let webp_image = webp::Decoder::new(data).decode().ok_or_else(|| {
PKAvatarError::InternalError(anyhow::anyhow!("webp decode failed").into())
})?;
webp_image.to_image()
} else {
reader.decode().map_err(|e| {
// print the ugly error, return the nice error
error!("error decoding image: {}", e);
PKAvatarError::ImageFormatError(e)
})?
};
let time_after_decode = Instant::now();
let image = resize(image, kind);
let time_after_resize = Instant::now();
let encoded = encode(image);
let time_after = Instant::now();
info!(
"{}: lossy size {}K (parse: {} ms, decode: {} ms, resize: {} ms, encode: {} ms)",
encoded.hash,
encoded.data.len() / 1024,
(time_after_parse - time_before).whole_milliseconds(),
(time_after_decode - time_after_parse).whole_milliseconds(),
(time_after_resize - time_after_decode).whole_milliseconds(),
(time_after - time_after_resize).whole_milliseconds(),
);
debug!(
"processed image {}: {} bytes, {}x{} -> {} bytes, {}x{}",
encoded.hash,
data.len(),
width,
height,
encoded.data.len(),
encoded.width,
encoded.height
);
Ok(encoded)
}
fn assert_dimensions((width, height): (u32, u32)) -> Result<(u32, u32), PKAvatarError> {
if width > MAX_DIMENSION || height > MAX_DIMENSION {
return Err(PKAvatarError::ImageDimensionsTooLarge(
(width, height),
(MAX_DIMENSION, MAX_DIMENSION),
));
}
Ok((width, height))
}
fn process_gif(input_data: &[u8], kind: ImageKind) -> Result<Option<ProcessOutput>, PKAvatarError> {
// gifs are only supported for banners, and we can't rescale gifs (i tried :/),
// so the max size check below is the real limit
if kind != ImageKind::Banner {
return Ok(None);
}
let reader = gif::Decoder::new(Cursor::new(input_data)).map_err(Into::<anyhow::Error>::into)?;
let (max_width, max_height) = kind.size();
if reader.width() as u32 > max_width || reader.height() as u32 > max_height {
return Err(PKAvatarError::ImageDimensionsTooLarge(
(reader.width() as u32, reader.height() as u32),
(max_width, max_height),
));
}
Ok(process_gif_inner(reader).map_err(Into::<anyhow::Error>::into)?)
}
fn process_gif_inner(
mut reader: gif::Decoder<Cursor<&[u8]>>,
) -> Result<Option<ProcessOutput>, anyhow::Error> {
let time_before = Instant::now();
let (width, height) = (reader.width(), reader.height());
let mut writer = gif::Encoder::new(
Vec::new(),
width as u16,
height as u16,
reader.global_palette().unwrap_or(&[]),
)?;
writer.set_repeat(reader.repeat())?;
let mut frame_buf = Vec::new();
let mut frame_count = 0;
while let Some(frame) = reader.next_frame_info()? {
let mut frame = frame.clone();
assert_dimensions((frame.width as u32, frame.height as u32))?;
frame_buf.clear();
frame_buf.resize(reader.buffer_size(), 0);
reader.read_into_buffer(&mut frame_buf)?;
frame.buffer = Cow::Borrowed(&frame_buf);
frame.make_lzw_pre_encoded();
writer.write_lzw_pre_encoded_frame(&frame)?;
frame_count += 1;
}
if frame_count == 1 {
// If there's only one frame, then this doesn't need to be a gif. webp it
// (unfortunately we can't tell if there's only one frame until after the first frame's been decoded...)
return Ok(None);
}
let data = writer.into_inner()?;
let time_after = Instant::now();
let hash = Hash::sha256(&data);
let original_data = reader.into_inner();
info!(
"processed gif {}: {}K -> {}K ({} ms, frames: {})",
hash,
original_data.buffer().len() / 1024,
data.len() / 1024,
(time_after - time_before).whole_milliseconds(),
frame_count
);
Ok(Some(ProcessOutput {
data,
format: ProcessedFormat::Gif,
hash,
width: width as u32,
height: height as u32,
}))
}
fn reader_for(data: &[u8]) -> image::io::Reader<Cursor<&[u8]>> {
image::io::Reader::new(Cursor::new(data))
.with_guessed_format()
.expect("cursor i/o is infallible")
}
#[instrument(skip_all)]
fn resize(image: DynamicImage, kind: ImageKind) -> DynamicImage {
let (target_width, target_height) = kind.size();
if image.width() <= target_width && image.height() <= target_height {
// don't resize if already smaller
return image;
}
// todo: best filter?
let resized = image.resize(
target_width,
target_height,
image::imageops::FilterType::Lanczos3,
);
resized
}
#[instrument(skip_all)]
// can't believe this is infallible
fn encode(image: DynamicImage) -> ProcessOutput {
let (width, height) = (image.width(), image.height());
let image_buf = image.to_rgba8();
let encoded_lossy = webp::Encoder::new(&*image_buf, webp::PixelLayout::Rgba, width, height)
.encode_simple(false, 90.0)
.expect("encode should be infallible")
.to_vec();
let hash = Hash::sha256(&encoded_lossy);
ProcessOutput {
data: encoded_lossy,
format: ProcessedFormat::Webp,
hash,
width,
height,
}
}
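A minimal sketch of driving the pipeline above end to end; the file name and the ImageKind::Avatar variant are assumptions for illustration:

async fn demo() -> Result<(), PKAvatarError> {
    // read some local bytes; wrap the io error into the crate's catch-all variant
    let bytes = std::fs::read("avatar.png").map_err(|e| PKAvatarError::InternalError(e.into()))?;
    let out = process_async(bytes, ImageKind::Avatar).await?;
    println!(
        "{}x{} {} ({} bytes)",
        out.width, out.height, out.format.mime_type(), out.data.len()
    );
    Ok(())
}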

crates/avatars/src/pull.rs Normal file

@@ -0,0 +1,166 @@
use std::time::Duration;
use std::{str::FromStr, sync::Arc};
use crate::PKAvatarError;
use anyhow::Context;
use reqwest::{Client, ClientBuilder, StatusCode, Url};
use time::Instant;
use tracing::{error, instrument};
const MAX_SIZE: u64 = 8 * 1024 * 1024;
pub struct PullResult {
pub data: Vec<u8>,
pub content_type: String,
pub last_modified: Option<String>,
}
#[instrument(skip_all)]
pub async fn pull(
client: Arc<Client>,
parsed_url: &ParsedUrl,
) -> Result<PullResult, PKAvatarError> {
let time_before = Instant::now();
let mut trimmed_url = trim_url_query(&parsed_url.full_url)?;
if trimmed_url.host_str() == Some("media.discordapp.net") {
trimmed_url
.set_host(Some("cdn.discordapp.com"))
.expect("set_host should not fail");
}
let response = client.get(trimmed_url.clone()).send().await.map_err(|e| {
error!("network error for {}: {}", parsed_url.full_url, e);
PKAvatarError::NetworkError(e)
})?;
let time_after_headers = Instant::now();
let status = response.status();
if status != StatusCode::OK {
return Err(PKAvatarError::BadCdnResponse(status));
}
let size = match response.content_length() {
None => return Err(PKAvatarError::MissingHeader("Content-Length")),
Some(size) if size > MAX_SIZE => {
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE))
}
Some(size) => size,
};
let content_type = response
.headers()
.get(reqwest::header::CONTENT_TYPE)
.and_then(|x| x.to_str().ok()) // invalid (non-unicode) header = missing, why not
.map(|mime| mime.split(';').next().unwrap_or("")) // cut off at ;
.ok_or(PKAvatarError::MissingHeader("Content-Type"))?
.to_owned();
let mime = match content_type.as_str() {
mime @ ("image/jpeg" | "image/png" | "image/gif" | "image/webp" | "image/tiff") => mime,
_ => return Err(PKAvatarError::UnsupportedContentType(content_type)),
};
let last_modified = response
.headers()
.get(reqwest::header::LAST_MODIFIED)
.and_then(|x| x.to_str().ok())
.map(|x| x.to_string());
let body = response.bytes().await.map_err(|e| {
error!("network error for {}: {}", parsed_url.full_url, e);
PKAvatarError::NetworkError(e)
})?;
if body.len() != size as usize {
// ???does this ever happen?
return Err(PKAvatarError::InternalError(anyhow::anyhow!(
"server responded with wrong length"
)));
}
let time_after_body = Instant::now();
let headers_time = time_after_headers - time_before;
let body_time = time_after_body - time_after_headers;
// non-OK statuses already returned early above, so this is always a successful pull
tracing::info!(
"{}: {} (headers: {}ms, body: {}ms)",
status,
&trimmed_url,
headers_time.whole_milliseconds(),
body_time.whole_milliseconds()
);
Ok(PullResult {
data: body.to_vec(),
content_type: mime.to_string(),
last_modified,
})
}
#[derive(Debug)]
pub struct ParsedUrl {
pub channel_id: u64,
pub attachment_id: u64,
pub filename: String,
pub full_url: String,
}
pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
// todo: should this return PKAvatarError::InvalidCdnUrl?
let url = Url::from_str(url).context("invalid url")?;
match (url.scheme(), url.domain()) {
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
_ => anyhow::bail!("not a discord cdn url"),
}
match url
.path_segments()
.map(|x| x.collect::<Vec<_>>())
.as_deref()
{
Some([_, channel_id, attachment_id, filename]) => {
let channel_id = u64::from_str(channel_id).context("invalid channel id")?;
let attachment_id = u64::from_str(attachment_id).context("invalid attachment id")?;
Ok(ParsedUrl {
channel_id,
attachment_id,
filename: filename.to_string(),
full_url: url.to_string(),
})
}
_ => anyhow::bail!("invaild discord cdn url"),
}
}
fn trim_url_query(url: &str) -> anyhow::Result<Url> {
let mut parsed = Url::parse(url)?;
let mut qs = form_urlencoded::Serializer::new(String::new());
for (key, value) in parsed.query_pairs() {
match key.as_ref() {
"ex" | "is" | "hm" => {
qs.append_pair(key.as_ref(), value.as_ref());
}
_ => {}
}
}
let new_query = qs.finish();
parsed.set_query(if !new_query.is_empty() {
Some(&new_query)
} else {
None
});
Ok(parsed)
}
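A quick sketch of what parse_url and trim_url_query above do to a typical cdn url; the ids and query values are made up:

fn demo() -> anyhow::Result<()> {
    let parsed = parse_url("https://cdn.discordapp.com/attachments/111/222/file.png?ex=1&extra=2")?;
    assert_eq!((parsed.channel_id, parsed.attachment_id), (111, 222));
    assert_eq!(parsed.filename, "file.png");
    // only the ex/is/hm signing params survive trimming
    assert_eq!(trim_url_query(&parsed.full_url)?.query(), Some("ex=1"));
    Ok(())
}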


@@ -0,0 +1,60 @@
use crate::process::ProcessOutput;
use tracing::error;
pub struct StoreResult {
pub id: String,
pub path: String,
}
pub async fn store(bucket: &s3::Bucket, res: &ProcessOutput) -> anyhow::Result<StoreResult> {
// errors here are all going to be internal
let encoded_hash = res.hash.to_string();
let path = format!(
"images/{}/{}.{}",
&encoded_hash[..2],
&encoded_hash[2..],
res.format.extension()
);
// todo: something better than these retries
let mut retry_count = 0;
loop {
if retry_count == 2 {
tokio::time::sleep(tokio::time::Duration::new(2, 0)).await;
}
if retry_count > 2 {
anyhow::bail!("error uploading image to cdn, too many retries") // nicer user-facing error?
}
retry_count += 1;
let resp = bucket
.put_object_with_content_type(&path, &res.data, res.format.mime_type())
.await?;
match resp.status_code() {
200 => {
tracing::debug!("uploaded image to {}", &path);
return Ok(StoreResult {
id: encoded_hash,
path,
});
}
500 | 503 => {
tracing::warn!(
"got {} uploading image to {} ({}), retrying... (try {}/3)",
resp.status_code(),
&path,
resp.as_str()?,
retry_count
);
continue;
}
_ => {
error!(
"storage backend responded status code {}",
resp.status_code()
);
anyhow::bail!("error uploading image to cdn") // nicer user-facing error?
}
}
}
}


@@ -0,0 +1,16 @@
[package]
name = "dispatch"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
hickory-client = "0.24.1"


@@ -0,0 +1,52 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

crates/dispatch/src/main.rs Normal file

@@ -0,0 +1,192 @@
#![feature(ip)]
use hickory_client::{
client::{AsyncClient, ClientHandle},
rr::{DNSClass, Name, RData, RecordType},
udp::UdpClientStream,
};
use reqwest::{redirect::Policy, StatusCode};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
use axum::{extract::State, http::Uri, routing::post, Json, Router};
mod logger;
// this package does not currently use libpk
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.json()
.with_env_filter(EnvFilter::from_default_env())
.init();
info!("hello world");
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
let (client, bg) = AsyncClient::connect(stream).await?;
tokio::spawn(bg);
let app = Router::new()
.route("/", post(dispatch))
.with_state(Arc::new(RwLock::new(DNSClient(client))))
.layer(axum::middleware::from_fn(logger::logger));
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await?;
axum::serve(listener, app).await?;
Ok(())
}
#[derive(Debug, serde::Deserialize)]
struct DispatchRequest {
auth: String,
url: String,
payload: String,
test: Option<String>,
}
#[derive(Debug)]
enum DispatchResponse {
OK,
BadData,
ResolveFailed,
NoIPs,
InvalidIP,
FetchFailed,
InvalidResponseCode(StatusCode),
TestFailed,
}
impl std::fmt::Display for DispatchResponse {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
async fn dispatch(
// not entirely sure if this RwLock is the right way to do it
State(dns): State<Arc<RwLock<DNSClient>>>,
Json(req): Json<DispatchRequest>,
) -> String {
// todo: fix
if req.auth != std::env::var("HTTP_AUTH_TOKEN").unwrap() {
return "".to_string();
}
let uri = match req.url.parse::<Uri>() {
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
Err(error) => {
error!(?error, "failed to parse uri {}", req.url);
return DispatchResponse::BadData.to_string();
}
_ => {
error!("uri {} is invalid", req.url);
return DispatchResponse::BadData.to_string();
}
};
let ips = {
let mut dns = dns.write().await;
match dns.resolve(uri.host().unwrap().to_string()).await {
Ok(v) => v,
Err(error) => {
error!(?error, "failed to resolve");
return DispatchResponse::ResolveFailed.to_string();
}
}
};
if ips.iter().any(|ip| !ip.is_global()) {
return DispatchResponse::InvalidIP.to_string();
}
if ips.is_empty() {
return DispatchResponse::NoIPs.to_string();
}
let ips: Vec<SocketAddr> = ips
.iter()
.map(|ip| SocketAddr::V4(SocketAddrV4::new(*ip, 443)))
.collect();
let client = reqwest::ClientBuilder::new()
.user_agent("PluralKit Dispatch (https://pluralkit.me/api/dispatch/)")
.redirect(Policy::none())
.timeout(Duration::from_secs(10))
.http1_only()
.use_rustls_tls()
.https_only(true)
.resolve_to_addrs(uri.host().unwrap(), &ips)
.build()
.unwrap();
let res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(req.payload)
.send()
.await;
match res {
Ok(res) if res.status() != 200 => {
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
}
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
if let Some(test) = req.test {
let test_res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(test)
.send()
.await;
match test_res {
Ok(res) if res.status() != 401 => return DispatchResponse::TestFailed.to_string(),
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
}
DispatchResponse::OK.to_string()
}
struct DNSClient(AsyncClient);
impl DNSClient {
async fn resolve(&mut self, host: String) -> anyhow::Result<Vec<Ipv4Addr>> {
let resp = self
.0
.query(Name::from_ascii(host)?, DNSClass::IN, RecordType::A)
.await?;
debug!("got dns response: {resp:?}");
Ok(resp
.answers()
.iter()
.filter_map(|ans| {
if let Some(RData::A(val)) = ans.data() {
Some(val.0)
} else {
None
}
})
.collect())
}
}
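For reference, callers hit this service with a JSON body shaped like DispatchRequest. A minimal sketch of such a caller (illustrative only: the token, target url, and localhost:5000 address are placeholders, and it assumes reqwest's json feature, inside some async fn returning anyhow::Result<()>):

let client = reqwest::Client::new();
let resp = client
    .post("http://localhost:5000/")
    .json(&serde_json::json!({
        "auth": "<HTTP_AUTH_TOKEN>",
        "url": "https://webhook.example.com/pk",
        "payload": "{\"type\":\"PING\"}",
    }))
    .send()
    .await?;
// the body is one of the DispatchResponse variants, e.g. "OK" or "InvalidIP"
println!("{}", resp.text().await?);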

27
crates/gateway/Cargo.toml Normal file
View file

@ -0,0 +1,27 @@
[package]
name = "gateway"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../libpk" }
metrics = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-gateway = { workspace = true }
twilight-cache-inmemory = { workspace = true }
twilight-util = { workspace = true }
twilight-model = { workspace = true }
twilight-http = { workspace = true }
serde_variant = "0.1.3"

View file

@ -0,0 +1,183 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Router,
};
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::Id;
use crate::discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
};
use std::sync::Arc;
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
let app = Router::new()
.route(
"/guilds/:guild_id",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => {
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
},
Err(err) => {
error!(?err, ?guild_id, "failed to get own guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
Err(err) => {
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut channels = Vec::new();
for id in channel_ids {
match cache.0.channel(id) {
Some(channel) => channels.push(channel.to_owned()),
None => {
tracing::error!(
channel_id = id.get(),
"referenced channel {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
}
match cache.0.channel(Id::new(channel_id)) {
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string())
}
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
}
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
Err(err) => {
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
)
.route(
"/guilds/:guild_id/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut roles = Vec::new();
for id in role_ids {
match cache.0.role(id) {
Some(role) => roles.push(role.value().resource().to_owned()),
None => {
tracing::error!(
role_id = id.get(),
"referenced role {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
})
)
.route("/stats", get(|State(cache): State<Arc<DiscordCache>>| async move {
let cluster = cluster_config();
let has_been_up = cache.2.read().await.len() as u32 == if cluster.total_shards > 16 {16} else {cluster.total_shards};
let stats = cache.0.stats();
let stats = json!({
"guild_count": stats.guilds(),
"channel_count": stats.channels(),
// just put this here until prom stats
"unavailable_guild_count": stats.unavailable_guilds(),
"up": has_been_up,
});
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
}))
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,368 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
Id,
},
};
use twilight_util::permission_calculator::PermissionCalculator;
lazy_static! {
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
| Permissions::SEND_MESSAGES
| Permissions::READ_MESSAGE_HISTORY
| Permissions::ADD_REACTIONS
| Permissions::ATTACH_FILES
| Permissions::EMBED_LINKS
| Permissions::USE_EXTERNAL_EMOJIS
| Permissions::CONNECT
| Permissions::SPEAK
| Permissions::USE_VAD;
}
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
Channel {
id,
kind: ChannelType::Private,
application_id: None,
applied_tags: None,
available_tags: None,
bitrate: None,
default_auto_archive_duration: None,
default_forum_layout: None,
default_reaction_emoji: None,
default_sort_order: None,
default_thread_rate_limit_per_user: None,
flags: None,
guild_id: None,
icon: None,
invitable: None,
last_message_id: None,
last_pin_timestamp: None,
managed: None,
member: None,
member_count: None,
message_count: None,
name: None,
newly_created: None,
nsfw: None,
owner_id: None,
parent_id: None,
permission_overwrites: None,
position: None,
rate_limit_per_user: None,
recipients: None,
rtc_region: None,
thread_metadata: None,
topic: None,
user_limit: None,
video_quality_mode: None,
}
}
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
CachedMember {
avatar: item.avatar,
communication_disabled_until: item.communication_disabled_until,
deaf: Some(item.deaf),
flags: item.flags,
joined_at: item.joined_at,
mute: Some(item.mute),
nick: item.nick,
premium_since: item.premium_since,
roles: item.roles,
pending: false,
user_id: id,
}
}
pub fn new() -> DiscordCache {
let mut client_builder = twilight_http::Client::builder().token(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.clone(),
);
if let Some(base_url) = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.api_base_url
.clone()
{
client_builder = client_builder.proxy(base_url, true);
}
let client = Arc::new(client_builder.build());
let cache = Arc::new(
InMemoryCache::builder()
.resource_types(
ResourceType::GUILD
| ResourceType::CHANNEL
| ResourceType::ROLE
| ResourceType::USER_CURRENT
| ResourceType::MEMBER_CURRENT,
)
.message_cache_size(0)
.build(),
);
DiscordCache(cache, client, RwLock::new(Vec::new()))
}
pub struct DiscordCache(
pub Arc<InMemoryCache>,
pub Arc<twilight_http::Client>,
pub RwLock<Vec<u32>>,
);
impl DiscordCache {
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
if self
.0
.guild(guild_id)
.ok_or_else(|| format_err!("guild not found"))?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id
== libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id
{
self.0
.member(guild_id, user_id)
.ok_or(format_err!("self member not found"))?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.root();
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
pub async fn channel_permissions(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
let channel = self
.0
.channel(channel_id)
.ok_or(format_err!("channel not found"))?;
if channel.value().guild_id.is_none() {
return Ok(*DM_PERMISSIONS);
}
let guild_id = channel.value().guild_id.unwrap();
if self
.0
.guild(guild_id)
.ok_or_else(|| {
tracing::error!(
channel_id = channel_id.get(),
guild_id = guild_id.get(),
"referenced guild from cached channel {channel_id} not found in cache"
);
format_err!("internal cache error")
})?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id
== libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id
{
self.0
.member(guild_id, user_id)
.ok_or_else(|| {
tracing::error!(
guild_id = guild_id.get(),
"self member for cached guild {guild_id} not found in cache"
);
format_err!("internal cache error")
})?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let overwrites = match channel.kind {
ChannelType::AnnouncementThread
| ChannelType::PrivateThread
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
_ => channel
.value()
.permission_overwrites()
.unwrap_or_default()
.to_vec(),
};
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
self.0.guild(id).map(|guild| {
let channels = self
.0
.guild_channels(id)
.map(|reference| {
reference
.iter()
.filter_map(|channel_id| {
let channel = self.0.channel(*channel_id)?;
if channel.kind.is_thread() {
None
} else {
Some(channel.value().clone())
}
})
.collect()
})
.unwrap_or_default();
let roles = self
.0
.guild_roles(id)
.map(|reference| {
reference
.iter()
.filter_map(|role_id| {
Some(self.0.role(*role_id)?.value().resource().clone())
})
.collect()
})
.unwrap_or_default();
Guild {
afk_channel_id: guild.afk_channel_id(),
afk_timeout: guild.afk_timeout(),
application_id: guild.application_id(),
approximate_member_count: None, // Only present in with_counts HTTP endpoint
banner: guild.banner().map(ToOwned::to_owned),
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
channels,
default_message_notifications: guild.default_message_notifications(),
description: guild.description().map(ToString::to_string),
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
emojis: vec![],
explicit_content_filter: guild.explicit_content_filter(),
features: guild.features().cloned().collect(),
icon: guild.icon().map(ToOwned::to_owned),
id: guild.id(),
joined_at: guild.joined_at(),
large: guild.large(),
max_members: guild.max_members(),
max_presences: guild.max_presences(),
max_video_channel_users: guild.max_video_channel_users(),
member_count: guild.member_count(),
members: vec![],
mfa_level: guild.mfa_level(),
name: guild.name().to_string(),
nsfw_level: guild.nsfw_level(),
owner_id: guild.owner_id(),
owner: guild.owner(),
permissions: guild.permissions(),
public_updates_channel_id: guild.public_updates_channel_id(),
preferred_locale: guild.preferred_locale().to_string(),
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
premium_subscription_count: guild.premium_subscription_count(),
premium_tier: guild.premium_tier(),
presences: vec![],
roles,
rules_channel_id: guild.rules_channel_id(),
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
splash: guild.splash().map(ToOwned::to_owned),
stage_instances: vec![],
stickers: vec![],
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: false,
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
}
})
}
}

View file

@ -0,0 +1,200 @@
use futures::StreamExt;
use libpk::_config::ClusterSettings;
use metrics::counter;
use std::sync::{mpsc::Sender, Arc};
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
};
use twilight_model::gateway::{
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::discord::identify_queue::{self, RedisQueue};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
pub fn cluster_config() -> ClusterSettings {
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.cluster
.clone()
.unwrap_or(libpk::_config::ClusterSettings {
node_id: 0,
total_shards: 1,
total_nodes: 1,
})
}
pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
let intents = Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::DIRECT_MESSAGE_REACTIONS
| Intents::GUILD_MESSAGES
| Intents::GUILD_MESSAGE_REACTIONS
| Intents::MESSAGE_CONTENT;
let queue = identify_queue::new(redis);
let cluster_settings = cluster_config();
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
(0, (cluster_settings.total_shards - 1).into())
} else {
(
(cluster_settings.node_id * 16).into(),
(((cluster_settings.node_id + 1) * 16) - 1).into(),
)
};
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
ConfigBuilder::new(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.to_owned(),
intents,
)
.presence(presence("pk;help", false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
);
Ok(shards.collect())
}
pub async fn runner(
mut shard: Shard<RedisQueue>,
_tx: Sender<(ShardId, String)>,
shard_state: ShardStateManager,
cache: Arc<DiscordCache>,
) {
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
info!("waiting for events");
while let Some(item) = shard.next().await {
let raw_event = match item {
Ok(evt) => match evt {
Message::Close(frame) => {
info!(
"shard {} closed: {}",
shard.id().number(),
if let Some(close) = frame {
format!("{} ({})", close.code, close.reason)
} else {
"unknown".to_string()
}
);
if let Err(error) = shard_state.socket_closed(shard.id().number()).await {
error!("failed to update shard state for socket closure: {error}");
}
continue;
}
Message::Text(text) => text,
},
Err(error) => {
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
continue;
}
};
let event = match twilight_gateway::parse(raw_event.clone(), EventTypeFlags::all()) {
Ok(Some(parsed)) => Event::from(parsed),
Ok(None) => {
// we received an event type unknown to twilight
// that's fine, we probably don't need it anyway
continue;
}
Err(error) => {
error!(
"shard {} failed to parse gateway event: {}",
shard.id().number(),
error
);
continue;
}
};
// log the event in metrics
// event_type * shard_id is too many labels and prometheus fails to query it
// so we split it into two metrics
counter!(
"pluralkit_gateway_events_type",
"event_type" => serde_variant::to_variant_name(&event.kind()).unwrap(),
)
.increment(1);
counter!(
"pluralkit_gateway_events_shard",
"shard_id" => shard.id().number().to_string(),
)
.increment(1);
// update shard state and discord cache
if let Err(error) = shard_state
.handle_event(shard.id().number(), event.clone())
.await
{
tracing::warn!(?error, "error updating redis state");
}
// need to do heartbeat separately, to get the latency
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = shard_state
.heartbeated(shard.id().number(), shard.latency())
.await
{
tracing::warn!(?error, "error updating redis state for latency");
}
if let Event::Ready(_) = event {
if !cache.2.read().await.contains(&shard.id().number()) {
cache.2.write().await.push(shard.id().number());
}
}
cache.0.update(&event);
// okay, we've handled the event internally, let's send it to consumers
// tx.send((shard.id(), raw_event)).unwrap();
}
}
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
UpdatePresencePayload {
activities: vec![Activity {
application_id: None,
assets: None,
buttons: vec![],
created_at: None,
details: None,
id: None,
state: None,
url: None,
emoji: None,
flags: None,
instance: None,
kind: ActivityType::Playing,
name: status.to_string(),
party: None,
secrets: None,
timestamps: None,
}],
afk: false,
since: None,
status: if going_away {
Status::Idle
} else {
Status::Online
},
}
}
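As a worked example of the shard split in create_shards above: each node takes a fixed window of 16 shards, so with total_shards = 48, node 1 runs shards 16 through 31 (a sketch of the same arithmetic):

let (node_id, total_shards): (u32, u32) = (1, 48);
let (start_shard, end_shard) = if total_shards < 16 {
    (0, total_shards - 1)
} else {
    (node_id * 16, ((node_id + 1) * 16) - 1)
};
assert_eq!((start_shard, end_shard), (16, 31));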

View file

@ -0,0 +1,88 @@
use fred::{
clients::RedisPool,
error::RedisError,
interfaces::KeysInterface,
types::{Expiration, SetOptions},
};
use std::fmt::Debug;
use std::time::Duration;
use tokio::sync::oneshot;
use tracing::{error, info};
use twilight_gateway::queue::Queue;
pub fn new(redis: RedisPool) -> RedisQueue {
RedisQueue {
redis,
concurrency: libpk::config
.discord
.as_ref()
.expect("missing discord config")
.max_concurrency,
}
}
#[derive(Clone)]
pub struct RedisQueue {
pub redis: RedisPool,
pub concurrency: u32,
}
impl Debug for RedisQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RedisQueue")
.field("concurrency", &self.concurrency)
.finish()
}
}
impl Queue for RedisQueue {
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
let (tx, rx) = oneshot::channel();
tokio::spawn(request_inner(
self.redis.clone(),
self.concurrency,
shard_id,
tx,
));
rx
}
}
const EXPIRY: i64 = 6; // seconds (used with Expiration::EX)
const RETRY_INTERVAL: u64 = 500; // milliseconds
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
let bucket = shard_id % concurrency;
let key = format!("pluralkit:identify:{}", bucket);
info!(shard_id, bucket, "waiting for allowance...");
loop {
let done: Result<Option<String>, RedisError> = redis
.set(
key.to_string(),
"1",
Some(Expiration::EX(EXPIRY)),
Some(SetOptions::NX),
false,
)
.await;
match done {
Ok(Some(_)) => {
info!(shard_id, bucket, "got allowance!");
// if the send fails, the receiver was already dropped (the shard
// stopped waiting); nothing left to do
let _ = tx.send(());
return;
}
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
}
}
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
}
}
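The bucket math mirrors Discord's identify rate limit: shards whose ids are congruent modulo max_concurrency must identify one at a time, while different buckets can proceed in parallel. A small illustration with concurrency = 16:

let concurrency: u32 = 16;
// shards 0, 16 and 32 all contend for pluralkit:identify:0 ...
assert_eq!(0 % concurrency, 0);
assert_eq!(16 % concurrency, 0);
assert_eq!(32 % concurrency, 0);
// ... while shard 1 waits only on pluralkit:identify:1
assert_eq!(1 % concurrency, 1);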

View file

@ -0,0 +1,4 @@
pub mod cache;
pub mod gateway;
pub mod identify_queue;
pub mod shard_state;

View file

@ -0,0 +1,91 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tracing::info;
use twilight_gateway::{Event, Latency};
use libpk::{state::*, util::redis::*};
#[derive(Clone)]
pub struct ShardStateManager {
redis: RedisPool,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
_ => Ok(()),
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<String> = self
.redis
.hget("pluralkit:shardstatus", shard_id)
.await
.to_option_or_error()?;
match data {
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
None => Ok(ShardState::default()),
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
(
shard_id.to_string(),
serde_json::to_string(&info).expect("could not serialize shard"),
),
)
.await?;
Ok(())
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
shard_id,
if resumed { "resumed" } else { "ready" }
);
counter!(
"pluralkit_gateway_shard_reconnect",
"shard_id" => shard_id.to_string(),
"resumed" => resumed.to_string(),
)
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let mut info = self.get_shard(shard_id).await?;
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self.get_shard(shard_id).await?;
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
self.save_shard(shard_id, info).await?;
Ok(())
}
}
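For reference, each field of the pluralkit:shardstatus hash is the JSON serialization of a ShardState (the struct lives in libpk, shown further down). An illustrative entry, with made-up values:

// hash field "3" -> (illustrative values only)
// {"shard_id":3,"up":true,"disconnection_count":2,"latency":87,
//  "last_heartbeat":1735776000,"last_connection":1735770000,"cluster_id":null}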

View file

@ -0,0 +1,72 @@
use std::time::Instant;
use axum::{
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
counter!(
"pluralkit_gateway_cache_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
)
.increment(1);
histogram!(
"pluralkit_gateway_cache_api_requests_bucket",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
)
.record(elapsed as f64 / 1_000_f64);
if response.status() != StatusCode::FOUND {
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
}
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

140
crates/gateway/src/main.rs Normal file
View file

@ -0,0 +1,140 @@
#![feature(let_chains)]
#![feature(if_let_guard)]
use chrono::Timelike;
use fred::{clients::RedisPool, interfaces::*};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
};
use std::{
sync::{mpsc::channel, Arc},
time::Duration,
vec::Vec,
};
use tokio::task::JoinSet;
use tracing::{info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod cache_api;
mod discord;
mod logger;
libpk::main!("gateway");
async fn real_main() -> anyhow::Result<()> {
let (shutdown_tx, shutdown_rx) = channel::<()>();
let shutdown_tx = Arc::new(shutdown_tx);
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let cache = Arc::new(discord::cache::new());
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
let mut set = JoinSet::new();
for shard in shards {
senders.push((shard.id(), shard.sender()));
signal_senders.push(shard.sender());
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
cache.clone(),
)));
}
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
set.spawn(tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
}));
let _ = shutdown_rx.recv();
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.abort_all();
info!("gateway exiting, have a nice day!");
Ok(())
}
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
))
.await;
info!("running per-minute scheduled tasks");
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
Ok(val) => Some(val),
Err(error) => {
tracing::warn!(?error, "failed to fetch bot status from redis");
None
}
};
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
} else {
"pk;help".to_string()
}
.as_str(),
false,
),
};
for sender in senders.iter() {
match sender.1.command(&presence) {
Err(error) => {
warn!(?error, "could not update presence on shard {}", sender.0)
}
_ => {}
};
}
}
}

24
crates/libpk/Cargo.toml Normal file
View file

@ -0,0 +1,24 @@
[package]
name = "libpk"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
sentry = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
time = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true}
twilight-model = { workspace = true }
uuid = { workspace = true }
config = "0.14.0"
json-subscriber = { version = "0.2.2", features = ["env-filter"] }
metrics-exporter-prometheus = { version = "0.15.3", default-features = false, features = ["tokio", "http-listener", "tracing"] }

145
crates/libpk/src/_config.rs Normal file
View file

@ -0,0 +1,145 @@
use config::Config;
use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
use twilight_model::id::{marker::UserMarker, Id};
#[derive(Clone, Deserialize, Debug)]
pub struct ClusterSettings {
pub node_id: u32,
pub total_shards: u32,
pub total_nodes: u32,
}
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: Id<UserMarker>,
pub bot_token: String,
pub client_secret: String,
pub max_concurrency: u32,
#[serde(default)]
pub cluster: Option<ClusterSettings>,
pub api_base_url: Option<String>,
#[serde(default = "_default_api_addr")]
pub cache_api_addr: String,
}
#[derive(Deserialize, Debug)]
pub struct DatabaseConfig {
pub(crate) data_db_uri: String,
pub(crate) data_db_max_connections: Option<u32>,
pub(crate) data_db_min_connections: Option<u32>,
pub(crate) messages_db_uri: Option<String>,
pub(crate) stats_db_uri: Option<String>,
pub(crate) db_password: Option<String>,
pub data_redis_addr: String,
}
fn _default_api_addr() -> String {
"[::]:5000".to_string()
}
#[derive(Deserialize, Clone, Debug)]
pub struct ApiConfig {
#[serde(default = "_default_api_addr")]
pub addr: String,
#[serde(default)]
pub ratelimit_redis_addr: Option<String>,
pub remote_url: String,
#[serde(default)]
pub temp_token2: Option<String>,
}
#[derive(Deserialize, Clone, Debug)]
pub struct AvatarsConfig {
pub s3: S3Config,
pub cdn_url: String,
#[serde(default = "_default_api_addr")]
pub bind_addr: String,
#[serde(default)]
pub migrate_worker_count: u32,
#[serde(default)]
pub cloudflare_zone_id: Option<String>,
#[serde(default)]
pub cloudflare_token: Option<String>,
}
#[derive(Deserialize, Clone, Debug)]
pub struct S3Config {
pub bucket: String,
pub application_id: String,
pub application_key: String,
pub endpoint: String,
}
#[derive(Deserialize, Debug)]
pub struct ScheduledTasksConfig {
pub set_guild_count: bool,
pub expected_gateway_count: usize,
pub gateway_url: String,
}
fn _metrics_default() -> bool {
false
}
fn _json_log_default() -> bool {
false
}
#[derive(Deserialize, Debug)]
pub struct PKConfig {
pub db: DatabaseConfig,
#[serde(default)]
pub discord: Option<DiscordConfig>,
#[serde(default)]
pub api: Option<ApiConfig>,
#[serde(default)]
pub avatars: Option<AvatarsConfig>,
#[serde(default)]
pub scheduled_tasks: Option<ScheduledTasksConfig>,
#[serde(default = "_metrics_default")]
pub run_metrics_server: bool,
#[serde(default = "_json_log_default")]
pub(crate) json_log: bool,
#[serde(default)]
pub sentry_url: Option<String>,
}
impl PKConfig {
pub fn api(self) -> ApiConfig {
self.api.expect("missing api config")
}
pub fn discord_config(self) -> DiscordConfig {
self.discord.expect("missing discord config")
}
}
// todo: consider passing this down instead of making it global
// especially since we have optional discord/api/avatars/etc config
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = {
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
.build().unwrap()
.try_deserialize::<PKConfig>().unwrap())
};
}
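Since CONFIG is built from config::Environment with prefix "pluralkit" and separator "__", nested settings map to double-underscored environment variables, e.g. (illustrative values, not a complete set):

pluralkit__db__data_db_uri=postgres://pluralkit@localhost/pluralkit
pluralkit__db__data_redis_addr=redis://localhost:6379
pluralkit__discord__bot_token=<token>
pluralkit__discord__cluster__node_id=0
pluralkit__discord__cluster__total_shards=1
pluralkit__discord__cluster__total_nodes=1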

View file

@ -0,0 +1,96 @@
use fred::clients::RedisPool;
use sqlx::postgres::{PgConnectOptions, PgPool, PgPoolOptions};
use std::str::FromStr;
use tracing::info;
pub mod repository;
pub mod types;
pub async fn init_redis() -> anyhow::Result<RedisPool> {
info!("connecting to redis");
let redis = RedisPool::new(
fred::types::RedisConfig::from_url_centralized(crate::config.db.data_redis_addr.as_ref())
.expect("redis url is invalid"),
None,
None,
Some(Default::default()),
10,
)?;
let redis_handle = redis.connect_pool();
tokio::spawn(async move { redis_handle });
Ok(redis)
}
pub async fn init_data_db() -> anyhow::Result<PgPool> {
info!("connecting to database");
let mut options = PgConnectOptions::from_str(&crate::config.db.data_db_uri)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
let mut pool = PgPoolOptions::new();
if let Some(max_conns) = crate::config.db.data_db_max_connections {
pool = pool.max_connections(max_conns);
}
if let Some(min_conns) = crate::config.db.data_db_min_connections {
pool = pool.min_connections(min_conns);
}
Ok(pool.connect_with(options).await?)
}
pub async fn init_messages_db() -> anyhow::Result<PgPool> {
info!("connecting to messages database");
let mut options = PgConnectOptions::from_str(
&crate::config
.db
.messages_db_uri
.as_ref()
.expect("missing messages db uri"),
)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
let mut pool = PgPoolOptions::new();
if let Some(max_conns) = crate::config.db.data_db_max_connections {
pool = pool.max_connections(max_conns);
}
if let Some(min_conns) = crate::config.db.data_db_min_connections {
pool = pool.min_connections(min_conns);
}
Ok(pool.connect_with(options).await?)
}
pub async fn init_stats_db() -> anyhow::Result<PgPool> {
info!("connecting to stats database");
let mut options = PgConnectOptions::from_str(
&crate::config
.db
.stats_db_uri
.as_ref()
.expect("missing messages db uri"),
)?;
if let Some(password) = crate::config.db.db_password.clone() {
options = options.password(&password);
}
Ok(PgPoolOptions::new()
.max_connections(1)
.min_connections(1)
.connect_with(options)
.await?)
}

View file

@ -0,0 +1,20 @@
pub async fn legacy_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
) -> anyhow::Result<Option<i32>> {
let mut system: Vec<LegacyTokenDbResponse> =
sqlx::query_as("select id from systems where token = $1")
.bind(token)
.fetch_all(pool)
.await?;
Ok(system.pop().map(|system| system.id))
}
#[derive(sqlx::FromRow)]
struct LegacyTokenDbResponse {
id: i32,
}

View file

@ -0,0 +1,111 @@
use sqlx::{PgPool, Postgres, Transaction};
use crate::db::types::avatars::*;
pub async fn get_by_id(pool: &PgPool, id: String) -> anyhow::Result<Option<ImageMeta>> {
Ok(sqlx::query_as("select * from images where id = $1")
.bind(id)
.fetch_optional(pool)
.await?)
}
pub async fn get_by_original_url(
pool: &PgPool,
original_url: &str,
) -> anyhow::Result<Option<ImageMeta>> {
Ok(
sqlx::query_as("select * from images where original_url = $1")
.bind(original_url)
.fetch_optional(pool)
.await?,
)
}
pub async fn get_by_attachment_id(
pool: &PgPool,
attachment_id: u64,
) -> anyhow::Result<Option<ImageMeta>> {
Ok(
sqlx::query_as("select * from images where original_attachment_id = $1")
.bind(attachment_id as i64)
.fetch_optional(pool)
.await?,
)
}
pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow::Result<()> {
sqlx::query(
r#"
delete from image_cleanup_jobs
where id in (
select id from images
where original_attachment_id = $1
)
"#,
)
.bind(attachment_id as i64)
.execute(pool)
.await?;
Ok(())
}
pub async fn pop_queue(
pool: &PgPool,
) -> anyhow::Result<Option<(Transaction<Postgres>, ImageQueueEntry)>> {
let mut tx = pool.begin().await?;
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")
.fetch_optional(&mut *tx).await?;
Ok(res.map(|x| (tx, x)))
}
pub async fn get_queue_length(pool: &PgPool) -> anyhow::Result<i64> {
Ok(sqlx::query_scalar("select count(*) from image_queue")
.fetch_one(pool)
.await?)
}
pub async fn get_stats(pool: &PgPool) -> anyhow::Result<Stats> {
Ok(sqlx::query_as(
"select count(*) as total_images, sum(file_size) as total_file_size from images",
)
.fetch_one(pool)
.await?)
}
pub async fn add_image(pool: &PgPool, meta: ImageMeta) -> anyhow::Result<bool> {
let kind_str = match meta.kind {
ImageKind::Avatar => "avatar",
ImageKind::Banner => "banner",
};
let res = sqlx::query("insert into images (id, url, content_type, original_url, file_size, width, height, original_file_size, original_type, original_attachment_id, kind, uploaded_by_account, uploaded_by_system, uploaded_at) values ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, (now() at time zone 'utc')) on conflict (id) do nothing")
.bind(meta.id)
.bind(meta.url)
.bind(meta.content_type)
.bind(meta.original_url)
.bind(meta.file_size)
.bind(meta.width)
.bind(meta.height)
.bind(meta.original_file_size)
.bind(meta.original_type)
.bind(meta.original_attachment_id)
.bind(kind_str)
.bind(meta.uploaded_by_account)
.bind(meta.uploaded_by_system)
.execute(pool).await?;
Ok(res.rows_affected() > 0)
}
pub async fn push_queue(
conn: &mut sqlx::PgConnection,
url: &str,
kind: ImageKind,
) -> anyhow::Result<()> {
sqlx::query("insert into image_queue (url, kind) values ($1, $2)")
.bind(url)
.bind(kind)
.execute(conn)
.await?;
Ok(())
}
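pop_queue above is the usual Postgres skip-locked work queue: the row is deleted inside a still-open transaction, so other workers skip it, but it reappears if the worker dies before committing. A usage sketch (process_entry is a hypothetical worker function):

// inside some async fn with a PgPool, returning anyhow::Result<()>
if let Some((tx, entry)) = pop_queue(&pool).await? {
    // do the work while the transaction (and thus the delete) is held open
    process_entry(&entry.url, entry.kind).await?;
    // commit makes the delete permanent; dropping tx instead rolls back
    // and effectively requeues the job
    tx.commit().await?;
}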

View file

@ -0,0 +1,7 @@
mod stats;
pub use stats::*;
pub mod avatars;
mod auth;
pub use auth::*;

View file

@ -0,0 +1,26 @@
pub async fn get_stats(pool: &sqlx::postgres::PgPool) -> anyhow::Result<Counts> {
let counts: Counts = sqlx::query_as("select * from info").fetch_one(pool).await?;
Ok(counts)
}
pub async fn insert_stats(
pool: &sqlx::postgres::PgPool,
table: &str,
value: i64,
) -> anyhow::Result<()> {
// `table` is interpolated directly into the query string; only call this
// with trusted, hard-coded table names (sql injection risk otherwise)
sqlx::query(format!("insert into {table} values (now(), $1)").as_str())
.bind(value)
.execute(pool)
.await?;
Ok(())
}
#[derive(serde::Serialize, sqlx::FromRow)]
pub struct Counts {
pub system_count: i64,
pub member_count: i64,
pub group_count: i64,
pub switch_count: i64,
pub message_count: i64,
}

View file

@ -0,0 +1,53 @@
use serde::{Deserialize, Serialize};
use sqlx::FromRow;
use time::OffsetDateTime;
use uuid::Uuid;
#[derive(FromRow)]
pub struct ImageMeta {
pub id: String,
pub kind: ImageKind,
pub content_type: String,
pub url: String,
pub file_size: i32,
pub width: i32,
pub height: i32,
pub uploaded_at: Option<OffsetDateTime>,
pub original_url: Option<String>,
pub original_attachment_id: Option<i64>,
pub original_file_size: Option<i32>,
pub original_type: Option<String>,
pub uploaded_by_account: Option<i64>,
pub uploaded_by_system: Option<Uuid>,
}
#[derive(FromRow, Serialize)]
pub struct Stats {
pub total_images: i64,
pub total_file_size: i64,
}
#[derive(Serialize, Deserialize, Clone, Copy, Debug, sqlx::Type, PartialEq)]
#[serde(rename_all = "snake_case")]
#[sqlx(rename_all = "snake_case", type_name = "text")]
pub enum ImageKind {
Avatar,
Banner,
}
impl ImageKind {
pub fn size(&self) -> (u32, u32) {
match self {
Self::Avatar => (512, 512),
Self::Banner => (1024, 1024),
}
}
}
#[derive(FromRow)]
pub struct ImageQueueEntry {
pub itemid: i32,
pub url: String,
pub kind: ImageKind,
}

View file

@ -0,0 +1 @@
pub mod avatars;

81
crates/libpk/src/lib.rs Normal file
View file

@ -0,0 +1,81 @@
#![feature(let_chains)]
use std::net::SocketAddr;
use metrics_exporter_prometheus::PrometheusBuilder;
use sentry::IntoDsn;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
pub mod db;
pub mod state;
pub mod util;
pub mod _config;
pub use crate::_config::CONFIG as config;
// functions in this file are only used by the main function below
pub fn init_logging(component: &str) -> anyhow::Result<()> {
if config.json_log {
let mut layer = json_subscriber::layer();
layer.inner_layer_mut().add_static_field(
"component",
serde_json::Value::String(component.to_string()),
);
tracing_subscriber::registry()
.with(layer)
.with(EnvFilter::from_default_env())
.init();
} else {
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.init();
}
Ok(())
}
pub fn init_metrics() -> anyhow::Result<()> {
if config.run_metrics_server {
PrometheusBuilder::new()
.with_http_listener("[::]:9000".parse::<SocketAddr>().unwrap())
.install()?;
}
Ok(())
}
pub fn init_sentry() -> sentry::ClientInitGuard {
sentry::init(sentry::ClientOptions {
dsn: config
.sentry_url
.clone()
.map(|u| u.into_dsn().unwrap())
.flatten(),
release: sentry::release_name!(),
..Default::default()
})
}
#[macro_export]
macro_rules! main {
($component:expr) => {
fn main() -> anyhow::Result<()> {
let _sentry_guard = libpk::init_sentry();
// we might also be able to use env!("CARGO_CRATE_NAME") here
libpk::init_logging($component)?;
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(err) = libpk::init_metrics() {
tracing::error!("failed to init metrics collector: {err}");
};
tracing::info!("hello world");
if let Err(err) = real_main().await {
tracing::error!("failed to run service: {err}");
};
});
Ok(())
}
};
}

12
crates/libpk/src/state.rs Normal file
View file

@ -0,0 +1,12 @@
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
pub struct ShardState {
pub shard_id: i32,
pub up: bool,
pub disconnection_count: i32,
/// milliseconds
pub latency: i32,
/// unix timestamp
pub last_heartbeat: i32,
pub last_connection: i32,
pub cluster_id: Option<i32>,
}

View file

@ -0,0 +1 @@
pub mod redis;

View file

@ -0,0 +1,15 @@
use fred::error::RedisError;
pub trait RedisErrorExt<T> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
}
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
match self {
Ok(v) => Ok(Some(v)),
Err(error) if error.is_not_found() => Ok(None),
Err(error) => Err(error),
}
}
}

View file

@ -0,0 +1,13 @@
[package]
name = "model_macros"
version = "0.1.0"
edition = "2021"
[lib]
proc-macro = true
[dependencies]
quote = "1.0"
proc-macro2 = "1.0"
syn = "2.0"

View file

@ -0,0 +1,259 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type};
#[derive(Clone, Debug)]
enum ElemPatchability {
None,
Private,
Public,
}
#[derive(Clone, Debug)]
struct ModelField {
name: Ident,
ty: Type,
patch: ElemPatchability,
json: Option<Expr>,
is_privacy: bool,
default: Option<Expr>,
}
fn parse_field(field: syn::Field) -> ModelField {
let mut f = ModelField {
name: field.ident.expect("field missing ident"),
ty: field.ty,
patch: ElemPatchability::None,
json: None,
is_privacy: false,
default: None,
};
for attr in field.attrs.iter() {
match &attr.meta {
Meta::Path(path) => {
let ident = path.segments[0].ident.to_string();
match ident.as_str() {
"private_patchable" => match f.patch {
ElemPatchability::None => {
f.patch = ElemPatchability::Private;
}
_ => {
panic!("cannot have multiple patch tags on same field");
}
},
"patchable" => match f.patch {
ElemPatchability::None => {
f.patch = ElemPatchability::Public;
}
_ => {
panic!("cannot have multiple patch tags on same field");
}
},
"privacy" => f.is_privacy = true,
_ => panic!("unknown attribute"),
}
}
Meta::NameValue(nv) => match nv.path.segments[0].ident.to_string().as_str() {
"json" => {
if f.json.is_some() {
panic!("cannot set json multiple times for same field");
}
f.json = Some(nv.value.clone());
}
"default" => {
if f.default.is_some() {
panic!("cannot set default multiple times for same field");
}
f.default = Some(nv.value.clone());
}
_ => panic!("unknown attribute"),
},
Meta::List(_) => panic!("unknown attribute"),
}
}
if matches!(f.patch, ElemPatchability::Public) && f.json.is_none() {
panic!("must have json name to be publicly patchable");
}
if f.json.is_some() && f.is_privacy {
panic!("cannot set custom json name for privacy field");
}
f
}
#[proc_macro_attribute]
pub fn pk_model(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let ast = parse_macro_input!(input as DeriveInput);
let model_type = match ast.data {
syn::Data::Struct(struct_data) => struct_data,
_ => panic!("pk_model can only be used on a struct"),
};
let tname = Ident::new(&format!("PK{}", ast.ident), Span::call_site());
let patchable_name = Ident::new(&format!("PK{}Patch", ast.ident), Span::call_site());
let fields = if let syn::Fields::Named(fields) = model_type.fields {
fields
.named
.iter()
.map(|f| parse_field(f.clone()))
.collect::<Vec<ModelField>>()
} else {
panic!("fields of a struct must be named");
};
// println!("{}: {:#?}", tname, fields);
let tfields = mk_tfields(fields.clone());
let from_json = mk_tfrom_json(fields.clone());
let from_sql = mk_tfrom_sql(fields.clone());
let to_json = mk_tto_json(fields.clone());
let fields: Vec<ModelField> = fields
.iter()
.filter(|f| !matches!(f.patch, ElemPatchability::None))
.cloned()
.collect();
let patch_fields = mk_patch_fields(fields.clone());
let patch_from_json = mk_patch_from_json(fields.clone());
let patch_validate = mk_patch_validate(fields.clone());
let patch_to_json = mk_patch_to_json(fields.clone());
let patch_to_sql = mk_patch_to_sql(fields.clone());
return quote! {
#[derive(sqlx::FromRow, Debug, Clone)]
pub struct #tname {
#tfields
}
impl #tname {
pub fn from_json(input: String) -> Self {
#from_json
}
pub fn to_json(self) -> serde_json::Value {
#to_json
}
}
#[derive(Debug, Clone)]
pub struct #patchable_name {
#patch_fields
}
impl #patchable_name {
pub fn from_json(input: String) -> Self {
#patch_from_json
}
pub fn validate(self) -> bool {
#patch_validate
}
pub fn to_sql(self) -> sea_query::UpdateStatement {
// sea_query::Query::update()
#patch_to_sql
}
pub fn to_json(self) -> serde_json::Value {
#patch_to_json
}
}
}
.into();
}
fn mk_tfields(fields: Vec<ModelField>) -> TokenStream {
fields
.iter()
.map(|f| {
let name = f.name.clone();
let ty = f.ty.clone();
quote! {
pub #name: #ty,
}
})
.collect()
}
fn mk_tfrom_json(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_tfrom_sql(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
// todo: check privacy access
let fielddefs: TokenStream = fields
.iter()
.filter_map(|f| {
f.json.as_ref().map(|v| {
let tname = f.name.clone();
if let Some(default) = f.default.as_ref() {
quote! {
#v: self.#tname.unwrap_or(#default),
}
} else {
quote! {
#v: self.#tname,
}
}
})
})
.collect();
let privacyfielddefs: TokenStream = fields
.iter()
.filter_map(|f| {
if f.is_privacy {
let tname = f.name.clone();
let tnamestr = f.name.clone().to_string();
Some(quote! {
#tnamestr: self.#tname,
})
} else {
None
}
})
.collect();
quote! {
serde_json::json!({
#fielddefs
"privacy": {
#privacyfielddefs
}
})
}
}
fn mk_patch_fields(fields: Vec<ModelField>) -> TokenStream {
fields
.iter()
.map(|f| {
let name = f.name.clone();
let ty = f.ty.clone();
quote! {
pub #name: Option<#ty>,
}
})
.collect()
}
fn mk_patch_validate(_fields: Vec<ModelField>) -> TokenStream {
quote! { true }
}
fn mk_patch_from_json(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_patch_to_sql(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_patch_to_json(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
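Roughly, the macro turns an annotated struct into a plain row struct plus a patch struct. A sketch of the expansion for a hypothetical two-field model (names invented for illustration; the from_json/to_sql bodies are still unimplemented!() as above):

#[derive(sqlx::FromRow, Debug, Clone)]
pub struct PKWidget {
    pub id: i32,
    pub name: Option<String>,
}

// only fields tagged #[patchable]/#[private_patchable] survive into the
// patch struct, each wrapped in a further Option to distinguish "field not
// provided" from "field explicitly set to null"
#[derive(Debug, Clone)]
pub struct PKWidgetPatch {
    pub name: Option<Option<String>>,
}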

13
crates/models/Cargo.toml Normal file
View file

@ -0,0 +1,13 @@
[package]
name = "pluralkit_models"
version = "0.1.0"
edition = "2021"
[dependencies]
chrono = { workspace = true, features = ["serde"] }
model_macros = { path = "../model_macros" }
sea-query = "0.32.1"
serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] }
sqlx = { workspace = true, default-features = false, features = ["chrono"] }
uuid = { workspace = true }

View file

@ -0,0 +1,35 @@
// postgres enums created in c# pluralkit implementations are "fake", i.e. they
// are actually ints in the database rather than postgres enums, because dapper
// does not support postgres enums
// here, we add some impls to support this kind of enum in sqlx
// there is probably a better way to do this, but it works for now.
// note: caller needs to implement From<i32> for their type
macro_rules! fake_enum_impls {
($n:ident) => {
impl Type<Postgres> for $n {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("INT4")
}
}
impl From<$n> for i32 {
fn from(enum_value: $n) -> Self {
enum_value as i32
}
}
impl<'r, DB: Database> Decode<'r, DB> for $n
where
i32: Decode<'r, DB>,
{
fn decode(
value: <DB as Database>::ValueRef<'r>,
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
let value = <i32 as Decode<DB>>::decode(value)?;
Ok(Self::from(value))
}
}
};
}
pub(crate) use fake_enum_impls;

11
crates/models/src/lib.rs Normal file
View file

@ -0,0 +1,11 @@
mod _util;
macro_rules! model {
($n:ident) => {
mod $n;
pub use $n::*;
};
}
model!(system);
model!(system_config);

View file

@ -0,0 +1,80 @@
use std::error::Error;
use model_macros::pk_model;
use chrono::NaiveDateTime;
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use uuid::Uuid;
use crate::_util::fake_enum_impls;
// todo: fix this
pub type SystemId = i32;
// todo: move this
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
pub enum PrivacyLevel {
Public,
Private,
}
fake_enum_impls!(PrivacyLevel);
impl From<i32> for PrivacyLevel {
fn from(value: i32) -> Self {
match value {
1 => PrivacyLevel::Public,
2 => PrivacyLevel::Private,
_ => unreachable!(),
}
}
}
#[pk_model]
struct System {
id: SystemId,
#[json = "id"]
#[private_patchable]
hid: String,
#[json = "uuid"]
uuid: Uuid,
#[json = "name"]
name: Option<String>,
#[json = "description"]
description: Option<String>,
#[json = "tag"]
tag: Option<String>,
#[json = "pronouns"]
pronouns: Option<String>,
#[json = "avatar_url"]
avatar_url: Option<String>,
#[json = "banner_image"]
banner_image: Option<String>,
#[json = "color"]
color: Option<String>,
token: Option<String>,
#[json = "webhook_url"]
webhook_url: Option<String>,
webhook_token: Option<String>,
#[json = "created"]
created: NaiveDateTime,
#[privacy]
name_privacy: PrivacyLevel,
#[privacy]
avatar_privacy: PrivacyLevel,
#[privacy]
description_privacy: PrivacyLevel,
#[privacy]
banner_privacy: PrivacyLevel,
#[privacy]
member_list_privacy: PrivacyLevel,
#[privacy]
front_privacy: PrivacyLevel,
#[privacy]
front_history_privacy: PrivacyLevel,
#[privacy]
group_list_privacy: PrivacyLevel,
#[privacy]
pronoun_privacy: PrivacyLevel,
}

View file

@ -0,0 +1,89 @@
use model_macros::pk_model;
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use std::error::Error;
use crate::{SystemId, _util::fake_enum_impls};
pub const DEFAULT_MEMBER_LIMIT: i32 = 1000;
pub const DEFAULT_GROUP_LIMIT: i32 = 250;
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
enum HidPadFormat {
#[serde(rename = "off")]
None,
Left,
Right,
}
fake_enum_impls!(HidPadFormat);
impl From<i32> for HidPadFormat {
fn from(value: i32) -> Self {
match value {
0 => HidPadFormat::None,
1 => HidPadFormat::Left,
2 => HidPadFormat::Right,
_ => unreachable!(),
}
}
}
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
enum ProxySwitchAction {
Off,
New,
Add,
}
fake_enum_impls!(ProxySwitchAction);
impl From<i32> for ProxySwitchAction {
fn from(value: i32) -> Self {
match value {
0 => ProxySwitchAction::Off,
1 => ProxySwitchAction::New,
2 => ProxySwitchAction::Add,
_ => unreachable!(),
}
}
}
#[pk_model]
struct SystemConfig {
system: SystemId,
#[json = "timezone"]
ui_tz: String,
#[json = "pings_enabled"]
pings_enabled: bool,
#[json = "latch_timeout"]
latch_timeout: Option<i32>,
#[json = "member_default_private"]
member_default_private: bool,
#[json = "group_default_private"]
group_default_private: bool,
#[json = "show_private_info"]
show_private_info: bool,
#[json = "member_limit"]
#[default = DEFAULT_MEMBER_LIMIT]
member_limit_override: Option<i32>,
#[json = "group_limit"]
#[default = DEFAULT_GROUP_LIMIT]
group_limit_override: Option<i32>,
#[json = "case_sensitive_proxy_tags"]
case_sensitive_proxy_tags: bool,
#[json = "proxy_error_message_enabled"]
proxy_error_message_enabled: bool,
#[json = "hid_display_split"]
hid_display_split: bool,
#[json = "hid_display_caps"]
hid_display_caps: bool,
#[json = "hid_list_padding"]
hid_list_padding: HidPadFormat,
#[json = "proxy_switch"]
proxy_switch: ProxySwitchAction,
#[json = "name_format"]
name_format: String,
#[json = "description_templates"]
description_templates: Vec<String>,
}

View file

@ -0,0 +1,20 @@
[package]
name = "scheduled_tasks"
version = "0.1.0"
edition = "2021"
[dependencies]
libpk = { path = "../libpk" }
anyhow = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
croner = "2.1.0"
num-format = "0.4.4"

View file

@ -0,0 +1,91 @@
use chrono::Utc;
use croner::Cron;
use fred::prelude::RedisPool;
use sqlx::PgPool;
use tokio::task::JoinSet;
use tracing::{debug, error, info};
mod tasks;
use tasks::*;
#[derive(Clone)]
pub struct AppCtx {
pub data: PgPool,
pub messages: PgPool,
pub stats: PgPool,
pub redis: RedisPool,
}
libpk::main!("scheduled_tasks");
async fn real_main() -> anyhow::Result<()> {
let ctx = AppCtx {
data: libpk::db::init_data_db().await?,
messages: libpk::db::init_messages_db().await?,
stats: libpk::db::init_stats_db().await?,
redis: libpk::db::init_redis().await?,
};
info!("starting scheduled tasks runner");
let mut set = JoinSet::new();
// i couldn't be bothered to figure out the types of passing in an async
// function to another function... so macro it is
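// note: with_seconds_optional() lets the schedules below mix classic
// 5-field cron patterns with 6-field patterns that lead with a seconds column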
macro_rules! doforever {
($cron:expr, $desc:expr, $fn:ident) => {
let ctx = ctx.clone();
let cron = Cron::new($cron)
.with_seconds_optional()
.parse()
.expect("invalid cron");
set.spawn(tokio::spawn(async move {
loop {
let ctx = ctx.clone();
let next_iter_time = cron.find_next_occurrence(&Utc::now(), false).unwrap();
debug!("next execution of {} at {:?}", $desc, next_iter_time);
let dur = next_iter_time - Utc::now();
tokio::time::sleep(dur.to_std().unwrap()).await;
info!("running {}", $desc);
let before = std::time::Instant::now();
if let Err(error) = $fn(ctx).await {
error!("failed to run {}: {}", $desc, error);
// sentry
}
let duration = before.elapsed();
info!("ran {} in {duration:?}", $desc);
// add prometheus log
}
}))
};
}
// every 10 seconds
doforever!(
"0,10,20,30,40,50 * * * * *",
"prometheus updater",
update_prometheus
);
// every minute
doforever!("* * * * *", "database stats updater", update_db_meta);
// every 10 minutes
doforever!(
"0,10,20,30,40,50 * * * *",
"message stats updater",
update_db_message_meta
);
// every minute
doforever!("* * * * *", "discord stats updater", update_discord_stats);
// on :00 and :30
doforever!(
"0,30 * * * *",
"queue deleted image cleanup job",
queue_deleted_image_cleanup
);
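// the triple `?` below unwraps three layers: the anyhow error from ok_or
// when the set is empty, the JoinSet's own JoinError, and the JoinError of
// the inner handle created by tokio::spawn inside doforever!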
set.join_next()
.await
.ok_or(anyhow::anyhow!("could not join_next"))???;
Ok(())
}

View file

@ -0,0 +1,151 @@
use std::time::Duration;
use anyhow::anyhow;
use fred::prelude::KeysInterface;
use libpk::{
config,
db::repository::{get_stats, insert_stats},
};
use metrics::gauge;
use num_format::{Locale, ToFormattedString};
use reqwest::ClientBuilder;
use sqlx::Executor;
use crate::AppCtx;
pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct Count {
count: i64,
}
let count: Count = sqlx::query_as("select count(*) from image_cleanup_jobs")
.fetch_one(&ctx.data)
.await?;
gauge!("pluralkit_image_cleanup_queue_length").set(count.count as f64);
// todo: remaining shard session_start_limit
Ok(())
}
pub async fn update_db_meta(ctx: AppCtx) -> anyhow::Result<()> {
ctx.data
.execute(
r#"
update info set
system_count = (select count(*) from systems),
member_count = (select count(*) from members),
group_count = (select count(*) from groups),
switch_count = (select count(*) from switches)
"#,
)
.await?;
let new_stats = get_stats(&ctx.data).await?;
insert_stats(&ctx.stats, "systems", new_stats.system_count).await?;
insert_stats(&ctx.stats, "members", new_stats.member_count).await?;
insert_stats(&ctx.stats, "groups", new_stats.group_count).await?;
insert_stats(&ctx.stats, "switches", new_stats.switch_count).await?;
Ok(())
}
pub async fn update_db_message_meta(ctx: AppCtx) -> anyhow::Result<()> {
#[derive(sqlx::FromRow)]
struct MessageCount {
count: i64,
}
let message_count: MessageCount = sqlx::query_as("select count(*) from messages")
.fetch_one(&ctx.messages)
.await?;
sqlx::query("update info set message_count = $1")
.bind(message_count.count)
.execute(&ctx.data)
.await?;
insert_stats(&ctx.stats, "messages", message_count.count).await?;
Ok(())
}
pub async fn update_discord_stats(ctx: AppCtx) -> anyhow::Result<()> {
let client = ClientBuilder::new()
.connect_timeout(Duration::from_secs(3))
.timeout(Duration::from_secs(3))
.build()
.expect("error making client");
let cfg = config
.scheduled_tasks
.as_ref()
.expect("missing scheduled_tasks config");
#[derive(serde::Deserialize)]
struct GatewayStatus {
up: bool,
guild_count: i64,
channel_count: i64,
}
let mut guild_count = 0;
let mut channel_count = 0;
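// note: the inclusive range polls cluster0 through cluster{expected_gateway_count},
// i.e. expected_gateway_count + 1 endpoints; this assumes clusters are
// numbered from zero up to and including that config value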
for idx in 0..=cfg.expected_gateway_count {
let res = client
.get(format!("http://cluster{idx}.{}/stats", cfg.gateway_url))
.send()
.await?;
let stat: GatewayStatus = res.json().await?;
if !stat.up {
return Err(anyhow!("cluster {idx} is not up"));
}
guild_count += stat.guild_count;
channel_count += stat.channel_count;
}
insert_stats(&ctx.stats, "guilds", guild_count).await?;
insert_stats(&ctx.stats, "channels", channel_count).await?;
if cfg.set_guild_count {
ctx.redis
.set::<(), &str, String>(
"pluralkit:botstatus",
format!(
"in {} servers",
guild_count.to_formatted_string(&Locale::en)
),
None,
None,
false,
)
.await?;
}
Ok(())
}
pub async fn queue_deleted_image_cleanup(ctx: AppCtx) -> anyhow::Result<()> {
// todo: we want to delete immediately when system is deleted, but after a
// delay if member is deleted
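// enqueue every stored image that is not already queued and is no longer
// referenced by any system, member, or group field (including the
// per-guild avatar overrides)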
ctx.data
.execute(
r#"
insert into image_cleanup_jobs
select id, now() from images where
not exists (select from image_cleanup_jobs j where j.id = images.id)
and not exists (select from systems where avatar_url = images.url)
and not exists (select from systems where banner_image = images.url)
and not exists (select from system_guild where avatar_url = images.url)
and not exists (select from members where avatar_url = images.url)
and not exists (select from members where banner_image = images.url)
and not exists (select from members where webhook_avatar_url = images.url)
and not exists (select from member_guild where avatar_url = images.url)
and not exists (select from groups where icon = images.url)
and not exists (select from groups where banner_image = images.url);
"#,
)
.await?;
Ok(())
}