feat: gateway service

This commit is contained in:
alyssa 2024-09-14 12:19:47 +09:00
parent 1118d8bdf8
commit e4ed354536
50 changed files with 1737 additions and 545 deletions

View file

@ -147,6 +147,7 @@ async fn main() -> anyhow::Result<()> {
let addr: &str = libpk::config.api.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())

View file

@ -0,0 +1,25 @@
[package]
name = "gateway"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../../lib/libpk" }
prost = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-gateway = { workspace = true }
twilight-cache-inmemory = { workspace = true }
twilight-util = { workspace = true }
twilight-model = { workspace = true }
twilight-http = { workspace = true }

View file

@ -0,0 +1,168 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Router,
};
use serde_json::to_string;
use tracing::{error, info};
use twilight_model::guild::Permissions;
use twilight_model::id::Id;
use crate::discord::cache::{dm_channel, DiscordCache, DM_PERMISSIONS};
use std::sync::Arc;
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
let app = Router::new()
.route(
"/guilds/:guild_id",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.client_id).await {
Ok(val) => {
println!("hh {}", Permissions::all().bits());
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
},
Err(err) => {
error!(?err, ?guild_id, "failed to get own guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
Err(err) => {
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut channels = Vec::new();
for id in channel_ids {
match cache.0.channel(id) {
Some(channel) => channels.push(channel.to_owned()),
None => {
tracing::error!(
channel_id = id.get(),
"referenced channel {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
}
match cache.0.channel(Id::new(channel_id)) {
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string())
}
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
}
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.client_id).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
Err(err) => {
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
)
.route(
"/guilds/:guild_id/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut roles = Vec::new();
for id in role_ids {
match cache.0.role(id) {
Some(role) => roles.push(role.value().resource().to_owned()),
None => {
tracing::error!(
role_id = id.get(),
"referenced role {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
})
)
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.api.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,339 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use twilight_cache_inmemory::{
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
Id,
},
};
use twilight_util::permission_calculator::PermissionCalculator;
lazy_static! {
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
| Permissions::SEND_MESSAGES
| Permissions::READ_MESSAGE_HISTORY
| Permissions::ADD_REACTIONS
| Permissions::ATTACH_FILES
| Permissions::EMBED_LINKS
| Permissions::USE_EXTERNAL_EMOJIS
| Permissions::CONNECT
| Permissions::SPEAK
| Permissions::USE_VAD;
}
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
Channel {
id,
kind: ChannelType::Private,
application_id: None,
applied_tags: None,
available_tags: None,
bitrate: None,
default_auto_archive_duration: None,
default_forum_layout: None,
default_reaction_emoji: None,
default_sort_order: None,
default_thread_rate_limit_per_user: None,
flags: None,
guild_id: None,
icon: None,
invitable: None,
last_message_id: None,
last_pin_timestamp: None,
managed: None,
member: None,
member_count: None,
message_count: None,
name: None,
newly_created: None,
nsfw: None,
owner_id: None,
parent_id: None,
permission_overwrites: None,
position: None,
rate_limit_per_user: None,
recipients: None,
rtc_region: None,
thread_metadata: None,
topic: None,
user_limit: None,
video_quality_mode: None,
}
}
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
CachedMember {
avatar: item.avatar,
communication_disabled_until: item.communication_disabled_until,
deaf: Some(item.deaf),
flags: item.flags,
joined_at: item.joined_at,
mute: Some(item.mute),
nick: item.nick,
premium_since: item.premium_since,
roles: item.roles,
pending: false,
user_id: id,
}
}
pub fn new() -> DiscordCache {
let mut client_builder =
twilight_http::Client::builder().token(libpk::config.discord.bot_token.clone());
if let Some(base_url) = libpk::config.discord.api_base_url.clone() {
client_builder = client_builder.proxy(base_url, true);
}
let client = Arc::new(client_builder.build());
let cache = Arc::new(
InMemoryCache::builder()
.resource_types(
ResourceType::GUILD
| ResourceType::CHANNEL
| ResourceType::ROLE
| ResourceType::USER_CURRENT
| ResourceType::MEMBER_CURRENT,
)
.message_cache_size(0)
.build(),
);
DiscordCache(cache, client)
}
pub struct DiscordCache(pub Arc<InMemoryCache>, pub Arc<twilight_http::Client>);
impl DiscordCache {
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
if self
.0
.guild(guild_id)
.ok_or_else(|| format_err!("guild not found"))?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id == libpk::config.discord.client_id {
self.0
.member(guild_id, user_id)
.ok_or(format_err!("self member not found"))?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.root();
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
pub async fn channel_permissions(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
let channel = self
.0
.channel(channel_id)
.ok_or(format_err!("channel not found"))?;
if channel.value().guild_id.is_none() {
return Ok(*DM_PERMISSIONS);
}
let guild_id = channel.value().guild_id.unwrap();
if self
.0
.guild(guild_id)
.ok_or_else(|| {
tracing::error!(
channel_id = channel_id.get(),
guild_id = guild_id.get(),
"referenced guild from cached channel {channel_id} not found in cache"
);
format_err!("internal cache error")
})?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id == libpk::config.discord.client_id {
self.0
.member(guild_id, user_id)
.ok_or_else(|| {
tracing::error!(
guild_id = guild_id.get(),
"self member for cached guild {guild_id} not found in cache"
);
format_err!("internal cache error")
})?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let overwrites = match channel.kind {
ChannelType::AnnouncementThread
| ChannelType::PrivateThread
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
_ => channel
.value()
.permission_overwrites()
.unwrap_or_default()
.to_vec(),
};
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
self.0.guild(id).map(|guild| {
let channels = self
.0
.guild_channels(id)
.map(|reference| {
reference
.iter()
.filter_map(|channel_id| {
let channel = self.0.channel(*channel_id)?;
if channel.kind.is_thread() {
None
} else {
Some(channel.value().clone())
}
})
.collect()
})
.unwrap_or_default();
let roles = self
.0
.guild_roles(id)
.map(|reference| {
reference
.iter()
.filter_map(|role_id| {
Some(self.0.role(*role_id)?.value().resource().clone())
})
.collect()
})
.unwrap_or_default();
Guild {
afk_channel_id: guild.afk_channel_id(),
afk_timeout: guild.afk_timeout(),
application_id: guild.application_id(),
approximate_member_count: None, // Only present in with_counts HTTP endpoint
banner: guild.banner().map(ToOwned::to_owned),
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
channels,
default_message_notifications: guild.default_message_notifications(),
description: guild.description().map(ToString::to_string),
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
emojis: vec![],
explicit_content_filter: guild.explicit_content_filter(),
features: guild.features().cloned().collect(),
icon: guild.icon().map(ToOwned::to_owned),
id: guild.id(),
joined_at: guild.joined_at(),
large: guild.large(),
max_members: guild.max_members(),
max_presences: guild.max_presences(),
max_video_channel_users: guild.max_video_channel_users(),
member_count: guild.member_count(),
members: vec![],
mfa_level: guild.mfa_level(),
name: guild.name().to_string(),
nsfw_level: guild.nsfw_level(),
owner_id: guild.owner_id(),
owner: guild.owner(),
permissions: guild.permissions(),
public_updates_channel_id: guild.public_updates_channel_id(),
preferred_locale: guild.preferred_locale().to_string(),
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
premium_subscription_count: guild.premium_subscription_count(),
premium_tier: guild.premium_tier(),
presences: vec![],
roles,
rules_channel_id: guild.rules_channel_id(),
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
splash: guild.splash().map(ToOwned::to_owned),
stage_instances: vec![],
stickers: vec![],
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: false,
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
}
})
}
}

View file

@ -0,0 +1,121 @@
use std::sync::{mpsc::Sender, Arc};
use tracing::{info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Shard, ShardId, StreamExt,
};
use twilight_model::gateway::{
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::discord::identify_queue::{self, RedisQueue};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
pub fn create_shards(redis: fred::pool::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
let intents = Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::DIRECT_MESSAGE_REACTIONS
| Intents::GUILD_MESSAGES
| Intents::GUILD_MESSAGE_REACTIONS
| Intents::MESSAGE_CONTENT;
let queue = identify_queue::new(redis);
let cluster_settings =
libpk::config
.discord
.cluster
.clone()
.unwrap_or(libpk::_config::ClusterSettings {
node_id: 0,
total_shards: 1,
total_nodes: 1,
});
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
(0, (cluster_settings.total_shards - 1).into())
} else {
(
(cluster_settings.node_id * 16).into(),
(((cluster_settings.node_id + 1) * 16) - 1).into(),
)
};
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
ConfigBuilder::new(libpk::config.discord.bot_token.to_owned(), intents)
.presence(presence("pk;help", false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
);
let mut shards_vec = Vec::new();
shards_vec.extend(shards);
Ok(shards_vec)
}
pub async fn runner(
mut shard: Shard<RedisQueue>,
tx: Sender<(ShardId, Event)>,
shard_state: ShardStateManager,
cache: Arc<DiscordCache>,
) {
//let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
info!("waiting for events");
while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
match item {
Ok(event) => {
if let Err(error) = shard_state
.handle_event(shard.id().number(), event.clone())
.await
{
tracing::warn!(?error, "error updating redis state")
}
cache.0.update(&event);
//if let Err(error) = tx.send((shard.id(), event)) {
// tracing::warn!(?error, "error sending event to global handler: {error}",);
//}
}
Err(error) => {
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
continue;
}
}
}
}
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
UpdatePresencePayload {
activities: vec![Activity {
application_id: None,
assets: None,
buttons: vec![],
created_at: None,
details: None,
id: None,
state: None,
url: None,
emoji: None,
flags: None,
instance: None,
kind: ActivityType::Playing,
name: status.to_string(),
party: None,
secrets: None,
timestamps: None,
}],
afk: false,
since: None,
status: if going_away {
Status::Idle
} else {
Status::Online
},
}
}

View file

@ -0,0 +1,87 @@
use fred::{
error::RedisError,
interfaces::KeysInterface,
pool::RedisPool,
types::{Expiration, SetOptions},
};
use std::fmt::Debug;
use std::time::Duration;
use tokio::sync::oneshot;
use tracing::{error, info};
use twilight_gateway::queue::Queue;
use libpk::util::redis::RedisErrorExt;
pub fn new(redis: RedisPool) -> RedisQueue {
RedisQueue {
redis,
concurrency: libpk::config.discord.max_concurrency,
}
}
#[derive(Clone)]
pub struct RedisQueue {
pub redis: RedisPool,
pub concurrency: u32,
}
impl Debug for RedisQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RedisQueue")
.field("concurrency", &self.concurrency)
.finish()
}
}
impl Queue for RedisQueue {
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
let (tx, rx) = oneshot::channel();
tokio::spawn(request_inner(
self.redis.clone(),
self.concurrency,
shard_id,
tx,
));
rx
}
}
const EXPIRY: i64 = 6;
const RETRY_INTERVAL: u64 = 500;
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
let bucket = shard_id % concurrency;
let key = format!("pluralkit:identify:{}", bucket);
info!(shard_id, bucket, "waiting for allowance...");
loop {
let done: Result<Option<String>, RedisError> = redis
.set(
key.to_string(),
"1",
Some(Expiration::EX(EXPIRY)),
Some(SetOptions::NX),
false,
)
.await
.to_option_or_error();
match done {
Ok(Some(_)) => {
info!(shard_id, bucket, "got allowance!");
// if this fails, it's probably already doing something else
let _ = tx.send(());
return;
}
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
}
}
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
}
}

