Merge remote-tracking branch 'upstream/main' into rust-command-parser

dusk 2025-09-26 15:16:54 +00:00
commit b353dcbda2
94 changed files with 2575 additions and 738 deletions

View file

@@ -5,6 +5,7 @@ edition = "2024"
[dependencies]
pluralkit_models = { path = "../models" }
pk_macros = { path = "../macros" }
libpk = { path = "../libpk" }
anyhow = { workspace = true }
@@ -26,3 +27,4 @@ reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde_urlencoded = "0.7.1"
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
subtle = "2.6.1"

View file

@@ -7,11 +7,16 @@ pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
internal: bool,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
Self { system_id, app_id }
pub fn new(system_id: Option<i32>, app_id: Option<i32>, internal: bool) -> Self {
Self {
system_id,
app_id,
internal,
}
}
pub fn system_id(&self) -> Option<i32> {
@@ -22,6 +27,10 @@ impl AuthState {
self.app_id
}
pub fn internal(&self) -> bool {
self.internal
}
pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel {
if self
.system_id

View file

@@ -2,10 +2,12 @@ use crate::ApiContext;
use axum::{extract::State, response::Json};
use fred::interfaces::*;
use libpk::state::ShardState;
use pk_macros::api_endpoint;
use serde::Deserialize;
use serde_json::{json, Value};
use serde_json::{Value, json};
use std::collections::HashMap;
#[allow(dead_code)]
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
@@ -13,34 +15,33 @@ struct ClusterStats {
pub channel_count: i32,
}
#[api_endpoint]
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
let mut shard_status = ctx
.redis
.hgetall::<HashMap<String, String>, &str>("pluralkit:shardstatus")
.await
.unwrap()
.await?
.values()
.map(|v| serde_json::from_str(v).expect("could not deserialize shard"))
.collect::<Vec<ShardState>>();
shard_status.sort_by(|a, b| b.shard_id.cmp(&a.shard_id));
Json(json!({
Ok(Json(json!({
"shards": shard_status,
}))
})))
}
#[api_endpoint]
pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
let stats = serde_json::from_str::<Value>(
ctx.redis
.get::<String, &'static str>("statsapi")
.await
.unwrap()
.await?
.as_str(),
)
.unwrap();
)?;
Json(stats)
Ok(Json(stats))
}
use std::time::Duration;
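Both endpoints lean on fred's generic response typing: the first type parameter picks the Rust type to decode into, the second the key type. A minimal sketch of that pattern outside the `#[api_endpoint]` wrapper (the pool type and `anyhow` error handling are assumptions, not from this diff):

```rust
use std::collections::HashMap;

use fred::interfaces::*;

async fn shard_statuses(redis: &fred::clients::RedisPool) -> anyhow::Result<Vec<String>> {
    // first type param: decoded response type; second: key type
    let map = redis
        .hgetall::<HashMap<String, String>, &str>("pluralkit:shardstatus")
        .await?;
    Ok(map.into_values().collect())
}
```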

View file

@@ -1,22 +1,18 @@
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Extension, Json,
};
use serde_json::json;
use axum::{Extension, Json, extract::State, response::IntoResponse};
use pk_macros::api_endpoint;
use serde_json::{Value, json};
use sqlx::Postgres;
use tracing::error;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use crate::{auth::AuthState, util::json_err, ApiContext};
use crate::{ApiContext, auth::AuthState, error::fail};
#[api_endpoint]
pub async fn get_system_settings(
Extension(auth): Extension<AuthState>,
Extension(system): Extension<PKSystem>,
State(ctx): State<ApiContext>,
) -> Response {
) -> Json<Value> {
let access_level = auth.access_level_for(&system);
let mut config = match sqlx::query_as::<Postgres, PKSystemConfig>(
@@ -27,23 +23,11 @@ pub async fn get_system_settings(
.await
{
Ok(Some(config)) => config,
Ok(None) => {
error!(
system = system.id,
"failed to find system config for existing system"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
Err(err) => {
error!(?err, "failed to query system config");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
Ok(None) => fail!(
system = system.id,
"failed to find system config for existing system"
),
Err(err) => fail!(?err, "failed to query system config"),
};
// fix this
@@ -51,7 +35,7 @@ pub async fn get_system_settings(
config.name_format = Some("{name} {tag}".to_string());
}
Json(&match access_level {
Ok(Json(match access_level {
PrivacyLevel::Private => config.to_json(),
PrivacyLevel::Public => json!({
"pings_enabled": config.pings_enabled,
@@ -64,6 +48,5 @@
"proxy_switch": config.proxy_switch,
"name_format": config.name_format,
}),
})
.into_response()
}))
}

View file

@@ -1,13 +1,17 @@
use axum::http::StatusCode;
use axum::{
http::StatusCode,
response::{IntoResponse, Response},
};
use std::fmt;
// todo
#[allow(dead_code)]
// todo: model parse errors
#[derive(Debug)]
pub struct PKError {
pub response_code: StatusCode,
pub json_code: i32,
pub message: &'static str,
pub inner: Option<anyhow::Error>,
}
impl fmt::Display for PKError {
@@ -16,17 +20,67 @@ impl fmt::Display for PKError {
}
}
impl std::error::Error for PKError {}
impl Clone for PKError {
fn clone(&self) -> PKError {
if self.inner.is_some() {
panic!("cannot clone PKError with inner error");
}
PKError {
response_code: self.response_code,
json_code: self.json_code,
message: self.message,
inner: None,
}
}
}
impl<E> From<E> for PKError
where
E: std::fmt::Display + Into<anyhow::Error>,
{
fn from(err: E) -> Self {
let mut res = GENERIC_SERVER_ERROR.clone();
res.inner = Some(err.into());
res
}
}
impl IntoResponse for PKError {
fn into_response(self) -> Response {
if let Some(inner) = self.inner {
tracing::error!(?inner, "error returned from handler");
}
crate::util::json_err(
self.response_code,
serde_json::to_string(&serde_json::json!({
"message": self.message,
"code": self.json_code,
}))
.unwrap(),
)
}
}
macro_rules! fail {
($($stuff:tt)+) => {{
tracing::error!($($stuff)+);
return Err(crate::error::GENERIC_SERVER_ERROR);
}};
}
pub(crate) use fail;
#[allow(unused_macros)]
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
const $name: PKError = PKError {
#[allow(dead_code)]
pub const $name: PKError = PKError {
response_code: $response_code,
json_code: $json_code,
message: $message,
inner: None,
};
};
}
// define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }
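Taken together, this gives handlers three exits: `?` on any anyhow-compatible error becomes a logged 500 via the blanket `From` impl, a predefined const like `GENERIC_BAD_REQUEST` returns a typed error without logging, and `fail!` logs and bails with the generic 500. A hypothetical handler sketch, not part of this diff (the endpoint name and checks are illustrative):

```rust
use axum::{Extension, Json};
use pk_macros::api_endpoint;
use serde_json::Value;

use crate::{auth::AuthState, error::{GENERIC_BAD_REQUEST, fail}};

#[api_endpoint]
pub async fn get_widget(Extension(auth): Extension<AuthState>) -> Json<Value> {
    if auth.system_id().is_none() {
        // const PKError with no inner error; safe to clone
        return Err(GENERIC_BAD_REQUEST);
    }
    // serde_json::Error is Display + Into<anyhow::Error>, so `?` converts it
    // into GENERIC_SERVER_ERROR with the inner error attached for logging
    let parsed: Value = serde_json::from_str(r#"{"ok": true}"#)?;
    if parsed.get("ok").is_none() {
        fail!("widget payload missing `ok` field"); // logs, then generic 500
    }
    Ok(Json(parsed))
}
```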

View file

@@ -1,17 +1,19 @@
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use axum::{
Extension, Router,
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
http::Uri,
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
Extension, Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
client::legacy::{Client, connect::HttpConnector},
rt::TokioExecutor,
};
use tracing::{error, info};
use tracing::info;
use pk_macros::api_endpoint;
mod auth;
mod endpoints;
@@ -28,11 +30,12 @@ pub struct ApiContext {
rproxy_client: Client<HttpConnector, Body>,
}
#[api_endpoint]
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
) -> Response {
let path = req.uri().path();
let path_query = req
.uri()
@@ -57,15 +60,7 @@ async fn rproxy(
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|error| {
error!(?error, "failed to serve reverse proxy to dotnet-api");
StatusCode::BAD_GATEWAY
})?
.into_response())
Ok(ctx.rproxy_client.request(req).await?.into_response())
}
// this function is manually formatted for easier legibility of route_services
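For reference, a sketch of how the `rproxy_client` field above might be constructed with hyper-util's legacy client; the builder calls are an assumption from hyper-util's API, not shown in this diff:

```rust
use axum::body::Body;
use hyper_util::{
    client::legacy::{Client, connect::HttpConnector},
    rt::TokioExecutor,
};

fn make_rproxy_client() -> Client<HttpConnector, Body> {
    // TokioExecutor drives hyper's connection tasks on the tokio runtime;
    // a plain HttpConnector is enough for an internal http:// upstream
    Client::builder(TokioExecutor::new()).build(HttpConnector::new())
}
```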

View file

@@ -5,10 +5,12 @@ use axum::{
response::Response,
};
use subtle::ConstantTimeEq;
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use crate::{ApiContext, util::json_err};
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let mut authed_system_id: Option<i32> = None;
@@ -48,15 +50,31 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.expect("missing api config")
.temp_token2
.as_ref()
// this is NOT how you validate tokens
// but this is low abuse risk so we're keeping it for now
&& app_auth_header == config_token2
&& app_auth_header
.as_bytes()
.ct_eq(config_token2.as_bytes())
.into()
{
authed_app_id = Some(1);
}
// todo: fix syntax
let internal = if req.headers().get("x-pluralkit-client-ip").is_none()
&& let Some(auth_header) = req
.headers()
.get("x-pluralkit-internalauth")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(real_token) = libpk::config.internal_auth.clone()
&& auth_header.as_bytes().ct_eq(real_token.as_bytes()).into()
{
true
} else {
false
};
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id));
.insert(AuthState::new(authed_system_id, authed_app_id, internal));
next.run(req).await
}
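The switch from `==` to `ct_eq` matters because equality on byte slices can short-circuit at the first mismatching byte, leaking how much of a token was correct through response timing. A minimal sketch of the pattern (note that subtle still returns early on a length mismatch; lengths are not treated as secret):

```rust
use subtle::ConstantTimeEq;

fn token_matches(supplied: &str, expected: &str) -> bool {
    // ct_eq compares every byte and returns a Choice, which only
    // collapses to a bool at the very end
    supplied.as_bytes().ct_eq(expected.as_bytes()).into()
}
```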

View file

@@ -2,7 +2,7 @@ use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
use crate::{auth::AuthState, util::header_or_unknown};

View file

@@ -6,11 +6,11 @@ use axum::{
routing::url_params::UrlParams,
};
use sqlx::{types::Uuid, Postgres};
use sqlx::{Postgres, types::Uuid};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use crate::{ApiContext, util::json_err};
use pluralkit_models::PKSystem;
// move this somewhere else
@@ -31,7 +31,7 @@ pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next)
StatusCode::BAD_REQUEST,
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
)
.into()
.into();
}
};

View file

@@ -45,21 +45,6 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
tokio::spawn(async move { handle });
let rscript = r.clone();
tokio::spawn(async move {
if let Ok(()) = rscript.wait_for_connect().await {
match rscript
.script_load::<String, String>(LUA_SCRIPT.to_string())
.await
{
Ok(_) => info!("connected to redis for request rate limiting"),
Err(error) => error!(?error, "could not load redis script"),
}
} else {
error!("could not wait for connection to load redis script!");
}
});
r
});
@@ -152,12 +137,34 @@ pub async fn do_request_ratelimited(
let period = 1; // seconds
let cost = 1; // todo: update this for group member endpoints
let script_exists: Vec<usize> =
match redis.script_exists(vec![LUA_SCRIPT_SHA.to_string()]).await {
Ok(exists) => exists,
Err(error) => {
error!(?error, "failed to check ratelimit script");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
);
}
};
if script_exists[0] != 1 {
match redis
.script_load::<String, String>(LUA_SCRIPT.to_string())
.await
{
Ok(_) => info!("successfully loaded ratelimit script to redis"),
Err(error) => {
error!(?error, "could not load redis script")
}
}
}
// local rate_limit_key = KEYS[1]
// local rate = ARGV[1]
// local period = ARGV[2]
// return {remaining, tostring(retry_after), reset_after}
// todo: check if error is script not found and reload script
let resp = redis
.evalsha::<(i32, String, u64), String, Vec<String>, Vec<i32>>(
LUA_SCRIPT_SHA.to_string(),
@@ -219,7 +226,7 @@ pub async fn do_request_ratelimited(
return response;
}
Err(error) => {
tracing::error!(?error, "error getting ratelimit info");
error!(?error, "error getting ratelimit info");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),

View file

@@ -3,7 +3,7 @@ use axum::{
http::{HeaderValue, StatusCode},
response::IntoResponse,
};
use serde_json::{json, to_string, Value};
use serde_json::{Value, json, to_string};
use tracing::error;
pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {

View file

@@ -1,7 +1,7 @@
[package]
name = "avatars"
version = "0.1.0"
edition = "2021"
edition = "2024"
[[bin]]
name = "avatar_cleanup"
@@ -21,9 +21,9 @@ uuid = { workspace = true }
data-encoding = "2.5.0"
gif = "0.13.1"
image = { version = "0.24.8", default-features = false, features = ["gif", "jpeg", "png", "webp", "tiff"] }
image = { version = "0.25.6", default-features = false, features = ["gif", "jpeg", "png", "webp", "tiff"] }
form_urlencoded = "1.2.1"
rust-s3 = { version = "0.33.0", default-features = false, features = ["tokio-rustls-tls"] }
sha2 = "0.10.8"
thiserror = "1.0.56"
webp = "0.2.6"
webp = "0.3.1"

View file

@@ -8,10 +8,10 @@ use anyhow::Context;
use axum::extract::State;
use axum::routing::get;
use axum::{
Json, Router,
http::StatusCode,
response::{IntoResponse, Response},
routing::post,
Json, Router,
};
use libpk::_config::AvatarsConfig;
use libpk::db::repository::avatars as db;

View file

@@ -4,7 +4,7 @@ use std::io::Cursor;
use std::time::Instant;
use tracing::{debug, error, info, instrument};
use crate::{hash::Hash, ImageKind, PKAvatarError};
use crate::{ImageKind, PKAvatarError, hash::Hash};
const MAX_DIMENSION: u32 = 4000;
@@ -211,8 +211,8 @@ fn process_gif_inner(
}))
}
fn reader_for(data: &[u8]) -> image::io::Reader<Cursor<&[u8]>> {
image::io::Reader::new(Cursor::new(data))
fn reader_for(data: &[u8]) -> image::ImageReader<Cursor<&[u8]>> {
image::ImageReader::new(Cursor::new(data))
.with_guessed_format()
.expect("cursor i/o is infallible")
}

View file

@@ -62,7 +62,7 @@ pub async fn pull(
let size = match response.content_length() {
None => return Err(PKAvatarError::MissingHeader("Content-Length")),
Some(size) if size > MAX_SIZE => {
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE))
return Err(PKAvatarError::ImageFileSizeTooLarge(size, MAX_SIZE));
}
Some(size) => size,
};
@@ -162,7 +162,7 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
attachment_id: 0,
filename: "".to_string(),
full_url: url.to_string(),
})
});
}
_ => anyhow::bail!("not a discord cdn url"),
}

View file

@@ -1,7 +1,7 @@
[package]
name = "dispatch"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
anyhow = { workspace = true }

View file

@@ -1,7 +1,7 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
// log any requests that take longer than 2 seconds
// todo: change as necessary

View file

@@ -5,7 +5,7 @@ use hickory_client::{
rr::{DNSClass, Name, RData, RecordType},
udp::UdpClientStream,
};
use reqwest::{redirect::Policy, StatusCode};
use reqwest::{StatusCode, redirect::Policy};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
@@ -14,7 +14,7 @@ use std::{
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::{debug, error};
use axum::{extract::State, http::Uri, routing::post, Json, Router};
use axum::{Json, Router, extract::State, http::Uri, routing::post};
mod logger;
@@ -127,7 +127,7 @@ async fn dispatch(
match res {
Ok(res) if res.status() != 200 => {
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
return DispatchResponse::InvalidResponseCode(res.status()).to_string();
}
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");

View file

@@ -1,18 +1,18 @@
use axum::{
Router,
extract::{ConnectInfo, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::{delete, get, post},
Router,
};
use libpk::runtime_config::RuntimeConfig;
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::{marker::ChannelMarker, Id};
use twilight_model::id::{Id, marker::ChannelMarker};
use crate::{
discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
cache::{DM_PERMISSIONS, DiscordCache, dm_channel},
gateway::cluster_config,
shard_state::ShardStateManager,
},

View file

@@ -4,18 +4,18 @@ use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
InMemoryCache, ResourceType,
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_gateway::Event;
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
Id,
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
},
};
use twilight_util::permission_calculator::PermissionCalculator;

View file

@@ -6,17 +6,17 @@ use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId, create_iterator,
};
use twilight_model::gateway::{
Intents,
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::{
discord::identify_queue::{self, RedisQueue},
RUNTIME_CONFIG_KEY_EVENT_TARGET,
discord::identify_queue::{self, RedisQueue},
};
use super::cache::DiscordCache;
@@ -116,7 +116,14 @@ pub async fn runner(
let raw_event = match item {
Ok(evt) => match evt {
Message::Close(frame) => {
let mut state_event = ShardStateEvent::Closed;
let close_code = if let Some(close) = frame {
match close.code {
4000..=4003 | 4005..=4009 => {
state_event = ShardStateEvent::Reconnect;
}
_ => {}
}
close.code.to_string()
} else {
"unknown".to_string()
@@ -131,9 +138,7 @@
)
.increment(1);
if let Err(error) =
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
{
if let Err(error) = tx_state.try_send((shard.id(), state_event, None, None)) {
error!("failed to update shard state for socket closure: {error}");
}
@@ -174,32 +179,45 @@ pub async fn runner(
)
.increment(1);
// update shard state and discord cache
if matches!(event, Event::Ready(_)) || matches!(event, Event::Resumed) {
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Other,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state");
// check for shard status events
match event {
Event::Ready(_) | Event::Resumed => {
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Other,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state");
}
}
}
// need to do heartbeat separately, to get the latency
let latency_num = shard
.latency()
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Heartbeat,
Some(event.clone()),
Some(latency_num),
))
{
tracing::error!(?error, "error updating shard state for latency");
Event::GatewayReconnect => {
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Reconnect,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state for reconnect");
}
}
Event::GatewayHeartbeatAck => {
// need to do heartbeat separately, to get the latency
let latency_num = shard
.latency()
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Heartbeat,
Some(event.clone()),
Some(latency_num),
)) {
tracing::error!(?error, "error updating shard state for latency");
}
}
_ => {}
}
if let Event::Ready(_) = event {
@@ -227,7 +245,9 @@ pub async fn runner(
}
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
tx.send((shard.id(), event, raw_event)).await.unwrap();
if let Err(error) = tx.try_send((shard.id(), event, raw_event)) {
tracing::error!(?error, "error sending shard event");
}
}
}
}
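For context on the close-code ranges above: Discord's gateway marks 4004 (authentication failed) and 4010 through 4014 (invalid shard, sharding required, invalid API version, invalid or disallowed intents) as non-resumable, which is why only the remaining 4xxx codes are recorded as reconnects. The classification reduces to:

```rust
// Matches the close-code arms in the runner above; everything else
// (including fatal 4004 and 4010..=4014) stays a plain Closed event.
fn close_is_reconnect(code: u16) -> bool {
    matches!(code, 4000..=4003 | 4005..=4009)
}
```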

View file

@@ -86,7 +86,7 @@ impl ShardStateManager {
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
pub async fn socket_closed(&self, shard_id: u32, reconnect: bool) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self
@@ -97,6 +97,9 @@
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false;
if reconnect {
info.last_reconnect = chrono::offset::Utc::now().timestamp() as i32
}
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;

View file

@@ -3,7 +3,7 @@
// - interaction: (custom_id where not_includes "help-menu")
use std::{
collections::{hash_map::Entry, HashMap},
collections::{HashMap, hash_map::Entry},
net::{IpAddr, SocketAddr},
time::Duration,
};
@@ -15,8 +15,8 @@ use twilight_gateway::Event;
use twilight_model::{
application::interaction::InteractionData,
id::{
marker::{ChannelMarker, MessageMarker, UserMarker},
Id,
marker::{ChannelMarker, MessageMarker, UserMarker},
},
};
@@ -103,7 +103,13 @@ impl EventAwaiter {
}
}
}
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
info!(
"ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions",
Instant::now().duration_since(now).as_micros(),
counts.0,
counts.1,
counts.2
);
}
}

View file

@@ -4,7 +4,7 @@ use axum::{
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use tracing::{Instrument, Level, info, span, warn};
// log any requests that take longer than 2 seconds
// todo: change as necessary

View file

@@ -1,4 +1,3 @@
#![feature(duration_constructors_lite)]
#![feature(if_let_guard)]
use chrono::Timelike;
@@ -9,7 +8,7 @@ use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
use reqwest::{ClientBuilder, StatusCode};
use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{
signal::unix::{signal, SignalKind},
signal::unix::{SignalKind, signal},
sync::mpsc::channel,
task::JoinSet,
};
@@ -108,8 +107,16 @@ async fn main() -> anyhow::Result<()> {
};
}
ShardStateEvent::Closed => {
if let Err(error) = shard_state.socket_closed(shard_id.number()).await {
error!("failed to update shard state for heartbeat: {error}")
if let Err(error) =
shard_state.socket_closed(shard_id.number(), false).await
{
error!("failed to update shard state for closed: {error}")
};
}
ShardStateEvent::Reconnect => {
if let Err(error) = shard_state.socket_closed(shard_id.number(), true).await
{
error!("failed to update shard state for reconnect: {error}")
};
}
ShardStateEvent::Other => {
@@ -120,7 +127,7 @@
)
.await
{
error!("failed to update shard state for heartbeat: {error}")
error!("failed to update shard state for other evt: {error}")
};
}
}

View file

@@ -3,8 +3,8 @@ use std::{sync::Arc, time::Duration};
use tracing::{error, info, warn};
use twilight_http::api_error::{ApiError, GeneralApiError};
use twilight_model::id::{
marker::{ChannelMarker, MessageMarker},
Id,
marker::{ChannelMarker, MessageMarker},
};
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null);

View file

@@ -3,7 +3,7 @@ use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
use twilight_model::id::{marker::UserMarker, Id};
use twilight_model::id::{Id, marker::UserMarker};
#[derive(Clone, Deserialize, Debug)]
pub struct ClusterSettings {
@@ -128,6 +128,9 @@ pub struct PKConfig {
#[serde(default)]
pub sentry_url: Option<String>,
#[serde(default)]
pub internal_auth: Option<String>,
}
impl PKConfig {
@@ -148,15 +151,11 @@ lazy_static! {
// hacks
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
unsafe {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var); }
}
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
unsafe {
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
}
unsafe { std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap()); }
}
Arc::new(Config::builder()

View file

@@ -51,7 +51,7 @@ pub async fn remove_deletion_queue(pool: &PgPool, attachment_id: u64) -> anyhow:
}
pub async fn pop_queue(
pool: &'_ PgPool,
pool: &PgPool,
) -> anyhow::Result<Option<(Transaction<'_, Postgres>, ImageQueueEntry)>> {
let mut tx = pool.begin().await?;
let res: Option<ImageQueueEntry> = sqlx::query_as("delete from image_queue where itemid = (select itemid from image_queue order by itemid for update skip locked limit 1) returning *")

View file

@@ -1,7 +1,7 @@
use serde::{Deserialize, Serialize};
use sqlx::{
types::chrono::{DateTime, Utc},
FromRow,
types::chrono::{DateTime, Utc},
};
use uuid::Uuid;

View file

@@ -2,7 +2,7 @@ use std::net::SocketAddr;
use metrics_exporter_prometheus::PrometheusBuilder;
use sentry::IntoDsn;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
use sentry_tracing::event_from_event;

View file

@@ -8,11 +8,13 @@ pub struct ShardState {
/// unix timestamp
pub last_heartbeat: i32,
pub last_connection: i32,
pub last_reconnect: i32,
pub cluster_id: Option<i32>,
}
pub enum ShardStateEvent {
Closed,
Heartbeat,
Reconnect,
Other,
}

View file

@@ -1,7 +1,7 @@
[package]
name = "pk_macros"
version = "0.1.0"
edition = "2021"
edition = "2024"
[lib]
proc-macro = true
@@ -10,4 +10,5 @@ proc-macro = true
quote = "1.0"
proc-macro2 = "1.0"
syn = "2.0"
prettyplease = "0.2.36"

crates/macros/src/api.rs (new file, 52 lines)
View file

@@ -0,0 +1,52 @@
use quote::quote;
use syn::{FnArg, ItemFn, Pat, parse_macro_input};
fn _pretty_print(ts: &proc_macro2::TokenStream) -> String {
let file = syn::parse_file(&ts.to_string()).unwrap();
prettyplease::unparse(&file)
}
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as ItemFn);
let fn_name = &input.sig.ident;
let fn_params = &input.sig.inputs;
let fn_body = &input.block;
let syn::ReturnType::Type(_, fn_return_type) = &input.sig.output else {
panic!("handler return type must not be nothing");
};
let pms: Vec<proc_macro2::TokenStream> = fn_params
.iter()
.map(|v| {
let FnArg::Typed(pat) = v else {
panic!("must not have self param in handler");
};
let mut pat = pat.pat.clone();
if let Pat::Ident(ident) = *pat {
let mut ident = ident.clone();
ident.mutability = None;
pat = Box::new(Pat::Ident(ident));
}
quote! { #pat }
})
.collect();
let res = quote! {
#[allow(unused_mut)]
pub async fn #fn_name(#fn_params) -> axum::response::Response {
async fn inner(#fn_params) -> Result<#fn_return_type, crate::error::PKError> {
#fn_body
}
match inner(#(#pms),*).await {
Ok(res) => res.into_response(),
Err(err) => err.into_response(),
}
}
};
res.into()
}
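A hand-written illustration of what this macro emits for a typical handler — not the generated output itself, just its shape, using `ApiContext` and `PKError` from the api crate. The original body moves into `inner`, `?` now surfaces as `crate::error::PKError`, and both arms collapse into an axum `Response`:

```rust
use axum::{Json, extract::State, response::IntoResponse};
use serde_json::Value;

pub async fn meta(State(ctx): State<ApiContext>) -> axum::response::Response {
    async fn inner(State(ctx): State<ApiContext>) -> Result<Json<Value>, crate::error::PKError> {
        // original handler body, where `?` converts into PKError
        let stats: Value = serde_json::from_str("{}")?;
        Ok(Json(stats))
    }
    // the macro re-emits each destructured pattern (e.g. State(ctx)) as an
    // expression when forwarding arguments to inner
    match inner(State(ctx)).await {
        Ok(res) => res.into_response(),
        Err(err) => err.into_response(),
    }
}
```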

View file

@@ -1,8 +1,14 @@
use proc_macro::TokenStream;
mod api;
mod entrypoint;
mod model;
#[proc_macro_attribute]
pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream {
api::macro_impl(args, input)
}
#[proc_macro_attribute]
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
entrypoint::macro_impl(args, input)

View file

@@ -1,6 +1,6 @@
use proc_macro2::{Span, TokenStream};
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Expr, Ident, Meta, Type};
use syn::{DeriveInput, Expr, Ident, Meta, Type, parse_macro_input};
#[derive(Clone, Debug)]
enum ElemPatchability {

View file

@@ -8,10 +8,11 @@ use std::{
fn main() -> Result<(), Box<dyn Error>> {
let out_dir = env::var("OUT_DIR")?;
let manifest_dir = env::var("CARGO_MANIFEST_DIR")?;
let dest_path = Path::new(&out_dir).join("data.rs");
let mut datafile = File::create(&dest_path)?;
let prefix = "../../../../../../crates/migrate/data";
let prefix = manifest_dir + "/data";
let ct = fs::read_dir("data/migrations")?
.filter(|p| {

View file

@@ -0,0 +1,6 @@
-- database version 53
-- add toggle for showing color codes on cv2 cards
alter table system_config add column card_show_color_hex bool default false;
update info set schema_version = 53;

View file

@@ -1,7 +1,7 @@
[package]
name = "pluralkit_models"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
chrono = { workspace = true, features = ["serde"] }

View file

@@ -18,7 +18,7 @@ pub enum PrivacyLevel {
}
// this sucks, put it somewhere else
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use sqlx::{Database, Decode, Postgres, Type, postgres::PgTypeInfo};
use std::error::Error;
_util::fake_enum_impls!(PrivacyLevel);

View file

@@ -1,7 +1,7 @@
[package]
name = "scheduled_tasks"
version = "0.1.0"
edition = "2021"
edition = "2024"
[dependencies]
libpk = { path = "../libpk" }
@@ -9,6 +9,7 @@ libpk = { path = "../libpk" }
anyhow = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }

View file

@@ -99,13 +99,25 @@ async fn main() -> anyhow::Result<()> {
update_db_message_meta
);
doforever!("* * * * *", "discord stats updater", update_discord_stats);
// on :00 and :30
// on hh:00 and hh:30
doforever!(
"0,30 * * * *",
"queue deleted image cleanup job",
queue_deleted_image_cleanup
);
// non-standard cron: at hh:mm:00, hh:mm:30
doforever!("0,30 * * * * *", "stats api updater", update_stats_api);
// every hour (could probably even be less frequent, basebackups are taken rarely)
doforever!(
"* * * * *",
"data basebackup info updater",
update_data_basebackup_prometheus
);
doforever!(
"* * * * *",
"messages basebackup info updater",
update_messages_basebackup_prometheus
);
set.join_next()
.await
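The "0,30 * * * * *" line uses a six-field, seconds-first cron syntax, hence the "non-standard" note. A sketch of how such an expression evaluates, using the `cron` crate purely for illustration — the scheduler this repo actually uses may parse it differently:

```rust
use std::str::FromStr;

use cron::Schedule;

fn print_next_stats_api_runs() -> anyhow::Result<()> {
    // field order: sec min hour day-of-month month day-of-week
    let schedule = Schedule::from_str("0,30 * * * * *")?;
    for dt in schedule.upcoming(chrono::Utc).take(3) {
        println!("stats api updater fires at {dt}"); // hh:mm:00 and hh:mm:30
    }
    Ok(())
}
```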

View file

@@ -1,4 +1,4 @@
use std::time::Duration;
use std::{collections::HashMap, time::Duration};
use anyhow::anyhow;
use fred::prelude::KeysInterface;
@@ -10,10 +10,22 @@ use metrics::gauge;
use num_format::{Locale, ToFormattedString};
use reqwest::ClientBuilder;
use sqlx::Executor;
use tokio::{process::Command, sync::Mutex};
use crate::AppCtx;
pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
let data_ts = *BASEBACKUP_TS.lock().await.get("data").unwrap_or(&0) as f64;
let messages_ts = *BASEBACKUP_TS.lock().await.get("messages").unwrap_or(&0) as f64;
let now_ts = chrono::Utc::now().timestamp() as f64;
gauge!("pluralkit_latest_backup_ts", "repo" => "data").set(data_ts);
gauge!("pluralkit_latest_backup_ts", "repo" => "messages").set(messages_ts);
gauge!("pluralkit_latest_backup_age", "repo" => "data").set(now_ts - data_ts);
gauge!("pluralkit_latest_backup_age", "repo" => "messages").set(now_ts - messages_ts);
#[derive(sqlx::FromRow)]
struct Count {
count: i64,
@@ -41,6 +53,83 @@ pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
Ok(())
}
lazy_static::lazy_static! {
static ref BASEBACKUP_TS: Mutex<HashMap<String, i64>> = Mutex::new(HashMap::new());
}
pub async fn update_data_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> {
update_basebackup_ts("data".to_string()).await
}
pub async fn update_messages_basebackup_prometheus(_: AppCtx) -> anyhow::Result<()> {
update_basebackup_ts("messages".to_string()).await
}
async fn update_basebackup_ts(repo: String) -> anyhow::Result<()> {
let mut env = HashMap::new();
for (key, value) in std::env::vars() {
if key.starts_with("AWS") {
env.insert(key, value);
}
}
env.insert(
"WALG_S3_PREFIX".to_string(),
format!("s3://pluralkit-backups/{repo}/"),
);
let output = Command::new("wal-g")
.arg("backup-list")
.arg("--json")
.envs(env)
.output()
.await?;
if !output.status.success() {
// todo: we should return error here
tracing::error!(
status = output.status.code(),
"failed to execute wal-g command"
);
return Ok(());
}
#[derive(serde::Deserialize)]
struct WalgBackupInfo {
backup_name: String,
time: String,
ts_parsed: Option<i64>,
}
let mut info =
serde_json::from_str::<Vec<WalgBackupInfo>>(&String::from_utf8_lossy(&output.stdout))?
.into_iter()
.filter(|v| v.backup_name.contains("base"))
.filter_map(|mut v| {
chrono::DateTime::parse_from_rfc3339(&v.time)
.ok()
.map(|dt| {
v.ts_parsed = Some(dt.with_timezone(&chrono::Utc).timestamp());
v
})
})
.collect::<Vec<WalgBackupInfo>>();
info.sort_by(|a, b| b.ts_parsed.cmp(&a.ts_parsed));
let Some(info) = info.first() else {
anyhow::bail!("could not find any basebackups in repo {repo}");
};
BASEBACKUP_TS
.lock()
.await
.insert(repo, info.ts_parsed.unwrap());
Ok(())
}
pub async fn update_db_meta(ctx: AppCtx) -> anyhow::Result<()> {
ctx.data
.execute(