From b5f18106e172a0dcd7109e416b0498a70b8a4251 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sat, 24 May 2025 12:23:19 +0000 Subject: [PATCH] feat(gateway): add /shard_status endpoint --- crates/gateway/src/cache_api.rs | 7 ++- crates/gateway/src/discord/shard_state.rs | 58 ++++++++++++++++------- crates/gateway/src/main.rs | 6 ++- 3 files changed, 51 insertions(+), 20 deletions(-) diff --git a/crates/gateway/src/cache_api.rs b/crates/gateway/src/cache_api.rs index 375ab86e..f8c3f556 100644 --- a/crates/gateway/src/cache_api.rs +++ b/crates/gateway/src/cache_api.rs @@ -14,6 +14,7 @@ use crate::{ discord::{ cache::{dm_channel, DiscordCache, DM_PERMISSIONS}, gateway::cluster_config, + shard_state::ShardStateManager, }, event_awaiter::{AwaitEventRequest, EventAwaiter}, }; @@ -25,7 +26,7 @@ fn status_code(code: StatusCode, body: String) -> Response { // this function is manually formatted for easier legibility of route_services #[rustfmt::skip] -pub async fn run_server(cache: Arc, runtime_config: Arc, awaiter: Arc) -> anyhow::Result<()> { +pub async fn run_server(cache: Arc, shard_state: Arc, runtime_config: Arc, awaiter: Arc) -> anyhow::Result<()> { // hacky fix for `move` let runtime_config_for_post = runtime_config.clone(); let runtime_config_for_delete = runtime_config.clone(); @@ -211,6 +212,10 @@ pub async fn run_server(cache: Arc, runtime_config: Arc, + shards: RwLock>, } pub fn new(redis: RedisPool) -> ShardStateManager { ShardStateManager { redis: redis, - shards: HashMap::new(), + shards: RwLock::new(HashMap::new()), } } impl ShardStateManager { - pub async fn handle_event(&mut self, shard_id: u32, event: Event) -> anyhow::Result<()> { + pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> { match event { Event::Ready(_) => self.ready_or_resumed(shard_id, false).await, Event::Resumed => self.ready_or_resumed(shard_id, true).await, @@ -31,21 +31,33 @@ impl ShardStateManager { } } - async fn save_shard(&mut self, shard_id: u32) -> anyhow::Result<()> { - let info = self.shards.get(&shard_id); + async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> { + { + let mut shards = self.shards.write().await; + shards.insert(id, state.clone()); + } self.redis .hset::<(), &str, (String, String)>( "pluralkit:shardstatus", ( - shard_id.to_string(), - serde_json::to_string(&info).expect("could not serialize shard"), + id.to_string(), + serde_json::to_string(&state).expect("could not serialize shard"), ), ) .await?; Ok(()) } - async fn ready_or_resumed(&mut self, shard_id: u32, resumed: bool) -> anyhow::Result<()> { + async fn get_shard(&self, id: u32) -> Option { + let shards = self.shards.read().await; + shards.get(&id).cloned() + } + + pub async fn get(&self) -> Vec { + self.shards.read().await.values().cloned().collect() + } + + async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> { info!( "shard {} {}", shard_id, @@ -59,40 +71,52 @@ impl ShardStateManager { .increment(1); gauge!("pluralkit_gateway_shard_up").increment(1); - let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + let mut info = self + .get_shard(shard_id) + .await + .unwrap_or(ShardState::default()); + info.shard_id = shard_id as i32; info.cluster_id = Some(cluster_config().node_id as i32); info.last_connection = chrono::offset::Utc::now().timestamp() as i32; info.up = true; - self.save_shard(shard_id).await?; + self.save_shard(shard_id, info).await?; Ok(()) } - pub async fn socket_closed(&mut self, shard_id: u32) -> anyhow::Result<()> { + pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> { gauge!("pluralkit_gateway_shard_up").decrement(1); - let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + let mut info = self + .get_shard(shard_id) + .await + .unwrap_or(ShardState::default()); + info.shard_id = shard_id as i32; info.cluster_id = Some(cluster_config().node_id as i32); info.up = false; info.disconnection_count += 1; - self.save_shard(shard_id).await?; + self.save_shard(shard_id, info).await?; Ok(()) } - pub async fn heartbeated(&mut self, shard_id: u32, latency: i32) -> anyhow::Result<()> { + pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> { gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency); - let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + let mut info = self + .get_shard(shard_id) + .await + .unwrap_or(ShardState::default()); + info.shard_id = shard_id as i32; info.cluster_id = Some(cluster_config().node_id as i32); info.up = true; info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32; info.latency = latency; - self.save_shard(shard_id).await?; + self.save_shard(shard_id, info).await?; Ok(()) } } diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index f7c88cb2..3b45e1b8 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -91,8 +91,10 @@ async fn main() -> anyhow::Result<()> { ))); } + let shard_state = Arc::new(discord::shard_state::new(redis.clone())); + set.spawn(tokio::spawn({ - let mut shard_state = discord::shard_state::new(redis.clone()); + let shard_state = shard_state.clone(); async move { while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await { @@ -187,7 +189,7 @@ async fn main() -> anyhow::Result<()> { )); set.spawn(tokio::spawn(async move { - match cache_api::run_server(cache, runtime_config, awaiter.clone()).await { + match cache_api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await { Err(error) => { error!(?error, "failed to serve cache api"); }