fix(gateway): improve shutdown flow

This commit is contained in:
alyssa 2025-04-26 17:00:48 +00:00
parent 6c0c7a5c99
commit 4a098e4533
4 changed files with 26 additions and 45 deletions

11
Cargo.lock generated
View file

@ -1222,7 +1222,6 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"serde_variant", "serde_variant",
"signal-hook",
"tokio", "tokio",
"tracing", "tracing",
"twilight-cache-inmemory", "twilight-cache-inmemory",
@ -3634,16 +3633,6 @@ version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64"
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]] [[package]]
name = "signal-hook-registry" name = "signal-hook-registry"
version = "1.4.2" version = "1.4.2"

View file

@ -18,7 +18,6 @@ reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-t
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
serde = { version = "1.0.196", features = ["derive"] } serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.117" serde_json = "1.0.117"
signal-hook = "0.3.17"
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "macros", "uuid"] } sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "macros", "uuid"] }
tokio = { version = "1.36.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] }
tracing = "0.1" tracing = "0.1"

View file

@ -16,7 +16,6 @@ metrics = { workspace = true }
reqwest = { workspace = true } reqwest = { workspace = true }
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true } serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true } tokio = { workspace = true }
tracing = { workspace = true } tracing = { workspace = true }

View file

@ -8,12 +8,12 @@ use event_awaiter::EventAwaiter;
use fred::{clients::RedisPool, interfaces::*}; use fred::{clients::RedisPool, interfaces::*};
use libpk::runtime_config::RuntimeConfig; use libpk::runtime_config::RuntimeConfig;
use reqwest::{ClientBuilder, StatusCode}; use reqwest::{ClientBuilder, StatusCode};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
};
use std::{sync::Arc, time::Duration, vec::Vec}; use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{sync::mpsc::channel, task::JoinSet}; use tokio::{
signal::unix::{signal, SignalKind},
sync::mpsc::channel,
task::JoinSet,
};
use tracing::{error, info, warn}; use tracing::{error, info, warn};
use twilight_gateway::{MessageSender, ShardId}; use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence; use twilight_model::gateway::payload::outgoing::UpdatePresence;
@ -27,9 +27,6 @@ const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
libpk::main!("gateway"); libpk::main!("gateway");
async fn real_main() -> anyhow::Result<()> { async fn real_main() -> anyhow::Result<()> {
let (shutdown_tx, mut shutdown_rx) = channel::<()>(1);
let shutdown_tx = Arc::new(shutdown_tx);
let redis = libpk::db::init_redis().await?; let redis = libpk::db::init_redis().await?;
let runtime_config = Arc::new( let runtime_config = Arc::new(
@ -68,7 +65,8 @@ async fn real_main() -> anyhow::Result<()> {
let shards = discord::gateway::create_shards(redis.clone())?; let shards = discord::gateway::create_shards(redis.clone())?;
// arbitrary // arbitrary
let (event_tx, mut event_rx) = channel(1000); // todo: make sure this doesn't fill up
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
let mut senders = Vec::new(); let mut senders = Vec::new();
let mut signal_senders = Vec::new(); let mut signal_senders = Vec::new();
@ -145,43 +143,39 @@ async fn real_main() -> anyhow::Result<()> {
async move { scheduled_task(redis, senders).await }, async move { scheduled_task(redis, senders).await },
)); ));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move { set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache, runtime_config, awaiter.clone()).await { match cache_api::run_server(cache, runtime_config, awaiter.clone()).await {
Err(error) => { Err(error) => {
tracing::error!(?error, "failed to serve cache api"); error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
} }
_ => unreachable!(), _ => unreachable!(),
} }
})); }));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
set.spawn(tokio::spawn(async move { set.spawn(tokio::spawn(async move {
for sig in signals.forever() { signal(SignalKind::interrupt()).unwrap().recv().await;
info!("received signal {:?}", sig); info!("got SIGINT");
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(()).await;
break;
}
})); }));
let _ = shutdown_rx.recv().await; set.spawn(tokio::spawn(async move {
signal(SignalKind::terminate()).unwrap().recv().await;
info!("got SIGTERM");
}));
set.join_next().await;
info!("gateway exiting, have a nice day!"); info!("gateway exiting, have a nice day!");
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
set.abort_all(); set.abort_all();
// sleep 500ms to allow everything to clean up properly // sleep 500ms to allow everything to clean up properly