View file

@ -0,0 +1,4 @@
pub mod cache;
pub mod gateway;
pub mod identify_queue;
pub mod shard_state;

View file

@ -0,0 +1,84 @@
use bytes::Bytes;
use fred::{interfaces::HashesInterface, pool::RedisPool};
use prost::Message;
use tracing::info;
use twilight_gateway::Event;
use libpk::{proto::*, util::redis::*};
#[derive(Clone)]
pub struct ShardStateManager {
redis: RedisPool,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id).await,
Event::Resumed => self.ready_or_resumed(shard_id).await,
Event::GatewayClose(_) => self.socket_closed(shard_id).await,
Event::GatewayHeartbeat(_) => self.heartbeated(shard_id).await,
_ => Ok(()),
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<Vec<u8>> = self
.redis
.hget("pluralkit:shardstatus", shard_id)
.await
.to_option_or_error()?;
match data {
Some(buf) => {
Ok(ShardState::decode(buf.as_slice()).expect("could not decode shard data!"))
}
None => Ok(ShardState::default()),
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
self.redis
.hset(
"pluralkit:shardstatus",
(
shard_id.to_string(),
Bytes::copy_from_slice(&info.encode_to_vec()),
),
)
.await?;
Ok(())
}
async fn ready_or_resumed(&self, shard_id: u32) -> anyhow::Result<()> {
info!("shard {} ready", shard_id);
let mut info = self.get_shard(shard_id).await?;
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
info!("shard {} closed", shard_id);
let mut info = self.get_shard(shard_id).await?;
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
async fn heartbeated(&self, shard_id: u32) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
// todo
// info.latency = latency.recent().front().map_or_else(|| 0, |d| d.as_millis()) as i32;
info.latency = 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
}

View file

@ -0,0 +1,52 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

View file

@ -0,0 +1,141 @@
use chrono::Timelike;
use fred::{interfaces::*, pool::RedisPool};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
};
use std::{
sync::{mpsc::channel, Arc},
time::Duration,
vec::Vec,
};
use tokio::task::JoinSet;
use tracing::{info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod cache_api;
mod discord;
mod logger;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
libpk::init_logging("gateway")?;
libpk::init_metrics()?;
info!("hello world");
let (shutdown_tx, shutdown_rx) = channel::<()>();
let shutdown_tx = Arc::new(shutdown_tx);
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let cache = Arc::new(discord::cache::new());
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
let mut set = JoinSet::new();
for shard in shards {
senders.push((shard.id(), shard.sender()));
signal_senders.push(shard.sender());
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
cache.clone(),
)));
}
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
});
let _ = shutdown_rx.recv();
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.abort_all();
info!("gateway exiting, have a nice day!");
Ok(())
}
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
))
.await;
info!("running per-minute scheduled tasks");
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
Ok(val) => Some(val),
Err(error) => {
tracing::warn!(?error, "failed to fetch bot status from redis");
None
}
};
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
} else {
"pk;help".to_string()
}
.as_str(),
false,
),
};
for sender in senders.iter() {
match sender.1.command(&presence) {
Err(error) => {
warn!(?error, "could not update presence on shard {}", sender.0)
}
_ => {}
};
}
}
}