mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-11 16:20:13 +00:00
Merge remote-tracking branch 'upstream/main' into rust-command-parser
This commit is contained in:
commit
e8f8e5f0a3
37 changed files with 316 additions and 201 deletions
|
|
@ -16,6 +16,7 @@ use axum::{
|
|||
use libpk::_config::AvatarsConfig;
|
||||
use libpk::db::repository::avatars as db;
|
||||
use libpk::db::types::avatars::*;
|
||||
use pull::ParsedUrl;
|
||||
use reqwest::{Client, ClientBuilder};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
|
|
@ -23,7 +24,7 @@ use std::error::Error;
|
|||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use thiserror::Error;
|
||||
use tracing::{error, info};
|
||||
use tracing::{error, info, warn};
|
||||
use uuid::Uuid;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
|
|
@ -35,9 +36,15 @@ pub enum PKAvatarError {
|
|||
#[error("discord cdn responded with status code: {0}")]
|
||||
BadCdnResponse(reqwest::StatusCode),
|
||||
|
||||
#[error("server responded with status code: {0}")]
|
||||
BadServerResponse(reqwest::StatusCode),
|
||||
|
||||
#[error("network error: {0}")]
|
||||
NetworkError(reqwest::Error),
|
||||
|
||||
#[error("network error: {0}")]
|
||||
NetworkErrorString(String),
|
||||
|
||||
#[error("response is missing header: {0}")]
|
||||
MissingHeader(&'static str),
|
||||
|
||||
|
|
@ -86,7 +93,6 @@ async fn pull(
|
|||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
|
||||
if !req.force {
|
||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
// remove any pending image cleanup
|
||||
|
|
@ -132,6 +138,26 @@ async fn pull(
|
|||
}))
|
||||
}
|
||||
|
||||
async fn verify(
|
||||
State(state): State<AppState>,
|
||||
Json(req): Json<PullRequest>,
|
||||
) -> Result<(), PKAvatarError> {
|
||||
let result = crate::pull::pull(
|
||||
state.pull_client,
|
||||
&ParsedUrl {
|
||||
full_url: req.url.clone(),
|
||||
channel_id: 0,
|
||||
attachment_id: 0,
|
||||
filename: "".to_string(),
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
let encoded = process::process_async(result.data, req.kind).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn stats(State(state): State<AppState>) -> Result<Json<Stats>, PKAvatarError> {
|
||||
Ok(Json(db::get_stats(&state.pool).await?))
|
||||
}
|
||||
|
|
@ -193,6 +219,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
// migrate::spawn_migrate_workers(Arc::new(state.clone()), state.config.migrate_worker_count);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/verify", post(verify))
|
||||
.route("/pull", post(pull))
|
||||
.route("/stats", get(stats))
|
||||
.with_state(state);
|
||||
|
|
@ -235,7 +262,12 @@ impl IntoResponse for PKAvatarError {
|
|||
};
|
||||
|
||||
// print inner error if otherwise hidden
|
||||
error!("error: {}", self.source().unwrap_or(&self));
|
||||
// `error!` calls go to sentry, so only use that if it's our error
|
||||
if matches!(self, PKAvatarError::InternalError(_)) {
|
||||
error!("error: {}", self.source().unwrap_or(&self));
|
||||
} else {
|
||||
warn!("error: {}", self.source().unwrap_or(&self));
|
||||
}
|
||||
|
||||
(
|
||||
status_code,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,8 @@ use std::{str::FromStr, sync::Arc};
|
|||
use crate::PKAvatarError;
|
||||
use anyhow::Context;
|
||||
use reqwest::{Client, StatusCode, Url};
|
||||
use std::error::Error;
|
||||
use std::fmt::Write;
|
||||
use std::time::Instant;
|
||||
use tracing::{error, instrument};
|
||||
|
||||
|
|
@ -28,14 +30,29 @@ pub async fn pull(
|
|||
.expect("set_host should not fail");
|
||||
}
|
||||
let response = client.get(trimmed_url.clone()).send().await.map_err(|e| {
|
||||
error!("network error for {}: {}", parsed_url.full_url, e);
|
||||
PKAvatarError::NetworkError(e)
|
||||
// terrible
|
||||
let mut s = format!("{}", e);
|
||||
if let Some(src) = e.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
let mut err = src;
|
||||
while let Some(src) = err.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
err = src;
|
||||
}
|
||||
}
|
||||
|
||||
error!("network error for {}: {}", parsed_url.full_url, s);
|
||||
PKAvatarError::NetworkErrorString(s)
|
||||
})?;
|
||||
let time_after_headers = Instant::now();
|
||||
let status = response.status();
|
||||
|
||||
if status != StatusCode::OK {
|
||||
return Err(PKAvatarError::BadCdnResponse(status));
|
||||
if trimmed_url.host_str() == Some("cdn.discordapp.com") {
|
||||
return Err(PKAvatarError::BadCdnResponse(status));
|
||||
} else {
|
||||
return Err(PKAvatarError::BadServerResponse(status));
|
||||
}
|
||||
}
|
||||
|
||||
let size = match response.content_length() {
|
||||
|
|
|
|||
|
|
@ -105,7 +105,7 @@ pub fn new() -> DiscordCache {
|
|||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true);
|
||||
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
|
|
|||
|
|
@ -83,29 +83,38 @@ pub async fn runner(
|
|||
cache: Arc<DiscordCache>,
|
||||
) {
|
||||
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
let shard_id = shard.id().number();
|
||||
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next().await {
|
||||
let raw_event = match item {
|
||||
Ok(evt) => match evt {
|
||||
Message::Close(frame) => {
|
||||
info!(
|
||||
"shard {} closed: {}",
|
||||
shard.id().number(),
|
||||
if let Some(close) = frame {
|
||||
format!("{} ({})", close.code, close.reason)
|
||||
} else {
|
||||
"unknown".to_string()
|
||||
}
|
||||
);
|
||||
if let Err(error) = shard_state.socket_closed(shard.id().number()).await {
|
||||
let close_code = if let Some(close) = frame {
|
||||
close.code.to_string()
|
||||
} else {
|
||||
"unknown".to_string()
|
||||
};
|
||||
|
||||
info!("shard {shard_id} closed: {close_code}");
|
||||
|
||||
counter!(
|
||||
"pluralkit_gateway_shard_closed",
|
||||
"shard_id" => shard_id.to_string(),
|
||||
"close_code" => close_code,
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
if let Err(error) = shard_state.socket_closed(shard_id).await {
|
||||
error!("failed to update shard state for socket closure: {error}");
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
Message::Text(text) => text,
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
|
||||
tracing::warn!(?error, "error receiving event from shard {shard_id}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
|
@ -118,11 +127,7 @@ pub async fn runner(
|
|||
continue;
|
||||
}
|
||||
Err(error) => {
|
||||
error!(
|
||||
"shard {} failed to parse gateway event: {}",
|
||||
shard.id().number(),
|
||||
error
|
||||
);
|
||||
error!("shard {shard_id} failed to parse gateway event: {error}");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
|
@ -137,29 +142,24 @@ pub async fn runner(
|
|||
.increment(1);
|
||||
counter!(
|
||||
"pluralkit_gateway_events_shard",
|
||||
"shard_id" => shard.id().number().to_string(),
|
||||
"shard_id" => shard_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
// update shard state and discord cache
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(shard.id().number(), event.clone())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state");
|
||||
if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await {
|
||||
tracing::error!(?error, "error updating redis state");
|
||||
}
|
||||
// need to do heartbeat separately, to get the latency
|
||||
if let Event::GatewayHeartbeatAck = event
|
||||
&& let Err(error) = shard_state
|
||||
.heartbeated(shard.id().number(), shard.latency())
|
||||
.await
|
||||
&& let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state for latency");
|
||||
tracing::error!(?error, "error updating redis state for latency");
|
||||
}
|
||||
|
||||
if let Event::Ready(_) = event {
|
||||
if !cache.2.read().await.contains(&shard.id().number()) {
|
||||
cache.2.write().await.push(shard.id().number());
|
||||
if !cache.2.read().await.contains(&shard_id) {
|
||||
cache.2.write().await.push(shard_id);
|
||||
}
|
||||
}
|
||||
cache.0.update(&event);
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ use metrics::{counter, gauge};
|
|||
use tracing::info;
|
||||
use twilight_gateway::{Event, Latency};
|
||||
|
||||
use libpk::{state::*, util::redis::*};
|
||||
use libpk::state::ShardState;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShardStateManager {
|
||||
|
|
@ -24,11 +24,7 @@ impl ShardStateManager {
|
|||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<String> = self
|
||||
.redis
|
||||
.hget("pluralkit:shardstatus", shard_id)
|
||||
.await
|
||||
.to_option_or_error()?;
|
||||
let data: Option<String> = self.redis.hget("pluralkit:shardstatus", shard_id).await?;
|
||||
match data {
|
||||
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
|
||||
None => Ok(ShardState::default()),
|
||||
|
|
|
|||
|
|
@ -21,3 +21,4 @@ uuid = { workspace = true }
|
|||
config = "0.14.0"
|
||||
json-subscriber = { version = "0.2.2", features = ["env-filter"] }
|
||||
metrics-exporter-prometheus = { version = "0.15.3", default-features = false, features = ["tokio", "http-listener", "tracing"] }
|
||||
sentry-tracing = "0.36.0"
|
||||
|
|
|
|||
|
|
@ -5,16 +5,28 @@ use metrics_exporter_prometheus::PrometheusBuilder;
|
|||
use sentry::IntoDsn;
|
||||
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilter};
|
||||
|
||||
use sentry_tracing::event_from_event;
|
||||
|
||||
pub mod db;
|
||||
pub mod state;
|
||||
pub mod util;
|
||||
|
||||
pub mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
// functions in this file are only used by the main function below
|
||||
|
||||
pub fn init_logging(component: &str) -> anyhow::Result<()> {
|
||||
pub fn init_logging(component: &str) {
|
||||
let sentry_layer =
|
||||
sentry_tracing::layer().event_mapper(|md, ctx| match md.metadata().level() {
|
||||
&tracing::Level::ERROR => {
|
||||
// for some reason this works, but letting the library handle it doesn't
|
||||
let event = event_from_event(md, ctx);
|
||||
sentry::capture_event(event);
|
||||
sentry_tracing::EventMapping::Ignore
|
||||
}
|
||||
_ => sentry_tracing::EventMapping::Ignore,
|
||||
});
|
||||
|
||||
if config.json_log {
|
||||
let mut layer = json_subscriber::layer();
|
||||
layer.inner_layer_mut().add_static_field(
|
||||
|
|
@ -22,16 +34,16 @@ pub fn init_logging(component: &str) -> anyhow::Result<()> {
|
|||
serde_json::Value::String(component.to_string()),
|
||||
);
|
||||
tracing_subscriber::registry()
|
||||
.with(sentry_layer)
|
||||
.with(layer)
|
||||
.with(EnvFilter::from_default_env())
|
||||
.init();
|
||||
} else {
|
||||
tracing_subscriber::fmt()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
tracing_subscriber::registry()
|
||||
.with(sentry_layer)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.init();
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn init_metrics() -> anyhow::Result<()> {
|
||||
|
|
@ -61,7 +73,7 @@ macro_rules! main {
|
|||
fn main() -> anyhow::Result<()> {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
// we might also be able to use env!("CARGO_CRATE_NAME") here
|
||||
libpk::init_logging($component)?;
|
||||
libpk::init_logging($component);
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
|
|
|
|||
|
|
@ -1 +0,0 @@
|
|||
pub mod redis;
|
||||
|
|
@ -1,15 +0,0 @@
|
|||
use fred::error::RedisError;
|
||||
|
||||
pub trait RedisErrorExt<T> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
|
||||
}
|
||||
|
||||
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
|
||||
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
|
||||
match self {
|
||||
Ok(v) => Ok(Some(v)),
|
||||
Err(error) if error.is_not_found() => Ok(None),
|
||||
Err(error) => Err(error),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ serde_json = { workspace = true }
|
|||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
croner = "2.1.0"
|
||||
num-format = "0.4.4"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use chrono::Utc;
|
||||
use croner::Cron;
|
||||
use fred::prelude::RedisPool;
|
||||
|
|
@ -14,15 +16,38 @@ pub struct AppCtx {
|
|||
pub messages: PgPool,
|
||||
pub stats: PgPool,
|
||||
pub redis: RedisPool,
|
||||
|
||||
pub discord: Arc<twilight_http::Client>,
|
||||
}
|
||||
|
||||
libpk::main!("scheduled_tasks");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let mut client_builder = twilight_http::Client::builder().token(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.clone(),
|
||||
);
|
||||
|
||||
if let Some(base_url) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
|
||||
}
|
||||
|
||||
let ctx = AppCtx {
|
||||
data: libpk::db::init_data_db().await?,
|
||||
messages: libpk::db::init_messages_db().await?,
|
||||
stats: libpk::db::init_stats_db().await?,
|
||||
redis: libpk::db::init_redis().await?,
|
||||
|
||||
discord: Arc::new(client_builder.build()),
|
||||
};
|
||||
|
||||
info!("starting scheduled tasks runner");
|
||||
|
|
|
|||
|
|
@ -25,7 +25,13 @@ pub async fn update_prometheus(ctx: AppCtx) -> anyhow::Result<()> {
|
|||
|
||||
gauge!("pluralkit_image_cleanup_queue_length").set(count.count as f64);
|
||||
|
||||
// todo: remaining shard session_start_limit
|
||||
let gateway = ctx.discord.gateway().authed().await?.model().await?;
|
||||
|
||||
gauge!("pluralkit_gateway_sessions_remaining")
|
||||
.set(gateway.session_start_limit.remaining as f64);
|
||||
gauge!("pluralkit_gateway_sessions_reset_after")
|
||||
.set(gateway.session_start_limit.reset_after as f64);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue