mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 13:06:50 +00:00
feat: gateway service
This commit is contained in:
parent
1118d8bdf8
commit
e4ed354536
50 changed files with 1737 additions and 545 deletions
|
|
@ -147,6 +147,7 @@ async fn main() -> anyhow::Result<()> {
|
|||
|
||||
let addr: &str = libpk::config.api.addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
25
services/gateway/Cargo.toml
Normal file
25
services/gateway/Cargo.toml
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
[package]
|
||||
name = "gateway"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
libpk = { path = "../../lib/libpk" }
|
||||
prost = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
twilight-gateway = { workspace = true }
|
||||
twilight-cache-inmemory = { workspace = true }
|
||||
twilight-util = { workspace = true }
|
||||
twilight-model = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
168
services/gateway/src/cache_api.rs
Normal file
168
services/gateway/src/cache_api.rs
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
Router,
|
||||
};
|
||||
use serde_json::to_string;
|
||||
use tracing::{error, info};
|
||||
use twilight_model::guild::Permissions;
|
||||
use twilight_model::id::Id;
|
||||
|
||||
use crate::discord::cache::{dm_channel, DiscordCache, DM_PERMISSIONS};
|
||||
use std::sync::Arc;
|
||||
|
||||
fn status_code(code: StatusCode, body: String) -> Response {
|
||||
(code, body).into_response()
|
||||
}
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/guilds/:guild_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild(Id::new(guild_id)) {
|
||||
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/members/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.0.member(Id::new(guild_id), libpk::config.discord.client_id) {
|
||||
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.client_id).await {
|
||||
Ok(val) => {
|
||||
println!("hh {}", Permissions::all().bits());
|
||||
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
|
||||
},
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, "failed to get own guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/:user_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/channels",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
|
||||
Some(channels) => channels.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut channels = Vec::new();
|
||||
for id in channel_ids {
|
||||
match cache.0.channel(id) {
|
||||
Some(channel) => channels.push(channel.to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
channel_id = id.get(),
|
||||
"referenced channel {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
|
||||
}
|
||||
match cache.0.channel(Id::new(channel_id)) {
|
||||
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
|
||||
None => status_code(StatusCode::NOT_FOUND, "".to_string())
|
||||
}
|
||||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
|
||||
}
|
||||
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.client_id).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
|
||||
Err(err) => {
|
||||
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
|
||||
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
|
||||
},
|
||||
}
|
||||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
|
||||
get(|| async { "todo" }),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/last_message",
|
||||
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/roles",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
|
||||
Some(roles) => roles.to_owned(),
|
||||
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
|
||||
};
|
||||
|
||||
let mut roles = Vec::new();
|
||||
for id in role_ids {
|
||||
match cache.0.role(id) {
|
||||
Some(role) => roles.push(role.value().resource().to_owned()),
|
||||
None => {
|
||||
tracing::error!(
|
||||
role_id = id.get(),
|
||||
"referenced role {} from guild {} not found in cache",
|
||||
id.get(), guild_id,
|
||||
);
|
||||
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
|
||||
})
|
||||
)
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
let addr: &str = libpk::config.api.addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
339
services/gateway/src/discord/cache.rs
Normal file
339
services/gateway/src/discord/cache.rs
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
use anyhow::format_err;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use twilight_cache_inmemory::{
|
||||
model::CachedMember,
|
||||
permission::{MemberRoles, RootError},
|
||||
traits::CacheableChannel,
|
||||
InMemoryCache, ResourceType,
|
||||
};
|
||||
use twilight_model::{
|
||||
channel::{Channel, ChannelType},
|
||||
guild::{Guild, Member, Permissions},
|
||||
id::{
|
||||
marker::{ChannelMarker, GuildMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
use twilight_util::permission_calculator::PermissionCalculator;
|
||||
|
||||
lazy_static! {
|
||||
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
|
||||
| Permissions::SEND_MESSAGES
|
||||
| Permissions::READ_MESSAGE_HISTORY
|
||||
| Permissions::ADD_REACTIONS
|
||||
| Permissions::ATTACH_FILES
|
||||
| Permissions::EMBED_LINKS
|
||||
| Permissions::USE_EXTERNAL_EMOJIS
|
||||
| Permissions::CONNECT
|
||||
| Permissions::SPEAK
|
||||
| Permissions::USE_VAD;
|
||||
}
|
||||
|
||||
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
|
||||
Channel {
|
||||
id,
|
||||
kind: ChannelType::Private,
|
||||
|
||||
application_id: None,
|
||||
applied_tags: None,
|
||||
available_tags: None,
|
||||
bitrate: None,
|
||||
default_auto_archive_duration: None,
|
||||
default_forum_layout: None,
|
||||
default_reaction_emoji: None,
|
||||
default_sort_order: None,
|
||||
default_thread_rate_limit_per_user: None,
|
||||
flags: None,
|
||||
guild_id: None,
|
||||
icon: None,
|
||||
invitable: None,
|
||||
last_message_id: None,
|
||||
last_pin_timestamp: None,
|
||||
managed: None,
|
||||
member: None,
|
||||
member_count: None,
|
||||
message_count: None,
|
||||
name: None,
|
||||
newly_created: None,
|
||||
nsfw: None,
|
||||
owner_id: None,
|
||||
parent_id: None,
|
||||
permission_overwrites: None,
|
||||
position: None,
|
||||
rate_limit_per_user: None,
|
||||
recipients: None,
|
||||
rtc_region: None,
|
||||
thread_metadata: None,
|
||||
topic: None,
|
||||
user_limit: None,
|
||||
video_quality_mode: None,
|
||||
}
|
||||
}
|
||||
|
||||
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
|
||||
CachedMember {
|
||||
avatar: item.avatar,
|
||||
communication_disabled_until: item.communication_disabled_until,
|
||||
deaf: Some(item.deaf),
|
||||
flags: item.flags,
|
||||
joined_at: item.joined_at,
|
||||
mute: Some(item.mute),
|
||||
nick: item.nick,
|
||||
premium_since: item.premium_since,
|
||||
roles: item.roles,
|
||||
pending: false,
|
||||
user_id: id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new() -> DiscordCache {
|
||||
let mut client_builder =
|
||||
twilight_http::Client::builder().token(libpk::config.discord.bot_token.clone());
|
||||
|
||||
if let Some(base_url) = libpk::config.discord.api_base_url.clone() {
|
||||
client_builder = client_builder.proxy(base_url, true);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
||||
let cache = Arc::new(
|
||||
InMemoryCache::builder()
|
||||
.resource_types(
|
||||
ResourceType::GUILD
|
||||
| ResourceType::CHANNEL
|
||||
| ResourceType::ROLE
|
||||
| ResourceType::USER_CURRENT
|
||||
| ResourceType::MEMBER_CURRENT,
|
||||
)
|
||||
.message_cache_size(0)
|
||||
.build(),
|
||||
);
|
||||
|
||||
DiscordCache(cache, client)
|
||||
}
|
||||
|
||||
pub struct DiscordCache(pub Arc<InMemoryCache>, pub Arc<twilight_http::Client>);
|
||||
|
||||
impl DiscordCache {
|
||||
pub async fn guild_permissions(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| format_err!("guild not found"))?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id == libpk::config.discord.client_id {
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or(format_err!("self member not found"))?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.root();
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
pub async fn channel_permissions(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
let channel = self
|
||||
.0
|
||||
.channel(channel_id)
|
||||
.ok_or(format_err!("channel not found"))?;
|
||||
|
||||
if channel.value().guild_id.is_none() {
|
||||
return Ok(*DM_PERMISSIONS);
|
||||
}
|
||||
|
||||
let guild_id = channel.value().guild_id.unwrap();
|
||||
|
||||
if self
|
||||
.0
|
||||
.guild(guild_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
channel_id = channel_id.get(),
|
||||
guild_id = guild_id.get(),
|
||||
"referenced guild from cached channel {channel_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.owner_id()
|
||||
== user_id
|
||||
{
|
||||
return Ok(Permissions::all());
|
||||
}
|
||||
|
||||
let member = if user_id == libpk::config.discord.client_id {
|
||||
self.0
|
||||
.member(guild_id, user_id)
|
||||
.ok_or_else(|| {
|
||||
tracing::error!(
|
||||
guild_id = guild_id.get(),
|
||||
"self member for cached guild {guild_id} not found in cache"
|
||||
);
|
||||
format_err!("internal cache error")
|
||||
})?
|
||||
.value()
|
||||
.to_owned()
|
||||
} else {
|
||||
member_to_cached_member(
|
||||
self.1
|
||||
.guild_member(guild_id, user_id)
|
||||
.await?
|
||||
.model()
|
||||
.await?,
|
||||
user_id,
|
||||
)
|
||||
};
|
||||
|
||||
let MemberRoles { assigned, everyone } = self
|
||||
.0
|
||||
.permissions()
|
||||
.member_roles(guild_id, &member)
|
||||
.map_err(RootError::from_member_roles)?;
|
||||
|
||||
let overwrites = match channel.kind {
|
||||
ChannelType::AnnouncementThread
|
||||
| ChannelType::PrivateThread
|
||||
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
|
||||
_ => channel
|
||||
.value()
|
||||
.permission_overwrites()
|
||||
.unwrap_or_default()
|
||||
.to_vec(),
|
||||
};
|
||||
|
||||
let calculator =
|
||||
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
|
||||
|
||||
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
|
||||
|
||||
Ok(self
|
||||
.0
|
||||
.permissions()
|
||||
.disable_member_communication(&member, permissions))
|
||||
}
|
||||
|
||||
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
|
||||
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
|
||||
self.0.guild(id).map(|guild| {
|
||||
let channels = self
|
||||
.0
|
||||
.guild_channels(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|channel_id| {
|
||||
let channel = self.0.channel(*channel_id)?;
|
||||
|
||||
if channel.kind.is_thread() {
|
||||
None
|
||||
} else {
|
||||
Some(channel.value().clone())
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
let roles = self
|
||||
.0
|
||||
.guild_roles(id)
|
||||
.map(|reference| {
|
||||
reference
|
||||
.iter()
|
||||
.filter_map(|role_id| {
|
||||
Some(self.0.role(*role_id)?.value().resource().clone())
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
.unwrap_or_default();
|
||||
|
||||
Guild {
|
||||
afk_channel_id: guild.afk_channel_id(),
|
||||
afk_timeout: guild.afk_timeout(),
|
||||
application_id: guild.application_id(),
|
||||
approximate_member_count: None, // Only present in with_counts HTTP endpoint
|
||||
banner: guild.banner().map(ToOwned::to_owned),
|
||||
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
|
||||
channels,
|
||||
default_message_notifications: guild.default_message_notifications(),
|
||||
description: guild.description().map(ToString::to_string),
|
||||
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
|
||||
emojis: vec![],
|
||||
explicit_content_filter: guild.explicit_content_filter(),
|
||||
features: guild.features().cloned().collect(),
|
||||
icon: guild.icon().map(ToOwned::to_owned),
|
||||
id: guild.id(),
|
||||
joined_at: guild.joined_at(),
|
||||
large: guild.large(),
|
||||
max_members: guild.max_members(),
|
||||
max_presences: guild.max_presences(),
|
||||
max_video_channel_users: guild.max_video_channel_users(),
|
||||
member_count: guild.member_count(),
|
||||
members: vec![],
|
||||
mfa_level: guild.mfa_level(),
|
||||
name: guild.name().to_string(),
|
||||
nsfw_level: guild.nsfw_level(),
|
||||
owner_id: guild.owner_id(),
|
||||
owner: guild.owner(),
|
||||
permissions: guild.permissions(),
|
||||
public_updates_channel_id: guild.public_updates_channel_id(),
|
||||
preferred_locale: guild.preferred_locale().to_string(),
|
||||
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
|
||||
premium_subscription_count: guild.premium_subscription_count(),
|
||||
premium_tier: guild.premium_tier(),
|
||||
presences: vec![],
|
||||
roles,
|
||||
rules_channel_id: guild.rules_channel_id(),
|
||||
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
|
||||
splash: guild.splash().map(ToOwned::to_owned),
|
||||
stage_instances: vec![],
|
||||
stickers: vec![],
|
||||
system_channel_flags: guild.system_channel_flags(),
|
||||
system_channel_id: guild.system_channel_id(),
|
||||
threads: vec![],
|
||||
unavailable: false,
|
||||
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
|
||||
verification_level: guild.verification_level(),
|
||||
voice_states: vec![],
|
||||
widget_channel_id: guild.widget_channel_id(),
|
||||
widget_enabled: guild.widget_enabled(),
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
121
services/gateway/src/discord/gateway.rs
Normal file
121
services/gateway/src/discord/gateway.rs
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
use std::sync::{mpsc::Sender, Arc};
|
||||
use tracing::{info, warn};
|
||||
use twilight_gateway::{
|
||||
create_iterator, ConfigBuilder, Event, EventTypeFlags, Shard, ShardId, StreamExt,
|
||||
};
|
||||
use twilight_model::gateway::{
|
||||
payload::outgoing::update_presence::UpdatePresencePayload,
|
||||
presence::{Activity, ActivityType, Status},
|
||||
Intents,
|
||||
};
|
||||
|
||||
use crate::discord::identify_queue::{self, RedisQueue};
|
||||
|
||||
use super::{cache::DiscordCache, shard_state::ShardStateManager};
|
||||
|
||||
pub fn create_shards(redis: fred::pool::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
|
||||
let intents = Intents::GUILDS
|
||||
| Intents::DIRECT_MESSAGES
|
||||
| Intents::DIRECT_MESSAGE_REACTIONS
|
||||
| Intents::GUILD_MESSAGES
|
||||
| Intents::GUILD_MESSAGE_REACTIONS
|
||||
| Intents::MESSAGE_CONTENT;
|
||||
|
||||
let queue = identify_queue::new(redis);
|
||||
|
||||
let cluster_settings =
|
||||
libpk::config
|
||||
.discord
|
||||
.cluster
|
||||
.clone()
|
||||
.unwrap_or(libpk::_config::ClusterSettings {
|
||||
node_id: 0,
|
||||
total_shards: 1,
|
||||
total_nodes: 1,
|
||||
});
|
||||
|
||||
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
|
||||
warn!("we have less than 16 shards, assuming single gateway process");
|
||||
(0, (cluster_settings.total_shards - 1).into())
|
||||
} else {
|
||||
(
|
||||
(cluster_settings.node_id * 16).into(),
|
||||
(((cluster_settings.node_id + 1) * 16) - 1).into(),
|
||||
)
|
||||
};
|
||||
|
||||
let shards = create_iterator(
|
||||
start_shard..end_shard + 1,
|
||||
cluster_settings.total_shards,
|
||||
ConfigBuilder::new(libpk::config.discord.bot_token.to_owned(), intents)
|
||||
.presence(presence("pk;help", false))
|
||||
.queue(queue.clone())
|
||||
.build(),
|
||||
|_, builder| builder.build(),
|
||||
);
|
||||
|
||||
let mut shards_vec = Vec::new();
|
||||
shards_vec.extend(shards);
|
||||
|
||||
Ok(shards_vec)
|
||||
}
|
||||
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
tx: Sender<(ShardId, Event)>,
|
||||
shard_state: ShardStateManager,
|
||||
cache: Arc<DiscordCache>,
|
||||
) {
|
||||
//let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
|
||||
match item {
|
||||
Ok(event) => {
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(shard.id().number(), event.clone())
|
||||
.await
|
||||
{
|
||||
tracing::warn!(?error, "error updating redis state")
|
||||
}
|
||||
cache.0.update(&event);
|
||||
//if let Err(error) = tx.send((shard.id(), event)) {
|
||||
// tracing::warn!(?error, "error sending event to global handler: {error}",);
|
||||
//}
|
||||
}
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
|
||||
UpdatePresencePayload {
|
||||
activities: vec![Activity {
|
||||
application_id: None,
|
||||
assets: None,
|
||||
buttons: vec![],
|
||||
created_at: None,
|
||||
details: None,
|
||||
id: None,
|
||||
state: None,
|
||||
url: None,
|
||||
emoji: None,
|
||||
flags: None,
|
||||
instance: None,
|
||||
kind: ActivityType::Playing,
|
||||
name: status.to_string(),
|
||||
party: None,
|
||||
secrets: None,
|
||||
timestamps: None,
|
||||
}],
|
||||
afk: false,
|
||||
since: None,
|
||||
status: if going_away {
|
||||
Status::Idle
|
||||
} else {
|
||||
Status::Online
|
||||
},
|
||||
}
|
||||
}
|
||||
87
services/gateway/src/discord/identify_queue.rs
Normal file
87
services/gateway/src/discord/identify_queue.rs
Normal file
|
|
@ -0,0 +1,87 @@
|
|||
use fred::{
|
||||
error::RedisError,
|
||||
interfaces::KeysInterface,
|
||||
pool::RedisPool,
|
||||
types::{Expiration, SetOptions},
|
||||
};
|
||||
use std::fmt::Debug;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::oneshot;
|
||||
use tracing::{error, info};
|
||||
use twilight_gateway::queue::Queue;
|
||||
|
||||
use libpk::util::redis::RedisErrorExt;
|
||||
|
||||
pub fn new(redis: RedisPool) -> RedisQueue {
|
||||
RedisQueue {
|
||||
redis,
|
||||
concurrency: libpk::config.discord.max_concurrency,
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisQueue {
|
||||
pub redis: RedisPool,
|
||||
pub concurrency: u32,
|
||||
}
|
||||
|
||||
impl Debug for RedisQueue {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RedisQueue")
|
||||
.field("concurrency", &self.concurrency)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
impl Queue for RedisQueue {
|
||||
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
|
||||
let (tx, rx) = oneshot::channel();
|
||||
|
||||
tokio::spawn(request_inner(
|
||||
self.redis.clone(),
|
||||
self.concurrency,
|
||||
shard_id,
|
||||
tx,
|
||||
));
|
||||
|
||||
rx
|
||||
}
|
||||
}
|
||||
|
||||
const EXPIRY: i64 = 6;
|
||||
const RETRY_INTERVAL: u64 = 500;
|
||||
|
||||
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
|
||||
let bucket = shard_id % concurrency;
|
||||
let key = format!("pluralkit:identify:{}", bucket);
|
||||
|
||||
info!(shard_id, bucket, "waiting for allowance...");
|
||||
loop {
|
||||
let done: Result<Option<String>, RedisError> = redis
|
||||
.set(
|
||||
key.to_string(),
|
||||
"1",
|
||||
Some(Expiration::EX(EXPIRY)),
|
||||
Some(SetOptions::NX),
|
||||
false,
|
||||
)
|
||||
.await
|
||||
.to_option_or_error();
|
||||
match done {
|
||||
Ok(Some(_)) => {
|
||||
info!(shard_id, bucket, "got allowance!");
|
||||
// if this fails, it's probably already doing something else
|
||||
let _ = tx.send(());
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
|
||||
}
|
||||
}
|
||||
4
services/gateway/src/discord/mod.rs
Normal file
4
services/gateway/src/discord/mod.rs
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
pub mod cache;
|
||||
pub mod gateway;
|
||||
pub mod identify_queue;
|
||||
pub mod shard_state;
|
||||
84
services/gateway/src/discord/shard_state.rs
Normal file
84
services/gateway/src/discord/shard_state.rs
Normal file
|
|
@ -0,0 +1,84 @@
|
|||
use bytes::Bytes;
|
||||
use fred::{interfaces::HashesInterface, pool::RedisPool};
|
||||
use prost::Message;
|
||||
use tracing::info;
|
||||
use twilight_gateway::Event;
|
||||
|
||||
use libpk::{proto::*, util::redis::*};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ShardStateManager {
|
||||
redis: RedisPool,
|
||||
}
|
||||
|
||||
pub fn new(redis: RedisPool) -> ShardStateManager {
|
||||
ShardStateManager { redis }
|
||||
}
|
||||
|
||||
impl ShardStateManager {
|
||||
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
|
||||
match event {
|
||||
Event::Ready(_) => self.ready_or_resumed(shard_id).await,
|
||||
Event::Resumed => self.ready_or_resumed(shard_id).await,
|
||||
Event::GatewayClose(_) => self.socket_closed(shard_id).await,
|
||||
Event::GatewayHeartbeat(_) => self.heartbeated(shard_id).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<Vec<u8>> = self
|
||||
.redis
|
||||
.hget("pluralkit:shardstatus", shard_id)
|
||||
.await
|
||||
.to_option_or_error()?;
|
||||
match data {
|
||||
Some(buf) => {
|
||||
Ok(ShardState::decode(buf.as_slice()).expect("could not decode shard data!"))
|
||||
}
|
||||
None => Ok(ShardState::default()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset(
|
||||
"pluralkit:shardstatus",
|
||||
(
|
||||
shard_id.to_string(),
|
||||
Bytes::copy_from_slice(&info.encode_to_vec()),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn ready_or_resumed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
info!("shard {} ready", shard_id);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.up = true;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
info!("shard {} closed", shard_id);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = false;
|
||||
info.disconnection_count += 1;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn heartbeated(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
info.up = true;
|
||||
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
|
||||
// todo
|
||||
// info.latency = latency.recent().front().map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
info.latency = 1;
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
52
services/gateway/src/logger.rs
Normal file
52
services/gateway/src/logger.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
uri.path(),
|
||||
elapsed
|
||||
);
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
141
services/gateway/src/main.rs
Normal file
141
services/gateway/src/main.rs
Normal file
|
|
@ -0,0 +1,141 @@
|
|||
use chrono::Timelike;
|
||||
use fred::{interfaces::*, pool::RedisPool};
|
||||
use signal_hook::{
|
||||
consts::{SIGINT, SIGTERM},
|
||||
iterator::Signals,
|
||||
};
|
||||
use std::{
|
||||
sync::{mpsc::channel, Arc},
|
||||
time::Duration,
|
||||
vec::Vec,
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod discord;
|
||||
mod logger;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
libpk::init_logging("gateway")?;
|
||||
libpk::init_metrics()?;
|
||||
info!("hello world");
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = channel::<()>();
|
||||
let shutdown_tx = Arc::new(shutdown_tx);
|
||||
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
let (event_tx, _event_rx) = channel();
|
||||
|
||||
let mut senders = Vec::new();
|
||||
let mut signal_senders = Vec::new();
|
||||
|
||||
let mut set = JoinSet::new();
|
||||
for shard in shards {
|
||||
senders.push((shard.id(), shard.sender()));
|
||||
signal_senders.push(shard.sender());
|
||||
set.spawn(tokio::spawn(discord::gateway::runner(
|
||||
shard,
|
||||
event_tx.clone(),
|
||||
shard_state.clone(),
|
||||
cache.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
set.spawn(tokio::spawn(
|
||||
async move { scheduled_task(redis, senders).await },
|
||||
));
|
||||
|
||||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}));
|
||||
|
||||
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
|
||||
|
||||
tokio::spawn(async move {
|
||||
for sig in signals.forever() {
|
||||
info!("received signal {:?}", sig);
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
});
|
||||
|
||||
let _ = shutdown_rx.recv();
|
||||
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
set.abort_all();
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
(60 - chrono::offset::Utc::now().second()).into(),
|
||||
))
|
||||
.await;
|
||||
info!("running per-minute scheduled tasks");
|
||||
|
||||
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
|
||||
Ok(val) => Some(val),
|
||||
Err(error) => {
|
||||
tracing::warn!(?error, "failed to fetch bot status from redis");
|
||||
None
|
||||
}
|
||||
};
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence(
|
||||
if let Some(status) = status {
|
||||
format!("pk;help | {}", status)
|
||||
} else {
|
||||
"pk;help".to_string()
|
||||
}
|
||||
.as_str(),
|
||||
false,
|
||||
),
|
||||
};
|
||||
|
||||
for sender in senders.iter() {
|
||||
match sender.1.command(&presence) {
|
||||
Err(error) => {
|
||||
warn!(?error, "could not update presence on shard {}", sender.0)
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue