mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
chore: reorganize rust crates
This commit is contained in:
parent
357122a892
commit
16ce67e02c
58 changed files with 6 additions and 13 deletions
27
crates/gateway/Cargo.toml
Normal file
27
crates/gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,27 @@
|
|||
[package]
|
||||
name = "gateway"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
metrics = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
twilight-gateway = { workspace = true }
|
||||
twilight-cache-inmemory = { workspace = true }
|
||||
twilight-util = { workspace = true }
|
||||
twilight-model = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
serde_variant = "0.1.3"
|
||||
183
crates/gateway/src/cache_api.rs
Normal file
183
crates/gateway/src/cache_api.rs
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use serde_json::{json, to_string};
|
||||
use tracing::{error, info};
|
||||
use twilight_model::id::Id;
|
||||
|
||||
use crate::discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn status_code(code: StatusCode, body: String) -> Response {
|
||||
(code, body).into_response()
|
||||
}
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/guilds/:guild_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild(Id::new(guild_id)) {
|
||||
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/members/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
|
||||
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => {
|
||||
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
|
||||
},
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, "failed to get own guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/:user_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/channels",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
|
||||
Some(channels) => channels.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut channels = Vec::new();
|
||||
for id in channel_ids {
|
||||
match cache.0.channel(id) {
|
||||
Some(channel) => channels.push(channel.to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
channel_id = id.get(),
|
||||
"referenced channel {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
|
||||
}
|
||||
match cache.0.channel(Id::new(channel_id)) {
|
||||
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string())
|
||||
}
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
|
||||
}
|
||||
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
|
||||
get(|| async { "todo" }),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/last_message",
|
||||
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/roles",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
|
||||
Some(roles) => roles.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut roles = Vec::new();
|
||||
for id in role_ids {
|
||||
match cache.0.role(id) {
|
||||
Some(role) => roles.push(role.value().resource().to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
role_id = id.get(),
|
||||
"referenced role {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
|
||||
})
|
||||
)
|
||||
|
||||
.route("/stats", get(|State(cache): State<Arc<DiscordCache>>| async move {
|
||||
let cluster = cluster_config();
|
||||
let has_been_up = cache.2.read().await.len() as u32 == if cluster.total_shards > 16 {16} else {cluster.total_shards};
|
||||
let stats = cache.0.stats();
|
||||
let stats = json!({
|
||||
"guild_count": stats.guilds(),
|
||||
"channel_count": stats.channels(),
|
||||
// just put this here until prom stats
|
||||
"unavailable_guild_count": stats.unavailable_guilds(),
|
||||
"up": has_been_up,
|
||||
});
|
||||
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
368
crates/gateway/src/discord/cache.rs
Normal file
368
crates/gateway/src/discord/cache.rs
Normal file
|
|
@ -0,0 +1,368 @@
|
|||
use anyhow::format_err;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::RwLock;
|
||||
use twilight_cache_inmemory::{
|
||||
model::CachedMember,
|
||||
permission::{MemberRoles, RootError},
|
||||
traits::CacheableChannel,
|
||||
InMemoryCache, ResourceType,
|
||||
};
|
||||
use twilight_model::{
|
||||
channel::{Channel, ChannelType},
|
||||
guild::{Guild, Member, Permissions},
|
||||
id::{
|
||||
marker::{ChannelMarker, GuildMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
use twilight_util::permission_calculator::PermissionCalculator;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
|
||||
| Permissions::SEND_MESSAGES
|
||||
| Permissions::READ_MESSAGE_HISTORY
|
||||
| Permissions::ADD_REACTIONS
|
||||
| Permissions::ATTACH_FILES
|
||||
| Permissions::EMBED_LINKS
|
||||
| Permissions::USE_EXTERNAL_EMOJIS
|
||||
| Permissions::CONNECT
|
||||
| Permissions::SPEAK
|
||||
| Permissions::USE_VAD;
|
||||
}
|
||||
|
||||
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
|
||||
Channel {
|
||||
id,
|
||||
kind: ChannelType::Private,
|
||||
|
||||
application_id: None,
|
||||
applied_tags: None,
|
||||
available_tags: None,
|
||||
bitrate: None,
|
||||
default_auto_archive_duration: None,
|
||||
default_forum_layout: None,
|
||||
default_reaction_emoji: None,
|
||||
default_sort_order: None,
|
||||
default_thread_rate_limit_per_user: None,
|
||||
flags: None,
|
||||
guild_id: None,
|
||||
icon: None,
|
||||
invitable: None,
|
||||
last_message_id: None,
|
||||
last_pin_timestamp: None,
|
||||
managed: None,
|
||||
member: None,
|
||||
member_count: None,
|
||||
message_count: None,
|
||||
name: None,
|
||||
newly_created: None,
|
||||
nsfw: None,
|
||||
owner_id: None,
|
||||
parent_id: None,
|
||||
permission_overwrites: None,
|
||||
position: None,
|
||||
rate_limit_per_user: None,
|
||||
recipients: None,
|
||||
rtc_region: None,
|
||||
thread_metadata: None,
|
||||
topic: None,
|
||||
user_limit: None,
|
||||
video_quality_mode: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
|
||||
CachedMember {
|
||||
avatar: item.avatar,
|
||||
communication_disabled_until: item.communication_disabled_until,
|
||||
deaf: Some(item.deaf),
|
||||
flags: item.flags,
|
||||
joined_at: item.joined_at,
|
||||
mute: Some(item.mute),
|
||||
nick: item.nick,
|
||||
premium_since: item.premium_since,
|
||||
roles: item.roles,
|
||||
pending: false,
|
||||
user_id: id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> DiscordCache {
|
||||
let mut client_builder = twilight_http::Client::builder().token(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.clone(),
|
||||
);
|
||||
|
||||
if let Some(base_url) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
||||
let cache = Arc::new(
|
||||
InMemoryCache::builder()
|
||||
.resource_types(
|
||||
ResourceType::GUILD
|
||||
| ResourceType::CHANNEL
|
||||
| ResourceType::ROLE
|
||||
| ResourceType::USER_CURRENT
|
||||
| ResourceType::MEMBER_CURRENT,
|
||||
)
|
||||
.message_cache_size(0)
|
||||
.build(),
|
||||
);
|
||||
|
||||
DiscordCache(cache, client, RwLock::new(Vec::new()))
|
||||
}
|
||||
|
||||
pub struct DiscordCache(
|
||||
pub Arc<InMemoryCache>,
|
||||
pub Arc<twilight_http::Client>,
|
||||
pub RwLock<Vec<u32>>,
|
||||
);
|
||||
|
||||
impl DiscordCache {
|
||||
pub async fn guild_permissions(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| format_err!("guild not found"))?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id
|
||||
== libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id
|
||||
{
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or(format_err!("self member not found"))?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.root();
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
pub async fn channel_permissions(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
let channel = self
|
||||
.0
|
||||
.channel(channel_id)
|
||||
.ok_or(format_err!("channel not found"))?;
|
||||
|
||||
if channel.value().guild_id.is_none() {
|
||||
return Ok(*DM_PERMISSIONS);
|
||||
}
|
||||
|
||||
let guild_id = channel.value().guild_id.unwrap();
|
||||
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
channel_id = channel_id.get(),
|
||||
guild_id = guild_id.get(),
|
||||
"referenced guild from cached channel {channel_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id
|
||||
== libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id
|
||||
{
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
guild_id = guild_id.get(),
|
||||
"self member for cached guild {guild_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
|
||||
let overwrites = match channel.kind {
|
||||
ChannelType::AnnouncementThread
|
||||
| ChannelType::PrivateThread
|
||||
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
|
||||
_ => channel
|
||||
.value()
|
||||
.permission_overwrites()
|
||||
.unwrap_or_default()
|
||||
.to_vec(),
|
||||
};
|
||||
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
|
||||
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
|
||||
self.0.guild(id).map(|guild| {
|
||||
let channels = self
|
||||
.0
|
||||
.guild_channels(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|channel_id| {
|
||||
let channel = self.0.channel(*channel_id)?;
|
||||
|
||||
if channel.kind.is_thread() {
|
||||
None
|
||||
} else {
|
||||
Some(channel.value().clone())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let roles = self
|
||||
.0
|
||||
.guild_roles(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|role_id| {
|
||||
Some(self.0.role(*role_id)?.value().resource().clone())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Guild {
|
||||
afk_channel_id: guild.afk_channel_id(),
|
||||
afk_timeout: guild.afk_timeout(),
|
||||
application_id: guild.application_id(),
|
||||
approximate_member_count: None, // Only present in with_counts HTTP endpoint
|
||||
banner: guild.banner().map(ToOwned::to_owned),
|
||||
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
|
||||
channels,
|
||||
default_message_notifications: guild.default_message_notifications(),
|
||||
description: guild.description().map(ToString::to_string),
|
||||
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
|
||||
emojis: vec![],
|
||||
explicit_content_filter: guild.explicit_content_filter(),
|
||||
features: guild.features().cloned().collect(),
|
||||
icon: guild.icon().map(ToOwned::to_owned),
|
||||
id: guild.id(),
|
||||
joined_at: guild.joined_at(),
|
||||
large: guild.large(),
|
||||
max_members: guild.max_members(),
|
||||
max_presences: guild.max_presences(),
|
||||
max_video_channel_users: guild.max_video_channel_users(),
|
||||
member_count: guild.member_count(),
|
||||
members: vec![],
|
||||
mfa_level: guild.mfa_level(),
|
||||
name: guild.name().to_string(),
|
||||
nsfw_level: guild.nsfw_level(),
|
||||
owner_id: guild.owner_id(),
|
||||
owner: guild.owner(),
|
||||
permissions: guild.permissions(),
|
||||
public_updates_channel_id: guild.public_updates_channel_id(),
|
||||
preferred_locale: guild.preferred_locale().to_string(),
|
||||
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
|
||||
premium_subscription_count: guild.premium_subscription_count(),
|
||||
premium_tier: guild.premium_tier(),
|
||||
presences: vec![],
|
||||
roles,
|
||||
rules_channel_id: guild.rules_channel_id(),
|
||||
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
|
||||
splash: guild.splash().map(ToOwned::to_owned),
|
||||
stage_instances: vec![],
|
||||
stickers: vec![],
|
||||
system_channel_flags: guild.system_channel_flags(),
|
||||
system_channel_id: guild.system_channel_id(),
|
||||
threads: vec![],
|
||||
unavailable: false,
|
||||
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
|
||||
verification_level: guild.verification_level(),
|
||||
voice_states: vec![],
|
||||
widget_channel_id: guild.widget_channel_id(),
|
||||
widget_enabled: guild.widget_enabled(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
200
crates/gateway/src/discord/gateway.rs
Normal file
200
crates/gateway/src/discord/gateway.rs
Normal file
|
|
@ -0,0 +1,200 @@
|
|||
use futures::StreamExt;
|
||||
use libpk::_config::ClusterSettings;
|
||||
use metrics::counter;
|
||||
use std::sync::{mpsc::Sender, Arc};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{
|
||||
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
|
||||
};
|
||||
use twilight_model::gateway::{
|
||||
payload::outgoing::update_presence::UpdatePresencePayload,
|
||||
presence::{Activity, ActivityType, Status},
|
||||
Intents,
|
||||
};
|
||||
|
||||
use crate::discord::identify_queue::{self, RedisQueue};
|
||||
|
||||
use super::{cache::DiscordCache, shard_state::ShardStateManager};
|
||||
|
||||
pub fn cluster_config() -> ClusterSettings {
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.cluster
|
||||
.clone()
|
||||
.unwrap_or(libpk::_config::ClusterSettings {
|
||||
node_id: 0,
|
||||
total_shards: 1,
|
||||
total_nodes: 1,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
|
||||
let intents = Intents::GUILDS
|
||||
| Intents::DIRECT_MESSAGES
|
||||
| Intents::DIRECT_MESSAGE_REACTIONS
|
||||
| Intents::GUILD_MESSAGES
|
||||
| Intents::GUILD_MESSAGE_REACTIONS
|
||||
| Intents::MESSAGE_CONTENT;
|
||||
|
||||
let queue = identify_queue::new(redis);
|
||||
|
||||
let cluster_settings = cluster_config();
|
||||
|
||||
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
|
||||
warn!("we have less than 16 shards, assuming single gateway process");
|
||||
(0, (cluster_settings.total_shards - 1).into())
|
||||
} else {
|
||||
(
|
||||
(cluster_settings.node_id * 16).into(),
|
||||
(((cluster_settings.node_id + 1) * 16) - 1).into(),
|
||||
)
|
||||
};
|
||||
|
||||
let shards = create_iterator(
|
||||
start_shard..end_shard + 1,
|
||||
cluster_settings.total_shards,
|
||||
ConfigBuilder::new(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.to_owned(),
|
||||
intents,
|
||||
)
|
||||
.presence(presence("pk;help", false))
|
||||
.queue(queue.clone())
|
||||
.build(),
|
||||
|_, builder| builder.build(),
|
||||
);
|
||||
|
||||
let mut shards_vec = Vec::new();
|
||||
shards_vec.extend(shards);
|
||||
|
||||
Ok(shards_vec)
|
||||
}
|
||||
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
_tx: Sender<(ShardId, String)>,
|
||||
shard_state: ShardStateManager,
|
||||
cache: Arc<DiscordCache>,
|
||||
) {
|
||||
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next().await {
|
||||
let raw_event = match item {
|
||||
Ok(evt) => match evt {
|
||||
Message::Close(frame) => {
|
||||
info!(
|
||||
"shard {} closed: {}",
|
||||
shard.id().number(),
|
||||
if let Some(close) = frame {
|
||||
format!("{} ({})", close.code, close.reason)
|
||||
} else {
|
||||
"unknown".to_string()
|
||||
}
|
||||
);
|
||||
if let Err(error) = shard_state.socket_closed(shard.id().number()).await {
|
||||
error!("failed to update shard state for socket closure: {error}");
|
||||
}
|
||||
continue;
|
||||
}
|
||||
Message::Text(text) => text,
|
||||
},
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
let event = match twilight_gateway::parse(raw_event.clone(), EventTypeFlags::all()) {
|
||||
Ok(Some(parsed)) => Event::from(parsed),
|
||||
Ok(None) => {
|
||||
// we received an event type unknown to twilight
|
||||
// that's fine, we probably don't need it anyway
|
||||
continue;
|
||||
}
|
||||
Err(error) => {
|
||||
error!(
|
||||
"shard {} failed to parse gateway event: {}",
|
||||
shard.id().number(),
|
||||
error
|
||||
);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
||||
// log the event in metrics
|
||||
// event_type * shard_id is too many labels and prometheus fails to query it
|
||||
// so we split it into two metrics
|
||||
counter!(
|
||||
"pluralkit_gateway_events_type",
|
||||
"event_type" => serde_variant::to_variant_name(&event.kind()).unwrap(),
|
||||
)
|
||||
.increment(1);
|
||||
counter!(
|
||||
"pluralkit_gateway_events_shard",
|
||||
"shard_id" => shard.id().number().to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
// update shard state and discord cache
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(shard.id().number(), event.clone())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state");
|
||||
}
|
||||
// need to do heartbeat separately, to get the latency
|
||||
if let Event::GatewayHeartbeatAck = event
|
||||
&& let Err(error) = shard_state
|
||||
.heartbeated(shard.id().number(), shard.latency())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state for latency");
|
||||
}
|
||||
|
||||
if let Event::Ready(_) = event {
|
||||
if !cache.2.read().await.contains(&shard.id().number()) {
|
||||
cache.2.write().await.push(shard.id().number());
|
||||
}
|
||||
}
|
||||
cache.0.update(&event);
|
||||
|
||||
// okay, we've handled the event internally, let's send it to consumers
|
||||
// tx.send((shard.id(), raw_event)).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
|
||||
UpdatePresencePayload {
|
||||
activities: vec![Activity {
|
||||
application_id: None,
|
||||
assets: None,
|
||||
buttons: vec![],
|
||||
created_at: None,
|
||||
details: None,
|
||||
id: None,
|
||||
state: None,
|
||||
url: None,
|
||||
emoji: None,
|
||||
flags: None,
|
||||
instance: None,
|
||||
kind: ActivityType::Playing,
|
||||
name: status.to_string(),
|
||||
party: None,
|
||||
secrets: None,
|
||||
timestamps: None,
|
||||
}],
|
||||
afk: false,
|
||||
since: None,
|
||||
status: if going_away {
|
||||
Status::Idle
|
||||
} else {
|
||||
Status::Online
|
||||
},
|
||||
}
|
||||
}
|
||||
88
crates/gateway/src/discord/identify_queue.rs
Normal file
88
crates/gateway/src/discord/identify_queue.rs
Normal file
|
|
@ -0,0 +1,88 @@
|
|||
use fred::{
|
||||
clients::RedisPool,
|
||||
error::RedisError,
|
||||
interfaces::KeysInterface,
|
||||
types::{Expiration, SetOptions},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{error, info};
|
||||
use twilight_gateway::queue::Queue;
|
||||
|
||||
pub fn new(redis: RedisPool) -> RedisQueue {
|
||||
RedisQueue {
|
||||
redis,
|
||||
concurrency: libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.max_concurrency,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisQueue {
|
||||
pub redis: RedisPool,
|
||||
pub concurrency: u32,
|
||||
}
|
||||
|
||||
impl Debug for RedisQueue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RedisQueue")
|
||||
.field("concurrency", &self.concurrency)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Queue for RedisQueue {
|
||||
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(request_inner(
|
||||
self.redis.clone(),
|
||||
self.concurrency,
|
||||
shard_id,
|
||||
tx,
|
||||
));
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
const EXPIRY: i64 = 6;
|
||||
const RETRY_INTERVAL: u64 = 500;
|
||||
|
||||
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
|
||||
let bucket = shard_id % concurrency;
|
||||
let key = format!("pluralkit:identify:{}", bucket);
|
||||
|
||||
info!(shard_id, bucket, "waiting for allowance...");
|
||||
loop {
|
||||
let done: Result<Option<String>, RedisError> = redis
|
||||
.set(
|
||||
key.to_string(),
|
||||
"1",
|
||||
Some(Expiration::EX(EXPIRY)),
|
||||
Some(SetOptions::NX),
|
||||
false,
|
||||
)
|
||||
.await;
|
||||
match done {
|
||||
Ok(Some(_)) => {
|
||||
info!(shard_id, bucket, "got allowance!");
|
||||
// if this fails, it's probably already doing something else
|
||||
let _ = tx.send(());
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
|
||||
}
|
||||
}
|
||||
4
crates/gateway/src/discord/mod.rs
Normal file
4
crates/gateway/src/discord/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod cache;
|
||||
pub mod gateway;
|
||||
pub mod identify_queue;
|
||||
pub mod shard_state;
|
||||
91
crates/gateway/src/discord/shard_state.rs
Normal file
91
crates/gateway/src/discord/shard_state.rs
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use metrics::{counter, gauge};
|
||||
use tracing::info;
|
||||
use twilight_gateway::{Event, Latency};
|
||||
|
||||
use libpk::{state::*, util::redis::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShardStateManager {
|
||||
redis: RedisPool,
|
||||
}
|
||||
|
||||
pub fn new(redis: RedisPool) -> ShardStateManager {
|
||||
ShardStateManager { redis }
|
||||
}
|
||||
|
||||
impl ShardStateManager {
|
||||
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
|
||||
match event {
|
||||
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
|
||||
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<String> = self
|
||||
.redis
|
||||
.hget("pluralkit:shardstatus", shard_id)
|
||||
.await
|
||||
.to_option_or_error()?;
|
||||
match data {
|
||||
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
|
||||
None => Ok(ShardState::default()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(
|
||||
"pluralkit:shardstatus",
|
||||
(
|
||||
shard_id.to_string(),
|
||||
serde_json::to_string(&info).expect("could not serialize shard"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"shard {} {}",
|
||||
shard_id,
|
||||
if resumed { "resumed" } else { "ready" }
|
||||
);
|
||||
counter!(
|
||||
"pluralkit_gateway_shard_reconnect",
|
||||
"shard_id" => shard_id.to_string(),
|
||||
"resumed" => resumed.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
gauge!("pluralkit_gateway_shard_up").increment(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.up = true;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_up").decrement(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = false;
|
||||
info.disconnection_count += 1;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = true;
|
||||
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.latency = latency
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
72
crates/gateway/src/logger.rs
Normal file
72
crates/gateway/src/logger.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{
|
||||
extract::MatchedPath, extract::Request, http::StatusCode, middleware::Next, response::Response,
|
||||
};
|
||||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
counter!(
|
||||
"pluralkit_gateway_cache_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
"pluralkit_gateway_cache_api_requests_bucket",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
if response.status() != StatusCode::FOUND {
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
uri.path(),
|
||||
elapsed
|
||||
);
|
||||
}
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
140
crates/gateway/src/main.rs
Normal file
140
crates/gateway/src/main.rs
Normal file
|
|
@ -0,0 +1,140 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(if_let_guard)]
|
||||
|
||||
use chrono::Timelike;
|
||||
use fred::{clients::RedisPool, interfaces::*};
|
||||
use signal_hook::{
|
||||
consts::{SIGINT, SIGTERM},
|
||||
iterator::Signals,
|
||||
};
|
||||
use std::{
|
||||
sync::{mpsc::channel, Arc},
|
||||
time::Duration,
|
||||
vec::Vec,
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod discord;
|
||||
mod logger;
|
||||
|
||||
libpk::main!("gateway");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let (shutdown_tx, shutdown_rx) = channel::<()>();
|
||||
let shutdown_tx = Arc::new(shutdown_tx);
|
||||
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
let (event_tx, _event_rx) = channel();
|
||||
|
||||
let mut senders = Vec::new();
|
||||
let mut signal_senders = Vec::new();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for shard in shards {
|
||||
senders.push((shard.id(), shard.sender()));
|
||||
signal_senders.push(shard.sender());
|
||||
set.spawn(tokio::spawn(discord::gateway::runner(
|
||||
shard,
|
||||
event_tx.clone(),
|
||||
shard_state.clone(),
|
||||
cache.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
set.spawn(tokio::spawn(
|
||||
async move { scheduled_task(redis, senders).await },
|
||||
));
|
||||
|
||||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}));
|
||||
|
||||
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
|
||||
|
||||
set.spawn(tokio::spawn(async move {
|
||||
for sig in signals.forever() {
|
||||
info!("received signal {:?}", sig);
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
}));
|
||||
|
||||
let _ = shutdown_rx.recv();
|
||||
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
set.abort_all();
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
(60 - chrono::offset::Utc::now().second()).into(),
|
||||
))
|
||||
.await;
|
||||
info!("running per-minute scheduled tasks");
|
||||
|
||||
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
|
||||
Ok(val) => Some(val),
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "failed to fetch bot status from redis");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence(
|
||||
if let Some(status) = status {
|
||||
format!("pk;help | {}", status)
|
||||
} else {
|
||||
"pk;help".to_string()
|
||||
}
|
||||
.as_str(),
|
||||
false,
|
||||
),
|
||||
};
|
||||
|
||||
for sender in senders.iter() {
|
||||
match sender.1.command(&presence) {
|
||||
Err(error) => {
|
||||
warn!(?error, "could not update presence on shard {}", sender.0)
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue