Merge remote-tracking branch 'upstream/main' into rust-command-parser

This commit is contained in:
dusk 2025-08-09 17:38:44 +03:00
commit f721b850d4
No known key found for this signature in database
183 changed files with 5121 additions and 1909 deletions

crates/api/src/auth.rs Normal file
View file

@ -0,0 +1,48 @@
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
pub const INTERNAL_SYSTEMID_HEADER: &str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &str = "x-pluralkit-appid";
#[derive(Clone)]
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
Self { system_id, app_id }
}
pub fn system_id(&self) -> Option<i32> {
self.system_id
}
pub fn app_id(&self) -> Option<i32> {
self.app_id
}
pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel {
if self
.system_id
.map(|id| id == a.authable_system_id())
.unwrap_or(false)
{
PrivacyLevel::Private
} else {
PrivacyLevel::Public
}
}
}
// authable trait/impls
pub trait Authable {
fn authable_system_id(&self) -> SystemId;
}
impl Authable for PKSystem {
fn authable_system_id(&self) -> SystemId {
self.id
}
}
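// Illustrative only (not part of this commit): a downstream handler can read
// the AuthState extension stored by the auth middleware and let
// access_level_for pick the view to render. The handler below is hypothetical
// and assumes PKSystem::to_json returns a serde_json::Value, as its use in the
// oauth endpoint suggests.
async fn example_handler(
    axum::Extension(auth): axum::Extension<AuthState>,
    axum::Extension(system): axum::Extension<PKSystem>,
) -> axum::Json<serde_json::Value> {
    // Private view only when the token's system id matches the target system
    let level = auth.access_level_for(&system);
    axum::Json(system.to_json(level))
}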

View file

@ -1 +1,2 @@
pub mod private;
pub mod system;

View file

@ -52,7 +52,7 @@ use axum::{
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
@ -151,14 +151,12 @@ pub async fn discord_callback(
.await
.expect("failed to query");
let Some(system) = system else {
return json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
};
let system_config: Option<PKSystemConfig> = sqlx::query_as(
r#"
@ -179,7 +177,7 @@ pub async fn discord_callback(
(
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(),
"system": system.to_json(PrivacyLevel::Private),
"config": system_config.to_json(),
"user": user,
"token": token,

View file

@ -0,0 +1,69 @@
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Extension, Json,
};
use serde_json::json;
use sqlx::Postgres;
use tracing::error;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use crate::{auth::AuthState, util::json_err, ApiContext};
pub async fn get_system_settings(
Extension(auth): Extension<AuthState>,
Extension(system): Extension<PKSystem>,
State(ctx): State<ApiContext>,
) -> Response {
let access_level = auth.access_level_for(&system);
let mut config = match sqlx::query_as::<Postgres, PKSystemConfig>(
"select * from system_config where system = $1",
)
.bind(system.id)
.fetch_optional(&ctx.db)
.await
{
Ok(Some(config)) => config,
Ok(None) => {
error!(
system = system.id,
"failed to find system config for existing system"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
Err(err) => {
error!(?err, "failed to query system config");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
};
// fix this
if config.name_format.is_none() {
config.name_format = Some("{name} {tag}".to_string());
}
Json(&match access_level {
PrivacyLevel::Private => config.to_json(),
PrivacyLevel::Public => json!({
"pings_enabled": config.pings_enabled,
"latch_timeout": config.latch_timeout,
"case_sensitive_proxy_tags": config.case_sensitive_proxy_tags,
"proxy_error_message_enabled": config.proxy_error_message_enabled,
"hid_display_split": config.hid_display_split,
"hid_display_caps": config.hid_display_caps,
"hid_list_padding": config.hid_list_padding,
"proxy_switch": config.proxy_switch,
"name_format": config.name_format,
}),
})
.into_response()
}

View file

@ -1,10 +1,13 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Extension, Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
@ -12,6 +15,7 @@ use hyper_util::{
};
use tracing::{error, info};
mod auth;
mod endpoints;
mod error;
mod middleware;
@ -27,6 +31,7 @@ pub struct ApiContext {
}
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
@ -41,12 +46,25 @@ async fn rproxy(
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|error| {
error!(?error, "failed to serve reverse proxy to dotnet-api");
StatusCode::BAD_GATEWAY
})?
.into_response())
@ -57,52 +75,52 @@ async fn rproxy(
fn router(ctx: ApiContext) -> Router {
// layers are applied bottom-up (the last .layer added runs outermost), so middleware goes at the end; see the sketch after this function
Router::new()
.route("/v2/systems/{system_id}", get(rproxy))
.route("/v2/systems/{system_id}", patch(rproxy))
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
.route("/v2/systems/{system_id}/settings", patch(rproxy))
.route("/v2/systems/{system_id}/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/{member_id}", get(rproxy))
.route("/v2/members/{member_id}", patch(rproxy))
.route("/v2/members/{member_id}", delete(rproxy))
.route("/v2/systems/{system_id}/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/{group_id}", get(rproxy))
.route("/v2/groups/{group_id}", patch(rproxy))
.route("/v2/groups/{group_id}", delete(rproxy))
.route("/v2/groups/{group_id}/members", get(rproxy))
.route("/v2/groups/{group_id}/members/add", post(rproxy))
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
.route("/v2/members/{member_id}/groups", get(rproxy))
.route("/v2/members/{member_id}/groups/add", post(rproxy))
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
.route("/v2/systems/{system_id}/switches", get(rproxy))
.route("/v2/systems/{system_id}/switches", post(rproxy))
.route("/v2/systems/{system_id}/fronters", get(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
.route("/v2/messages/{message_id}", get(rproxy))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
@ -111,16 +129,19 @@ fn router(ctx: ApiContext) -> Router {
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::logger::logger))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::params::params))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
.layer(axum::middleware::from_fn(middleware::cors::cors))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)
@ -128,8 +149,8 @@ fn router(ctx: ApiContext) -> Router {
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
}
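// For reference (illustrative, not part of this commit): axum applies layers
// bottom-up, so in router() above the CatchPanicLayer added last is outermost
// and runs first, followed by cors, auth, params, logger, and the rate limiter
// closest to the handlers. A minimal standalone sketch of that ordering:
fn layer_order_example() -> Router {
    async fn outer(req: ExtractRequest, next: axum::middleware::Next) -> axum::response::Response {
        next.run(req).await // added last => wraps everything, runs first
    }
    async fn inner(req: ExtractRequest, next: axum::middleware::Next) -> axum::response::Response {
        next.run(req).await // added first => innermost, runs last before the handler
    }
    Router::new()
        .route("/", get(|| async { "ok" }))
        .layer(axum::middleware::from_fn(inner))
        .layer(axum::middleware::from_fn(outer))
}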
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;

View file

@ -0,0 +1,62 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
// fetch user authorization
if let Some(system_auth_header) = req
.headers()
.get("authorization")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
}
{
authed_system_id = Some(system_id);
}
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = req
.headers()
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(config_token2) = libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
.as_ref()
// this is NOT how you validate tokens
// but this is low abuse risk so we're keeping it for now
&& app_auth_header == config_token2
{
authed_app_id = Some(1);
}
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id));
next.run(req).await
}

View file

@ -1,45 +0,0 @@
use axum::{
extract::{Request, State},
http::HeaderValue,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::ApiContext;
use super::logger::DID_AUTHENTICATE_HEADER;
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove("x-pluralkit-systemid");
let auth_header = headers
.get("authorization")
.map(|h| h.to_str().ok())
.flatten();
let mut authenticated = false;
if let Some(auth_header) = auth_header {
if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
None
}
}
{
headers.append(
"x-pluralkit-systemid",
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
);
authenticated = true;
}
}
let mut response = next.run(request).await;
if authenticated {
response
.headers_mut()
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
}
response
}

View file

@ -4,27 +4,30 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::{auth::AuthState, util::header_or_unknown};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
let uri = request.uri().clone();
let request_span = span!(
@ -37,25 +40,26 @@ pub async fn logger(request: Request, next: Next) -> Response {
);
let start = Instant::now();
let response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let system_id = auth
.system_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
let app_id = auth
.app_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
counter!(
"pluralkit_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);
histogram!(
@ -63,7 +67,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.record(elapsed as f64 / 1_000_f64);
@ -81,7 +86,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);

View file

@ -1,13 +1,6 @@
pub mod auth;
pub mod cors;
pub mod ignore_invalid_routes;
pub mod logger;
pub mod params;
pub mod ratelimit;

View file

@ -0,0 +1,139 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
routing::url_params::UrlParams,
};
use sqlx::{types::Uuid, Postgres};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use pluralkit_models::PKSystem;
// move this somewhere else
fn parse_hid(hid: &str) -> String {
if hid.len() > 7 || hid.len() < 5 {
hid.to_string()
} else {
hid.to_lowercase().replace("-", "")
}
}
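// Hedged sketch (not part of this commit): the expected parse_hid behavior,
// written as a test to document the normalization rules above — 5-7 character
// ids are lowercased with display dashes stripped, anything else passes
// through untouched for the db lookup to reject.
#[cfg(test)]
mod parse_hid_tests {
    use super::parse_hid;

    #[test]
    fn normalizes_five_to_seven_char_ids() {
        assert_eq!(parse_hid("ABCDE"), "abcde"); // lowercased
        assert_eq!(parse_hid("abc-def"), "abcdef"); // display dash stripped
    }

    #[test]
    fn leaves_other_lengths_untouched() {
        assert_eq!(parse_hid("abcd"), "abcd"); // too short
        assert_eq!(parse_hid("AbCdEfGhI"), "AbCdEfGhI"); // too long
    }
}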
pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let pms = match req.extensions().get::<UrlParams>() {
None => Vec::new(),
Some(UrlParams::Params(pms)) => pms.clone(),
_ => {
return json_err(
StatusCode::BAD_REQUEST,
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
)
.into()
}
};
for (key, value) in pms {
match key.as_ref() {
"system_id" => match value.as_str() {
"@me" => {
let Some(system_id) = req
.extensions()
.get::<AuthState>()
.expect("missing auth state")
.system_id()
else {
return json_err(
StatusCode::UNAUTHORIZED,
r#"{"message":"401: Missing or invalid Authorization header","code": 0}"#.to_string(),
)
.into();
};
match sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where id = $1",
)
.bind(system_id)
.fetch_optional(&ctx.db)
.await
{
Ok(Some(system)) => {
req.extensions_mut().insert(system);
}
Ok(None) => {
error!(
?system_id,
"could not find previously authenticated system in db"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
Err(err) => {
error!(
?err,
"failed to query previously authenticated system in db"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
}
}
id => {
match match Uuid::parse_str(id) {
Ok(uuid) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where uuid = $1",
)
.bind(uuid),
Err(_) => match id.parse::<i64>() {
Ok(parsed) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where id = (select system from accounts where uid = $1)"
)
.bind(parsed),
Err(_) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where hid = $1",
)
.bind(parse_hid(id))
},
}
.fetch_optional(&ctx.db)
.await
{
Ok(Some(system)) => {
req.extensions_mut().insert(system);
}
Ok(None) => {
return json_err(
StatusCode::NOT_FOUND,
r#"{"message":"System not found.","code":20001}"#.to_string(),
)
}
Err(err) => {
error!(?err, ?id, "failed to query system from path in db");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
}
}
},
"member_id" => {}
"group_id" => {}
"switch_id" => {}
"guild_id" => {}
_ => {}
}
}
next.run(req).await
}

View file

@ -10,7 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::{
auth::AuthState,
util::{header_or_unknown, json_err},
};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -50,7 +53,7 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
.await
{
Ok(_) => info!("connected to redis for request rate limiting"),
Err(error) => error!(?error, "could not load redis script"),
}
} else {
error!("could not wait for connection to load redis script!");
@ -103,37 +106,28 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
// looks like this chooses the tokens/sec by app_id or endpoint
// then chooses the key by system_id or source_ip
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/{message_id}" {
RatelimitType::Message
@ -145,12 +139,12 @@ pub async fn do_request_ratelimited(
let rl_key = format!(
"{}:{}",
if let Some(system_id) = auth.system_id()
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
system_id.to_string()
} else {
source_ip.to_string()
},
rlimit.key()
);
@ -224,8 +218,8 @@ pub async fn do_request_ratelimited(
return response;
}
Err(error) => {
tracing::error!(?error, "error getting ratelimit info");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),

View file

@ -11,7 +11,7 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!(?err, ?value, "failed to parse header value");
"failed to parse"
}
}
@ -34,11 +34,7 @@ where
.unwrap(),
),
None => {
error!(?error, "error in handler {}", std::any::type_name::<F>());
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
@ -48,14 +44,15 @@ where
}
}
pub fn handle_panic(error: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!(?error, "caught panic from handler");
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
// todo: make 500 not duplicated
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();

View file

@ -4,8 +4,8 @@ use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info};
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
@ -13,7 +13,7 @@ async fn real_main() -> anyhow::Result<()> {
let bucket = {
let region = s3::Region::Custom {
region: "auto".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
@ -38,8 +38,8 @@ async fn real_main() -> anyhow::Result<()> {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
match cleanup_job(pool.clone(), bucket.clone()).await {
Ok(()) => {}
Err(error) => {
error!(?error, "failed to run avatar cleanup job");
// sentry
}
}
@ -55,9 +55,10 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
let mut tx = pool.begin().await?;
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
// images are only added to the table after 24h
r#"
select id from image_cleanup_jobs
where ts < now() - interval '1 day'
for update skip locked limit 1;"#,
)
.fetch_optional(&mut *tx)
@ -72,6 +73,7 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
if image_data.is_none() {
// unsure how this can happen? there is a FK reference
info!("image {image_id} was already deleted, skipping");
sqlx::query("delete from image_cleanup_jobs where id = $1")
.bind(image_id)

View file

@ -93,7 +93,7 @@ async fn pull(
) -> Result<Json<PullResponse>, PKAvatarError> {
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
// remove any pending image cleanup
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
@ -170,8 +170,8 @@ pub struct AppState {
config: Arc<AvatarsConfig>,
}
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
@ -179,7 +179,7 @@ async fn real_main() -> anyhow::Result<()> {
let bucket = {
let region = s3::Region::Custom {
region: "auto".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
@ -232,26 +232,11 @@ async fn real_main() -> anyhow::Result<()> {
Ok(())
}
struct AppError(anyhow::Error);
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
error!("error handling request: {}", self.0);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: self.0.to_string(),
}),
)
.into_response()
}
}
impl IntoResponse for PKAvatarError {
fn into_response(self) -> Response {
let status_code = match self {
@ -278,12 +263,3 @@ impl IntoResponse for PKAvatarError {
.into_response()
}
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

View file

@ -129,9 +129,9 @@ pub async fn worker(worker_id: u32, state: Arc<AppState>) {
Ok(()) => {}
Err(e) => {
error!(
"error in migrate worker {}: {}",
worker_id,
e.source().unwrap_or(&e)
error = e.source().unwrap_or(&e),
?worker_id,
"error in migrate worker",
);
tokio::time::sleep(Duration::from_secs(5)).await;
}

View file

@ -84,7 +84,7 @@ pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarEr
} else {
reader.decode().map_err(|e| {
// print the ugly error, return the nice error
error!(error = format!("{e:#?}"), "error decoding image");
PKAvatarError::ImageFormatError(e)
})?
};

View file

@ -41,7 +41,11 @@ pub async fn pull(
}
}
error!(
url = parsed_url.full_url,
error = s,
"network error pulling image"
);
PKAvatarError::NetworkErrorString(s)
})?;
let time_after_headers = Instant::now();
@ -82,7 +86,22 @@ pub async fn pull(
.map(|x| x.to_string());
let body = response.bytes().await.map_err(|e| {
// terrible
let mut s = format!("{}", e);
if let Some(src) = e.source() {
let _ = write!(s, ": {}", src);
let mut err = src;
while let Some(src) = err.source() {
let _ = write!(s, ": {}", src);
err = src;
}
}
error!(
url = parsed_url.full_url,
error = s,
"network error pulling image"
);
PKAvatarError::NetworkError(e)
})?;
if body.len() != size as usize {
@ -137,6 +156,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
match (url.scheme(), url.domain()) {
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
("https", Some("serve.apparyllis.com")) => {
return Ok(ParsedUrl {
channel_id: 0,
attachment_id: 0,
filename: "".to_string(),
full_url: url.to_string(),
})
}
_ => anyhow::bail!("not a discord cdn url"),
}
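// Behavior sketch (hypothetical values): a discord cdn url parses into real
// channel/attachment ids further down, while an apparyllis url short-circuits
// above with zeroed ids, so only full_url is meaningful for it downstream:
//   "https://cdn.discordapp.com/attachments/123/456/a.png" => ids 123/456
//   "https://serve.apparyllis.com/avatars/x.png"           => ids 0/0
//   anything else                                          => "not a discord cdn url"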

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
libpk = { path = "../libpk" }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View file

@ -19,17 +19,8 @@ use axum::{extract::State, http::Uri, routing::post, Json, Router};
mod logger;
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
let (client, bg) = AsyncClient::connect(stream).await?;
@ -86,11 +77,11 @@ async fn dispatch(
let uri = match req.url.parse::<Uri>() {
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
Err(error) => {
error!(?error, uri = req.url, "failed to parse uri");
return DispatchResponse::BadData.to_string();
}
_ => {
error!(uri = req.url, "uri is invalid");
return DispatchResponse::BadData.to_string();
}
};

View file

@ -13,8 +13,9 @@ futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../libpk" }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

View file

@ -1,19 +1,24 @@
use axum::{
extract::{ConnectInfo, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Router,
};
use libpk::runtime_config::RuntimeConfig;
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::{marker::ChannelMarker, Id};
use crate::{
discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
shard_state::ShardStateManager,
},
event_awaiter::{AwaitEventRequest, EventAwaiter},
};
use std::{net::SocketAddr, sync::Arc};
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
@ -21,10 +26,15 @@ fn status_code(code: StatusCode, body: String) -> Response {
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
// hacky fix for `move`
let runtime_config_for_post = runtime_config.clone();
let runtime_config_for_delete = runtime_config.clone();
let awaiter_for_clear = awaiter.clone();
let app = Router::new()
.route(
"/guilds/:guild_id",
"/guilds/{guild_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
@ -33,7 +43,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/members/@me",
"/guilds/{guild_id}/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
@ -42,7 +52,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
"/guilds/{guild_id}/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => {
@ -56,7 +66,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
"/guilds/{guild_id}/permissions/{user_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
@ -69,7 +79,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
)
.route(
"/guilds/:guild_id/channels",
"/guilds/{guild_id}/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
@ -95,7 +105,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
"/guilds/{guild_id}/channels/{channel_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
@ -107,7 +117,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
"/guilds/{guild_id}/channels/{channel_id}/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
@ -122,16 +132,19 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
"/guilds/{guild_id}/channels/{channel_id}/permissions/{user_id}",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
"/guilds/{guild_id}/channels/{channel_id}/last_message",
get(|State(cache): State<Arc<DiscordCache>>, Path((_guild_id, channel_id)): Path<(u64, Id<ChannelMarker>)>| async move {
let lm = cache.get_last_message(channel_id).await;
status_code(StatusCode::FOUND, to_string(&lm).unwrap())
}),
)
.route(
"/guilds/:guild_id/roles",
"/guilds/{guild_id}/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
@ -171,13 +184,45 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
}))
.route("/runtime_config", get(|| async move {
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/runtime_config/{key}", post(|Path(key): Path<String>, body: String| async move {
let runtime_config = runtime_config_for_post;
runtime_config.set(key, body).await.expect("failed to update runtime config");
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/runtime_config/{key}", delete(|Path(key): Path<String>| async move {
let runtime_config = runtime_config_for_delete;
runtime_config.delete(key).await.expect("failed to update runtime config");
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/await_event", post(|ConnectInfo(addr): ConnectInfo<SocketAddr>, body: String| async move {
info!("got request: {body} from: {addr}");
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
return status_code(StatusCode::BAD_REQUEST, "".to_string());
};
awaiter.handle_request(req, addr).await;
status_code(StatusCode::NO_CONTENT, "".to_string())
}))
.route("/clear_awaiter", post(|| async move {
awaiter_for_clear.clear().await;
status_code(StatusCode::NO_CONTENT, "".to_string())
}))
.route("/shard_status", get(|| async move {
status_code(StatusCode::FOUND, to_string(&shard_state.get().await).unwrap())
}))
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
Ok(())
}

View file

@ -1,6 +1,7 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
model::CachedMember,
@ -8,11 +9,12 @@ use twilight_cache_inmemory::{
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_gateway::Event;
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
Id,
},
};
@ -123,16 +125,134 @@ pub fn new() -> DiscordCache {
.build(),
);
DiscordCache(
cache,
client,
RwLock::new(Vec::new()),
RwLock::new(HashMap::new()),
)
}
#[derive(Clone, Serialize)]
pub struct CachedMessage {
id: Id<MessageMarker>,
referenced_message: Option<Id<MessageMarker>>,
author_username: String,
}
#[derive(Clone, Serialize)]
pub struct LastMessageCacheEntry {
pub current: CachedMessage,
pub previous: Option<CachedMessage>,
}
pub struct DiscordCache(
pub Arc<InMemoryCache>,
pub Arc<twilight_http::Client>,
pub RwLock<Vec<u32>>,
pub RwLock<HashMap<Id<ChannelMarker>, LastMessageCacheEntry>>,
);
impl DiscordCache {
pub async fn get_last_message(
&self,
channel: Id<ChannelMarker>,
) -> Option<LastMessageCacheEntry> {
self.3.read().await.get(&channel).cloned()
}
pub async fn update(&self, event: &twilight_gateway::Event) {
self.0.update(event);
match event {
Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let cur = e.get();
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: Some(cur.current.clone()),
});
}
std::collections::hash_map::Entry::Vacant(e) => {
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: None,
});
}
},
Event::MessageDelete(m) => {
self.handle_message_deletion(m.channel_id, vec![m.id]).await;
}
Event::MessageDeleteBulk(m) => {
self.handle_message_deletion(m.channel_id, m.ids.clone())
.await;
}
_ => {}
};
}
async fn handle_message_deletion(
&self,
channel_id: Id<ChannelMarker>,
mids: Vec<Id<MessageMarker>>,
) {
let mut lm = self.3.write().await;
let Some(entry) = lm.get(&channel_id) else {
return;
};
let mut entry = entry.clone();
// if none of the deleted messages are relevant, just return
if !mids.contains(&entry.current.id)
&& entry
.previous
.clone()
.map(|v| !mids.contains(&v.id))
.unwrap_or(true) // no previous message cached: only current could be relevant
{
return;
}
// remove "previous" entry if it was deleted
if let Some(prev) = entry.previous.clone()
&& mids.contains(&prev.id)
{
entry.previous = None;
}
// set "current" entry to "previous" if current entry was deleted
// (if the "previous" entry still exists, it was not deleted)
if let Some(prev) = entry.previous.clone()
&& mids.contains(&entry.current.id)
{
entry.current = prev;
entry.previous = None;
}
// if the current entry was already deleted, but previous wasn't,
// we would've set current to previous
// so if current is deleted this means both current and previous have
// been deleted
// so just drop the cache entry here
if mids.contains(&entry.current.id) && entry.previous.is_none() {
lm.remove(&channel_id);
return;
}
// ok, update the entry
lm.insert(channel_id, entry.clone());
}
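    // Worked example of the rules above (illustrative): starting from
    // { current: B, previous: Some(A) },
    //   deleting A        => { current: B, previous: None }
    //   deleting B        => { current: A, previous: None }   (previous promoted)
    //   deleting A and B  => entry dropped from the map entirely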
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
@ -356,12 +476,14 @@ impl DiscordCache {
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: Some(false),
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
guild_scheduled_events: guild.guild_scheduled_events().to_vec(),
max_stage_video_channel_users: guild.max_stage_video_channel_users(),
}
})
}

View file

@ -1,7 +1,9 @@
use anyhow::anyhow;
use futures::StreamExt;
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent};
use metrics::counter;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
@ -12,9 +14,12 @@ use twilight_model::gateway::{
Intents,
};
use crate::{
discord::identify_queue::{self, RedisQueue},
RUNTIME_CONFIG_KEY_EVENT_TARGET,
};
use super::cache::DiscordCache;
pub fn cluster_config() -> ClusterSettings {
libpk::config
@ -44,6 +49,12 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
if cluster_settings.node_id != 0 {
return Err(anyhow!(
"expecting to be node 0 in single-process mode, but we are node {}",
cluster_settings.node_id
));
}
(0, (cluster_settings.total_shards - 1).into())
} else {
(
@ -52,6 +63,13 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
)
};
let prefix = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_prefix_for_gateway
.clone();
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
@ -64,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
.to_owned(),
intents,
)
.presence(presence("pk;help", false))
.presence(presence(format!("{prefix}help").as_str(), false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
@ -76,15 +94,23 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
Ok(shards_vec)
}
#[tracing::instrument(fields(shard = %shard.id()), skip_all)]
pub async fn runner(
mut shard: Shard<RedisQueue>,
tx: Sender<(ShardId, Event, String)>,
tx_state: Sender<(ShardId, ShardStateEvent, Option<Event>, Option<i32>)>,
cache: Arc<DiscordCache>,
runtime_config: Arc<RuntimeConfig>,
) {
let shard_id = shard.id().number();
let our_user_id = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id;
info!("waiting for events");
while let Some(item) = shard.next().await {
let raw_event = match item {
@ -105,7 +131,9 @@ pub async fn runner(
)
.increment(1);
if let Err(error) =
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
{
error!("failed to update shard state for socket closure: {error}");
}
@ -127,7 +155,7 @@ pub async fn runner(
continue;
}
Err(error) => {
error!(?error, ?shard_id, "failed to parse gateway event");
continue;
}
};
@ -147,14 +175,31 @@ pub async fn runner(
.increment(1);
// update shard state and discord cache
if matches!(event, Event::Ready(_)) || matches!(event, Event::Resumed) {
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Other,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state");
}
}
// need to do heartbeat separately, to get the latency
let latency_num = shard
.latency()
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Heartbeat,
Some(event.clone()),
Some(latency_num),
))
{
tracing::error!(?error, "error updating shard state for latency");
}
if let Event::Ready(_) = event {
@ -162,10 +207,28 @@ pub async fn runner(
cache.2.write().await.push(shard_id);
}
}
cache.update(&event).await;
// okay, we've handled the event internally, let's send it to consumers
// some basic filtering here is useful
// match guards can't be combined across | patterns, so each event we want
// gets its own no-op arm here, and the default arm skips to the next event
match event {
Event::InteractionCreate(_) => {}
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
Event::MessageUpdate(ref m) if m.author.id != our_user_id && !m.author.bot => {}
Event::MessageDelete(_) => {}
Event::MessageDeleteBulk(_) => {}
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
_ => {
continue;
}
}
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
tx.send((shard.id(), event, raw_event)).await.unwrap();
}
}
}

View file

@ -78,8 +78,8 @@ async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: on
Ok(None) => {
// not allowed yet, waiting
}
Err(error) => {
error!(?error, ?shard_id, ?bucket, "error getting shard allowance")
}
}

View file

@ -1,49 +1,63 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tokio::sync::RwLock;
use tracing::info;
use twilight_gateway::Event;
use std::collections::HashMap;
use libpk::state::ShardState;
use super::gateway::cluster_config;
pub struct ShardStateManager {
redis: RedisPool,
shards: RwLock<HashMap<u32, ShardState>>,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager {
redis,
shards: RwLock::new(HashMap::new()),
}
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
// also update gateway.rs with event types
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
_ => Ok(()),
}
}
async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
{
let mut shards = self.shards.write().await;
shards.insert(id, state.clone());
}
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
(
id.to_string(),
serde_json::to_string(&state).expect("could not serialize shard"),
),
)
.await?;
Ok(())
}
async fn get_shard(&self, id: u32) -> Option<ShardState> {
let shards = self.shards.read().await;
shards.get(&id).cloned()
}
pub async fn get(&self) -> Vec<ShardState> {
self.shards.read().await.values().cloned().collect()
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
@ -57,32 +71,52 @@ impl ShardStateManager {
)
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency;
self.save_shard(shard_id, info).await?;
Ok(())
}

View file

@ -0,0 +1,242 @@
// - reaction: (message_id, user_id)
// - message: (author_id, channel_id, ?options)
// - interaction: (custom_id where not_includes "help-menu")
use std::{
collections::{hash_map::Entry, HashMap},
net::{IpAddr, SocketAddr},
time::Duration,
};
use serde::Deserialize;
use tokio::{sync::RwLock, time::Instant};
use tracing::info;
use twilight_gateway::Event;
use twilight_model::{
application::interaction::InteractionData,
id::{
marker::{ChannelMarker, MessageMarker, UserMarker},
Id,
},
};
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
#[derive(Deserialize)]
#[serde(untagged)]
pub enum AwaitEventRequest {
Reaction {
message_id: Id<MessageMarker>,
user_id: Id<UserMarker>,
target: String,
timeout: Option<u64>,
},
Message {
channel_id: Id<ChannelMarker>,
author_id: Id<UserMarker>,
target: String,
timeout: Option<u64>,
options: Option<Vec<String>>,
},
Interaction {
id: String,
target: String,
timeout: Option<u64>,
},
}
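// Hedged sketch (not part of this commit): with #[serde(untagged)] the variant
// is picked by which fields are present; these hypothetical payloads document
// what the /await_event route accepts (ids and timeouts invented).
#[cfg(test)]
mod await_event_request_tests {
    use super::AwaitEventRequest;

    #[test]
    fn variant_is_chosen_by_field_shape() {
        // message_id + user_id selects the Reaction variant
        let req: AwaitEventRequest = serde_json::from_str(
            r#"{"message_id":"123","user_id":"456","target":"source-addr","timeout":60}"#,
        )
        .unwrap();
        assert!(matches!(req, AwaitEventRequest::Reaction { .. }));

        // channel_id + author_id selects the Message variant
        let req: AwaitEventRequest = serde_json::from_str(
            r#"{"channel_id":"123","author_id":"456","target":"source-addr"}"#,
        )
        .unwrap();
        assert!(matches!(req, AwaitEventRequest::Message { .. }));
    }
}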
pub struct EventAwaiter {
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
messages: RwLock<
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
>,
interactions: RwLock<HashMap<String, (Instant, String)>>,
}
impl EventAwaiter {
pub fn new() -> Self {
Self {
    reactions: RwLock::new(HashMap::new()),
    messages: RwLock::new(HashMap::new()),
    interactions: RwLock::new(HashMap::new()),
}
}
pub async fn cleanup_loop(&self) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
info!("running event_awaiter cleanup loop");
let mut counts = (0, 0, 0);
let now = Instant::now();
{
let mut reactions = self.reactions.write().await;
for key in reactions.clone().keys() {
if let Entry::Occupied(entry) = reactions.entry(key.clone())
&& entry.get().0 < now
{
counts.0 += 1;
entry.remove();
}
}
}
{
let mut messages = self.messages.write().await;
for key in messages.clone().keys() {
if let Entry::Occupied(entry) = messages.entry(key.clone())
&& entry.get().0 < now
{
counts.1 += 1;
entry.remove();
}
}
}
{
let mut interactions = self.interactions.write().await;
for key in interactions.clone().keys() {
if let Entry::Occupied(entry) = interactions.entry(key.clone())
&& entry.get().0 < now
{
counts.2 += 1;
entry.remove();
}
}
}
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
}
}
pub async fn target_for_event(&self, event: Event) -> Option<String> {
match event {
Event::MessageCreate(message) => {
let mut messages = self.messages.write().await;
messages
.remove(&(message.channel_id, message.author.id))
.map(|(timeout, target, options)| {
if let Some(options) = options
&& !options.contains(&message.content.to_lowercase())
{
messages.insert(
(message.channel_id, message.author.id),
(timeout, target, Some(options)),
);
return None;
}
Some((*target).to_string())
})?
}
Event::ReactionAdd(reaction)
if let Some((_, target)) = self
.reactions
.write()
.await
.remove(&(reaction.message_id, reaction.user_id)) =>
{
Some((*target).to_string())
}
Event::InteractionCreate(interaction)
if let Some(data) = interaction.data.clone()
&& let InteractionData::MessageComponent(component) = data
&& !component.custom_id.contains("help-menu")
&& let Some((_, target)) =
self.interactions.write().await.remove(&component.custom_id) =>
{
Some((*target).to_string())
}
_ => None,
}
}
pub async fn handle_request(&self, req: AwaitEventRequest, addr: SocketAddr) {
match req {
AwaitEventRequest::Reaction {
message_id,
user_id,
target,
timeout,
} => {
self.reactions.write().await.insert(
(message_id, user_id),
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
),
);
}
AwaitEventRequest::Message {
channel_id,
author_id,
target,
timeout,
options,
} => {
self.messages.write().await.insert(
(channel_id, author_id),
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
options,
),
);
}
AwaitEventRequest::Interaction {
id,
target,
timeout,
} => {
self.interactions.write().await.insert(
id,
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
),
);
}
}
}
pub async fn clear(&self) {
self.reactions.write().await.clear();
self.messages.write().await.clear();
self.interactions.write().await.clear();
}
}
fn target_or_addr(target: String, addr: SocketAddr) -> String {
if target == "source-addr" {
let ip_str = match addr.ip() {
IpAddr::V4(v4) => v4.to_string(),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
v4.to_string()
} else {
format!("[{v6}]")
}
}
};
format!("http://{ip_str}:5002/events")
} else {
target
}
}
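// Illustrative behavior of target_or_addr (addresses invented): explicit
// targets pass through, while the "source-addr" sentinel is rewritten to the
// caller's ip on the fixed event port 5002, unmapping ipv4-in-ipv6 addresses:
//   target_or_addr("http://bot:5002/events".into(), addr) => "http://bot:5002/events"
//   target_or_addr("source-addr".into(), 10.0.0.5:41234)  => "http://10.0.0.5:5002/events"
//   target_or_addr("source-addr".into(), [::ffff:10.0.0.5]:41234) => "http://10.0.0.5:5002/events"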

View file

@ -1,39 +1,79 @@
#![feature(let_chains)]
#![feature(if_let_guard)]
#![feature(duration_constructors)]
use chrono::Timelike;
use discord::gateway::cluster_config;
use event_awaiter::EventAwaiter;
use fred::{clients::RedisPool, interfaces::*};
use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
use reqwest::{ClientBuilder, StatusCode};
use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{
signal::unix::{signal, SignalKind},
sync::mpsc::channel,
task::JoinSet,
};
use tracing::{error, info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod api;
mod discord;
mod event_awaiter;
mod logger;
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &str = "event_target";
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let runtime_config = Arc::new(
RuntimeConfig::new(
redis.clone(),
format!(
"{}:{}",
libpk::config.runtime_config_key.as_ref().unwrap(),
cluster_config().node_id
),
)
.await?,
);
// hacky, but needed for selfhost for now
if let Some(target) = libpk::config
.discord
.as_ref()
.unwrap()
.gateway_target
.clone()
{
runtime_config
.set(RUNTIME_CONFIG_KEY_EVENT_TARGET.to_string(), target)
.await?;
}
let cache = Arc::new(discord::cache::new());
let awaiter = Arc::new(EventAwaiter::new());
tokio::spawn({
let awaiter = awaiter.clone();
async move { awaiter.cleanup_loop().await }
});
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
// arbitrary
// todo: make sure this doesn't fill up
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
// todo: make sure this doesn't fill up
let (state_tx, mut state_rx) = channel::<(
ShardId,
ShardStateEvent,
Option<twilight_gateway::Event>,
Option<i32>,
)>(1000);
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
@ -45,61 +85,160 @@ async fn real_main() -> anyhow::Result<()> {
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
state_tx.clone(),
cache.clone(),
runtime_config.clone(),
)));
}
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
set.spawn(tokio::spawn({
let shard_state = shard_state.clone();
async move {
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
match state_event {
ShardStateEvent::Heartbeat => {
if let Some(latency) = latency
    && let Err(error) = shard_state
        .heartbeated(shard_id.number(), latency)
.await
{
error!("failed to update shard state for heartbeat: {error}")
};
}
ShardStateEvent::Closed => {
if let Err(error) = shard_state.socket_closed(shard_id.number()).await {
error!("failed to update shard state for heartbeat: {error}")
};
}
ShardStateEvent::Other => {
if let Err(error) = shard_state
.handle_event(
shard_id.number(),
parsed_event.expect("shard state event not provided!"),
)
.await
{
error!("failed to update shard state for heartbeat: {error}")
};
}
}
}
}
}));
set.spawn(tokio::spawn({
let runtime_config = runtime_config.clone();
let awaiter = awaiter.clone();
async move {
let client = Arc::new(
ClientBuilder::new()
.connect_timeout(Duration::from_secs(1))
.timeout(Duration::from_secs(1))
.build()
.expect("error making client"),
);
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
info!(target = ?target, "sending event to awaiter");
Some(target)
} else if let Some(target) =
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
{
Some(target)
} else {
None
};
if let Some(target) = target {
tokio::spawn({
let client = client.clone();
async move {
match client
.post(format!("{target}/{}", shard_id.number()))
.body(raw_event)
.send()
.await
{
Ok(res) => {
if res.status() != StatusCode::OK {
error!(
status = ?res.status(),
target = ?target,
"got non-200 from bot while sending event",
);
}
}
Err(error) => {
error!(?error, "failed to request event target");
}
}
}
});
}
}
}
}));
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
match api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
error!(?error, "failed to serve cache api");
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
set.spawn(tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
signal(SignalKind::interrupt()).unwrap().recv().await;
info!("got SIGINT");
}));
let _ = shutdown_rx.recv();
set.spawn(tokio::spawn(async move {
signal(SignalKind::terminate()).unwrap().recv().await;
info!("got SIGTERM");
}));
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.join_next().await;
info!("gateway exiting, have a nice day!");
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
set.abort_all();
info!("gateway exiting, have a nice day!");
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
Ok(())
}
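
For the receiving side (not part of this diff), whatever serves the event target URL only needs to accept the raw gateway payload at POST {target}/{shard_id} and answer 200, matching the forwarding loop above. A hypothetical sketch, assuming axum 0.7 path syntax and the /events path produced by target_or_addr:

use axum::{extract::Path, http::StatusCode, routing::post, Router};

async fn receive_event(Path(shard_id): Path<u32>, body: String) -> StatusCode {
    // body is the raw gateway event JSON forwarded by the loop above
    tracing::info!(shard_id, len = body.len(), "received gateway event");
    StatusCode::OK
}

fn event_router() -> Router {
    Router::new().route("/events/:shard_id", post(receive_event))
}
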
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
let prefix = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_prefix_for_gateway
.clone();
println!("{prefix}");
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
@ -119,9 +258,9 @@ async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
format!("{prefix}help | {status}")
} else {
"pk;help".to_string()
format!("{prefix}help")
}
.as_str(),
false,

View file

@ -0,0 +1,15 @@
[package]
name = "gdpr_worker"
version = "0.1.0"
edition = "2021"
[dependencies]
libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
futures = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
twilight-model = { workspace = true }

View file

@ -0,0 +1,149 @@
#![feature(let_chains)]
use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info, warn};
use twilight_http::api_error::{ApiError, GeneralApiError};
use twilight_model::id::{
marker::{ChannelMarker, MessageMarker},
Id,
};
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel_id bigint not null);
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_messages_db().await?;
let mut client_builder = twilight_http::Client::builder()
.token(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.clone(),
)
.timeout(Duration::from_secs(30));
if let Some(base_url) = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.api_base_url
.clone()
{
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
}
let client = Arc::new(client_builder.build());
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
match run_job(db.clone(), client.clone()).await {
Ok(()) => {}
Err(error) => {
error!(?error, "failed to run messages gdpr job");
}
}
}
}
#[derive(FromRow)]
struct GdprJobEntry {
mid: i64,
channel_id: i64,
}
async fn run_job(pool: sqlx::PgPool, discord: Arc<twilight_http::Client>) -> anyhow::Result<()> {
let mut tx = pool.begin().await?;
let message: Option<GdprJobEntry> = sqlx::query_as(
"select mid, channel_id from messages_gdpr_jobs for update skip locked limit 1;",
)
.fetch_optional(&mut *tx)
.await?;
let Some(message) = message else {
info!("no job to run, sleeping for 1 minute");
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
return Ok(());
};
info!("got mid={}, cleaning up...", message.mid);
// naively delete message on discord's end
let res = discord
.delete_message(
Id::<ChannelMarker>::new(message.channel_id as u64),
Id::<MessageMarker>::new(message.mid as u64),
)
.await;
if res.is_ok() {
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
.bind(message.mid)
.execute(&mut *tx)
.await?;
}
if let Err(err) = res {
if let twilight_http::error::ErrorType::Response { error, status, .. } = err.kind()
&& let ApiError::General(GeneralApiError { code, .. }) = error
{
match (status.get(), code) {
(403, _) => {
warn!(
"got 403 while deleting message in channel {}, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
(_, 10003) => {
warn!(
"deleting message in channel {}: channel not found, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
(_, 10008) => {
warn!("deleting message {}: message not found", message.mid);
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
.bind(message.mid)
.execute(&mut *tx)
.await?;
}
(_, 50083) => {
warn!(
"could not delete message in thread {}: thread is archived, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
_ => {
error!(
?status,
?code,
message_id = message.mid,
"got unknown error deleting message",
);
}
}
} else {
return Err(err.into());
}
}
tx.commit().await?;
Ok(())
}
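
The select ... for update skip locked is what makes this loop safe to scale out: each transaction claims at most one unclaimed row, concurrent workers skip rows that are already locked instead of blocking on them, and a crash before commit releases the row for retry. A hypothetical enqueue helper for the other side (the insert is not shown in this diff; column names follow the query above):

async fn enqueue_gdpr_job(pool: &sqlx::PgPool, mid: i64, channel_id: i64) -> anyhow::Result<()> {
    // hypothetical: the deleting side only needs (mid, channel_id) per message
    sqlx::query("insert into messages_gdpr_jobs (mid, channel_id) values ($1, $2)")
        .bind(mid)
        .bind(channel_id)
        .execute(pool)
        .await?;
    Ok(())
}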

View file

@ -8,6 +8,7 @@ anyhow = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
pk_macros = { path = "../macros" }
sentry = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View file

@ -12,10 +12,16 @@ pub struct ClusterSettings {
pub total_nodes: u32,
}
fn _default_bot_prefix() -> String {
"pk;".to_string()
}
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: Id<UserMarker>,
pub bot_token: String,
#[serde(default = "_default_bot_prefix")]
pub bot_prefix_for_gateway: String,
pub client_secret: String,
pub max_concurrency: u32,
#[serde(default)]
@ -24,6 +30,9 @@ pub struct DiscordConfig {
#[serde(default = "_default_api_addr")]
pub cache_api_addr: String,
#[serde(default)]
pub gateway_target: Option<String>,
}
#[derive(Deserialize, Debug)]
@ -85,6 +94,7 @@ pub struct ScheduledTasksConfig {
pub set_guild_count: bool,
pub expected_gateway_count: usize,
pub gateway_url: String,
pub prometheus_url: String,
}
fn _metrics_default() -> bool {
@ -113,6 +123,9 @@ pub struct PKConfig {
#[serde(default = "_json_log_default")]
pub(crate) json_log: bool,
#[serde(default)]
pub runtime_config_key: Option<String>,
#[serde(default)]
pub sentry_url: Option<String>,
}
@ -132,10 +145,15 @@ impl PKConfig {
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = {
// hacks
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
}
Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))

View file

@ -8,12 +8,15 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
use sentry_tracing::event_from_event;
pub mod db;
pub mod runtime_config;
pub mod state;
pub mod _config;
pub use crate::_config::CONFIG as config;
// functions in this file are only used by the main function below
// functions in this file are only used by the main function in macros/entrypoint.rs
pub use pk_macros::main;
pub fn init_logging(component: &str) {
let sentry_layer =
@ -42,6 +45,7 @@ pub fn init_logging(component: &str) {
tracing_subscriber::registry()
.with(sentry_layer)
.with(tracing_subscriber::fmt::layer())
.with(EnvFilter::from_default_env())
.init();
}
}
@ -66,28 +70,3 @@ pub fn init_sentry() -> sentry::ClientInitGuard {
..Default::default()
})
}
#[macro_export]
macro_rules! main {
($component:expr) => {
fn main() -> anyhow::Result<()> {
let _sentry_guard = libpk::init_sentry();
// we might also be able to use env!("CARGO_CRATE_NAME") here
libpk::init_logging($component);
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(err) = libpk::init_metrics() {
tracing::error!("failed to init metrics collector: {err}");
};
tracing::info!("hello world");
if let Err(err) = real_main().await {
tracing::error!("failed to run service: {err}");
};
});
Ok(())
}
};
}

View file

@ -0,0 +1,72 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use std::collections::HashMap;
use tokio::sync::RwLock;
use tracing::info;
pub struct RuntimeConfig {
redis: RedisPool,
settings: RwLock<HashMap<String, String>>,
redis_key: String,
}
impl RuntimeConfig {
pub async fn new(redis: RedisPool, component_key: String) -> anyhow::Result<Self> {
let redis_key = format!("remote_config:{component_key}");
let mut c = RuntimeConfig {
redis,
settings: RwLock::new(HashMap::new()),
redis_key,
};
c.load().await?;
Ok(c)
}
pub async fn load(&mut self) -> anyhow::Result<()> {
let redis_config: HashMap<String, String> = self.redis.hgetall(&self.redis_key).await?;
let mut settings = self.settings.write().await;
for (key, value) in redis_config {
settings.insert(key, value);
}
info!("starting with runtime config: {:?}", settings);
Ok(())
}
pub async fn set(&self, key: String, value: String) -> anyhow::Result<()> {
self.redis
.hset::<(), &str, (String, String)>(&self.redis_key, (key.clone(), value.clone()))
.await?;
self.settings
.write()
.await
.insert(key.clone(), value.clone());
info!("updated runtime config: {key}={value}");
Ok(())
}
pub async fn delete(&self, key: String) -> anyhow::Result<()> {
self.redis
.hdel::<(), &str, String>(&self.redis_key, key.clone())
.await?;
self.settings.write().await.remove(&key.clone());
info!("updated runtime config: {key} removed");
Ok(())
}
pub async fn get(&self, key: &str) -> Option<String> {
self.settings.read().await.get(key).cloned()
}
pub async fn exists(&self, key: &str) -> bool {
self.settings.read().await.contains_key(key)
}
pub async fn get_all(&self) -> HashMap<String, String> {
self.settings.read().await.clone()
}
}
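
A usage sketch (not part of this diff); the component key and target URL are illustrative, and init_redis is the same helper used elsewhere in this commit:

async fn runtime_config_demo() -> anyhow::Result<()> {
    let redis = libpk::db::init_redis().await?;
    // values live in the redis hash remote_config:gateway:0 and in a local cache
    let config = RuntimeConfig::new(redis, "gateway:0".to_string()).await?;
    config
        .set("event_target".to_string(), "http://bot:5002/events".to_string())
        .await?;
    assert_eq!(
        config.get("event_target").await.as_deref(),
        Some("http://bot:5002/events")
    );
    config.delete("event_target".to_string()).await?;
    Ok(())
}

Note that set and delete write through to both redis and the local cache, but load() only runs at construction, so changes made out-of-band in redis aren't observed until restart.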

View file

@ -1,4 +1,4 @@
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)]
pub struct ShardState {
pub shard_id: i32,
pub up: bool,
@ -10,3 +10,9 @@ pub struct ShardState {
pub last_connection: i32,
pub cluster_id: Option<i32>,
}
pub enum ShardStateEvent {
Closed,
Heartbeat,
Other,
}

View file

@ -1,5 +1,5 @@
[package]
name = "model_macros"
name = "pk_macros"
version = "0.1.0"
edition = "2021"

View file

@ -0,0 +1,41 @@
use proc_macro::{Delimiter, TokenTree};
use quote::quote;
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
// yes, this ignores everything except the codeblock
// it's fine.
let body = match input.into_iter().last().expect("empty") {
TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => group.stream(),
_ => panic!("invalid function"),
};
let body = proc_macro2::TokenStream::from(body);
return quote! {
fn main() {
let _sentry_guard = libpk::init_sentry();
libpk::init_logging(env!("CARGO_CRATE_NAME"));
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(error) = libpk::init_metrics() {
tracing::error!(?error, "failed to init metrics collector");
};
tracing::info!("hello world");
let result: anyhow::Result<()> = async { #body }.await;
if let Err(error) = result {
tracing::error!(?error, "failed to run service");
};
});
}
}
.into();
}
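
At a use site (as in the gateway, gdpr_worker, and migrate crates in this commit), the attribute replaces the old libpk::main! macro:

#[libpk::main]
async fn main() -> anyhow::Result<()> {
    tracing::info!("service-specific startup goes here");
    Ok(())
}

Since the macro keeps only the brace-delimited body, the signature is cosmetic; the body just has to evaluate to anyhow::Result<()>.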

14
crates/macros/src/lib.rs Normal file
View file

@ -0,0 +1,14 @@
use proc_macro::TokenStream;
mod entrypoint;
mod model;
#[proc_macro_attribute]
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
entrypoint::macro_impl(args, input)
}
#[proc_macro_attribute]
pub fn pk_model(args: TokenStream, input: TokenStream) -> TokenStream {
model::macro_impl(args, input)
}

View file

@ -16,6 +16,7 @@ struct ModelField {
patch: ElemPatchability,
json: Option<Expr>,
is_privacy: bool,
privacy: Option<Expr>,
default: Option<Expr>,
}
@ -26,6 +27,7 @@ fn parse_field(field: syn::Field) -> ModelField {
patch: ElemPatchability::None,
json: None,
is_privacy: false,
privacy: None,
default: None,
};
@ -61,6 +63,12 @@ fn parse_field(field: syn::Field) -> ModelField {
}
f.json = Some(nv.value.clone());
}
"privacy" => {
if f.privacy.is_some() {
panic!("cannot set privacy multiple times for same field");
}
f.privacy = Some(nv.value.clone());
}
"default" => {
if f.default.is_some() {
panic!("cannot set default multiple times for same field");
@ -84,8 +92,7 @@ fn parse_field(field: syn::Field) -> ModelField {
f
}
#[proc_macro_attribute]
pub fn pk_model(
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
@ -108,8 +115,6 @@ pub fn pk_model(
panic!("fields of a struct must be named");
};
// println!("{}: {:#?}", tname, fields);
let tfields = mk_tfields(fields.clone());
let from_json = mk_tfrom_json(fields.clone());
let _from_sql = mk_tfrom_sql(fields.clone());
@ -138,9 +143,7 @@ pub fn pk_model(
#from_json
}
pub fn to_json(self) -> serde_json::Value {
#to_json
}
#to_json
}
#[derive(Debug, Clone)]
@ -189,19 +192,28 @@ fn mk_tfrom_sql(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
// todo: check privacy access
let has_privacy = fields.iter().any(|f| f.privacy.is_some());
let fielddefs: TokenStream = fields
.iter()
.filter_map(|f| {
f.json.as_ref().map(|v| {
let tname = f.name.clone();
if let Some(default) = f.default.as_ref() {
let maybepriv = if let Some(privacy) = f.privacy.as_ref() {
quote! {
#v: self.#tname.unwrap_or(#default),
#v: crate::_util::privacy_lookup!(self.#tname, self.#privacy, lookup_level)
}
} else {
quote! {
#v: self.#tname,
#v: self.#tname
}
};
if let Some(default) = f.default.as_ref() {
quote! {
#maybepriv.unwrap_or(#default),
}
} else {
quote! {
#maybepriv,
}
}
})
@ -223,13 +235,35 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
})
.collect();
quote! {
serde_json::json!({
#fielddefs
"privacy": {
#privacyfielddefs
let privdef = if has_privacy {
quote! {
, lookup_level: crate::PrivacyLevel
}
} else {
quote! {}
};
let privacy_fielddefs = if has_privacy {
quote! {
"privacy": if matches!(lookup_level, crate::PrivacyLevel::Private) {
Some(serde_json::json!({
#privacyfielddefs
}))
} else {
None
}
})
}
} else {
quote! {}
};
quote! {
pub fn to_json(self #privdef) -> serde_json::Value {
serde_json::json!({
#fielddefs
#privacy_fielddefs
})
}
}
}
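
A hand-expanded approximation (an illustration, not actual macro output) of what this generates for a model with one privacy-gated, JSON-exposed field named description; privacy_lookup! is the models-crate helper referenced above, assumed to yield the real value at Private level and a redacted one at Public level:

pub fn to_json(self, lookup_level: crate::PrivacyLevel) -> serde_json::Value {
    serde_json::json!({
        "description": crate::_util::privacy_lookup!(
            self.description,
            self.description_privacy,
            lookup_level
        ),
        // the privacy sub-object is only serialized for the owner
        "privacy": if matches!(lookup_level, crate::PrivacyLevel::Private) {
            Some(serde_json::json!({
                "description_privacy": self.description_privacy,
            }))
        } else {
            None
        }
    })
}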

12
crates/migrate/Cargo.toml Normal file
View file

@ -0,0 +1,12 @@
[package]
name = "migrate"
version = "0.1.0"
edition = "2021"
[dependencies]
libpk = { path = "../libpk" }
anyhow = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

55
crates/migrate/build.rs Normal file
View file

@ -0,0 +1,55 @@
use std::{
env,
error::Error,
fs::{self, File},
io::Write,
path::Path,
};
fn main() -> Result<(), Box<dyn Error>> {
let out_dir = env::var("OUT_DIR")?;
let dest_path = Path::new(&out_dir).join("data.rs");
let mut datafile = File::create(&dest_path)?;
let prefix = "../../../../../../crates/migrate/data";
let ct = fs::read_dir("data/migrations")?
.filter(|p| {
p.as_ref()
.unwrap()
.file_name()
.into_string()
.unwrap()
.contains(".sql")
})
.count();
writeln!(&mut datafile, "const MIGRATIONS: [&'static str; {ct}] = [")?;
for idx in 0..ct {
writeln!(
&mut datafile,
"\tinclude_str!(\"{prefix}/migrations/{idx}.sql\"),"
)?;
}
writeln!(&mut datafile, "];\n")?;
writeln!(
&mut datafile,
"const CLEAN: &'static str = include_str!(\"{prefix}/clean.sql\");"
)?;
writeln!(
&mut datafile,
"const VIEWS: &'static str = include_str!(\"{prefix}/views.sql\");"
)?;
writeln!(
&mut datafile,
"const FUNCTIONS: &'static str = include_str!(\"{prefix}/functions.sql\");"
)?;
writeln!(
&mut datafile,
"const SEED: &'static str = include_str!(\"{prefix}/seed.sql\");"
)?;
Ok(())
}
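
For a data/migrations directory containing just 0.sql and 1.sql, the generated data.rs would read:

const MIGRATIONS: [&'static str; 2] = [
    include_str!("../../../../../../crates/migrate/data/migrations/0.sql"),
    include_str!("../../../../../../crates/migrate/data/migrations/1.sql"),
];

const CLEAN: &'static str = include_str!("../../../../../../crates/migrate/data/clean.sql");
const VIEWS: &'static str = include_str!("../../../../../../crates/migrate/data/views.sql");
const FUNCTIONS: &'static str = include_str!("../../../../../../crates/migrate/data/functions.sql");
const SEED: &'static str = include_str!("../../../../../../crates/migrate/data/seed.sql");

The deep ../../../../../../ prefix exists because include_str! resolves paths relative to the file containing it, and the generated file lives inside OUT_DIR.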

View file

@ -0,0 +1,16 @@
-- This gets run on every bot startup and makes sure we're starting from a clean slate
-- Then, the views/functions.sql files get run, and they recreate the necessary objects
-- This does mean we can't use any functions in row triggers, etc. Still unsure how to handle this.
drop view if exists system_last_switch;
drop view if exists system_fronters;
drop view if exists member_list;
drop view if exists group_list;
drop function if exists message_context;
drop function if exists proxy_members;
drop function if exists has_private_members;
drop function if exists generate_hid;
drop function if exists find_free_system_hid;
drop function if exists find_free_member_hid;
drop function if exists find_free_group_hid;

View file

@ -0,0 +1,185 @@
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
returns table (
allow_autoproxy bool,
system_id int,
system_tag text,
system_avatar text,
latch_timeout integer,
case_sensitive_proxy_tags bool,
proxy_error_message_enabled bool,
proxy_switch int,
name_format text,
tag_enabled bool,
proxy_enabled bool,
system_guild_tag text,
system_guild_avatar text,
guild_name_format text,
last_switch int,
last_switch_members int[],
last_switch_timestamp timestamp,
log_channel bigint,
in_blacklist bool,
in_log_blacklist bool,
log_cleanup_enabled bool,
require_system_tag bool,
suppress_notifications bool,
deny_bot_usage bool
)
as $$
select
-- accounts table
accounts.allow_autoproxy as allow_autoproxy,
-- systems table
systems.id as system_id,
systems.tag as system_tag,
systems.avatar_url as system_avatar,
-- system_config table
system_config.latch_timeout as latch_timeout,
system_config.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
system_config.proxy_error_message_enabled as proxy_error_message_enabled,
system_config.proxy_switch as proxy_switch,
system_config.name_format as name_format,
-- system_guild table
coalesce(system_guild.tag_enabled, true) as tag_enabled,
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
system_guild.tag as system_guild_tag,
system_guild.avatar_url as system_guild_avatar,
system_guild.name_format as guild_name_format,
-- system_last_switch view
system_last_switch.switch as last_switch,
system_last_switch.members as last_switch_members,
system_last_switch.timestamp as last_switch_timestamp,
-- servers table
servers.log_channel as log_channel,
((channel_id = any (servers.blacklist))
or (thread_id = any (servers.blacklist))) as in_blacklist,
((channel_id = any (servers.log_blacklist))
or (thread_id = any (servers.log_blacklist))) as in_log_blacklist,
coalesce(servers.log_cleanup_enabled, false) as log_cleanup_enabled,
coalesce(servers.require_system_tag, false) as require_system_tag,
coalesce(servers.suppress_notifications, false) as suppress_notifications,
-- abuse_logs table
coalesce(abuse_logs.deny_bot_usage, false) as deny_bot_usage
-- We need a "from" clause, so we just use some bogus data that's always present
-- This ensures we always have exactly one row going forward, so we can left join afterwards and still get data
from (select 1) as _placeholder
left join accounts on accounts.uid = account_id
left join servers on servers.id = guild_id
left join systems on systems.id = accounts.system
left join system_config on system_config.system = accounts.system
left join system_guild on system_guild.system = accounts.system
and system_guild.guild = guild_id
left join system_last_switch on system_last_switch.system = accounts.system
left join abuse_logs on abuse_logs.id = accounts.abuse_log
$$ language sql stable rows 1;
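
Illustratively (not part of this diff), callers fetch the single context row through sqlx; the ids here are placeholders and the column set matches the return table above:

async fn fetch_message_context(pool: &sqlx::PgPool) -> anyhow::Result<()> {
    // always exactly one row ("rows 1" above), even for unknown accounts:
    // the left joins just produce nulls
    let row = sqlx::query("select * from message_context($1, $2, $3, $4)")
        .bind(1234567890i64) // account_id (placeholder)
        .bind(2345678901i64) // guild_id (placeholder)
        .bind(3456789012i64) // channel_id (placeholder)
        .bind(0i64)          // thread_id (placeholder)
        .fetch_one(pool)
        .await?;
    let _ = row;
    Ok(())
}
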
-- Fetches info about proxying related to a given account/guild
-- Returns one row per member in the system; meant to be used in conjunction with `message_context`
create function proxy_members(account_id bigint, guild_id bigint)
returns table (
id int,
proxy_tags proxy_tag[],
keep_proxy bool,
tts bool,
server_keep_proxy bool,
server_name text,
display_name text,
name text,
server_avatar text,
webhook_avatar text,
avatar text,
color char(6),
allow_autoproxy bool
)
as $$
select
-- Basic data
members.id as id,
members.proxy_tags as proxy_tags,
members.keep_proxy as keep_proxy,
members.tts as tts,
member_guild.keep_proxy as server_keep_proxy,
-- Name info
member_guild.display_name as server_name,
members.display_name as display_name,
members.name as name,
-- Avatar info
member_guild.avatar_url as server_avatar,
members.webhook_avatar_url as webhook_avatar,
members.avatar_url as avatar,
members.color as color,
members.allow_autoproxy as allow_autoproxy
from accounts
inner join systems on systems.id = accounts.system
inner join members on members.system = systems.id
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
where accounts.uid = account_id
$$ language sql stable rows 10;
create function has_private_members(system_hid int) returns bool as $$
declare m int;
begin
m := count(id) from members where system = system_hid and member_visibility = 2;
if m > 0 then return true;
else return false;
end if;
end
$$ language plpgsql;
create function generate_hid() returns char(6) as $$
select string_agg(substr('abcefghjknoprstuvwxyz', ceil(random() * 21)::integer, 1), '') from generate_series(1, 6)
$$ language sql volatile;
create function find_free_system_hid() returns char(6) as $$
declare new_hid char(6);
begin
loop
new_hid := generate_hid();
if not exists (select 1 from systems where hid = new_hid) then return new_hid; end if;
end loop;
end
$$ language plpgsql volatile;
create function find_free_member_hid() returns char(6) as $$
declare new_hid char(6);
begin
loop
new_hid := generate_hid();
if not exists (select 1 from members where hid = new_hid) then return new_hid; end if;
end loop;
end
$$ language plpgsql volatile;
create function find_free_group_hid() returns char(6) as $$
declare new_hid char(6);
begin
loop
new_hid := generate_hid();
if not exists (select 1 from groups where hid = new_hid) then return new_hid; end if;
end loop;
end
$$ language plpgsql volatile;
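
Illustratively (not part of this diff), callers can lean on the database for id allocation; sqlx maps the char(6) result to a String:

async fn new_member_hid(pool: &sqlx::PgPool) -> anyhow::Result<String> {
    // loops server-side until generate_hid() produces an unused value
    let (hid,): (String,) = sqlx::query_as("select find_free_member_hid()")
        .fetch_one(pool)
        .await?;
    Ok(hid)
}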

View file

@ -0,0 +1,112 @@
-- SCHEMA VERSION 0, 2019-12-26
-- "initial version", considered a "starting point" for the migrations
-- also the assumed database layout of someone either migrating from an older version of PK or starting a new instance,
-- so everything here *should* be idempotent given a schema version older than this or nonexistent.
-- Create proxy_tag compound type if it doesn't exist
do $$ begin
create type proxy_tag as (
prefix text,
suffix text
);
exception when duplicate_object then null;
end $$;
create table if not exists systems
(
id serial primary key,
hid char(5) unique not null,
name text,
description text,
tag text,
avatar_url text,
token text,
created timestamp not null default (current_timestamp at time zone 'utc'),
ui_tz text not null default 'UTC'
);
create table if not exists system_guild
(
system serial not null references systems (id) on delete cascade,
guild bigint not null,
proxy_enabled bool not null default true,
primary key (system, guild)
);
create table if not exists members
(
id serial primary key,
hid char(5) unique not null,
system serial not null references systems (id) on delete cascade,
color char(6),
avatar_url text,
name text not null,
display_name text,
birthday date,
pronouns text,
description text,
proxy_tags proxy_tag[] not null default array[]::proxy_tag[], -- Rationale on making this an array rather than a separate table - we never need to query them individually, only access them as part of a selected Member struct
keep_proxy bool not null default false,
created timestamp not null default (current_timestamp at time zone 'utc')
);
create table if not exists member_guild
(
member serial not null references members (id) on delete cascade,
guild bigint not null,
display_name text default null,
primary key (member, guild)
);
create table if not exists accounts
(
uid bigint primary key,
system serial not null references systems (id) on delete cascade
);
create table if not exists messages
(
mid bigint primary key,
channel bigint not null,
member serial not null references members (id) on delete cascade,
sender bigint not null,
original_mid bigint
);
create table if not exists switches
(
id serial primary key,
system serial not null references systems (id) on delete cascade,
timestamp timestamp not null default (current_timestamp at time zone 'utc')
);
create table if not exists switch_members
(
id serial primary key,
switch serial not null references switches (id) on delete cascade,
member serial not null references members (id) on delete cascade
);
create table if not exists webhooks
(
channel bigint primary key,
webhook bigint not null,
token text not null
);
create table if not exists servers
(
id bigint primary key,
log_channel bigint,
log_blacklist bigint[] not null default array[]::bigint[],
blacklist bigint[] not null default array[]::bigint[]
);
create index if not exists idx_switches_system on switches using btree (system asc nulls last) include ("timestamp");
create index if not exists idx_switch_members_switch on switch_members using btree (switch asc nulls last) include (member);
create index if not exists idx_message_member on messages (member);

View file

@ -0,0 +1,15 @@
-- SCHEMA VERSION 1: 2019-12-26
-- First version introducing the migration system, therefore we add the info/version table
create table info
(
id int primary key not null default 1, -- enforced only equal to 1
schema_version int,
constraint singleton check (id = 1) -- enforce singleton table/row
);
-- We do an insert here since we *just* added the table
-- Future migrations should do an update at the end
insert into info (schema_version) values (1);

View file

@ -0,0 +1,11 @@
-- SCHEMA VERSION 10: 2020-10-09 --
-- Member/group limit override per-system
alter table systems add column member_limit_override smallint default null;
alter table systems add column group_limit_override smallint default null;
-- Lowering global limit to 1000 in this commit, so increase it for systems already above that
update systems s set member_limit_override = 1500
where (select count(*) from members m where m.system = s.id) > 1000;
update info set schema_version = 10;

View file

@ -0,0 +1,10 @@
-- SCHEMA VERSION 11: 2020-10-23 --
-- Create command message table --
create table command_messages
(
message_id bigint primary key not null,
author_id bigint not null
);
update info set schema_version = 11;

View file

@ -0,0 +1,10 @@
-- SCHEMA VERSION 12: 2020-12-08 --
-- Add disabling front/latch autoproxy per-member --
-- Add disabling autoproxy per-account --
-- Add configurable latch timeout --
alter table members add column allow_autoproxy bool not null default true;
alter table accounts add column allow_autoproxy bool not null default true;
alter table systems add column latch_timeout int; -- in seconds
update info set schema_version = 12;

View file

@ -0,0 +1,7 @@
-- SCHEMA VERSION 13: 2021-03-28 --
-- Add system and group colors --
alter table systems add column color char(6);
alter table groups add column color char(6);
update info set schema_version = 13;

View file

@ -0,0 +1,15 @@
-- SCHEMA VERSION 14: 2021-06-10 --
-- Add shard status table --
create table shards (
id int not null primary key,
-- 0 = down, 1 = up
status smallint not null default 0,
ping float,
last_heartbeat timestamptz,
last_connection timestamptz
);
update info set schema_version = 14;

View file

@ -0,0 +1,8 @@
-- SCHEMA VERSION 15: 2021-08-01
-- add banner (large) images to entities with "cards"
alter table systems add column banner_image text;
alter table members add column banner_image text;
alter table groups add column banner_image text;
update info set schema_version = 15;

View file

@ -0,0 +1,7 @@
-- SCHEMA VERSION 16: 2021-08-02 --
-- Add server-specific system tag --
alter table system_guild add column tag text default null;
alter table system_guild add column tag_enabled bool not null default true;
update info set schema_version = 16;

View file

@ -0,0 +1,8 @@
-- schema version 17: 2021-09-26 --
-- add channel_id to command message table
alter table command_messages add column channel_id bigint;
update command_messages set channel_id = 0;
alter table command_messages alter column channel_id set not null;
update info set schema_version = 17;

View file

@ -0,0 +1,18 @@
-- schema version 18: 2021-09-26 --
-- Add UUIDs for APIs
create extension if not exists pgcrypto;
alter table systems add column uuid uuid default gen_random_uuid();
create index systems_uuid_idx on systems(uuid);
alter table members add column uuid uuid default gen_random_uuid();
create index members_uuid_idx on members(uuid);
alter table switches add column uuid uuid default gen_random_uuid();
create index switches_uuid_idx on switches(uuid);
alter table groups add column uuid uuid default gen_random_uuid();
create index groups_uuid_idx on groups(uuid);
update info set schema_version = 18;

View file

@ -0,0 +1,10 @@
-- schema version 19: 2021-10-15 --
-- add stats to info table
alter table info add column system_count int;
alter table info add column member_count int;
alter table info add column group_count int;
alter table info add column switch_count int;
alter table info add column message_count int;
update info set schema_version = 19;

View file

@ -0,0 +1,13 @@
-- We're doing a pseudo-enum here since Dapper is wonky with enums
-- Still getting mapped to enums at the CLR level, though.
-- https://github.com/StackExchange/Dapper/issues/332 (from 2015, still unsolved!)
-- 1 = "public"
-- 2 = "private"
-- not doing a bool here since I want to leave open the possibility of other privacy levels (e.g. "mutuals only")
alter table systems add column description_privacy integer check (description_privacy in (1, 2)) not null default 1;
alter table systems add column member_list_privacy integer check (member_list_privacy in (1, 2)) not null default 1;
alter table systems add column front_privacy integer check (front_privacy in (1, 2)) not null default 1;
alter table systems add column front_history_privacy integer check (front_history_privacy in (1, 2)) not null default 1;
alter table members add column member_privacy integer check (member_privacy in (1, 2)) not null default 1;
update info set schema_version = 2;

View file

@ -0,0 +1,7 @@
-- schema version 20: insert date
-- add outgoing webhook to systems
alter table systems add column webhook_url text;
alter table systems add column webhook_token text;
update info set schema_version = 20;

View file

@ -0,0 +1,29 @@
-- schema version 21
-- create `system_config` table
create table system_config (
system int primary key references systems(id) on delete cascade,
ui_tz text not null default 'UTC',
pings_enabled bool not null default true,
latch_timeout int,
member_limit_override int,
group_limit_override int
);
insert into system_config select
id as system,
ui_tz,
pings_enabled,
latch_timeout,
member_limit_override,
group_limit_override
from systems;
alter table systems
drop column ui_tz,
drop column pings_enabled,
drop column latch_timeout,
drop column member_limit_override,
drop column group_limit_override;
update info set schema_version = 21;

View file

@ -0,0 +1,7 @@
-- schema version 22
-- automatically set members/groups as private when creating
alter table system_config add column member_default_private bool not null default false;
alter table system_config add column group_default_private bool not null default false;
update info set schema_version = 22;

View file

@ -0,0 +1,6 @@
-- schema version 23
-- show/hide private information when looked up by linked accounts
alter table system_config add column show_private_info bool default true;
update info set schema_version = 23;

View file

@ -0,0 +1,10 @@
-- schema version 24
-- don't drop message rows when system/member are deleted
alter table messages alter column member drop not null;
alter table messages drop constraint messages_member_fkey;
alter table messages
add constraint messages_member_fkey
foreign key (member) references members(id) on delete set null;
update info set schema_version = 24;

View file

@ -0,0 +1,7 @@
-- schema version 25
-- group name privacy
alter table groups add column name_privacy integer check (name_privacy in (1, 2)) not null default 1;
alter table groups add column metadata_privacy integer check (metadata_privacy in (1, 2)) not null default 1;
update info set schema_version = 25;

View file

@ -0,0 +1,12 @@
-- schema version 26
-- cache Discord DM channels in the database
alter table accounts alter column system drop not null;
alter table accounts drop constraint accounts_system_fkey;
alter table accounts
add constraint accounts_system_fkey
foreign key (system) references systems(id) on delete set null;
alter table accounts add column dm_channel bigint;
update info set schema_version = 26;

View file

@ -0,0 +1,36 @@
-- schema version 27
-- autoproxy locations
-- mode pseudo-enum: (copied from 3.sql)
-- 1 = autoproxy off
-- 2 = front mode (first fronter)
-- 3 = latch mode (last proxyer)
-- 4 = member mode (specific member)
create table autoproxy (
system int references systems(id) on delete cascade,
channel_id bigint,
guild_id bigint,
autoproxy_mode int check (autoproxy_mode in (1, 2, 3, 4)) not null default 1,
autoproxy_member int references members(id) on delete set null,
last_latch_timestamp timestamp,
check (
(channel_id = 0 and guild_id = 0)
or (channel_id != 0 and guild_id = 0)
or (channel_id = 0 and guild_id != 0)
),
primary key (system, channel_id, guild_id)
);
insert into autoproxy select
system,
0 as channel_id,
guild as guild_id,
autoproxy_mode,
autoproxy_member
from system_guild;
alter table system_guild drop column autoproxy_mode;
alter table system_guild drop column autoproxy_member;
update info set schema_version = 27;

View file

@ -0,0 +1,7 @@
-- schema version 28
-- system pronouns
alter table systems add column pronouns text;
alter table systems add column pronoun_privacy integer check (pronoun_privacy in (1, 2)) not null default 1;
update info set schema_version = 28;

View file

@ -0,0 +1,5 @@
-- schema version 29
alter table systems add column is_deleting bool default false;
update info set schema_version = 29;

View file

@ -0,0 +1,15 @@
-- Same sort of pseudo-enum due to Dapper limitations. See 2.sql.
-- 1 = autoproxy off
-- 2 = front mode (first fronter)
-- 3 = latch mode (last proxyer)
-- 4 = member mode (specific member)
alter table system_guild add column autoproxy_mode int check (autoproxy_mode in (1, 2, 3, 4)) not null default 1;
-- for member mode
alter table system_guild add column autoproxy_member int references members (id) on delete set null;
-- for latch mode
-- not *really* nullable, null just means old (pre-schema-change) data.
alter table messages add column guild bigint default null;
update info set schema_version = 3;

View file

@ -0,0 +1,5 @@
-- schema version 30
alter table system_config add column description_templates text[] not null default array[]::text[];
update info set schema_version = 30;

View file

@ -0,0 +1,5 @@
-- schema version 31
alter table system_config add column case_sensitive_proxy_tags boolean not null default true;
update info set schema_version = 31;

View file

@ -0,0 +1,6 @@
-- database version 32
-- re-add last message timestamp to members
alter table members add column last_message_timestamp timestamp;
update info set schema_version = 32;

View file

@ -0,0 +1,6 @@
-- database version 33
-- add webhook_avatar_url to system members
alter table members add column webhook_avatar_url text;
update info set schema_version = 33;

View file

@ -0,0 +1,6 @@
-- database version 34
-- add proxy_error_message_enabled to system config
alter table system_config add column proxy_error_message_enabled bool default true;
update info set schema_version = 34;

View file

@ -0,0 +1,7 @@
-- database version 35
-- add guild avatar and guild name to system guild settings
alter table system_guild add column avatar_url text;
alter table system_guild add column display_name text;
update info set schema_version = 35;

View file

@ -0,0 +1,7 @@
-- database version 36
-- add system avatar privacy and system name privacy
alter table systems add column name_privacy integer not null default 1;
alter table systems add column avatar_privacy integer not null default 1;
update info set schema_version = 36;

View file

@ -0,0 +1,7 @@
-- database version 37
-- add proxy tag privacy
alter table members add column proxy_privacy integer not null default 1;
alter table members add constraint members_proxy_privacy_check check (proxy_privacy = ANY (ARRAY[1,2]));
update info set schema_version = 37;

View file

@ -0,0 +1,6 @@
-- database version 38
-- add per-member tts flag
alter table members add column tts boolean not null default false;
update info set schema_version = 38;

View file

@ -0,0 +1,7 @@
-- database version 39
-- add missing privacy constraints
alter table systems add constraint systems_name_privacy_check check (name_privacy = ANY (ARRAY[1,2]));
alter table systems add constraint systems_avatar_privacy_check check (avatar_privacy = ANY (ARRAY[1,2]));
update info set schema_version = 39;

View file

@ -0,0 +1,3 @@
-- SCHEMA VERSION 4: 2020-02-12
alter table member_guild add column avatar_url text;
update info set schema_version = 4;

View file

@ -0,0 +1,6 @@
-- database version 40
-- add per-server keepproxy toggle
alter table member_guild add column keep_proxy bool default null;
update info set schema_version = 40;

View file

@ -0,0 +1,10 @@
-- database version 41
-- fix statistics counts
alter table info alter column system_count type bigint using system_count::bigint;
alter table info alter column member_count type bigint using member_count::bigint;
alter table info alter column group_count type bigint using group_count::bigint;
alter table info alter column switch_count type bigint using switch_count::bigint;
alter table info alter column message_count type bigint using message_count::bigint;
update info set schema_version = 41;

View file

@ -0,0 +1,11 @@
-- database version 42
-- move to 6 character HIDs, add HID display config setting
alter table systems alter column hid type char(6) using rpad(hid, 6, ' ');
alter table members alter column hid type char(6) using rpad(hid, 6, ' ');
alter table groups alter column hid type char(6) using rpad(hid, 6, ' ');
alter table system_config add column hid_display_split bool default false;
alter table system_config add column hid_display_caps bool default false;
update info set schema_version = 42;

View file

@ -0,0 +1,6 @@
-- database version 43
-- add config setting for padding 5-character IDs in lists
alter table system_config add column hid_list_padding int not null default 0;
update info set schema_version = 43;

View file

@ -0,0 +1,23 @@
-- database version 44
-- add abuse handling measures
create table abuse_logs (
id serial primary key,
uuid uuid default gen_random_uuid(),
description text,
deny_bot_usage bool not null default false,
created timestamp not null default (current_timestamp at time zone 'utc')
);
alter table accounts add column abuse_log integer default null references abuse_logs (id) on delete set null;
create index abuse_logs_uuid_idx on abuse_logs (uuid);
-- we now need to handle a row in "accounts" table being created with no
-- system (rather than just system being set to null after insert)
--
-- set default null and drop the sequence (from the column being created
-- as type SERIAL)
alter table accounts alter column system set default null;
drop sequence accounts_system_seq;
update info set schema_version = 44;

View file

@ -0,0 +1,6 @@
-- database version 45
-- add new config setting "proxy_switch"
alter table system_config add column proxy_switch bool default false;
update info set schema_version = 45;

View file

@ -0,0 +1,12 @@
-- database version 46
-- adds banner privacy
alter table members add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
alter table groups add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
alter table systems add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
update members set banner_privacy = 2 where description_privacy = 2;
update groups set banner_privacy = 2 where description_privacy = 2;
update systems set banner_privacy = 2 where description_privacy = 2;
update info set schema_version = 46;

View file

@ -0,0 +1,6 @@
-- database version 47
-- add config setting for supplying a custom tag format in names
alter table system_config add column name_format text;
update info set schema_version = 47;

View file

@ -0,0 +1,9 @@
-- database version 48
--
-- add guild settings for disabling "invalid command" responses &
-- enforcing the presence of system tags
alter table servers add column invalid_command_response_enabled bool not null default true;
alter table servers add column require_system_tag bool not null default false;
update info set schema_version = 48;

View file

@ -0,0 +1,6 @@
-- database version 49
-- add guild name format
alter table system_guild add column name_format text;
update info set schema_version = 49;

View file

@ -0,0 +1,3 @@
-- SCHEMA VERSION 5: 2020-02-14
alter table servers add column log_cleanup_enabled bool not null default false;
update info set schema_version = 5;

View file

@ -0,0 +1,11 @@
-- database version 50
-- change proxy switch config to an enum
alter table system_config
alter column proxy_switch drop default,
alter column proxy_switch type int
using case when proxy_switch then 1 else 0 end,
alter column proxy_switch set default 0,
add constraint proxy_switch_check check (proxy_switch = ANY (ARRAY[0,1,2]));
update info set schema_version = 50;

View file

@ -0,0 +1,7 @@
-- database version 51
--
-- add guild setting for SUPPRESS_NOTIFICATIONS message flag on proxied messages
alter table servers add column suppress_notifications bool not null default false;
update info set schema_version = 51;

View file

@ -0,0 +1,21 @@
-- database version 52
-- messages db updates
create index messages_by_original on messages(original_mid);
create index messages_by_sender on messages(sender);
-- remove old table from database version 11
alter table command_messages rename to command_messages_old;
create table command_messages (
mid bigint primary key,
channel bigint not null,
guild bigint not null,
sender bigint not null,
original_mid bigint not null
);
create index command_messages_by_original on command_messages(original_mid);
create index command_messages_by_sender on command_messages(sender);
update info set schema_version = 52;

View file

@ -0,0 +1,3 @@
-- SCHEMA VERSION 6: 2020-03-21
alter table systems add column pings_enabled bool not null default true;
update info set schema_version = 6;

View file

@ -0,0 +1,35 @@
-- SCHEMA VERSION 7: 2020-06-12
-- (in-db message count row)
-- Add message count row to members table, initialize it with the correct data
alter table members add column message_count int not null default 0;
update members set message_count = counts.count
from (select member, count(*) as count from messages group by messages.member) as counts
where counts.member = members.id;
-- Create a trigger function to increment the message count on inserting to the messages table
create function trg_msgcount_increment() returns trigger as $$
begin
update members set message_count = message_count + 1 where id = NEW.member;
return NEW;
end;
$$ language plpgsql;
create trigger increment_member_message_count before insert on messages for each row execute procedure trg_msgcount_increment();
-- Create a trigger function to decrement the message count on deleting from the messages table
create function trg_msgcount_decrement() returns trigger as $$
begin
-- Don't decrement if count <= zero (shouldn't happen, but we don't want negative message counts)
update members set message_count = message_count - 1 where id = OLD.member and message_count > 0;
return OLD;
end;
$$ language plpgsql;
create trigger decrement_member_message_count before delete on messages for each row execute procedure trg_msgcount_decrement();
-- (update schema ver)
update info set schema_version = 7;

View file

@ -0,0 +1,24 @@
-- SCHEMA VERSION 8: 2020-05-13 --
-- Create new columns --
alter table members add column description_privacy integer check (description_privacy in (1, 2)) not null default 1;
alter table members add column name_privacy integer check (name_privacy in (1, 2)) not null default 1;
alter table members add column avatar_privacy integer check (avatar_privacy in (1, 2)) not null default 1;
alter table members add column birthday_privacy integer check (birthday_privacy in (1, 2)) not null default 1;
alter table members add column pronoun_privacy integer check (pronoun_privacy in (1, 2)) not null default 1;
alter table members add column metadata_privacy integer check (metadata_privacy in (1, 2)) not null default 1;
-- alter table members add column color_privacy integer check (color_privacy in (1, 2)) not null default 1;
-- Transfer existing settings --
update members set description_privacy = member_privacy;
update members set name_privacy = member_privacy;
update members set avatar_privacy = member_privacy;
update members set birthday_privacy = member_privacy;
update members set pronoun_privacy = member_privacy;
update members set metadata_privacy = member_privacy;
-- update members set color_privacy = member_privacy;
-- Rename member_privacy to member_visibility --
alter table members rename column member_privacy to member_visibility;
-- Update Schema Info --
update info set schema_version = 8;

View file

@ -0,0 +1,32 @@
-- SCHEMA VERSION 9: 2020-08-25 --
-- Adds support for member groups.
create table groups (
id int primary key generated always as identity,
hid char(5) unique not null,
system int not null references systems(id) on delete cascade,
name text not null,
display_name text,
description text,
icon text,
-- Description columns follow the same pattern as usual: 1 = public, 2 = private
description_privacy integer check (description_privacy in (1, 2)) not null default 1,
icon_privacy integer check (icon_privacy in (1, 2)) not null default 1,
list_privacy integer check (list_privacy in (1, 2)) not null default 1,
visibility integer check (visibility in (1, 2)) not null default 1,
created timestamp with time zone not null default (current_timestamp at time zone 'utc')
);
create table group_members (
group_id int not null references groups(id) on delete cascade,
member_id int not null references members(id) on delete cascade,
primary key (group_id, member_id)
);
alter table systems add column group_list_privacy integer check (group_list_privacy in (1, 2)) not null default 1;
update systems set group_list_privacy = member_list_privacy;
update info set schema_version = 9;

View file

@ -0,0 +1,8 @@
-- example data (for integration tests or such)
insert into systems (hid, token) values (
'exmpl',
'vlPitT0tEgT++a450w1/afODy5NXdALcHDwryX6dOIZdGUGbZg+5IH3nrUsQihsw'
);
insert into system_config (system) values (1);
insert into system_guild (system, guild) values (1, 466707357099884544);

View file

@ -0,0 +1,75 @@
-- Returns one row per system, containing info about latest switch + array of member IDs (for future joins)
create view system_last_switch as
select systems.id as system,
last_switch.id as switch,
last_switch.timestamp as timestamp,
array(select member from switch_members where switch_members.switch = last_switch.id order by switch_members.id) as members
from systems
inner join lateral (select * from switches where switches.system = systems.id order by timestamp desc limit 1) as last_switch on true;
create view member_list as
select members.*,
-- Find last switch timestamp
(
select max(switches.timestamp)
from switch_members
inner join switches on switches.id = switch_members.switch
where switch_members.member = members.id
) as last_switch_time,
-- Extract month/day from birthday and "force" the year to a constant (4), so sorting compares month/day only
case when members.birthday is not null then
make_date(
4,
extract(month from members.birthday)::integer,
extract(day from members.birthday)::integer
) end as birthday_md,
-- Extract member description as seen by "the public"
case
-- Privacy '1' = public; just return description as normal
when members.description_privacy = 1 then members.description
-- Any other privacy (rn just '2'), return null description (missing case = null in SQL)
end as public_description,
-- Extract member name as seen by "the public"
case
-- Privacy '1' = public; just return name as normal
when members.name_privacy = 1 then members.name
-- Any other privacy (rn just '2'), return display name
else coalesce(members.display_name, members.name)
end as public_name
from members;
create view group_list as
select groups.*,
-- Find public group member count
(
select count(*) from group_members
inner join members on group_members.member_id = members.id
where
group_members.group_id = groups.id and members.member_visibility = 1
) as public_member_count,
-- Find private group member count
(
select count(*) from group_members
inner join members on group_members.member_id = members.id
where
group_members.group_id = groups.id
) as total_member_count,
-- Extract group description as seen by "the public"
case
-- Privacy '1' = public; just return description as normal
when groups.description_privacy = 1 then groups.description
-- Any other privacy (rn just '2'), return null description (missing case = null in SQL)
end as public_description,
-- Extract member name as seen by "the public"
case
-- Privacy '1' = public; just return name as normal
when groups.name_privacy = 1 then groups.name
-- Any other privacy (rn just '2'), return display name
else coalesce(groups.display_name, groups.name)
end as public_name
from groups;
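
Illustratively (not part of this diff), list endpoints can read the precomputed public projection instead of re-implementing privacy checks in application code:

async fn public_member_names(pool: &sqlx::PgPool, system: i32) -> anyhow::Result<Vec<String>> {
    // member_visibility = 1 filters to publicly visible members;
    // public_name already applies the view's name_privacy fallback
    let rows: Vec<(String,)> = sqlx::query_as(
        "select public_name from member_list where system = $1 and member_visibility = 1",
    )
    .bind(system)
    .fetch_all(pool)
    .await?;
    Ok(rows.into_iter().map(|(name,)| name).collect())
}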

View file

@ -0,0 +1,70 @@
#![feature(let_chains)]
use tracing::info;
include!(concat!(env!("OUT_DIR"), "/data.rs"));
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
// clean
// get current migration
// migrate to latest
// run views
// run functions
#[derive(sqlx::FromRow)]
struct CurrentMigration {
schema_version: i32,
}
let info = match sqlx::query_as("select schema_version from info")
.fetch_optional(&db)
.await
{
Ok(Some(result)) => result,
Ok(None) => CurrentMigration { schema_version: -1 },
Err(e) if format!("{e}").contains("relation \"info\" does not exist") => {
CurrentMigration { schema_version: -1 }
}
Err(e) => return Err(e.into()),
};
info!("current migration: {}", info.schema_version);
info!("running clean.sql");
sqlx::raw_sql(fix_feff(CLEAN)).execute(&db).await?;
for idx in (info.schema_version + 1) as usize..MIGRATIONS.len() {
info!("running migration {idx}");
sqlx::raw_sql(fix_feff(MIGRATIONS[idx]))
.execute(&db)
.await?;
}
info!("running views.sql");
sqlx::raw_sql(fix_feff(VIEWS)).execute(&db).await?;
info!("running functions.sql");
sqlx::raw_sql(fix_feff(FUNCTIONS)).execute(&db).await?;
if let Ok(var) = std::env::var("SEED")
&& var == "true"
{
info!("running seed.sql");
sqlx::raw_sql(fix_feff(SEED)).execute(&db).await?;
info!(
"example system created with hid 'exmpl', token 'vlPitT0tEgT++a450w1/afODy5NXdALcHDwryX6dOIZdGUGbZg+5IH3nrUsQihsw', guild_id 466707357099884544"
);
}
info!("all done!");
Ok(())
}
// some migration scripts have \u{feff} at the start
fn fix_feff(sql: &str) -> &str {
sql.trim_start_matches("\u{feff}")
}

View file

@ -5,7 +5,7 @@ edition = "2021"
[dependencies]
chrono = { workspace = true, features = ["serde"] }
model_macros = { path = "../model_macros" }
pk_macros = { path = "../macros" }
sea-query = "0.32.1"
serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] }

Some files were not shown because too many files have changed in this diff