mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-13 01:00:12 +00:00
Merge remote-tracking branch 'upstream/main' into rust-command-parser
This commit is contained in:
commit
f721b850d4
183 changed files with 5121 additions and 1909 deletions
48
crates/api/src/auth.rs
Normal file
48
crates/api/src/auth.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
|
||||
|
||||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
system_id: Option<i32>,
|
||||
app_id: Option<i32>,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
|
||||
Self { system_id, app_id }
|
||||
}
|
||||
|
||||
pub fn system_id(&self) -> Option<i32> {
|
||||
self.system_id
|
||||
}
|
||||
|
||||
pub fn app_id(&self) -> Option<i32> {
|
||||
self.app_id
|
||||
}
|
||||
|
||||
pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel {
|
||||
if self
|
||||
.system_id
|
||||
.map(|id| id == a.authable_system_id())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
PrivacyLevel::Private
|
||||
} else {
|
||||
PrivacyLevel::Public
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// authable trait/impls
|
||||
|
||||
pub trait Authable {
|
||||
fn authable_system_id(&self) -> SystemId;
|
||||
}
|
||||
|
||||
impl Authable for PKSystem {
|
||||
fn authable_system_id(&self) -> SystemId {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
pub mod private;
|
||||
pub mod system;
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ use axum::{
|
|||
};
|
||||
use hyper::StatusCode;
|
||||
use libpk::config;
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig};
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
|
||||
use reqwest::ClientBuilder;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
|
|
@ -151,14 +151,12 @@ pub async fn discord_callback(
|
|||
.await
|
||||
.expect("failed to query");
|
||||
|
||||
if system.is_none() {
|
||||
let Some(system) = system else {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"user does not have a system registered".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let system = system.unwrap();
|
||||
};
|
||||
|
||||
let system_config: Option<PKSystemConfig> = sqlx::query_as(
|
||||
r#"
|
||||
|
|
@ -179,7 +177,7 @@ pub async fn discord_callback(
|
|||
(
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"system": system.to_json(),
|
||||
"system": system.to_json(PrivacyLevel::Private),
|
||||
"config": system_config.to_json(),
|
||||
"user": user,
|
||||
"token": token,
|
||||
|
|
|
|||
69
crates/api/src/endpoints/system.rs
Normal file
69
crates/api/src/endpoints/system.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Extension, Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use sqlx::Postgres;
|
||||
use tracing::error;
|
||||
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
|
||||
|
||||
use crate::{auth::AuthState, util::json_err, ApiContext};
|
||||
|
||||
pub async fn get_system_settings(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
Extension(system): Extension<PKSystem>,
|
||||
State(ctx): State<ApiContext>,
|
||||
) -> Response {
|
||||
let access_level = auth.access_level_for(&system);
|
||||
|
||||
let mut config = match sqlx::query_as::<Postgres, PKSystemConfig>(
|
||||
"select * from system_config where system = $1",
|
||||
)
|
||||
.bind(system.id)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) => {
|
||||
error!(
|
||||
system = system.id,
|
||||
"failed to find system config for existing system"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query system config");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// fix this
|
||||
if config.name_format.is_none() {
|
||||
config.name_format = Some("{name} {tag}".to_string());
|
||||
}
|
||||
|
||||
Json(&match access_level {
|
||||
PrivacyLevel::Private => config.to_json(),
|
||||
PrivacyLevel::Public => json!({
|
||||
"pings_enabled": config.pings_enabled,
|
||||
"latch_timeout": config.latch_timeout,
|
||||
"case_sensitive_proxy_tags": config.case_sensitive_proxy_tags,
|
||||
"proxy_error_message_enabled": config.proxy_error_message_enabled,
|
||||
"hid_display_split": config.hid_display_split,
|
||||
"hid_display_caps": config.hid_display_caps,
|
||||
"hid_list_padding": config.hid_list_padding,
|
||||
"proxy_switch": config.proxy_switch,
|
||||
"name_format": config.name_format,
|
||||
}),
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
|
@ -1,10 +1,13 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
http::{Response, StatusCode, Uri},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
Extension, Router,
|
||||
};
|
||||
use hyper_util::{
|
||||
client::legacy::{connect::HttpConnector, Client},
|
||||
|
|
@ -12,6 +15,7 @@ use hyper_util::{
|
|||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
mod auth;
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod middleware;
|
||||
|
|
@ -27,6 +31,7 @@ pub struct ApiContext {
|
|||
}
|
||||
|
||||
async fn rproxy(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
State(ctx): State<ApiContext>,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Result<Response<Body>, StatusCode> {
|
||||
|
|
@ -41,12 +46,25 @@ async fn rproxy(
|
|||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
let headers = req.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
|
||||
Ok(ctx
|
||||
.rproxy_client
|
||||
.request(req)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
|
||||
.map_err(|error| {
|
||||
error!(?error, "failed to serve reverse proxy to dotnet-api");
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?
|
||||
.into_response())
|
||||
|
|
@ -57,52 +75,52 @@ async fn rproxy(
|
|||
fn router(ctx: ApiContext) -> Router {
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
Router::new()
|
||||
.route("/v2/systems/:system_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", get(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
|
||||
.route("/v2/systems/{system_id}/settings", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/members", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/members", get(rproxy))
|
||||
.route("/v2/members", post(rproxy))
|
||||
.route("/v2/members/:member_id", get(rproxy))
|
||||
.route("/v2/members/:member_id", patch(rproxy))
|
||||
.route("/v2/members/:member_id", delete(rproxy))
|
||||
.route("/v2/members/{member_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}", patch(rproxy))
|
||||
.route("/v2/members/{member_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/groups", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/groups", get(rproxy))
|
||||
.route("/v2/groups", post(rproxy))
|
||||
.route("/v2/groups/:group_id", get(rproxy))
|
||||
.route("/v2/groups/:group_id", patch(rproxy))
|
||||
.route("/v2/groups/:group_id", delete(rproxy))
|
||||
.route("/v2/groups/{group_id}", get(rproxy))
|
||||
.route("/v2/groups/{group_id}", patch(rproxy))
|
||||
.route("/v2/groups/{group_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/groups/:group_id/members", get(rproxy))
|
||||
.route("/v2/groups/:group_id/members/add", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/remove", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/add", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/groups", get(rproxy))
|
||||
.route("/v2/members/:member_id/groups/add", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/remove", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups", get(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/add", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches", post(rproxy))
|
||||
.route("/v2/systems/:system_id/fronters", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", post(rproxy))
|
||||
.route("/v2/systems/{system_id}/fronters", get(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
|
||||
|
||||
.route("/v2/messages/:message_id", get(rproxy))
|
||||
.route("/v2/messages/{message_id}", get(rproxy))
|
||||
|
||||
.route("/private/bulk_privacy/member", post(rproxy))
|
||||
.route("/private/bulk_privacy/group", post(rproxy))
|
||||
|
|
@ -111,16 +129,19 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
|
||||
.route("/private/stats", get(endpoints::private::meta))
|
||||
|
||||
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
|
||||
.route("/v2/members/:member_id/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::cors))
|
||||
.layer(axum::middleware::from_fn(middleware::logger))
|
||||
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::logger::logger))
|
||||
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::params::params))
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
|
||||
|
||||
.layer(axum::middleware::from_fn(middleware::cors::cors))
|
||||
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
|
||||
|
||||
.with_state(ctx)
|
||||
|
|
@ -128,8 +149,8 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
|
||||
}
|
||||
|
||||
libpk::main!("api");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_data_db().await?;
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
|
|
|
|||
62
crates/api/src/middleware/auth.rs
Normal file
62
crates/api/src/middleware/auth.rs
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use tracing::error;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
|
||||
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let mut authed_system_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<i32> = None;
|
||||
|
||||
// fetch user authorization
|
||||
if let Some(system_auth_header) = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
{
|
||||
authed_system_id = Some(system_id);
|
||||
}
|
||||
|
||||
// fetch app authorization
|
||||
// todo: actually fetch it from db
|
||||
if let Some(app_auth_header) = req
|
||||
.headers()
|
||||
.get("x-pluralkit-app")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(config_token2) = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
.as_ref()
|
||||
// this is NOT how you validate tokens
|
||||
// but this is low abuse risk so we're keeping it for now
|
||||
&& app_auth_header == config_token2
|
||||
{
|
||||
authed_app_id = Some(1);
|
||||
}
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(AuthState::new(authed_system_id, authed_app_id));
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderValue,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::ApiContext;
|
||||
|
||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
headers.remove("x-pluralkit-systemid");
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten();
|
||||
let mut authenticated = false;
|
||||
if let Some(auth_header) = auth_header {
|
||||
if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
None
|
||||
}
|
||||
}
|
||||
{
|
||||
headers.append(
|
||||
"x-pluralkit-systemid",
|
||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
||||
);
|
||||
authenticated = true;
|
||||
}
|
||||
}
|
||||
let mut response = next.run(request).await;
|
||||
if authenticated {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
||||
}
|
||||
response
|
||||
}
|
||||
|
|
@ -4,27 +4,30 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
|
|||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
use crate::{auth::AuthState, util::header_or_unknown};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_span = span!(
|
||||
|
|
@ -37,25 +40,26 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let authenticated = {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
let system_id = auth
|
||||
.system_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
let app_id = auth
|
||||
.app_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
counter!(
|
||||
"pluralkit_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
|
|
@ -63,7 +67,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
|
|
@ -81,7 +86,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,6 @@
|
|||
mod cors;
|
||||
pub use cors::cors;
|
||||
|
||||
mod logger;
|
||||
pub use logger::logger;
|
||||
|
||||
mod ignore_invalid_routes;
|
||||
pub use ignore_invalid_routes::ignore_invalid_routes;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cors;
|
||||
pub mod ignore_invalid_routes;
|
||||
pub mod logger;
|
||||
pub mod params;
|
||||
pub mod ratelimit;
|
||||
|
||||
mod authnz;
|
||||
pub use authnz::authnz;
|
||||
|
|
|
|||
139
crates/api/src/middleware/params.rs
Normal file
139
crates/api/src/middleware/params.rs
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
routing::url_params::UrlParams,
|
||||
};
|
||||
|
||||
use sqlx::{types::Uuid, Postgres};
|
||||
use tracing::error;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
use pluralkit_models::PKSystem;
|
||||
|
||||
// move this somewhere else
|
||||
fn parse_hid(hid: &str) -> String {
|
||||
if hid.len() > 7 || hid.len() < 5 {
|
||||
hid.to_string()
|
||||
} else {
|
||||
hid.to_lowercase().replace("-", "")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let pms = match req.extensions().get::<UrlParams>() {
|
||||
None => Vec::new(),
|
||||
Some(UrlParams::Params(pms)) => pms.clone(),
|
||||
_ => {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
};
|
||||
|
||||
for (key, value) in pms {
|
||||
match key.as_ref() {
|
||||
"system_id" => match value.as_str() {
|
||||
"@me" => {
|
||||
let Some(system_id) = req
|
||||
.extensions()
|
||||
.get::<AuthState>()
|
||||
.expect("missing auth state")
|
||||
.system_id()
|
||||
else {
|
||||
return json_err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
r#"{"message":"401: Missing or invalid Authorization header","code": 0}"#.to_string(),
|
||||
)
|
||||
.into();
|
||||
};
|
||||
|
||||
match sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where id = $1",
|
||||
)
|
||||
.bind(system_id)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(system)) => {
|
||||
req.extensions_mut().insert(system);
|
||||
}
|
||||
Ok(None) => {
|
||||
error!(
|
||||
?system_id,
|
||||
"could not find previously authenticated system in db"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
?err,
|
||||
"failed to query previously authenticated system in db"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
id => {
|
||||
println!("a {id}");
|
||||
match match Uuid::parse_str(id) {
|
||||
Ok(uuid) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where uuid = $1",
|
||||
)
|
||||
.bind(uuid),
|
||||
Err(_) => match id.parse::<i64>() {
|
||||
Ok(parsed) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where id = (select system from accounts where uid = $1)"
|
||||
)
|
||||
.bind(parsed),
|
||||
Err(_) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where hid = $1",
|
||||
)
|
||||
.bind(parse_hid(id))
|
||||
},
|
||||
}
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(system)) => {
|
||||
req.extensions_mut().insert(system);
|
||||
}
|
||||
Ok(None) => {
|
||||
return json_err(
|
||||
StatusCode::NOT_FOUND,
|
||||
r#"{"message":"System not found.","code":20001}"#.to_string(),
|
||||
)
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?err, ?id, "failed to query system from path in db");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"member_id" => {}
|
||||
"group_id" => {}
|
||||
"switch_id" => {}
|
||||
"guild_id" => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
@ -10,7 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
|
|||
use metrics::counter;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
use crate::{
|
||||
auth::AuthState,
|
||||
util::{header_or_unknown, json_err},
|
||||
};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
|
|
@ -50,7 +53,7 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
|||
.await
|
||||
{
|
||||
Ok(_) => info!("connected to redis for request rate limiting"),
|
||||
Err(err) => error!("could not load redis script: {}", err),
|
||||
Err(error) => error!(?error, "could not load redis script"),
|
||||
}
|
||||
} else {
|
||||
error!("could not wait for connection to load redis script!");
|
||||
|
|
@ -103,37 +106,28 @@ pub async fn do_request_ratelimited(
|
|||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/53667
|
||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if let Some(token2) = &libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
{
|
||||
if header.to_str().unwrap_or("invalid") == token2 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let rlimit = if is_temp_token2 {
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
// looks like this chooses the tokens/sec by app_id or endpoint
|
||||
// then chooses the key by system_id or source_ip
|
||||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if let Some(app_id) = auth.app_id()
|
||||
&& app_id == 1
|
||||
{
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
|
|
@ -145,12 +139,12 @@ pub async fn do_request_ratelimited(
|
|||
|
||||
let rl_key = format!(
|
||||
"{}:{}",
|
||||
if authenticated_system_id != "unknown"
|
||||
if let Some(system_id) = auth.system_id()
|
||||
&& matches!(rlimit, RatelimitType::GenericUpdate)
|
||||
{
|
||||
authenticated_system_id
|
||||
system_id.to_string()
|
||||
} else {
|
||||
source_ip
|
||||
source_ip.to_string()
|
||||
},
|
||||
rlimit.key()
|
||||
);
|
||||
|
|
@ -224,8 +218,8 @@ pub async fn do_request_ratelimited(
|
|||
|
||||
return response;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("error getting ratelimit info: {}", err);
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "error getting ratelimit info");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
|
|||
match value.to_str() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
error!("failed to parse header value {:#?}: {:#?}", value, err);
|
||||
error!(?err, ?value, "failed to parse header value");
|
||||
"failed to parse"
|
||||
}
|
||||
}
|
||||
|
|
@ -34,11 +34,7 @@ where
|
|||
.unwrap(),
|
||||
),
|
||||
None => {
|
||||
error!(
|
||||
"error in handler {}: {:#?}",
|
||||
std::any::type_name::<F>(),
|
||||
error
|
||||
);
|
||||
error!(?error, "error in handler {}", std::any::type_name::<F>(),);
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
|
|
@ -48,14 +44,15 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
error!("caught panic from handler: {:#?}", err);
|
||||
pub fn handle_panic(error: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
error!(?error, "caught panic from handler");
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
// todo: make 500 not duplicated
|
||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||
let mut response = (code, text).into_response();
|
||||
let headers = response.headers_mut();
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ use sqlx::prelude::FromRow;
|
|||
use std::{sync::Arc, time::Duration};
|
||||
use tracing::{error, info};
|
||||
|
||||
libpk::main!("avatar_cleanup");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
|
|
@ -13,7 +13,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
region: "auto".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
|
|
@ -38,8 +38,8 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
match cleanup_job(pool.clone(), bucket.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
error!("failed to run avatar cleanup job: {}", err);
|
||||
Err(error) => {
|
||||
error!(?error, "failed to run avatar cleanup job");
|
||||
// sentry
|
||||
}
|
||||
}
|
||||
|
|
@ -55,9 +55,10 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
|
|||
let mut tx = pool.begin().await?;
|
||||
|
||||
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
|
||||
// no timestamp checking here
|
||||
// images are only added to the table after 24h
|
||||
r#"
|
||||
select id from image_cleanup_jobs
|
||||
where ts < now() - interval '1 day'
|
||||
for update skip locked limit 1;"#,
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
|
|
@ -72,6 +73,7 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
|
|||
|
||||
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
|
||||
if image_data.is_none() {
|
||||
// unsure how this can happen? there is a FK reference
|
||||
info!("image {image_id} was already deleted, skipping");
|
||||
sqlx::query("delete from image_cleanup_jobs where id = $1")
|
||||
.bind(image_id)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ async fn pull(
|
|||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
if !req.force {
|
||||
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
|
||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
// remove any pending image cleanup
|
||||
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
||||
|
|
@ -170,8 +170,8 @@ pub struct AppState {
|
|||
config: Arc<AvatarsConfig>,
|
||||
}
|
||||
|
||||
libpk::main!("avatars");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
|
|
@ -179,7 +179,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
region: "auto".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
|
|
@ -232,26 +232,11 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct AppError(anyhow::Error);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
error!("error handling request: {}", self.0);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: self.0.to_string(),
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for PKAvatarError {
|
||||
fn into_response(self) -> Response {
|
||||
let status_code = match self {
|
||||
|
|
@ -278,12 +263,3 @@ impl IntoResponse for PKAvatarError {
|
|||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> From<E> for AppError
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn from(err: E) -> Self {
|
||||
Self(err.into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,9 +129,9 @@ pub async fn worker(worker_id: u32, state: Arc<AppState>) {
|
|||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"error in migrate worker {}: {}",
|
||||
worker_id,
|
||||
e.source().unwrap_or(&e)
|
||||
error = e.source().unwrap_or(&e)
|
||||
?worker_id,
|
||||
"error in migrate worker",
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarEr
|
|||
} else {
|
||||
reader.decode().map_err(|e| {
|
||||
// print the ugly error, return the nice error
|
||||
error!("error decoding image: {}", e);
|
||||
error!(error = format!("{e:#?}"), "error decoding image");
|
||||
PKAvatarError::ImageFormatError(e)
|
||||
})?
|
||||
};
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ pub async fn pull(
|
|||
}
|
||||
}
|
||||
|
||||
error!("network error for {}: {}", parsed_url.full_url, s);
|
||||
error!(
|
||||
url = parsed_url.full_url,
|
||||
error = s,
|
||||
"network error pulling image"
|
||||
);
|
||||
PKAvatarError::NetworkErrorString(s)
|
||||
})?;
|
||||
let time_after_headers = Instant::now();
|
||||
|
|
@ -82,7 +86,22 @@ pub async fn pull(
|
|||
.map(|x| x.to_string());
|
||||
|
||||
let body = response.bytes().await.map_err(|e| {
|
||||
error!("network error for {}: {}", parsed_url.full_url, e);
|
||||
// terrible
|
||||
let mut s = format!("{}", e);
|
||||
if let Some(src) = e.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
let mut err = src;
|
||||
while let Some(src) = err.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
err = src;
|
||||
}
|
||||
}
|
||||
|
||||
error!(
|
||||
url = parsed_url.full_url,
|
||||
error = s,
|
||||
"network error pulling image"
|
||||
);
|
||||
PKAvatarError::NetworkError(e)
|
||||
})?;
|
||||
if body.len() != size as usize {
|
||||
|
|
@ -137,6 +156,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
|
|||
|
||||
match (url.scheme(), url.domain()) {
|
||||
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
||||
("https", Some("serve.apparyllis.com")) => {
|
||||
return Ok(ParsedUrl {
|
||||
channel_id: 0,
|
||||
attachment_id: 0,
|
||||
filename: "".to_string(),
|
||||
full_url: url.to_string(),
|
||||
})
|
||||
}
|
||||
_ => anyhow::bail!("not a discord cdn url"),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -19,17 +19,8 @@ use axum::{extract::State, http::Uri, routing::post, Json, Router};
|
|||
|
||||
mod logger;
|
||||
|
||||
// this package does not currently use libpk
|
||||
|
||||
#[tokio::main]
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
info!("hello world");
|
||||
|
||||
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
|
||||
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
|
||||
let (client, bg) = AsyncClient::connect(stream).await?;
|
||||
|
|
@ -86,11 +77,11 @@ async fn dispatch(
|
|||
let uri = match req.url.parse::<Uri>() {
|
||||
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to parse uri {}", req.url);
|
||||
error!(?error, uri = req.url, "failed to parse uri");
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
_ => {
|
||||
error!("uri {} is invalid", req.url);
|
||||
error!(uri = req.url, "uri is invalid");
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ futures = { workspace = true }
|
|||
lazy_static = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,24 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{ConnectInfo, Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use libpk::runtime_config::RuntimeConfig;
|
||||
use serde_json::{json, to_string};
|
||||
use tracing::{error, info};
|
||||
use twilight_model::id::Id;
|
||||
use twilight_model::id::{marker::ChannelMarker, Id};
|
||||
|
||||
use crate::discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
use crate::{
|
||||
discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
shard_state::ShardStateManager,
|
||||
},
|
||||
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
fn status_code(code: StatusCode, body: String) -> Response {
|
||||
(code, body).into_response()
|
||||
|
|
@ -21,10 +26,15 @@ fn status_code(code: StatusCode, body: String) -> Response {
|
|||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
||||
pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
|
||||
// hacky fix for `move`
|
||||
let runtime_config_for_post = runtime_config.clone();
|
||||
let runtime_config_for_delete = runtime_config.clone();
|
||||
let awaiter_for_clear = awaiter.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/guilds/:guild_id",
|
||||
"/guilds/{guild_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild(Id::new(guild_id)) {
|
||||
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
|
||||
|
|
@ -33,7 +43,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/members/@me",
|
||||
"/guilds/{guild_id}/members/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
|
||||
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
|
||||
|
|
@ -42,7 +52,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/@me",
|
||||
"/guilds/{guild_id}/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => {
|
||||
|
|
@ -56,7 +66,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/:user_id",
|
||||
"/guilds/{guild_id}/permissions/{user_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
|
||||
|
|
@ -69,7 +79,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/channels",
|
||||
"/guilds/{guild_id}/channels",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
|
||||
Some(channels) => channels.to_owned(),
|
||||
|
|
@ -95,7 +105,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id",
|
||||
"/guilds/{guild_id}/channels/{channel_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
|
||||
|
|
@ -107,7 +117,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
|
||||
"/guilds/{guild_id}/channels/{channel_id}/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
|
||||
|
|
@ -122,16 +132,19 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
|
||||
"/guilds/{guild_id}/channels/{channel_id}/permissions/{user_id}",
|
||||
get(|| async { "todo" }),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/last_message",
|
||||
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
|
||||
"/guilds/{guild_id}/channels/{channel_id}/last_message",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((_guild_id, channel_id)): Path<(u64, Id<ChannelMarker>)>| async move {
|
||||
let lm = cache.get_last_message(channel_id).await;
|
||||
status_code(StatusCode::FOUND, to_string(&lm).unwrap())
|
||||
}),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/roles",
|
||||
"/guilds/{guild_id}/roles",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
|
||||
Some(roles) => roles.to_owned(),
|
||||
|
|
@ -171,13 +184,45 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
|
||||
}))
|
||||
|
||||
.route("/runtime_config", get(|| async move {
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
.route("/runtime_config/{key}", post(|Path(key): Path<String>, body: String| async move {
|
||||
let runtime_config = runtime_config_for_post;
|
||||
runtime_config.set(key, body).await.expect("failed to update runtime config");
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
.route("/runtime_config/{key}", delete(|Path(key): Path<String>| async move {
|
||||
let runtime_config = runtime_config_for_delete;
|
||||
runtime_config.delete(key).await.expect("failed to update runtime config");
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
|
||||
.route("/await_event", post(|ConnectInfo(addr): ConnectInfo<SocketAddr>, body: String| async move {
|
||||
info!("got request: {body} from: {addr}");
|
||||
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
||||
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
||||
};
|
||||
|
||||
awaiter.handle_request(req, addr).await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
.route("/clear_awaiter", post(|| async move {
|
||||
awaiter_for_clear.clear().await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
|
||||
.route("/shard_status", get(|| async move {
|
||||
status_code(StatusCode::FOUND, to_string(&shard_state.get().await).unwrap())
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
use anyhow::format_err;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::RwLock;
|
||||
use twilight_cache_inmemory::{
|
||||
model::CachedMember,
|
||||
|
|
@ -8,11 +9,12 @@ use twilight_cache_inmemory::{
|
|||
traits::CacheableChannel,
|
||||
InMemoryCache, ResourceType,
|
||||
};
|
||||
use twilight_gateway::Event;
|
||||
use twilight_model::{
|
||||
channel::{Channel, ChannelType},
|
||||
guild::{Guild, Member, Permissions},
|
||||
id::{
|
||||
marker::{ChannelMarker, GuildMarker, UserMarker},
|
||||
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
|
|
@ -123,16 +125,134 @@ pub fn new() -> DiscordCache {
|
|||
.build(),
|
||||
);
|
||||
|
||||
DiscordCache(cache, client, RwLock::new(Vec::new()))
|
||||
DiscordCache(
|
||||
cache,
|
||||
client,
|
||||
RwLock::new(Vec::new()),
|
||||
RwLock::new(HashMap::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct CachedMessage {
|
||||
id: Id<MessageMarker>,
|
||||
referenced_message: Option<Id<MessageMarker>>,
|
||||
author_username: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct LastMessageCacheEntry {
|
||||
pub current: CachedMessage,
|
||||
pub previous: Option<CachedMessage>,
|
||||
}
|
||||
|
||||
pub struct DiscordCache(
|
||||
pub Arc<InMemoryCache>,
|
||||
pub Arc<twilight_http::Client>,
|
||||
pub RwLock<Vec<u32>>,
|
||||
pub RwLock<HashMap<Id<ChannelMarker>, LastMessageCacheEntry>>,
|
||||
);
|
||||
|
||||
impl DiscordCache {
|
||||
pub async fn get_last_message(
|
||||
&self,
|
||||
channel: Id<ChannelMarker>,
|
||||
) -> Option<LastMessageCacheEntry> {
|
||||
self.3.read().await.get(&channel).cloned()
|
||||
}
|
||||
|
||||
pub async fn update(&self, event: &twilight_gateway::Event) {
|
||||
self.0.update(event);
|
||||
|
||||
match event {
|
||||
Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) {
|
||||
std::collections::hash_map::Entry::Occupied(mut e) => {
|
||||
let cur = e.get();
|
||||
e.insert(LastMessageCacheEntry {
|
||||
current: CachedMessage {
|
||||
id: m.id,
|
||||
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
|
||||
author_username: m.author.name.clone(),
|
||||
},
|
||||
previous: Some(cur.current.clone()),
|
||||
});
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(e) => {
|
||||
e.insert(LastMessageCacheEntry {
|
||||
current: CachedMessage {
|
||||
id: m.id,
|
||||
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
|
||||
author_username: m.author.name.clone(),
|
||||
},
|
||||
previous: None,
|
||||
});
|
||||
}
|
||||
},
|
||||
Event::MessageDelete(m) => {
|
||||
self.handle_message_deletion(m.channel_id, vec![m.id]).await;
|
||||
}
|
||||
Event::MessageDeleteBulk(m) => {
|
||||
self.handle_message_deletion(m.channel_id, m.ids.clone())
|
||||
.await;
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
|
||||
async fn handle_message_deletion(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
mids: Vec<Id<MessageMarker>>,
|
||||
) {
|
||||
let mut lm = self.3.write().await;
|
||||
|
||||
let Some(entry) = lm.get(&channel_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut entry = entry.clone();
|
||||
|
||||
// if none of the deleted messages are relevant, just return
|
||||
if !mids.contains(&entry.current.id)
|
||||
&& entry
|
||||
.previous
|
||||
.clone()
|
||||
.map(|v| !mids.contains(&v.id))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// remove "previous" entry if it was deleted
|
||||
if let Some(prev) = entry.previous.clone()
|
||||
&& mids.contains(&prev.id)
|
||||
{
|
||||
entry.previous = None;
|
||||
}
|
||||
|
||||
// set "current" entry to "previous" if current entry was deleted
|
||||
// (if the "previous" entry still exists, it was not deleted)
|
||||
if let Some(prev) = entry.previous.clone()
|
||||
&& mids.contains(&entry.current.id)
|
||||
{
|
||||
entry.current = prev;
|
||||
entry.previous = None;
|
||||
}
|
||||
|
||||
// if the current entry was already deleted, but previous wasn't,
|
||||
// we would've set current to previous
|
||||
// so if current is deleted this means both current and previous have
|
||||
// been deleted
|
||||
// so just drop the cache entry here
|
||||
if mids.contains(&entry.current.id) && entry.previous.is_none() {
|
||||
lm.remove(&channel_id);
|
||||
return;
|
||||
}
|
||||
|
||||
// ok, update the entry
|
||||
lm.insert(channel_id, entry.clone());
|
||||
}
|
||||
|
||||
pub async fn guild_permissions(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
|
|
@ -356,12 +476,14 @@ impl DiscordCache {
|
|||
system_channel_flags: guild.system_channel_flags(),
|
||||
system_channel_id: guild.system_channel_id(),
|
||||
threads: vec![],
|
||||
unavailable: false,
|
||||
unavailable: Some(false),
|
||||
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
|
||||
verification_level: guild.verification_level(),
|
||||
voice_states: vec![],
|
||||
widget_channel_id: guild.widget_channel_id(),
|
||||
widget_enabled: guild.widget_enabled(),
|
||||
guild_scheduled_events: guild.guild_scheduled_events().to_vec(),
|
||||
max_stage_video_channel_users: guild.max_stage_video_channel_users(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
use anyhow::anyhow;
|
||||
use futures::StreamExt;
|
||||
use libpk::_config::ClusterSettings;
|
||||
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent};
|
||||
use metrics::counter;
|
||||
use std::sync::{mpsc::Sender, Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{
|
||||
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
|
||||
|
|
@ -12,9 +14,12 @@ use twilight_model::gateway::{
|
|||
Intents,
|
||||
};
|
||||
|
||||
use crate::discord::identify_queue::{self, RedisQueue};
|
||||
use crate::{
|
||||
discord::identify_queue::{self, RedisQueue},
|
||||
RUNTIME_CONFIG_KEY_EVENT_TARGET,
|
||||
};
|
||||
|
||||
use super::{cache::DiscordCache, shard_state::ShardStateManager};
|
||||
use super::cache::DiscordCache;
|
||||
|
||||
pub fn cluster_config() -> ClusterSettings {
|
||||
libpk::config
|
||||
|
|
@ -44,6 +49,12 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
|
||||
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
|
||||
warn!("we have less than 16 shards, assuming single gateway process");
|
||||
if cluster_settings.node_id != 0 {
|
||||
return Err(anyhow!(
|
||||
"expecting to be node 0 in single-process mode, but we are node {}",
|
||||
cluster_settings.node_id
|
||||
));
|
||||
}
|
||||
(0, (cluster_settings.total_shards - 1).into())
|
||||
} else {
|
||||
(
|
||||
|
|
@ -52,6 +63,13 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
)
|
||||
};
|
||||
|
||||
let prefix = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_prefix_for_gateway
|
||||
.clone();
|
||||
|
||||
let shards = create_iterator(
|
||||
start_shard..end_shard + 1,
|
||||
cluster_settings.total_shards,
|
||||
|
|
@ -64,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
.to_owned(),
|
||||
intents,
|
||||
)
|
||||
.presence(presence("pk;help", false))
|
||||
.presence(presence(format!("{prefix}help").as_str(), false))
|
||||
.queue(queue.clone())
|
||||
.build(),
|
||||
|_, builder| builder.build(),
|
||||
|
|
@ -76,15 +94,23 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
Ok(shards_vec)
|
||||
}
|
||||
|
||||
#[tracing::instrument(fields(shard = %shard.id()), skip_all)]
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
_tx: Sender<(ShardId, String)>,
|
||||
shard_state: ShardStateManager,
|
||||
tx: Sender<(ShardId, Event, String)>,
|
||||
tx_state: Sender<(ShardId, ShardStateEvent, Option<Event>, Option<i32>)>,
|
||||
cache: Arc<DiscordCache>,
|
||||
runtime_config: Arc<RuntimeConfig>,
|
||||
) {
|
||||
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
let shard_id = shard.id().number();
|
||||
|
||||
let our_user_id = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id;
|
||||
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next().await {
|
||||
let raw_event = match item {
|
||||
|
|
@ -105,7 +131,9 @@ pub async fn runner(
|
|||
)
|
||||
.increment(1);
|
||||
|
||||
if let Err(error) = shard_state.socket_closed(shard_id).await {
|
||||
if let Err(error) =
|
||||
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
|
||||
{
|
||||
error!("failed to update shard state for socket closure: {error}");
|
||||
}
|
||||
|
||||
|
|
@ -127,7 +155,7 @@ pub async fn runner(
|
|||
continue;
|
||||
}
|
||||
Err(error) => {
|
||||
error!("shard {shard_id} failed to parse gateway event: {error}");
|
||||
error!(?error, ?shard_id, "failed to parse gateway event");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
|
@ -147,14 +175,31 @@ pub async fn runner(
|
|||
.increment(1);
|
||||
|
||||
// update shard state and discord cache
|
||||
if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await {
|
||||
tracing::error!(?error, "error updating redis state");
|
||||
if matches!(event, Event::Ready(_)) || matches!(event, Event::Resumed) {
|
||||
if let Err(error) = tx_state.try_send((
|
||||
shard.id(),
|
||||
ShardStateEvent::Other,
|
||||
Some(event.clone()),
|
||||
None,
|
||||
)) {
|
||||
tracing::error!(?error, "error updating shard state");
|
||||
}
|
||||
}
|
||||
// need to do heartbeat separately, to get the latency
|
||||
let latency_num = shard
|
||||
.latency()
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
if let Event::GatewayHeartbeatAck = event
|
||||
&& let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await
|
||||
&& let Err(error) = tx_state.try_send((
|
||||
shard.id(),
|
||||
ShardStateEvent::Heartbeat,
|
||||
Some(event.clone()),
|
||||
Some(latency_num),
|
||||
))
|
||||
{
|
||||
tracing::error!(?error, "error updating redis state for latency");
|
||||
tracing::error!(?error, "error updating shard state for latency");
|
||||
}
|
||||
|
||||
if let Event::Ready(_) = event {
|
||||
|
|
@ -162,10 +207,28 @@ pub async fn runner(
|
|||
cache.2.write().await.push(shard_id);
|
||||
}
|
||||
}
|
||||
cache.0.update(&event);
|
||||
cache.update(&event).await;
|
||||
|
||||
// okay, we've handled the event internally, let's send it to consumers
|
||||
// tx.send((shard.id(), raw_event)).unwrap();
|
||||
|
||||
// some basic filtering here is useful
|
||||
// we can't use if matching using the | operator, so anything matched does nothing
|
||||
// and the default match skips the next block (continues to the next event)
|
||||
match event {
|
||||
Event::InteractionCreate(_) => {}
|
||||
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
|
||||
Event::MessageUpdate(ref m) if m.author.id != our_user_id && !m.author.bot => {}
|
||||
Event::MessageDelete(_) => {}
|
||||
Event::MessageDeleteBulk(_) => {}
|
||||
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
|
||||
tx.send((shard.id(), event, raw_event)).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -78,8 +78,8 @@ async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: on
|
|||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
Err(error) => {
|
||||
error!(?error, ?shard_id, ?bucket, "error getting shard allowance")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,49 +1,63 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use metrics::{counter, gauge};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
use twilight_gateway::{Event, Latency};
|
||||
use twilight_gateway::Event;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use libpk::state::ShardState;
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::gateway::cluster_config;
|
||||
|
||||
pub struct ShardStateManager {
|
||||
redis: RedisPool,
|
||||
shards: RwLock<HashMap<u32, ShardState>>,
|
||||
}
|
||||
|
||||
pub fn new(redis: RedisPool) -> ShardStateManager {
|
||||
ShardStateManager { redis }
|
||||
ShardStateManager {
|
||||
redis: redis,
|
||||
shards: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
impl ShardStateManager {
|
||||
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
|
||||
match event {
|
||||
// also update gateway.rs with event types
|
||||
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
|
||||
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<String> = self.redis.hget("pluralkit:shardstatus", shard_id).await?;
|
||||
match data {
|
||||
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
|
||||
None => Ok(ShardState::default()),
|
||||
async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
|
||||
{
|
||||
let mut shards = self.shards.write().await;
|
||||
shards.insert(id, state.clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(
|
||||
"pluralkit:shardstatus",
|
||||
(
|
||||
shard_id.to_string(),
|
||||
serde_json::to_string(&info).expect("could not serialize shard"),
|
||||
id.to_string(),
|
||||
serde_json::to_string(&state).expect("could not serialize shard"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_shard(&self, id: u32) -> Option<ShardState> {
|
||||
let shards = self.shards.read().await;
|
||||
shards.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub async fn get(&self) -> Vec<ShardState> {
|
||||
self.shards.read().await.values().cloned().collect()
|
||||
}
|
||||
|
||||
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"shard {} {}",
|
||||
|
|
@ -57,32 +71,52 @@ impl ShardStateManager {
|
|||
)
|
||||
.increment(1);
|
||||
gauge!("pluralkit_gateway_shard_up").increment(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.up = true;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_up").decrement(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.up = false;
|
||||
info.disconnection_count += 1;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.up = true;
|
||||
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.latency = latency
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string())
|
||||
.set(info.latency);
|
||||
info.latency = latency;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
242
crates/gateway/src/event_awaiter.rs
Normal file
242
crates/gateway/src/event_awaiter.rs
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
// - reaction: (message_id, user_id)
|
||||
// - message: (author_id, channel_id, ?options)
|
||||
// - interaction: (custom_id where not_includes "help-menu")
|
||||
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::{IpAddr, SocketAddr},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tokio::{sync::RwLock, time::Instant};
|
||||
use tracing::info;
|
||||
use twilight_gateway::Event;
|
||||
use twilight_model::{
|
||||
application::interaction::InteractionData,
|
||||
id::{
|
||||
marker::{ChannelMarker, MessageMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
|
||||
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AwaitEventRequest {
|
||||
Reaction {
|
||||
message_id: Id<MessageMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
Message {
|
||||
channel_id: Id<ChannelMarker>,
|
||||
author_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
options: Option<Vec<String>>,
|
||||
},
|
||||
Interaction {
|
||||
id: String,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct EventAwaiter {
|
||||
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
|
||||
messages: RwLock<
|
||||
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
|
||||
>,
|
||||
interactions: RwLock<HashMap<String, (Instant, String)>>,
|
||||
}
|
||||
|
||||
impl EventAwaiter {
|
||||
pub fn new() -> Self {
|
||||
let v = Self {
|
||||
reactions: RwLock::new(HashMap::new()),
|
||||
messages: RwLock::new(HashMap::new()),
|
||||
interactions: RwLock::new(HashMap::new()),
|
||||
};
|
||||
|
||||
v
|
||||
}
|
||||
|
||||
pub async fn cleanup_loop(&self) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
info!("running event_awaiter cleanup loop");
|
||||
let mut counts = (0, 0, 0);
|
||||
let now = Instant::now();
|
||||
{
|
||||
let mut reactions = self.reactions.write().await;
|
||||
for key in reactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = reactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.0 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut messages = self.messages.write().await;
|
||||
for key in messages.clone().keys() {
|
||||
if let Entry::Occupied(entry) = messages.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.1 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut interactions = self.interactions.write().await;
|
||||
for key in interactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = interactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.2 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn target_for_event(&self, event: Event) -> Option<String> {
|
||||
match event {
|
||||
Event::MessageCreate(message) => {
|
||||
let mut messages = self.messages.write().await;
|
||||
|
||||
messages
|
||||
.remove(&(message.channel_id, message.author.id))
|
||||
.map(|(timeout, target, options)| {
|
||||
if let Some(options) = options
|
||||
&& !options.contains(&message.content.to_lowercase())
|
||||
{
|
||||
messages.insert(
|
||||
(message.channel_id, message.author.id),
|
||||
(timeout, target, Some(options)),
|
||||
);
|
||||
return None;
|
||||
}
|
||||
Some((*target).to_string())
|
||||
})?
|
||||
}
|
||||
Event::ReactionAdd(reaction)
|
||||
if let Some((_, target)) = self
|
||||
.reactions
|
||||
.write()
|
||||
.await
|
||||
.remove(&(reaction.message_id, reaction.user_id)) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
Event::InteractionCreate(interaction)
|
||||
if let Some(data) = interaction.data.clone()
|
||||
&& let InteractionData::MessageComponent(component) = data
|
||||
&& !component.custom_id.contains("help-menu")
|
||||
&& let Some((_, target)) =
|
||||
self.interactions.write().await.remove(&component.custom_id) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_request(&self, req: AwaitEventRequest, addr: SocketAddr) {
|
||||
match req {
|
||||
AwaitEventRequest::Reaction {
|
||||
message_id,
|
||||
user_id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.reactions.write().await.insert(
|
||||
(message_id, user_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Message {
|
||||
channel_id,
|
||||
author_id,
|
||||
target,
|
||||
timeout,
|
||||
options,
|
||||
} => {
|
||||
self.messages.write().await.insert(
|
||||
(channel_id, author_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
options,
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Interaction {
|
||||
id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.interactions.write().await.insert(
|
||||
id,
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear(&self) {
|
||||
self.reactions.write().await.clear();
|
||||
self.messages.write().await.clear();
|
||||
self.interactions.write().await.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn target_or_addr(target: String, addr: SocketAddr) -> String {
|
||||
if target == "source-addr" {
|
||||
let ip_str = match addr.ip() {
|
||||
IpAddr::V4(v4) => v4.to_string(),
|
||||
IpAddr::V6(v6) => {
|
||||
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||
v4.to_string()
|
||||
} else {
|
||||
format!("[{v6}]")
|
||||
}
|
||||
}
|
||||
};
|
||||
format!("http://{ip_str}:5002/events")
|
||||
} else {
|
||||
target
|
||||
}
|
||||
}
|
||||
|
|
@ -1,39 +1,79 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(if_let_guard)]
|
||||
#![feature(duration_constructors)]
|
||||
|
||||
use chrono::Timelike;
|
||||
use discord::gateway::cluster_config;
|
||||
use event_awaiter::EventAwaiter;
|
||||
use fred::{clients::RedisPool, interfaces::*};
|
||||
use signal_hook::{
|
||||
consts::{SIGINT, SIGTERM},
|
||||
iterator::Signals,
|
||||
use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
|
||||
use reqwest::{ClientBuilder, StatusCode};
|
||||
use std::{sync::Arc, time::Duration, vec::Vec};
|
||||
use tokio::{
|
||||
signal::unix::{signal, SignalKind},
|
||||
sync::mpsc::channel,
|
||||
task::JoinSet,
|
||||
};
|
||||
use std::{
|
||||
sync::{mpsc::channel, Arc},
|
||||
time::Duration,
|
||||
vec::Vec,
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod api;
|
||||
mod discord;
|
||||
mod event_awaiter;
|
||||
mod logger;
|
||||
|
||||
libpk::main!("gateway");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let (shutdown_tx, shutdown_rx) = channel::<()>();
|
||||
let shutdown_tx = Arc::new(shutdown_tx);
|
||||
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let runtime_config = Arc::new(
|
||||
RuntimeConfig::new(
|
||||
redis.clone(),
|
||||
format!(
|
||||
"{}:{}",
|
||||
libpk::config.runtime_config_key.as_ref().unwrap(),
|
||||
cluster_config().node_id
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
// hacky, but needed for selfhost for now
|
||||
if let Some(target) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.gateway_target
|
||||
.clone()
|
||||
{
|
||||
runtime_config
|
||||
.set(RUNTIME_CONFIG_KEY_EVENT_TARGET.to_string(), target)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
let awaiter = Arc::new(EventAwaiter::new());
|
||||
tokio::spawn({
|
||||
let awaiter = awaiter.clone();
|
||||
async move { awaiter.cleanup_loop().await }
|
||||
});
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
let (event_tx, _event_rx) = channel();
|
||||
// arbitrary
|
||||
// todo: make sure this doesn't fill up
|
||||
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
|
||||
|
||||
// todo: make sure this doesn't fill up
|
||||
let (state_tx, mut state_rx) = channel::<(
|
||||
ShardId,
|
||||
ShardStateEvent,
|
||||
Option<twilight_gateway::Event>,
|
||||
Option<i32>,
|
||||
)>(1000);
|
||||
|
||||
let mut senders = Vec::new();
|
||||
let mut signal_senders = Vec::new();
|
||||
|
|
@ -45,61 +85,160 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
set.spawn(tokio::spawn(discord::gateway::runner(
|
||||
shard,
|
||||
event_tx.clone(),
|
||||
shard_state.clone(),
|
||||
state_tx.clone(),
|
||||
cache.clone(),
|
||||
runtime_config.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
|
||||
|
||||
set.spawn(tokio::spawn({
|
||||
let shard_state = shard_state.clone();
|
||||
|
||||
async move {
|
||||
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
|
||||
match state_event {
|
||||
ShardStateEvent::Heartbeat => {
|
||||
if !latency.is_none()
|
||||
&& let Err(error) = shard_state
|
||||
.heartbeated(shard_id.number(), latency.unwrap())
|
||||
.await
|
||||
{
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
ShardStateEvent::Closed => {
|
||||
if let Err(error) = shard_state.socket_closed(shard_id.number()).await {
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
ShardStateEvent::Other => {
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(
|
||||
shard_id.number(),
|
||||
parsed_event.expect("shard state event not provided!"),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
set.spawn(tokio::spawn({
|
||||
let runtime_config = runtime_config.clone();
|
||||
let awaiter = awaiter.clone();
|
||||
|
||||
async move {
|
||||
let client = Arc::new(
|
||||
ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.timeout(Duration::from_secs(1))
|
||||
.build()
|
||||
.expect("error making client"),
|
||||
);
|
||||
|
||||
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
|
||||
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
|
||||
info!(target = ?target, "sending event to awaiter");
|
||||
Some(target)
|
||||
} else if let Some(target) =
|
||||
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
|
||||
{
|
||||
Some(target)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(target) = target {
|
||||
tokio::spawn({
|
||||
let client = client.clone();
|
||||
async move {
|
||||
match client
|
||||
.post(format!("{target}/{}", shard_id.number()))
|
||||
.body(raw_event)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if res.status() != StatusCode::OK {
|
||||
error!(
|
||||
status = ?res.status(),
|
||||
target = ?target,
|
||||
"got non-200 from bot while sending event",
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to request event target");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
set.spawn(tokio::spawn(
|
||||
async move { scheduled_task(redis, senders).await },
|
||||
));
|
||||
|
||||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache).await {
|
||||
match api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
error!(?error, "failed to serve cache api");
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}));
|
||||
|
||||
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
|
||||
|
||||
set.spawn(tokio::spawn(async move {
|
||||
for sig in signals.forever() {
|
||||
info!("received signal {:?}", sig);
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
signal(SignalKind::interrupt()).unwrap().recv().await;
|
||||
info!("got SIGINT");
|
||||
}));
|
||||
|
||||
let _ = shutdown_rx.recv();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
signal(SignalKind::terminate()).unwrap().recv().await;
|
||||
info!("got SIGTERM");
|
||||
}));
|
||||
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
set.join_next().await;
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
set.abort_all();
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
|
||||
let prefix = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_prefix_for_gateway
|
||||
.clone();
|
||||
|
||||
println!("{prefix}");
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
(60 - chrono::offset::Utc::now().second()).into(),
|
||||
|
|
@ -119,9 +258,9 @@ async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>
|
|||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence(
|
||||
if let Some(status) = status {
|
||||
format!("pk;help | {}", status)
|
||||
format!("{prefix}help | {status}")
|
||||
} else {
|
||||
"pk;help".to_string()
|
||||
format!("{prefix}help")
|
||||
}
|
||||
.as_str(),
|
||||
false,
|
||||
|
|
|
|||
15
crates/gdpr_worker/Cargo.toml
Normal file
15
crates/gdpr_worker/Cargo.toml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
[package]
|
||||
name = "gdpr_worker"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libpk = { path = "../libpk" }
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
twilight-model = { workspace = true }
|
||||
149
crates/gdpr_worker/src/main.rs
Normal file
149
crates/gdpr_worker/src/main.rs
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use sqlx::prelude::FromRow;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_http::api_error::{ApiError, GeneralApiError};
|
||||
use twilight_model::id::{
|
||||
marker::{ChannelMarker, MessageMarker},
|
||||
Id,
|
||||
};
|
||||
|
||||
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null);
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_messages_db().await?;
|
||||
|
||||
let mut client_builder = twilight_http::Client::builder()
|
||||
.token(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.clone(),
|
||||
)
|
||||
.timeout(Duration::from_secs(30));
|
||||
|
||||
if let Some(base_url) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
match run_job(db.clone(), client.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to run messages gdpr job");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromRow)]
|
||||
struct GdprJobEntry {
|
||||
mid: i64,
|
||||
channel_id: i64,
|
||||
}
|
||||
|
||||
async fn run_job(pool: sqlx::PgPool, discord: Arc<twilight_http::Client>) -> anyhow::Result<()> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
let message: Option<GdprJobEntry> = sqlx::query_as(
|
||||
"select mid, channel_id from messages_gdpr_jobs for update skip locked limit 1;",
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let Some(message) = message else {
|
||||
info!("no job to run, sleeping for 1 minute");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
info!("got mid={}, cleaning up...", message.mid);
|
||||
|
||||
// naively delete message on discord's end
|
||||
let res = discord
|
||||
.delete_message(
|
||||
Id::<ChannelMarker>::new(message.channel_id as u64),
|
||||
Id::<MessageMarker>::new(message.mid as u64),
|
||||
)
|
||||
.await;
|
||||
|
||||
if res.is_ok() {
|
||||
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
|
||||
.bind(message.mid)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Err(err) = res {
|
||||
if let twilight_http::error::ErrorType::Response { error, status, .. } = err.kind()
|
||||
&& let ApiError::General(GeneralApiError { code, .. }) = error
|
||||
{
|
||||
match (status.get(), code) {
|
||||
(403, _) => {
|
||||
warn!(
|
||||
"got 403 while deleting message in channel {}, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 10003) => {
|
||||
warn!(
|
||||
"deleting message in channel {}: channel not found, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 10008) => {
|
||||
warn!("deleting message {}: message not found", message.mid);
|
||||
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
|
||||
.bind(message.mid)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 50083) => {
|
||||
warn!(
|
||||
"could not delete message in thread {}: thread is archived, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
error!(
|
||||
?status,
|
||||
?code,
|
||||
message_id = message.mid,
|
||||
"got unknown error deleting message",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
0
crates/h
Normal file
0
crates/h
Normal file
|
|
@ -8,6 +8,7 @@ anyhow = { workspace = true }
|
|||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
pk_macros = { path = "../macros" }
|
||||
sentry = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -12,10 +12,16 @@ pub struct ClusterSettings {
|
|||
pub total_nodes: u32,
|
||||
}
|
||||
|
||||
fn _default_bot_prefix() -> String {
|
||||
"pk;".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DiscordConfig {
|
||||
pub client_id: Id<UserMarker>,
|
||||
pub bot_token: String,
|
||||
#[serde(default = "_default_bot_prefix")]
|
||||
pub bot_prefix_for_gateway: String,
|
||||
pub client_secret: String,
|
||||
pub max_concurrency: u32,
|
||||
#[serde(default)]
|
||||
|
|
@ -24,6 +30,9 @@ pub struct DiscordConfig {
|
|||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub cache_api_addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub gateway_target: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
|
|
@ -85,6 +94,7 @@ pub struct ScheduledTasksConfig {
|
|||
pub set_guild_count: bool,
|
||||
pub expected_gateway_count: usize,
|
||||
pub gateway_url: String,
|
||||
pub prometheus_url: String,
|
||||
}
|
||||
|
||||
fn _metrics_default() -> bool {
|
||||
|
|
@ -113,6 +123,9 @@ pub struct PKConfig {
|
|||
#[serde(default = "_json_log_default")]
|
||||
pub(crate) json_log: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub runtime_config_key: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub sentry_url: Option<String>,
|
||||
}
|
||||
|
|
@ -132,10 +145,15 @@ impl PKConfig {
|
|||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
pub static ref CONFIG: Arc<PKConfig> = {
|
||||
// hacks
|
||||
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var);
|
||||
}
|
||||
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
|
||||
}
|
||||
|
||||
Arc::new(Config::builder()
|
||||
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
|
||||
|
|
|
|||
|
|
@ -8,12 +8,15 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
|
|||
use sentry_tracing::event_from_event;
|
||||
|
||||
pub mod db;
|
||||
pub mod runtime_config;
|
||||
pub mod state;
|
||||
|
||||
pub mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
// functions in this file are only used by the main function below
|
||||
// functions in this file are only used by the main function in macros/entrypoint.rs
|
||||
|
||||
pub use pk_macros::main;
|
||||
|
||||
pub fn init_logging(component: &str) {
|
||||
let sentry_layer =
|
||||
|
|
@ -42,6 +45,7 @@ pub fn init_logging(component: &str) {
|
|||
tracing_subscriber::registry()
|
||||
.with(sentry_layer)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(EnvFilter::from_default_env())
|
||||
.init();
|
||||
}
|
||||
}
|
||||
|
|
@ -66,28 +70,3 @@ pub fn init_sentry() -> sentry::ClientInitGuard {
|
|||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! main {
|
||||
($component:expr) => {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
// we might also be able to use env!("CARGO_CRATE_NAME") here
|
||||
libpk::init_logging($component);
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(err) = libpk::init_metrics() {
|
||||
tracing::error!("failed to init metrics collector: {err}");
|
||||
};
|
||||
tracing::info!("hello world");
|
||||
if let Err(err) = real_main().await {
|
||||
tracing::error!("failed to run service: {err}");
|
||||
};
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
|||
72
crates/libpk/src/runtime_config.rs
Normal file
72
crates/libpk/src/runtime_config.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
pub struct RuntimeConfig {
|
||||
redis: RedisPool,
|
||||
settings: RwLock<HashMap<String, String>>,
|
||||
redis_key: String,
|
||||
}
|
||||
|
||||
impl RuntimeConfig {
|
||||
pub async fn new(redis: RedisPool, component_key: String) -> anyhow::Result<Self> {
|
||||
let redis_key = format!("remote_config:{component_key}");
|
||||
|
||||
let mut c = RuntimeConfig {
|
||||
redis,
|
||||
settings: RwLock::new(HashMap::new()),
|
||||
redis_key,
|
||||
};
|
||||
|
||||
c.load().await?;
|
||||
|
||||
Ok(c)
|
||||
}
|
||||
|
||||
pub async fn load(&mut self) -> anyhow::Result<()> {
|
||||
let redis_config: HashMap<String, String> = self.redis.hgetall(&self.redis_key).await?;
|
||||
|
||||
let mut settings = self.settings.write().await;
|
||||
|
||||
for (key, value) in redis_config {
|
||||
settings.insert(key, value);
|
||||
}
|
||||
|
||||
info!("starting with runtime config: {:?}", settings);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set(&self, key: String, value: String) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(&self.redis_key, (key.clone(), value.clone()))
|
||||
.await?;
|
||||
self.settings
|
||||
.write()
|
||||
.await
|
||||
.insert(key.clone(), value.clone());
|
||||
info!("updated runtime config: {key}={value}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete(&self, key: String) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hdel::<(), &str, String>(&self.redis_key, key.clone())
|
||||
.await?;
|
||||
self.settings.write().await.remove(&key.clone());
|
||||
info!("updated runtime config: {key} removed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get(&self, key: &str) -> Option<String> {
|
||||
self.settings.read().await.get(key).cloned()
|
||||
}
|
||||
|
||||
pub async fn exists(&self, key: &str) -> bool {
|
||||
self.settings.read().await.contains_key(key)
|
||||
}
|
||||
|
||||
pub async fn get_all(&self) -> HashMap<String, String> {
|
||||
self.settings.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)]
|
||||
pub struct ShardState {
|
||||
pub shard_id: i32,
|
||||
pub up: bool,
|
||||
|
|
@ -10,3 +10,9 @@ pub struct ShardState {
|
|||
pub last_connection: i32,
|
||||
pub cluster_id: Option<i32>,
|
||||
}
|
||||
|
||||
pub enum ShardStateEvent {
|
||||
Closed,
|
||||
Heartbeat,
|
||||
Other,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[package]
|
||||
name = "model_macros"
|
||||
name = "pk_macros"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
41
crates/macros/src/entrypoint.rs
Normal file
41
crates/macros/src/entrypoint.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
use proc_macro::{Delimiter, TokenTree};
|
||||
use quote::quote;
|
||||
|
||||
pub fn macro_impl(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
// yes, this ignores everything except the codeblock
|
||||
// it's fine.
|
||||
let body = match input.into_iter().last().expect("empty") {
|
||||
TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => group.stream(),
|
||||
_ => panic!("invalid function"),
|
||||
};
|
||||
|
||||
let body = proc_macro2::TokenStream::from(body);
|
||||
|
||||
return quote! {
|
||||
fn main() {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
libpk::init_logging(env!("CARGO_CRATE_NAME"));
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(error) = libpk::init_metrics() {
|
||||
tracing::error!(?error, "failed to init metrics collector");
|
||||
};
|
||||
|
||||
tracing::info!("hello world");
|
||||
|
||||
let result: anyhow::Result<()> = async { #body }.await;
|
||||
|
||||
if let Err(error) = result {
|
||||
tracing::error!(?error, "failed to run service");
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
.into();
|
||||
}
|
||||
14
crates/macros/src/lib.rs
Normal file
14
crates/macros/src/lib.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
use proc_macro::TokenStream;
|
||||
|
||||
mod entrypoint;
|
||||
mod model;
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
entrypoint::macro_impl(args, input)
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
model::macro_impl(args, input)
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ struct ModelField {
|
|||
patch: ElemPatchability,
|
||||
json: Option<Expr>,
|
||||
is_privacy: bool,
|
||||
privacy: Option<Expr>,
|
||||
default: Option<Expr>,
|
||||
}
|
||||
|
||||
|
|
@ -26,6 +27,7 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
patch: ElemPatchability::None,
|
||||
json: None,
|
||||
is_privacy: false,
|
||||
privacy: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
|
|
@ -61,6 +63,12 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
}
|
||||
f.json = Some(nv.value.clone());
|
||||
}
|
||||
"privacy" => {
|
||||
if f.privacy.is_some() {
|
||||
panic!("cannot set privacy multiple times for same field");
|
||||
}
|
||||
f.privacy = Some(nv.value.clone());
|
||||
}
|
||||
"default" => {
|
||||
if f.default.is_some() {
|
||||
panic!("cannot set default multiple times for same field");
|
||||
|
|
@ -84,8 +92,7 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
f
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(
|
||||
pub fn macro_impl(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
|
|
@ -108,8 +115,6 @@ pub fn pk_model(
|
|||
panic!("fields of a struct must be named");
|
||||
};
|
||||
|
||||
// println!("{}: {:#?}", tname, fields);
|
||||
|
||||
let tfields = mk_tfields(fields.clone());
|
||||
let from_json = mk_tfrom_json(fields.clone());
|
||||
let _from_sql = mk_tfrom_sql(fields.clone());
|
||||
|
|
@ -138,9 +143,7 @@ pub fn pk_model(
|
|||
#from_json
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#to_json
|
||||
}
|
||||
#to_json
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -189,19 +192,28 @@ fn mk_tfrom_sql(_fields: Vec<ModelField>) -> TokenStream {
|
|||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
// todo: check privacy access
|
||||
let has_privacy = fields.iter().any(|f| f.privacy.is_some());
|
||||
let fielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
f.json.as_ref().map(|v| {
|
||||
let tname = f.name.clone();
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
let maybepriv = if let Some(privacy) = f.privacy.as_ref() {
|
||||
quote! {
|
||||
#v: self.#tname.unwrap_or(#default),
|
||||
#v: crate::_util::privacy_lookup!(self.#tname, self.#privacy, lookup_level)
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#v: self.#tname,
|
||||
#v: self.#tname
|
||||
}
|
||||
};
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
quote! {
|
||||
#maybepriv.unwrap_or(#default),
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#maybepriv,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -223,13 +235,35 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
|||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
"privacy": {
|
||||
#privacyfielddefs
|
||||
let privdef = if has_privacy {
|
||||
quote! {
|
||||
, lookup_level: crate::PrivacyLevel
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
||||
let privacy_fielddefs = if has_privacy {
|
||||
quote! {
|
||||
"privacy": if matches!(lookup_level, crate::PrivacyLevel::Private) {
|
||||
Some(serde_json::json!({
|
||||
#privacyfielddefs
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
||||
quote! {
|
||||
pub fn to_json(self #privdef) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
#privacy_fielddefs
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
12
crates/migrate/Cargo.toml
Normal file
12
crates/migrate/Cargo.toml
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
[package]
|
||||
name = "migrate"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libpk = { path = "../libpk" }
|
||||
|
||||
anyhow = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
55
crates/migrate/build.rs
Normal file
55
crates/migrate/build.rs
Normal file
|
|
@ -0,0 +1,55 @@
|
|||
use std::{
|
||||
env,
|
||||
error::Error,
|
||||
fs::{self, File},
|
||||
io::Write,
|
||||
path::Path,
|
||||
};
|
||||
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
let out_dir = env::var("OUT_DIR")?;
|
||||
let dest_path = Path::new(&out_dir).join("data.rs");
|
||||
let mut datafile = File::create(&dest_path)?;
|
||||
|
||||
let prefix = "../../../../../../crates/migrate/data";
|
||||
|
||||
let ct = fs::read_dir("data/migrations")?
|
||||
.filter(|p| {
|
||||
p.as_ref()
|
||||
.unwrap()
|
||||
.file_name()
|
||||
.into_string()
|
||||
.unwrap()
|
||||
.contains(".sql")
|
||||
})
|
||||
.count();
|
||||
|
||||
writeln!(&mut datafile, "const MIGRATIONS: [&'static str; {ct}] = [")?;
|
||||
for idx in 0..ct {
|
||||
writeln!(
|
||||
&mut datafile,
|
||||
"\tinclude_str!(\"{prefix}/migrations/{idx}.sql\"),"
|
||||
)?;
|
||||
}
|
||||
writeln!(&mut datafile, "];\n")?;
|
||||
|
||||
writeln!(
|
||||
&mut datafile,
|
||||
"const CLEAN: &'static str = include_str!(\"{prefix}/clean.sql\");"
|
||||
)?;
|
||||
writeln!(
|
||||
&mut datafile,
|
||||
"const VIEWS: &'static str = include_str!(\"{prefix}/views.sql\");"
|
||||
)?;
|
||||
writeln!(
|
||||
&mut datafile,
|
||||
"const FUNCTIONS: &'static str = include_str!(\"{prefix}/functions.sql\");"
|
||||
)?;
|
||||
|
||||
writeln!(
|
||||
&mut datafile,
|
||||
"const SEED: &'static str = include_str!(\"{prefix}/seed.sql\");"
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
16
crates/migrate/data/clean.sql
Normal file
16
crates/migrate/data/clean.sql
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
-- This gets run on every bot startup and makes sure we're starting from a clean slate
|
||||
-- Then, the views/functions.sql files get run, and they recreate the necessary objects
|
||||
-- This does mean we can't use any functions in row triggers, etc. Still unsure how to handle this.
|
||||
|
||||
drop view if exists system_last_switch;
|
||||
drop view if exists system_fronters;
|
||||
drop view if exists member_list;
|
||||
drop view if exists group_list;
|
||||
|
||||
drop function if exists message_context;
|
||||
drop function if exists proxy_members;
|
||||
drop function if exists has_private_members;
|
||||
drop function if exists generate_hid;
|
||||
drop function if exists find_free_system_hid;
|
||||
drop function if exists find_free_member_hid;
|
||||
drop function if exists find_free_group_hid;
|
||||
185
crates/migrate/data/functions.sql
Normal file
185
crates/migrate/data/functions.sql
Normal file
|
|
@ -0,0 +1,185 @@
|
|||
create function message_context(account_id bigint, guild_id bigint, channel_id bigint, thread_id bigint)
|
||||
returns table (
|
||||
allow_autoproxy bool,
|
||||
|
||||
system_id int,
|
||||
system_tag text,
|
||||
system_avatar text,
|
||||
|
||||
latch_timeout integer,
|
||||
case_sensitive_proxy_tags bool,
|
||||
proxy_error_message_enabled bool,
|
||||
proxy_switch int,
|
||||
name_format text,
|
||||
|
||||
tag_enabled bool,
|
||||
proxy_enabled bool,
|
||||
system_guild_tag text,
|
||||
system_guild_avatar text,
|
||||
guild_name_format text,
|
||||
|
||||
last_switch int,
|
||||
last_switch_members int[],
|
||||
last_switch_timestamp timestamp,
|
||||
|
||||
log_channel bigint,
|
||||
in_blacklist bool,
|
||||
in_log_blacklist bool,
|
||||
log_cleanup_enabled bool,
|
||||
require_system_tag bool,
|
||||
suppress_notifications bool,
|
||||
|
||||
deny_bot_usage bool
|
||||
)
|
||||
as $$
|
||||
select
|
||||
-- accounts table
|
||||
accounts.allow_autoproxy as allow_autoproxy,
|
||||
|
||||
-- systems table
|
||||
systems.id as system_id,
|
||||
systems.tag as system_tag,
|
||||
systems.avatar_url as system_avatar,
|
||||
|
||||
-- system_config table
|
||||
system_config.latch_timeout as latch_timeout,
|
||||
system_config.case_sensitive_proxy_tags as case_sensitive_proxy_tags,
|
||||
system_config.proxy_error_message_enabled as proxy_error_message_enabled,
|
||||
system_config.proxy_switch as proxy_switch,
|
||||
system_config.name_format as name_format,
|
||||
|
||||
-- system_guild table
|
||||
coalesce(system_guild.tag_enabled, true) as tag_enabled,
|
||||
coalesce(system_guild.proxy_enabled, true) as proxy_enabled,
|
||||
system_guild.tag as system_guild_tag,
|
||||
system_guild.avatar_url as system_guild_avatar,
|
||||
system_guild.name_format as guild_name_format,
|
||||
|
||||
-- system_last_switch view
|
||||
system_last_switch.switch as last_switch,
|
||||
system_last_switch.members as last_switch_members,
|
||||
system_last_switch.timestamp as last_switch_timestamp,
|
||||
|
||||
-- servers table
|
||||
servers.log_channel as log_channel,
|
||||
((channel_id = any (servers.blacklist))
|
||||
or (thread_id = any (servers.blacklist))) as in_blacklist,
|
||||
((channel_id = any (servers.log_blacklist))
|
||||
or (thread_id = any (servers.log_blacklist))) as in_log_blacklist,
|
||||
coalesce(servers.log_cleanup_enabled, false) as log_cleanup_enabled,
|
||||
coalesce(servers.require_system_tag, false) as require_system_tag,
|
||||
coalesce(servers.suppress_notifications, false) as suppress_notifications,
|
||||
|
||||
-- abuse_logs table
|
||||
coalesce(abuse_logs.deny_bot_usage, false) as deny_bot_usage
|
||||
|
||||
-- We need a "from" clause, so we just use some bogus data that's always present
|
||||
-- This ensure we always have exactly one row going forward, so we can left join afterwards and still get data
|
||||
from (select 1) as _placeholder
|
||||
left join accounts on accounts.uid = account_id
|
||||
left join servers on servers.id = guild_id
|
||||
left join systems on systems.id = accounts.system
|
||||
left join system_config on system_config.system = accounts.system
|
||||
left join system_guild on system_guild.system = accounts.system
|
||||
and system_guild.guild = guild_id
|
||||
left join system_last_switch on system_last_switch.system = accounts.system
|
||||
left join abuse_logs on abuse_logs.id = accounts.abuse_log
|
||||
$$ language sql stable rows 1;
|
||||
|
||||
-- Fetches info about proxying related to a given account/guild
|
||||
-- Returns one row per member in system, should be used in conjuction with `message_context` too
|
||||
create function proxy_members(account_id bigint, guild_id bigint)
|
||||
returns table (
|
||||
id int,
|
||||
proxy_tags proxy_tag[],
|
||||
keep_proxy bool,
|
||||
tts bool,
|
||||
server_keep_proxy bool,
|
||||
|
||||
server_name text,
|
||||
display_name text,
|
||||
name text,
|
||||
|
||||
server_avatar text,
|
||||
webhook_avatar text,
|
||||
avatar text,
|
||||
|
||||
color char(6),
|
||||
|
||||
allow_autoproxy bool
|
||||
)
|
||||
as $$
|
||||
select
|
||||
-- Basic data
|
||||
members.id as id,
|
||||
members.proxy_tags as proxy_tags,
|
||||
members.keep_proxy as keep_proxy,
|
||||
members.tts as tts,
|
||||
member_guild.keep_proxy as server_keep_proxy,
|
||||
|
||||
-- Name info
|
||||
member_guild.display_name as server_name,
|
||||
members.display_name as display_name,
|
||||
members.name as name,
|
||||
|
||||
-- Avatar info
|
||||
member_guild.avatar_url as server_avatar,
|
||||
members.webhook_avatar_url as webhook_avatar,
|
||||
members.avatar_url as avatar,
|
||||
|
||||
members.color as color,
|
||||
|
||||
members.allow_autoproxy as allow_autoproxy
|
||||
from accounts
|
||||
inner join systems on systems.id = accounts.system
|
||||
inner join members on members.system = systems.id
|
||||
left join member_guild on member_guild.member = members.id and member_guild.guild = guild_id
|
||||
where accounts.uid = account_id
|
||||
$$ language sql stable rows 10;
|
||||
|
||||
create function has_private_members(system_hid int) returns bool as $$
|
||||
declare m int;
|
||||
begin
|
||||
m := count(id) from members where system = system_hid and member_visibility = 2;
|
||||
if m > 0 then return true;
|
||||
else return false;
|
||||
end if;
|
||||
end
|
||||
$$ language plpgsql;
|
||||
|
||||
create function generate_hid() returns char(6) as $$
|
||||
select string_agg(substr('abcefghjknoprstuvwxyz', ceil(random() * 21)::integer, 1), '') from generate_series(1, 6)
|
||||
$$ language sql volatile;
|
||||
|
||||
|
||||
create function find_free_system_hid() returns char(6) as $$
|
||||
declare new_hid char(6);
|
||||
begin
|
||||
loop
|
||||
new_hid := generate_hid();
|
||||
if not exists (select 1 from systems where hid = new_hid) then return new_hid; end if;
|
||||
end loop;
|
||||
end
|
||||
$$ language plpgsql volatile;
|
||||
|
||||
|
||||
create function find_free_member_hid() returns char(6) as $$
|
||||
declare new_hid char(6);
|
||||
begin
|
||||
loop
|
||||
new_hid := generate_hid();
|
||||
if not exists (select 1 from members where hid = new_hid) then return new_hid; end if;
|
||||
end loop;
|
||||
end
|
||||
$$ language plpgsql volatile;
|
||||
|
||||
|
||||
create function find_free_group_hid() returns char(6) as $$
|
||||
declare new_hid char(6);
|
||||
begin
|
||||
loop
|
||||
new_hid := generate_hid();
|
||||
if not exists (select 1 from groups where hid = new_hid) then return new_hid; end if;
|
||||
end loop;
|
||||
end
|
||||
$$ language plpgsql volatile;
|
||||
112
crates/migrate/data/migrations/0.sql
Normal file
112
crates/migrate/data/migrations/0.sql
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
-- SCHEMA VERSION 0, 2019-12-26
|
||||
-- "initial version", considered a "starting point" for the migrations
|
||||
|
||||
-- also the assumed database layout of someone either migrating from an older version of PK or starting a new instance,
|
||||
-- so everything here *should* be idempotent given a schema version older than this or nonexistent.
|
||||
|
||||
-- Create proxy_tag compound type if it doesn't exist
|
||||
do $$ begin
|
||||
create type proxy_tag as (
|
||||
prefix text,
|
||||
suffix text
|
||||
);
|
||||
exception when duplicate_object then null;
|
||||
end $$;
|
||||
|
||||
create table if not exists systems
|
||||
(
|
||||
id serial primary key,
|
||||
hid char(5) unique not null,
|
||||
name text,
|
||||
description text,
|
||||
tag text,
|
||||
avatar_url text,
|
||||
token text,
|
||||
created timestamp not null default (current_timestamp at time zone 'utc'),
|
||||
ui_tz text not null default 'UTC'
|
||||
);
|
||||
|
||||
create table if not exists system_guild
|
||||
(
|
||||
system serial not null references systems (id) on delete cascade,
|
||||
guild bigint not null,
|
||||
|
||||
proxy_enabled bool not null default true,
|
||||
|
||||
primary key (system, guild)
|
||||
);
|
||||
|
||||
create table if not exists members
|
||||
(
|
||||
id serial primary key,
|
||||
hid char(5) unique not null,
|
||||
system serial not null references systems (id) on delete cascade,
|
||||
color char(6),
|
||||
avatar_url text,
|
||||
name text not null,
|
||||
display_name text,
|
||||
birthday date,
|
||||
pronouns text,
|
||||
description text,
|
||||
proxy_tags proxy_tag[] not null default array[]::proxy_tag[], -- Rationale on making this an array rather than a separate table - we never need to query them individually, only access them as part of a selected Member struct
|
||||
keep_proxy bool not null default false,
|
||||
created timestamp not null default (current_timestamp at time zone 'utc')
|
||||
);
|
||||
|
||||
create table if not exists member_guild
|
||||
(
|
||||
member serial not null references members (id) on delete cascade,
|
||||
guild bigint not null,
|
||||
|
||||
display_name text default null,
|
||||
|
||||
primary key (member, guild)
|
||||
);
|
||||
|
||||
create table if not exists accounts
|
||||
(
|
||||
uid bigint primary key,
|
||||
system serial not null references systems (id) on delete cascade
|
||||
);
|
||||
|
||||
create table if not exists messages
|
||||
(
|
||||
mid bigint primary key,
|
||||
channel bigint not null,
|
||||
member serial not null references members (id) on delete cascade,
|
||||
sender bigint not null,
|
||||
original_mid bigint
|
||||
);
|
||||
|
||||
create table if not exists switches
|
||||
(
|
||||
id serial primary key,
|
||||
system serial not null references systems (id) on delete cascade,
|
||||
timestamp timestamp not null default (current_timestamp at time zone 'utc')
|
||||
);
|
||||
|
||||
create table if not exists switch_members
|
||||
(
|
||||
id serial primary key,
|
||||
switch serial not null references switches (id) on delete cascade,
|
||||
member serial not null references members (id) on delete cascade
|
||||
);
|
||||
|
||||
create table if not exists webhooks
|
||||
(
|
||||
channel bigint primary key,
|
||||
webhook bigint not null,
|
||||
token text not null
|
||||
);
|
||||
|
||||
create table if not exists servers
|
||||
(
|
||||
id bigint primary key,
|
||||
log_channel bigint,
|
||||
log_blacklist bigint[] not null default array[]::bigint[],
|
||||
blacklist bigint[] not null default array[]::bigint[]
|
||||
);
|
||||
|
||||
create index if not exists idx_switches_system on switches using btree (system asc nulls last) include ("timestamp");
|
||||
create index if not exists idx_switch_members_switch on switch_members using btree (switch asc nulls last) include (member);
|
||||
create index if not exists idx_message_member on messages (member);
|
||||
15
crates/migrate/data/migrations/1.sql
Normal file
15
crates/migrate/data/migrations/1.sql
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
-- SCHEMA VERSION 1: 2019-12-26
|
||||
-- First version introducing the migration system, therefore we add the info/version table
|
||||
|
||||
create table info
|
||||
(
|
||||
id int primary key not null default 1, -- enforced only equal to 1
|
||||
|
||||
schema_version int,
|
||||
|
||||
constraint singleton check (id = 1) -- enforce singleton table/row
|
||||
);
|
||||
|
||||
-- We do an insert here since we *just* added the table
|
||||
-- Future migrations should do an update at the end
|
||||
insert into info (schema_version) values (1);
|
||||
11
crates/migrate/data/migrations/10.sql
Normal file
11
crates/migrate/data/migrations/10.sql
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
-- SCHEMA VERSION 10: 2020-10-09 --
|
||||
-- Member/group limit override per-system
|
||||
|
||||
alter table systems add column member_limit_override smallint default null;
|
||||
alter table systems add column group_limit_override smallint default null;
|
||||
|
||||
-- Lowering global limit to 1000 in this commit, so increase it for systems already above that
|
||||
update systems s set member_limit_override = 1500
|
||||
where (select count(*) from members m where m.system = s.id) > 1000;
|
||||
|
||||
update info set schema_version = 10;
|
||||
10
crates/migrate/data/migrations/11.sql
Normal file
10
crates/migrate/data/migrations/11.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- SCHEMA VERSION 11: 2020-10-23 --
|
||||
-- Create command message table --
|
||||
|
||||
create table command_messages
|
||||
(
|
||||
message_id bigint primary key not null,
|
||||
author_id bigint not null
|
||||
);
|
||||
|
||||
update info set schema_version = 11;
|
||||
10
crates/migrate/data/migrations/12.sql
Normal file
10
crates/migrate/data/migrations/12.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- SCHEMA VERSION 12: 2020-12-08 --
|
||||
-- Add disabling front/latch autoproxy per-member --
|
||||
-- Add disabling autoproxy per-account --
|
||||
-- Add configurable latch timeout --
|
||||
|
||||
alter table members add column allow_autoproxy bool not null default true;
|
||||
alter table accounts add column allow_autoproxy bool not null default true;
|
||||
alter table systems add column latch_timeout int; -- in seconds
|
||||
|
||||
update info set schema_version = 12;
|
||||
7
crates/migrate/data/migrations/13.sql
Normal file
7
crates/migrate/data/migrations/13.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- SCHEMA VERSION 13: 2021-03-28 --
|
||||
-- Add system and group colors --
|
||||
|
||||
alter table systems add column color char(6);
|
||||
alter table groups add column color char(6);
|
||||
|
||||
update info set schema_version = 13;
|
||||
15
crates/migrate/data/migrations/14.sql
Normal file
15
crates/migrate/data/migrations/14.sql
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
-- SCHEMA VERSION 14: 2021-06-10 --
|
||||
-- Add shard status table --
|
||||
|
||||
create table shards (
|
||||
id int not null primary key,
|
||||
|
||||
-- 0 = down, 1 = up
|
||||
status smallint not null default 0,
|
||||
|
||||
ping float,
|
||||
last_heartbeat timestamptz,
|
||||
last_connection timestamptz
|
||||
);
|
||||
|
||||
update info set schema_version = 14;
|
||||
8
crates/migrate/data/migrations/15.sql
Normal file
8
crates/migrate/data/migrations/15.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
-- SCHEMA VERSION 15: 2021-08-01
|
||||
-- add banner (large) images to entities with "cards"
|
||||
|
||||
alter table systems add column banner_image text;
|
||||
alter table members add column banner_image text;
|
||||
alter table groups add column banner_image text;
|
||||
|
||||
update info set schema_version = 15;
|
||||
7
crates/migrate/data/migrations/16.sql
Normal file
7
crates/migrate/data/migrations/16.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- SCHEMA VERSION 16: 2021-08-02 --
|
||||
-- Add server-specific system tag --
|
||||
|
||||
alter table system_guild add column tag text default null;
|
||||
alter table system_guild add column tag_enabled bool not null default true;
|
||||
|
||||
update info set schema_version = 16;
|
||||
8
crates/migrate/data/migrations/17.sql
Normal file
8
crates/migrate/data/migrations/17.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
-- schema version 17: 2021-09-26 --
|
||||
-- add channel_id to command message table
|
||||
|
||||
alter table command_messages add column channel_id bigint;
|
||||
update command_messages set channel_id = 0;
|
||||
alter table command_messages alter column channel_id set not null;
|
||||
|
||||
update info set schema_version = 17;
|
||||
18
crates/migrate/data/migrations/18.sql
Normal file
18
crates/migrate/data/migrations/18.sql
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
-- schema version 18: 2021-09-26 --
|
||||
-- Add UUIDs for APIs
|
||||
|
||||
create extension if not exists pgcrypto;
|
||||
|
||||
alter table systems add column uuid uuid default gen_random_uuid();
|
||||
create index systems_uuid_idx on systems(uuid);
|
||||
|
||||
alter table members add column uuid uuid default gen_random_uuid();
|
||||
create index members_uuid_idx on members(uuid);
|
||||
|
||||
alter table switches add column uuid uuid default gen_random_uuid();
|
||||
create index switches_uuid_idx on switches(uuid);
|
||||
|
||||
alter table groups add column uuid uuid default gen_random_uuid();
|
||||
create index groups_uuid_idx on groups(uuid);
|
||||
|
||||
update info set schema_version = 18;
|
||||
10
crates/migrate/data/migrations/19.sql
Normal file
10
crates/migrate/data/migrations/19.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- schema version 19: 2021-10-15 --
|
||||
-- add stats to info table
|
||||
|
||||
alter table info add column system_count int;
|
||||
alter table info add column member_count int;
|
||||
alter table info add column group_count int;
|
||||
alter table info add column switch_count int;
|
||||
alter table info add column message_count int;
|
||||
|
||||
update info set schema_version = 19;
|
||||
13
crates/migrate/data/migrations/2.sql
Normal file
13
crates/migrate/data/migrations/2.sql
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
-- We're doing a psuedo-enum here since Dapper is wonky with enums
|
||||
-- Still getting mapped to enums at the CLR level, though.
|
||||
-- https://github.com/StackExchange/Dapper/issues/332 (from 2015, still unsolved!)
|
||||
-- 1 = "public"
|
||||
-- 2 = "private"
|
||||
-- not doing a bool here since I want to open up for the possibliity of other privacy levels (eg. "mutuals only")
|
||||
alter table systems add column description_privacy integer check (description_privacy in (1, 2)) not null default 1;
|
||||
alter table systems add column member_list_privacy integer check (member_list_privacy in (1, 2)) not null default 1;
|
||||
alter table systems add column front_privacy integer check (front_privacy in (1, 2)) not null default 1;
|
||||
alter table systems add column front_history_privacy integer check (front_history_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column member_privacy integer check (member_privacy in (1, 2)) not null default 1;
|
||||
|
||||
update info set schema_version = 2;
|
||||
7
crates/migrate/data/migrations/20.sql
Normal file
7
crates/migrate/data/migrations/20.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- schema version 20: insert date
|
||||
-- add outgoing webhook to systems
|
||||
|
||||
alter table systems add column webhook_url text;
|
||||
alter table systems add column webhook_token text;
|
||||
|
||||
update info set schema_version = 20;
|
||||
29
crates/migrate/data/migrations/21.sql
Normal file
29
crates/migrate/data/migrations/21.sql
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
-- schema version 21
|
||||
-- create `system_config` table
|
||||
|
||||
create table system_config (
|
||||
system int primary key references systems(id) on delete cascade,
|
||||
ui_tz text not null default 'UTC',
|
||||
pings_enabled bool not null default true,
|
||||
latch_timeout int,
|
||||
member_limit_override int,
|
||||
group_limit_override int
|
||||
);
|
||||
|
||||
insert into system_config select
|
||||
id as system,
|
||||
ui_tz,
|
||||
pings_enabled,
|
||||
latch_timeout,
|
||||
member_limit_override,
|
||||
group_limit_override
|
||||
from systems;
|
||||
|
||||
alter table systems
|
||||
drop column ui_tz,
|
||||
drop column pings_enabled,
|
||||
drop column latch_timeout,
|
||||
drop column member_limit_override,
|
||||
drop column group_limit_override;
|
||||
|
||||
update info set schema_version = 21;
|
||||
7
crates/migrate/data/migrations/22.sql
Normal file
7
crates/migrate/data/migrations/22.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- schema version 22
|
||||
-- automatically set members/groups as private when creating
|
||||
|
||||
alter table system_config add column member_default_private bool not null default false;
|
||||
alter table system_config add column group_default_private bool not null default false;
|
||||
|
||||
update info set schema_version = 22;
|
||||
6
crates/migrate/data/migrations/23.sql
Normal file
6
crates/migrate/data/migrations/23.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- schema version 23
|
||||
-- show/hide private information when looked up by linked accounts
|
||||
|
||||
alter table system_config add column show_private_info bool default true;
|
||||
|
||||
update info set schema_version = 23;
|
||||
10
crates/migrate/data/migrations/24.sql
Normal file
10
crates/migrate/data/migrations/24.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- schema version 24
|
||||
-- don't drop message rows when system/member are deleted
|
||||
|
||||
alter table messages alter column member drop not null;
|
||||
alter table messages drop constraint messages_member_fkey;
|
||||
alter table messages
|
||||
add constraint messages_member_fkey
|
||||
foreign key (member) references members(id) on delete set null;
|
||||
|
||||
update info set schema_version = 24;
|
||||
7
crates/migrate/data/migrations/25.sql
Normal file
7
crates/migrate/data/migrations/25.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- schema version 25
|
||||
-- group name privacy
|
||||
|
||||
alter table groups add column name_privacy integer check (name_privacy in (1, 2)) not null default 1;
|
||||
alter table groups add column metadata_privacy integer check (metadata_privacy in (1, 2)) not null default 1;
|
||||
|
||||
update info set schema_version = 25;
|
||||
12
crates/migrate/data/migrations/26.sql
Normal file
12
crates/migrate/data/migrations/26.sql
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
-- schema version 26
|
||||
-- cache Discord DM channels in the database
|
||||
|
||||
alter table accounts alter column system drop not null;
|
||||
alter table accounts drop constraint accounts_system_fkey;
|
||||
alter table accounts
|
||||
add constraint accounts_system_fkey
|
||||
foreign key (system) references systems(id) on delete set null;
|
||||
|
||||
alter table accounts add column dm_channel bigint;
|
||||
|
||||
update info set schema_version = 26;
|
||||
36
crates/migrate/data/migrations/27.sql
Normal file
36
crates/migrate/data/migrations/27.sql
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
-- schema version 27
|
||||
-- autoproxy locations
|
||||
|
||||
-- mode pseudo-enum: (copied from 3.sql)
|
||||
-- 1 = autoproxy off
|
||||
-- 2 = front mode (first fronter)
|
||||
-- 3 = latch mode (last proxyer)
|
||||
-- 4 = member mode (specific member)
|
||||
|
||||
create table autoproxy (
|
||||
system int references systems(id) on delete cascade,
|
||||
channel_id bigint,
|
||||
guild_id bigint,
|
||||
autoproxy_mode int check (autoproxy_mode in (1, 2, 3, 4)) not null default 1,
|
||||
autoproxy_member int references members(id) on delete set null,
|
||||
last_latch_timestamp timestamp,
|
||||
check (
|
||||
(channel_id = 0 and guild_id = 0)
|
||||
or (channel_id != 0 and guild_id = 0)
|
||||
or (channel_id = 0 and guild_id != 0)
|
||||
),
|
||||
primary key (system, channel_id, guild_id)
|
||||
);
|
||||
|
||||
insert into autoproxy select
|
||||
system,
|
||||
0 as channel_id,
|
||||
guild as guild_id,
|
||||
autoproxy_mode,
|
||||
autoproxy_member
|
||||
from system_guild;
|
||||
|
||||
alter table system_guild drop column autoproxy_mode;
|
||||
alter table system_guild drop column autoproxy_member;
|
||||
|
||||
update info set schema_version = 27;
|
||||
7
crates/migrate/data/migrations/28.sql
Normal file
7
crates/migrate/data/migrations/28.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- schema version 28
|
||||
-- system pronouns
|
||||
|
||||
alter table systems add column pronouns text;
|
||||
alter table systems add column pronoun_privacy integer check (pronoun_privacy in (1, 2)) not null default 1;
|
||||
|
||||
update info set schema_version = 28;
|
||||
5
crates/migrate/data/migrations/29.sql
Normal file
5
crates/migrate/data/migrations/29.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
-- schema version 29
|
||||
|
||||
alter table systems add column is_deleting bool default false;
|
||||
|
||||
update info set schema_version = 29;
|
||||
15
crates/migrate/data/migrations/3.sql
Normal file
15
crates/migrate/data/migrations/3.sql
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
-- Same sort of psuedo-enum due to Dapper limitations. See 2.sql.
|
||||
-- 1 = autoproxy off
|
||||
-- 2 = front mode (first fronter)
|
||||
-- 3 = latch mode (last proxyer)
|
||||
-- 4 = member mode (specific member)
|
||||
alter table system_guild add column autoproxy_mode int check (autoproxy_mode in (1, 2, 3, 4)) not null default 1;
|
||||
|
||||
-- for member mode
|
||||
alter table system_guild add column autoproxy_member int references members (id) on delete set null;
|
||||
|
||||
-- for latch mode
|
||||
-- not *really* nullable, null just means old (pre-schema-change) data.
|
||||
alter table messages add column guild bigint default null;
|
||||
|
||||
update info set schema_version = 3;
|
||||
5
crates/migrate/data/migrations/30.sql
Normal file
5
crates/migrate/data/migrations/30.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
-- schema version 30
|
||||
|
||||
alter table system_config add column description_templates text[] not null default array[]::text[];
|
||||
|
||||
update info set schema_version = 30;
|
||||
5
crates/migrate/data/migrations/31.sql
Normal file
5
crates/migrate/data/migrations/31.sql
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
-- schema version 31
|
||||
|
||||
alter table system_config add column case_sensitive_proxy_tags boolean not null default true;
|
||||
|
||||
update info set schema_version = 31;
|
||||
6
crates/migrate/data/migrations/32.sql
Normal file
6
crates/migrate/data/migrations/32.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 32
|
||||
-- re-add last message timestamp to members
|
||||
|
||||
alter table members add column last_message_timestamp timestamp;
|
||||
|
||||
update info set schema_version = 32;
|
||||
6
crates/migrate/data/migrations/33.sql
Normal file
6
crates/migrate/data/migrations/33.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 33
|
||||
-- add webhook_avatar_url to system members
|
||||
|
||||
alter table members add column webhook_avatar_url text;
|
||||
|
||||
update info set schema_version = 33;
|
||||
6
crates/migrate/data/migrations/34.sql
Normal file
6
crates/migrate/data/migrations/34.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 34
|
||||
-- add proxy_error_message_enabled to system config
|
||||
|
||||
alter table system_config add column proxy_error_message_enabled bool default true;
|
||||
|
||||
update info set schema_version = 34;
|
||||
7
crates/migrate/data/migrations/35.sql
Normal file
7
crates/migrate/data/migrations/35.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- database version 35
|
||||
-- add guild avatar and guild name to system guild settings
|
||||
|
||||
alter table system_guild add column avatar_url text;
|
||||
alter table system_guild add column display_name text;
|
||||
|
||||
update info set schema_version = 35;
|
||||
7
crates/migrate/data/migrations/36.sql
Normal file
7
crates/migrate/data/migrations/36.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- database version 36
|
||||
-- add system avatar privacy and system name privacy
|
||||
|
||||
alter table systems add column name_privacy integer not null default 1;
|
||||
alter table systems add column avatar_privacy integer not null default 1;
|
||||
|
||||
update info set schema_version = 36;
|
||||
7
crates/migrate/data/migrations/37.sql
Normal file
7
crates/migrate/data/migrations/37.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- database version 37
|
||||
-- add proxy tag privacy
|
||||
|
||||
alter table members add column proxy_privacy integer not null default 1;
|
||||
alter table members add constraint members_proxy_privacy_check check (proxy_privacy = ANY (ARRAY[1,2]));
|
||||
|
||||
update info set schema_version = 37;
|
||||
6
crates/migrate/data/migrations/38.sql
Normal file
6
crates/migrate/data/migrations/38.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 38
|
||||
-- add proxy tag privacy
|
||||
|
||||
alter table members add column tts boolean not null default false;
|
||||
|
||||
update info set schema_version = 38;
|
||||
7
crates/migrate/data/migrations/39.sql
Normal file
7
crates/migrate/data/migrations/39.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- database version 39
|
||||
-- add missing privacy constraints
|
||||
|
||||
alter table systems add constraint systems_name_privacy_check check (name_privacy = ANY (ARRAY[1,2]));
|
||||
alter table systems add constraint systems_avatar_privacy_check check (avatar_privacy = ANY (ARRAY[1,2]));
|
||||
|
||||
update info set schema_version = 39;
|
||||
3
crates/migrate/data/migrations/4.sql
Normal file
3
crates/migrate/data/migrations/4.sql
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
-- SCHEMA VERSION 4: 2020-02-12
|
||||
alter table member_guild add column avatar_url text;
|
||||
update info set schema_version = 4;
|
||||
6
crates/migrate/data/migrations/40.sql
Normal file
6
crates/migrate/data/migrations/40.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 40
|
||||
-- add per-server keepproxy toggle
|
||||
|
||||
alter table member_guild add column keep_proxy bool default null;
|
||||
|
||||
update info set schema_version = 40;
|
||||
10
crates/migrate/data/migrations/41.sql
Normal file
10
crates/migrate/data/migrations/41.sql
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
-- database version 41
|
||||
-- fix statistics counts
|
||||
|
||||
alter table info alter column system_count type bigint using system_count::bigint;
|
||||
alter table info alter column member_count type bigint using member_count::bigint;
|
||||
alter table info alter column group_count type bigint using group_count::bigint;
|
||||
alter table info alter column switch_count type bigint using switch_count::bigint;
|
||||
alter table info alter column message_count type bigint using message_count::bigint;
|
||||
|
||||
update info set schema_version = 41;
|
||||
11
crates/migrate/data/migrations/42.sql
Normal file
11
crates/migrate/data/migrations/42.sql
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
-- database version 42
|
||||
-- move to 6 character HIDs, add HID display config setting
|
||||
|
||||
alter table systems alter column hid type char(6) using rpad(hid, 6, ' ');
|
||||
alter table members alter column hid type char(6) using rpad(hid, 6, ' ');
|
||||
alter table groups alter column hid type char(6) using rpad(hid, 6, ' ');
|
||||
|
||||
alter table system_config add column hid_display_split bool default false;
|
||||
alter table system_config add column hid_display_caps bool default false;
|
||||
|
||||
update info set schema_version = 42;
|
||||
6
crates/migrate/data/migrations/43.sql
Normal file
6
crates/migrate/data/migrations/43.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 43
|
||||
-- add config setting for padding 5-character IDs in lists
|
||||
|
||||
alter table system_config add column hid_list_padding int not null default 0;
|
||||
|
||||
update info set schema_version = 43;
|
||||
23
crates/migrate/data/migrations/44.sql
Normal file
23
crates/migrate/data/migrations/44.sql
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
-- database version 44
|
||||
-- add abuse handling measures
|
||||
|
||||
create table abuse_logs (
|
||||
id serial primary key,
|
||||
uuid uuid default gen_random_uuid(),
|
||||
description text,
|
||||
deny_bot_usage bool not null default false,
|
||||
created timestamp not null default (current_timestamp at time zone 'utc')
|
||||
);
|
||||
|
||||
alter table accounts add column abuse_log integer default null references abuse_logs (id) on delete set null;
|
||||
create index abuse_logs_uuid_idx on abuse_logs (uuid);
|
||||
|
||||
-- we now need to handle a row in "accounts" table being created with no
|
||||
-- system (rather than just system being set to null after insert)
|
||||
--
|
||||
-- set default null and drop the sequence (from the column being created
|
||||
-- as type SERIAL)
|
||||
alter table accounts alter column system set default null;
|
||||
drop sequence accounts_system_seq;
|
||||
|
||||
update info set schema_version = 44;
|
||||
6
crates/migrate/data/migrations/45.sql
Normal file
6
crates/migrate/data/migrations/45.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 45
|
||||
-- add new config setting "proxy_switch"
|
||||
|
||||
alter table system_config add column proxy_switch bool default false;
|
||||
|
||||
update info set schema_version = 45;
|
||||
12
crates/migrate/data/migrations/46.sql
Normal file
12
crates/migrate/data/migrations/46.sql
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
-- database version 46
|
||||
-- adds banner privacy
|
||||
|
||||
alter table members add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
|
||||
alter table groups add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
|
||||
alter table systems add column banner_privacy int not null default 1 check (banner_privacy = ANY (ARRAY[1,2]));
|
||||
|
||||
update members set banner_privacy = 2 where description_privacy = 2;
|
||||
update groups set banner_privacy = 2 where description_privacy = 2;
|
||||
update systems set banner_privacy = 2 where description_privacy = 2;
|
||||
|
||||
update info set schema_version = 46;
|
||||
6
crates/migrate/data/migrations/47.sql
Normal file
6
crates/migrate/data/migrations/47.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 47
|
||||
-- add config setting for supplying a custom tag format in names
|
||||
|
||||
alter table system_config add column name_format text;
|
||||
|
||||
update info set schema_version = 47;
|
||||
9
crates/migrate/data/migrations/48.sql
Normal file
9
crates/migrate/data/migrations/48.sql
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
-- database version 48
|
||||
--
|
||||
-- add guild settings for disabling "invalid command" responses &
|
||||
-- enforcing the presence of system tags
|
||||
|
||||
alter table servers add column invalid_command_response_enabled bool not null default true;
|
||||
alter table servers add column require_system_tag bool not null default false;
|
||||
|
||||
update info set schema_version = 48;
|
||||
6
crates/migrate/data/migrations/49.sql
Normal file
6
crates/migrate/data/migrations/49.sql
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
-- database version 49
|
||||
-- add guild name format
|
||||
|
||||
alter table system_guild add column name_format text;
|
||||
|
||||
update info set schema_version = 49;
|
||||
3
crates/migrate/data/migrations/5.sql
Normal file
3
crates/migrate/data/migrations/5.sql
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
-- SCHEMA VERSION 5: 2020-02-14
|
||||
alter table servers add column log_cleanup_enabled bool not null default false;
|
||||
update info set schema_version = 5;
|
||||
11
crates/migrate/data/migrations/50.sql
Normal file
11
crates/migrate/data/migrations/50.sql
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
-- database version 50
|
||||
-- change proxy switch config to an enum
|
||||
|
||||
alter table system_config
|
||||
alter column proxy_switch drop default,
|
||||
alter column proxy_switch type int
|
||||
using case when proxy_switch then 1 else 0 end,
|
||||
alter column proxy_switch set default 0,
|
||||
add constraint proxy_switch_check check (proxy_switch = ANY (ARRAY[0,1,2]));
|
||||
|
||||
update info set schema_version = 50;
|
||||
7
crates/migrate/data/migrations/51.sql
Normal file
7
crates/migrate/data/migrations/51.sql
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
-- database version 51
|
||||
--
|
||||
-- add guild setting for SUPPRESS_NOTIFICATIONS message flag on proxied messages
|
||||
|
||||
alter table servers add column suppress_notifications bool not null default false;
|
||||
|
||||
update info set schema_version = 51;
|
||||
21
crates/migrate/data/migrations/52.sql
Normal file
21
crates/migrate/data/migrations/52.sql
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
-- database version 52
|
||||
-- messages db updates
|
||||
|
||||
create index messages_by_original on messages(original_mid);
|
||||
create index messages_by_sender on messages(sender);
|
||||
|
||||
-- remove old table from database version 11
|
||||
alter table command_messages rename to command_messages_old;
|
||||
|
||||
create table command_messages (
|
||||
mid bigint primary key,
|
||||
channel bigint not null,
|
||||
guild bigint not null,
|
||||
sender bigint not null,
|
||||
original_mid bigint not null
|
||||
);
|
||||
|
||||
create index command_messages_by_original on command_messages(original_mid);
|
||||
create index command_messages_by_sender on command_messages(sender);
|
||||
|
||||
update info set schema_version = 52;
|
||||
3
crates/migrate/data/migrations/6.sql
Normal file
3
crates/migrate/data/migrations/6.sql
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
-- SCHEMA VERSION 6: 2020-03-21
|
||||
alter table systems add column pings_enabled bool not null default true;
|
||||
update info set schema_version = 6;
|
||||
35
crates/migrate/data/migrations/7.sql
Normal file
35
crates/migrate/data/migrations/7.sql
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
-- SCHEMA VERSION 7: 2020-06-12
|
||||
-- (in-db message count row)
|
||||
|
||||
-- Add message count row to members table, initialize it with the correct data
|
||||
alter table members add column message_count int not null default 0;
|
||||
|
||||
update members set message_count = counts.count
|
||||
from (select member, count(*) as count from messages group by messages.member) as counts
|
||||
where counts.member = members.id;
|
||||
|
||||
-- Create a trigger function to increment the message count on inserting to the messages table
|
||||
create function trg_msgcount_increment() returns trigger as $$
|
||||
begin
|
||||
update members set message_count = message_count + 1 where id = NEW.member;
|
||||
return NEW;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger increment_member_message_count before insert on messages for each row execute procedure trg_msgcount_increment();
|
||||
|
||||
|
||||
-- Create a trigger function to decrement the message count on deleting from the messages table
|
||||
create function trg_msgcount_decrement() returns trigger as $$
|
||||
begin
|
||||
-- Don't decrement if count <= zero (shouldn't happen, but we don't want negative message counts)
|
||||
update members set message_count = message_count - 1 where id = OLD.member and message_count > 0;
|
||||
return OLD;
|
||||
end;
|
||||
$$ language plpgsql;
|
||||
|
||||
create trigger decrement_member_message_count before delete on messages for each row execute procedure trg_msgcount_decrement();
|
||||
|
||||
|
||||
-- (update schema ver)
|
||||
update info set schema_version = 7;
|
||||
24
crates/migrate/data/migrations/8.sql
Normal file
24
crates/migrate/data/migrations/8.sql
Normal file
|
|
@ -0,0 +1,24 @@
|
|||
-- SCHEMA VERSION 8: 2020-05-13 --
|
||||
-- Create new columns --
|
||||
alter table members add column description_privacy integer check (description_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column name_privacy integer check (name_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column avatar_privacy integer check (avatar_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column birthday_privacy integer check (birthday_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column pronoun_privacy integer check (pronoun_privacy in (1, 2)) not null default 1;
|
||||
alter table members add column metadata_privacy integer check (metadata_privacy in (1, 2)) not null default 1;
|
||||
-- alter table members add column color_privacy integer check (color_privacy in (1, 2)) not null default 1;
|
||||
|
||||
-- Transfer existing settings --
|
||||
update members set description_privacy = member_privacy;
|
||||
update members set name_privacy = member_privacy;
|
||||
update members set avatar_privacy = member_privacy;
|
||||
update members set birthday_privacy = member_privacy;
|
||||
update members set pronoun_privacy = member_privacy;
|
||||
update members set metadata_privacy = member_privacy;
|
||||
-- update members set color_privacy = member_privacy;
|
||||
|
||||
-- Rename member_privacy to member_visibility --
|
||||
alter table members rename column member_privacy to member_visibility;
|
||||
|
||||
-- Update Schema Info --
|
||||
update info set schema_version = 8;
|
||||
32
crates/migrate/data/migrations/9.sql
Normal file
32
crates/migrate/data/migrations/9.sql
Normal file
|
|
@ -0,0 +1,32 @@
|
|||
-- SCHEMA VERSION 9: 2020-08-25 --
|
||||
-- Adds support for member groups.
|
||||
|
||||
create table groups (
|
||||
id int primary key generated always as identity,
|
||||
hid char(5) unique not null,
|
||||
system int not null references systems(id) on delete cascade,
|
||||
|
||||
name text not null,
|
||||
display_name text,
|
||||
description text,
|
||||
icon text,
|
||||
|
||||
-- Description columns follow the same pattern as usual: 1 = public, 2 = private
|
||||
description_privacy integer check (description_privacy in (1, 2)) not null default 1,
|
||||
icon_privacy integer check (icon_privacy in (1, 2)) not null default 1,
|
||||
list_privacy integer check (list_privacy in (1, 2)) not null default 1,
|
||||
visibility integer check (visibility in (1, 2)) not null default 1,
|
||||
|
||||
created timestamp with time zone not null default (current_timestamp at time zone 'utc')
|
||||
);
|
||||
|
||||
create table group_members (
|
||||
group_id int not null references groups(id) on delete cascade,
|
||||
member_id int not null references members(id) on delete cascade,
|
||||
primary key (group_id, member_id)
|
||||
);
|
||||
|
||||
alter table systems add column group_list_privacy integer check (group_list_privacy in (1, 2)) not null default 1;
|
||||
update systems set group_list_privacy = member_list_privacy;
|
||||
|
||||
update info set schema_version = 9;
|
||||
8
crates/migrate/data/seed.sql
Normal file
8
crates/migrate/data/seed.sql
Normal file
|
|
@ -0,0 +1,8 @@
|
|||
-- example data (for integration tests or such)
|
||||
|
||||
insert into systems (hid, token) values (
|
||||
'exmpl',
|
||||
'vlPitT0tEgT++a450w1/afODy5NXdALcHDwryX6dOIZdGUGbZg+5IH3nrUsQihsw'
|
||||
);
|
||||
insert into system_config (system) values (1);
|
||||
insert into system_guild (system, guild) values (1, 466707357099884544);
|
||||
75
crates/migrate/data/views.sql
Normal file
75
crates/migrate/data/views.sql
Normal file
|
|
@ -0,0 +1,75 @@
|
|||
-- Returns one row per system, containing info about latest switch + array of member IDs (for future joins)
|
||||
create view system_last_switch as
|
||||
select systems.id as system,
|
||||
last_switch.id as switch,
|
||||
last_switch.timestamp as timestamp,
|
||||
array(select member from switch_members where switch_members.switch = last_switch.id order by switch_members.id) as members
|
||||
from systems
|
||||
inner join lateral (select * from switches where switches.system = systems.id order by timestamp desc limit 1) as last_switch on true;
|
||||
|
||||
create view member_list as
|
||||
select members.*,
|
||||
-- Find last switch timestamp
|
||||
(
|
||||
select max(switches.timestamp)
|
||||
from switch_members
|
||||
inner join switches on switches.id = switch_members.switch
|
||||
where switch_members.member = members.id
|
||||
) as last_switch_time,
|
||||
|
||||
-- Extract month/day from birthday and "force" the year identical (just using 4) -> month/day only sorting!
|
||||
case when members.birthday is not null then
|
||||
make_date(
|
||||
4,
|
||||
extract(month from members.birthday)::integer,
|
||||
extract(day from members.birthday)::integer
|
||||
) end as birthday_md,
|
||||
|
||||
-- Extract member description as seen by "the public"
|
||||
case
|
||||
-- Privacy '1' = public; just return description as normal
|
||||
when members.description_privacy = 1 then members.description
|
||||
-- Any other privacy (rn just '2'), return null description (missing case = null in SQL)
|
||||
end as public_description,
|
||||
|
||||
-- Extract member name as seen by "the public"
|
||||
case
|
||||
-- Privacy '1' = public; just return name as normal
|
||||
when members.name_privacy = 1 then members.name
|
||||
-- Any other privacy (rn just '2'), return display name
|
||||
else coalesce(members.display_name, members.name)
|
||||
end as public_name
|
||||
from members;
|
||||
|
||||
create view group_list as
|
||||
select groups.*,
|
||||
-- Find public group member count
|
||||
(
|
||||
select count(*) from group_members
|
||||
inner join members on group_members.member_id = members.id
|
||||
where
|
||||
group_members.group_id = groups.id and members.member_visibility = 1
|
||||
) as public_member_count,
|
||||
-- Find private group member count
|
||||
(
|
||||
select count(*) from group_members
|
||||
inner join members on group_members.member_id = members.id
|
||||
where
|
||||
group_members.group_id = groups.id
|
||||
) as total_member_count,
|
||||
|
||||
-- Extract group description as seen by "the public"
|
||||
case
|
||||
-- Privacy '1' = public; just return description as normal
|
||||
when groups.description_privacy = 1 then groups.description
|
||||
-- Any other privacy (rn just '2'), return null description (missing case = null in SQL)
|
||||
end as public_description,
|
||||
|
||||
-- Extract member name as seen by "the public"
|
||||
case
|
||||
-- Privacy '1' = public; just return name as normal
|
||||
when groups.name_privacy = 1 then groups.name
|
||||
-- Any other privacy (rn just '2'), return display name
|
||||
else coalesce(groups.display_name, groups.name)
|
||||
end as public_name
|
||||
from groups;
|
||||
70
crates/migrate/src/main.rs
Normal file
70
crates/migrate/src/main.rs
Normal file
|
|
@ -0,0 +1,70 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use tracing::info;
|
||||
|
||||
include!(concat!(env!("OUT_DIR"), "/data.rs"));
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_data_db().await?;
|
||||
|
||||
// clean
|
||||
// get current migration
|
||||
// migrate to latest
|
||||
// run views
|
||||
// run functions
|
||||
|
||||
#[derive(sqlx::FromRow)]
|
||||
struct CurrentMigration {
|
||||
schema_version: i32,
|
||||
}
|
||||
|
||||
let info = match sqlx::query_as("select schema_version from info")
|
||||
.fetch_optional(&db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(result)) => result,
|
||||
Ok(None) => CurrentMigration { schema_version: -1 },
|
||||
Err(e) if format!("{e}").contains("relation \"info\" does not exist") => {
|
||||
CurrentMigration { schema_version: -1 }
|
||||
}
|
||||
Err(e) => return Err(e.into()),
|
||||
};
|
||||
|
||||
info!("current migration: {}", info.schema_version);
|
||||
|
||||
info!("running clean.sql");
|
||||
sqlx::raw_sql(fix_feff(CLEAN)).execute(&db).await?;
|
||||
|
||||
for idx in (info.schema_version + 1) as usize..MIGRATIONS.len() {
|
||||
info!("running migration {idx}");
|
||||
sqlx::raw_sql(fix_feff(MIGRATIONS[idx as usize]))
|
||||
.execute(&db)
|
||||
.await?;
|
||||
}
|
||||
|
||||
info!("running views.sql");
|
||||
sqlx::raw_sql(fix_feff(VIEWS)).execute(&db).await?;
|
||||
|
||||
info!("running functions.sql");
|
||||
sqlx::raw_sql(fix_feff(FUNCTIONS)).execute(&db).await?;
|
||||
|
||||
if let Ok(var) = std::env::var("SEED")
|
||||
&& var == "true"
|
||||
{
|
||||
info!("running seed.sql");
|
||||
sqlx::raw_sql(fix_feff(SEED)).execute(&db).await?;
|
||||
info!(
|
||||
"example system created with hid 'exmpl', token 'vlPitT0tEgT++a450w1/afODy5NXdALcHDwryX6dOIZdGUGbZg+5IH3nrUsQihsw', guild_id 466707357099884544"
|
||||
);
|
||||
}
|
||||
|
||||
info!("all done!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// some migration scripts have \u{feff} at the start
|
||||
fn fix_feff(sql: &str) -> &str {
|
||||
sql.trim_start_matches("\u{feff}")
|
||||
}
|
||||
|
|
@ -5,7 +5,7 @@ edition = "2021"
|
|||
|
||||
[dependencies]
|
||||
chrono = { workspace = true, features = ["serde"] }
|
||||
model_macros = { path = "../model_macros" }
|
||||
pk_macros = { path = "../macros" }
|
||||
sea-query = "0.32.1"
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true, features = ["preserve_order"] }
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue