chore: reorganize rust crates

alyssa 2025-01-02 00:50:36 +00:00
parent 357122a892
commit 16ce67e02c
58 changed files with 6 additions and 13 deletions

crates/gateway/Cargo.toml Normal file
View file

@@ -0,0 +1,27 @@
[package]
name = "gateway"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../libpk" }
metrics = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-gateway = { workspace = true }
twilight-cache-inmemory = { workspace = true }
twilight-util = { workspace = true }
twilight-model = { workspace = true }
twilight-http = { workspace = true }
serde_variant = "0.1.3"

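The `{ workspace = true }` entries above resolve against a `[workspace.dependencies]` table in the repository-root Cargo.toml, which this section does not include. A minimal sketch of what that root manifest might contain; the member list and version numbers are illustrative assumptions, not taken from this commit:

[workspace]
members = ["crates/gateway", "crates/libpk"]
resolver = "2"

[workspace.dependencies]
anyhow = "1"
axum = "0.7"
tokio = { version = "1", features = ["full"] }
twilight-gateway = "0.16"
# ...plus one pinned entry for each remaining `workspace = true` dependency above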
crates/gateway/src/cache_api.rs Normal file
View file

@@ -0,0 +1,183 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Router,
};
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::Id;
use crate::discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
};
use std::sync::Arc;
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
let app = Router::new()
.route(
"/guilds/:guild_id",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => {
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
},
Err(err) => {
error!(?err, ?guild_id, "failed to get own guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
Err(err) => {
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut channels = Vec::new();
for id in channel_ids {
match cache.0.channel(id) {
Some(channel) => channels.push(channel.to_owned()),
None => {
tracing::error!(
channel_id = id.get(),
"referenced channel {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
}
match cache.0.channel(Id::new(channel_id)) {
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string())
}
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
}
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
Err(err) => {
error!(?err, ?channel_id, ?guild_id, "failed to get own channel permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
)
.route(
"/guilds/:guild_id/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut roles = Vec::new();
for id in role_ids {
match cache.0.role(id) {
Some(role) => roles.push(role.value().resource().to_owned()),
None => {
tracing::error!(
role_id = id.get(),
"referenced role {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
})
)
.route("/stats", get(|State(cache): State<Arc<DiscordCache>>| async move {
let cluster = cluster_config();
let has_been_up = cache.2.read().await.len() as u32 == if cluster.total_shards > 16 {16} else {cluster.total_shards};
let stats = cache.0.stats();
let stats = json!({
"guild_count": stats.guilds(),
"channel_count": stats.channels(),
// just put this here until prom stats
"unavailable_guild_count": stats.unavailable_guilds(),
"up": has_been_up,
});
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
}))
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

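Note that the permission routes serialize `val.bits()`, so the response body is the raw permission bits as a JSON number. A minimal sketch of how a consumer might decode such a body (the HTTP request itself is elided; `decode_permissions` is a hypothetical helper):

use twilight_model::guild::Permissions;

/// Decode the body of e.g. `GET /guilds/:guild_id/permissions/:user_id`,
/// which the route above writes as `to_string(&val.bits())`.
fn decode_permissions(body: &str) -> anyhow::Result<Permissions> {
    let bits: u64 = body.trim().parse()?;
    Ok(Permissions::from_bits_truncate(bits))
}

fn main() -> anyhow::Result<()> {
    // 2048 is SEND_MESSAGES (bit 11).
    let perms = decode_permissions("2048")?;
    assert!(perms.contains(Permissions::SEND_MESSAGES));
    Ok(())
}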
crates/gateway/src/discord/cache.rs Normal file
View file

@@ -0,0 +1,368 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
Id,
},
};
use twilight_util::permission_calculator::PermissionCalculator;
lazy_static! {
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
| Permissions::SEND_MESSAGES
| Permissions::READ_MESSAGE_HISTORY
| Permissions::ADD_REACTIONS
| Permissions::ATTACH_FILES
| Permissions::EMBED_LINKS
| Permissions::USE_EXTERNAL_EMOJIS
| Permissions::CONNECT
| Permissions::SPEAK
| Permissions::USE_VAD;
}
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
Channel {
id,
kind: ChannelType::Private,
application_id: None,
applied_tags: None,
available_tags: None,
bitrate: None,
default_auto_archive_duration: None,
default_forum_layout: None,
default_reaction_emoji: None,
default_sort_order: None,
default_thread_rate_limit_per_user: None,
flags: None,
guild_id: None,
icon: None,
invitable: None,
last_message_id: None,
last_pin_timestamp: None,
managed: None,
member: None,
member_count: None,
message_count: None,
name: None,
newly_created: None,
nsfw: None,
owner_id: None,
parent_id: None,
permission_overwrites: None,
position: None,
rate_limit_per_user: None,
recipients: None,
rtc_region: None,
thread_metadata: None,
topic: None,
user_limit: None,
video_quality_mode: None,
}
}
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
CachedMember {
avatar: item.avatar,
communication_disabled_until: item.communication_disabled_until,
deaf: Some(item.deaf),
flags: item.flags,
joined_at: item.joined_at,
mute: Some(item.mute),
nick: item.nick,
premium_since: item.premium_since,
roles: item.roles,
pending: false,
user_id: id,
}
}
pub fn new() -> DiscordCache {
let mut client_builder = twilight_http::Client::builder().token(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.clone(),
);
if let Some(base_url) = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.api_base_url
.clone()
{
client_builder = client_builder.proxy(base_url, true);
}
let client = Arc::new(client_builder.build());
let cache = Arc::new(
InMemoryCache::builder()
.resource_types(
ResourceType::GUILD
| ResourceType::CHANNEL
| ResourceType::ROLE
| ResourceType::USER_CURRENT
| ResourceType::MEMBER_CURRENT,
)
.message_cache_size(0)
.build(),
);
DiscordCache(cache, client, RwLock::new(Vec::new()))
}
pub struct DiscordCache(
pub Arc<InMemoryCache>,
pub Arc<twilight_http::Client>,
pub RwLock<Vec<u32>>,
);
impl DiscordCache {
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
if self
.0
.guild(guild_id)
.ok_or_else(|| format_err!("guild not found"))?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id
== libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id
{
self.0
.member(guild_id, user_id)
.ok_or(format_err!("self member not found"))?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.root();
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
pub async fn channel_permissions(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
let channel = self
.0
.channel(channel_id)
.ok_or(format_err!("channel not found"))?;
if channel.value().guild_id.is_none() {
return Ok(*DM_PERMISSIONS);
}
let guild_id = channel.value().guild_id.unwrap();
if self
.0
.guild(guild_id)
.ok_or_else(|| {
tracing::error!(
channel_id = channel_id.get(),
guild_id = guild_id.get(),
"referenced guild from cached channel {channel_id} not found in cache"
);
format_err!("internal cache error")
})?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id
== libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id
{
self.0
.member(guild_id, user_id)
.ok_or_else(|| {
tracing::error!(
guild_id = guild_id.get(),
"self member for cached guild {guild_id} not found in cache"
);
format_err!("internal cache error")
})?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let overwrites = match channel.kind {
ChannelType::AnnouncementThread
| ChannelType::PrivateThread
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
_ => channel
.value()
.permission_overwrites()
.unwrap_or_default()
.to_vec(),
};
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
self.0.guild(id).map(|guild| {
let channels = self
.0
.guild_channels(id)
.map(|reference| {
reference
.iter()
.filter_map(|channel_id| {
let channel = self.0.channel(*channel_id)?;
if channel.kind.is_thread() {
None
} else {
Some(channel.value().clone())
}
})
.collect()
})
.unwrap_or_default();
let roles = self
.0
.guild_roles(id)
.map(|reference| {
reference
.iter()
.filter_map(|role_id| {
Some(self.0.role(*role_id)?.value().resource().clone())
})
.collect()
})
.unwrap_or_default();
Guild {
afk_channel_id: guild.afk_channel_id(),
afk_timeout: guild.afk_timeout(),
application_id: guild.application_id(),
approximate_member_count: None, // Only present in with_counts HTTP endpoint
banner: guild.banner().map(ToOwned::to_owned),
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
channels,
default_message_notifications: guild.default_message_notifications(),
description: guild.description().map(ToString::to_string),
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
emojis: vec![],
explicit_content_filter: guild.explicit_content_filter(),
features: guild.features().cloned().collect(),
icon: guild.icon().map(ToOwned::to_owned),
id: guild.id(),
joined_at: guild.joined_at(),
large: guild.large(),
max_members: guild.max_members(),
max_presences: guild.max_presences(),
max_video_channel_users: guild.max_video_channel_users(),
member_count: guild.member_count(),
members: vec![],
mfa_level: guild.mfa_level(),
name: guild.name().to_string(),
nsfw_level: guild.nsfw_level(),
owner_id: guild.owner_id(),
owner: guild.owner(),
permissions: guild.permissions(),
public_updates_channel_id: guild.public_updates_channel_id(),
preferred_locale: guild.preferred_locale().to_string(),
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
premium_subscription_count: guild.premium_subscription_count(),
premium_tier: guild.premium_tier(),
presences: vec![],
roles,
rules_channel_id: guild.rules_channel_id(),
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
splash: guild.splash().map(ToOwned::to_owned),
stage_instances: vec![],
stickers: vec![],
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: false,
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
}
})
}
}

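Both `guild_permissions` and `channel_permissions` bottom out in twilight's `PermissionCalculator` (from `twilight-util` with the `permission-calculator` feature): `root()` computes guild-level permissions from the `@everyone` permissions plus the member's assigned roles, and `in_channel()` additionally applies channel overwrites. A standalone sketch of the `root()` path, using made-up IDs and permissions:

use twilight_model::guild::Permissions;
use twilight_model::id::Id;
use twilight_util::permission_calculator::PermissionCalculator;

fn main() {
    let guild_id = Id::new(1);
    let user_id = Id::new(2);
    // `@everyone` grants VIEW_CHANNEL; one assigned role adds SEND_MESSAGES.
    let everyone = Permissions::VIEW_CHANNEL;
    let assigned = [(Id::new(3), Permissions::SEND_MESSAGES)];

    let perms = PermissionCalculator::new(guild_id, user_id, everyone, &assigned).root();
    assert!(perms.contains(Permissions::VIEW_CHANNEL | Permissions::SEND_MESSAGES));
}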
crates/gateway/src/discord/gateway.rs Normal file
View file

@@ -0,0 +1,200 @@
use futures::StreamExt;
use libpk::_config::ClusterSettings;
use metrics::counter;
use std::sync::{mpsc::Sender, Arc};
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
};
use twilight_model::gateway::{
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::discord::identify_queue::{self, RedisQueue};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
pub fn cluster_config() -> ClusterSettings {
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.cluster
.clone()
.unwrap_or(libpk::_config::ClusterSettings {
node_id: 0,
total_shards: 1,
total_nodes: 1,
})
}
pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
let intents = Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::DIRECT_MESSAGE_REACTIONS
| Intents::GUILD_MESSAGES
| Intents::GUILD_MESSAGE_REACTIONS
| Intents::MESSAGE_CONTENT;
let queue = identify_queue::new(redis);
let cluster_settings = cluster_config();
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
(0, (cluster_settings.total_shards - 1).into())
} else {
(
(cluster_settings.node_id * 16).into(),
(((cluster_settings.node_id + 1) * 16) - 1).into(),
)
};
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
ConfigBuilder::new(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.to_owned(),
intents,
)
.presence(presence("pk;help", false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
);
let mut shards_vec = Vec::new();
shards_vec.extend(shards);
Ok(shards_vec)
}
pub async fn runner(
mut shard: Shard<RedisQueue>,
_tx: Sender<(ShardId, String)>,
shard_state: ShardStateManager,
cache: Arc<DiscordCache>,
) {
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
info!("waiting for events");
while let Some(item) = shard.next().await {
let raw_event = match item {
Ok(evt) => match evt {
Message::Close(frame) => {
info!(
"shard {} closed: {}",
shard.id().number(),
if let Some(close) = frame {
format!("{} ({})", close.code, close.reason)
} else {
"unknown".to_string()
}
);
if let Err(error) = shard_state.socket_closed(shard.id().number()).await {
error!("failed to update shard state for socket closure: {error}");
}
continue;
}
Message::Text(text) => text,
},
Err(error) => {
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
continue;
}
};
let event = match twilight_gateway::parse(raw_event.clone(), EventTypeFlags::all()) {
Ok(Some(parsed)) => Event::from(parsed),
Ok(None) => {
// we received an event type unknown to twilight
// that's fine, we probably don't need it anyway
continue;
}
Err(error) => {
error!(
"shard {} failed to parse gateway event: {}",
shard.id().number(),
error
);
continue;
}
};
// log the event in metrics
// event_type * shard_id is too many labels and prometheus fails to query it
// so we split it into two metrics
counter!(
"pluralkit_gateway_events_type",
"event_type" => serde_variant::to_variant_name(&event.kind()).unwrap(),
)
.increment(1);
counter!(
"pluralkit_gateway_events_shard",
"shard_id" => shard.id().number().to_string(),
)
.increment(1);
// update shard state and discord cache
if let Err(error) = shard_state
.handle_event(shard.id().number(), event.clone())
.await
{
tracing::warn!(?error, "error updating redis state");
}
// need to do heartbeat separately, to get the latency
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = shard_state
.heartbeated(shard.id().number(), shard.latency())
.await
{
tracing::warn!(?error, "error updating redis state for latency");
}
if let Event::Ready(_) = event {
if !cache.2.read().await.contains(&shard.id().number()) {
cache.2.write().await.push(shard.id().number());
}
}
cache.0.update(&event);
// okay, we've handled the event internally, let's send it to consumers
// tx.send((shard.id(), raw_event)).unwrap();
}
}
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
UpdatePresencePayload {
activities: vec![Activity {
application_id: None,
assets: None,
buttons: vec![],
created_at: None,
details: None,
id: None,
state: None,
url: None,
emoji: None,
flags: None,
instance: None,
kind: ActivityType::Playing,
name: status.to_string(),
party: None,
secrets: None,
timestamps: None,
}],
afk: false,
since: None,
status: if going_away {
Status::Idle
} else {
Status::Online
},
}
}

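In the multi-node branch of `create_shards`, each node owns a fixed window of 16 shard IDs, so node `n` runs shards `n*16` through `n*16 + 15`. A quick check of that arithmetic (assuming `total_shards` is at least 16):

fn shard_range(node_id: u32) -> (u32, u32) {
    // mirrors create_shards: 16 shards per node, inclusive bounds
    (node_id * 16, ((node_id + 1) * 16) - 1)
}

fn main() {
    assert_eq!(shard_range(0), (0, 15));
    assert_eq!(shard_range(2), (32, 47)); // node 2 owns shards 32..=47
}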
crates/gateway/src/discord/identify_queue.rs Normal file
View file

@@ -0,0 +1,88 @@
use fred::{
clients::RedisPool,
error::RedisError,
interfaces::KeysInterface,
types::{Expiration, SetOptions},
};
use std::fmt::Debug;
use std::time::Duration;
use tokio::sync::oneshot;
use tracing::{error, info};
use twilight_gateway::queue::Queue;
pub fn new(redis: RedisPool) -> RedisQueue {
RedisQueue {
redis,
concurrency: libpk::config
.discord
.as_ref()
.expect("missing discord config")
.max_concurrency,
}
}
#[derive(Clone)]
pub struct RedisQueue {
pub redis: RedisPool,
pub concurrency: u32,
}
impl Debug for RedisQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RedisQueue")
.field("concurrency", &self.concurrency)
.finish()
}
}
impl Queue for RedisQueue {
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
let (tx, rx) = oneshot::channel();
tokio::spawn(request_inner(
self.redis.clone(),
self.concurrency,
shard_id,
tx,
));
rx
}
}
const EXPIRY: i64 = 6;
const RETRY_INTERVAL: u64 = 500;
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
let bucket = shard_id % concurrency;
let key = format!("pluralkit:identify:{}", bucket);
info!(shard_id, bucket, "waiting for allowance...");
loop {
let done: Result<Option<String>, RedisError> = redis
.set(
key.to_string(),
"1",
Some(Expiration::EX(EXPIRY)),
Some(SetOptions::NX),
false,
)
.await;
match done {
Ok(Some(_)) => {
info!(shard_id, bucket, "got allowance!");
// if this fails, it's probably already doing something else
let _ = tx.send(());
return;
}
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
}
}
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
}
}

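The queue maps each shard onto an identify bucket with `shard_id % max_concurrency`, and because the Redis key is created with `NX` and expires after `EXPIRY` seconds, each bucket grants at most one identify every six seconds across every gateway process sharing the pool. A worked example of the bucketing (a `max_concurrency` of 16 is illustrative):

fn bucket(shard_id: u32, max_concurrency: u32) -> u32 {
    // mirrors request_inner: shard_id % concurrency
    shard_id % max_concurrency
}

fn main() {
    // Shards 0 and 16 share bucket 0, so they identify at least
    // EXPIRY seconds apart; shard 5 proceeds independently in bucket 5.
    assert_eq!(bucket(0, 16), 0);
    assert_eq!(bucket(16, 16), 0);
    assert_eq!(bucket(5, 16), 5);
}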
crates/gateway/src/discord/mod.rs Normal file
View file

@@ -0,0 +1,4 @@
pub mod cache;
pub mod gateway;
pub mod identify_queue;
pub mod shard_state;

crates/gateway/src/discord/shard_state.rs Normal file
View file

@@ -0,0 +1,91 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tracing::info;
use twilight_gateway::{Event, Latency};
use libpk::{state::*, util::redis::*};
#[derive(Clone)]
pub struct ShardStateManager {
redis: RedisPool,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
_ => Ok(()),
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<String> = self
.redis
.hget("pluralkit:shardstatus", shard_id)
.await
.to_option_or_error()?;
match data {
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
None => Ok(ShardState::default()),
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
(
shard_id.to_string(),
serde_json::to_string(&info).expect("could not serialize shard"),
),
)
.await?;
Ok(())
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
shard_id,
if resumed { "resumed" } else { "ready" }
);
counter!(
"pluralkit_gateway_shard_reconnect",
"shard_id" => shard_id.to_string(),
"resumed" => resumed.to_string(),
)
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let mut info = self.get_shard(shard_id).await?;
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self.get_shard(shard_id).await?;
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
self.save_shard(shard_id, info).await?;
Ok(())
}
}

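`ShardState` itself comes from `libpk::state` and is not part of this diff. Judging from the fields read and written above, a plausible shape would be the following sketch (an assumption, not the actual libpk definition):

use serde::{Deserialize, Serialize};

// Hypothetical reconstruction from the accesses in shard_state.rs;
// the real struct lives in libpk and may have more fields.
#[derive(Clone, Default, Serialize, Deserialize)]
pub struct ShardState {
    pub up: bool,
    pub disconnection_count: u32,
    pub last_connection: i32, // unix timestamp, seconds
    pub last_heartbeat: i32,  // unix timestamp, seconds
    pub latency: i32,         // most recent heartbeat latency, milliseconds
}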
crates/gateway/src/logger.rs Normal file
View file

@@ -0,0 +1,72 @@
use std::time::Instant;
use axum::{
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
};
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
counter!(
"pluralkit_gateway_cache_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
)
.increment(1);
histogram!(
"pluralkit_gateway_cache_api_requests_bucket",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
)
.record(elapsed as f64 / 1_000_f64);
if response.status() != StatusCode::FOUND {
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
}
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

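The middleware is attached with `axum::middleware::from_fn`, exactly as `run_server` does in cache_api.rs. A minimal standalone wiring sketch, assuming the file above is in scope as a `logger` module and using a placeholder route and address:

use axum::{middleware, routing::get, Router};

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let app: Router = Router::new()
        .route("/health", get(|| async { "ok" }))
        .layer(middleware::from_fn(logger::logger));
    let listener = tokio::net::TcpListener::bind("127.0.0.1:8080").await?;
    axum::serve(listener, app).await?;
    Ok(())
}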
crates/gateway/src/main.rs Normal file
View file

@@ -0,0 +1,140 @@
#![feature(let_chains)]
#![feature(if_let_guard)]
use chrono::Timelike;
use fred::{clients::RedisPool, interfaces::*};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
};
use std::{
sync::{mpsc::channel, Arc},
time::Duration,
vec::Vec,
};
use tokio::task::JoinSet;
use tracing::{info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod cache_api;
mod discord;
mod logger;
libpk::main!("gateway");
async fn real_main() -> anyhow::Result<()> {
let (shutdown_tx, shutdown_rx) = channel::<()>();
let shutdown_tx = Arc::new(shutdown_tx);
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let cache = Arc::new(discord::cache::new());
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
let mut set = JoinSet::new();
for shard in shards {
senders.push((shard.id(), shard.sender()));
signal_senders.push(shard.sender());
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
cache.clone(),
)));
}
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
set.spawn(tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
}));
let _ = shutdown_rx.recv();
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.abort_all();
info!("gateway exiting, have a nice day!");
Ok(())
}
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
))
.await;
info!("running per-minute scheduled tasks");
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
Ok(val) => Some(val),
Err(error) => {
tracing::warn!(?error, "failed to fetch bot status from redis");
None
}
};
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
} else {
"pk;help".to_string()
}
.as_str(),
false,
),
};
for sender in senders.iter() {
match sender.1.command(&presence) {
Err(error) => {
warn!(?error, "could not update presence on shard {}", sender.0)
}
_ => {}
};
}
}
}
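`scheduled_task` re-aligns itself to the top of each minute by sleeping for `60 - current_second` seconds: at hh:mm:23 it sleeps 37 seconds, and exactly on a minute boundary it sleeps the full 60. A quick check of that arithmetic:

fn seconds_until_next_minute(current_second: u32) -> u32 {
    // mirrors scheduled_task: 60 - chrono::offset::Utc::now().second()
    60 - current_second
}

fn main() {
    assert_eq!(seconds_until_next_minute(23), 37);
    assert_eq!(seconds_until_next_minute(59), 1);
    assert_eq!(seconds_until_next_minute(0), 60); // on the boundary: wait a full minute
}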