diff --git a/crates/gateway/src/discord/gateway.rs b/crates/gateway/src/discord/gateway.rs index 5ca24185..24864da8 100644 --- a/crates/gateway/src/discord/gateway.rs +++ b/crates/gateway/src/discord/gateway.rs @@ -1,6 +1,6 @@ use anyhow::anyhow; use futures::StreamExt; -use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig}; +use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent}; use metrics::counter; use std::sync::Arc; use tokio::sync::mpsc::Sender; @@ -19,7 +19,7 @@ use crate::{ RUNTIME_CONFIG_KEY_EVENT_TARGET, }; -use super::{cache::DiscordCache, shard_state::ShardStateManager}; +use super::cache::DiscordCache; pub fn cluster_config() -> ClusterSettings { libpk::config @@ -90,7 +90,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result, tx: Sender<(ShardId, Event, String)>, - shard_state: ShardStateManager, + tx_state: Sender<(ShardId, ShardStateEvent, Option, Option)>, cache: Arc, runtime_config: Arc, ) { @@ -123,7 +123,9 @@ pub async fn runner( ) .increment(1); - if let Err(error) = shard_state.socket_closed(shard_id).await { + if let Err(error) = + tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None)) + { error!("failed to update shard state for socket closure: {error}"); } @@ -165,14 +167,29 @@ pub async fn runner( .increment(1); // update shard state and discord cache - if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await { - tracing::error!(?error, "error updating redis state"); + if let Err(error) = tx_state.try_send(( + shard.id(), + ShardStateEvent::Other, + Some(event.clone()), + None, + )) { + tracing::error!(?error, "error updating shard state"); } // need to do heartbeat separately, to get the latency + let latency_num = shard + .latency() + .recent() + .first() + .map_or_else(|| 0, |d| d.as_millis()) as i32; if let Event::GatewayHeartbeatAck = event - && let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await + && let Err(error) = tx_state.try_send(( + shard.id(), + ShardStateEvent::Heartbeat, + Some(event.clone()), + Some(latency_num), + )) { - tracing::error!(?error, "error updating redis state for latency"); + tracing::error!(?error, "error updating shard state for latency"); } if let Event::Ready(_) = event { diff --git a/crates/gateway/src/discord/shard_state.rs b/crates/gateway/src/discord/shard_state.rs index a7579583..d063fb5b 100644 --- a/crates/gateway/src/discord/shard_state.rs +++ b/crates/gateway/src/discord/shard_state.rs @@ -1,21 +1,29 @@ use fred::{clients::RedisPool, interfaces::HashesInterface}; use metrics::{counter, gauge}; use tracing::info; -use twilight_gateway::{Event, Latency}; +use twilight_gateway::Event; + +use std::collections::HashMap; use libpk::state::ShardState; +use super::gateway::cluster_config; + #[derive(Clone)] pub struct ShardStateManager { redis: RedisPool, + shards: HashMap, } pub fn new(redis: RedisPool) -> ShardStateManager { - ShardStateManager { redis } + ShardStateManager { + redis: redis, + shards: HashMap::new(), + } } impl ShardStateManager { - pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> { + pub async fn handle_event(&mut self, shard_id: u32, event: Event) -> anyhow::Result<()> { match event { Event::Ready(_) => self.ready_or_resumed(shard_id, false).await, Event::Resumed => self.ready_or_resumed(shard_id, true).await, @@ -23,15 +31,8 @@ impl ShardStateManager { } } - async fn get_shard(&self, shard_id: u32) -> anyhow::Result { - let data: Option = self.redis.hget("pluralkit:shardstatus", shard_id).await?; - match data { - Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")), - None => Ok(ShardState::default()), - } - } - - async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> { + async fn save_shard(&mut self, shard_id: u32) -> anyhow::Result<()> { + let info = self.shards.get(&shard_id); self.redis .hset::<(), &str, (String, String)>( "pluralkit:shardstatus", @@ -44,7 +45,7 @@ impl ShardStateManager { Ok(()) } - async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> { + async fn ready_or_resumed(&mut self, shard_id: u32, resumed: bool) -> anyhow::Result<()> { info!( "shard {} {}", shard_id, @@ -57,33 +58,41 @@ impl ShardStateManager { ) .increment(1); gauge!("pluralkit_gateway_shard_up").increment(1); - let mut info = self.get_shard(shard_id).await?; + + let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + info.shard_id = shard_id as i32; + info.cluster_id = Some(cluster_config().node_id as i32); info.last_connection = chrono::offset::Utc::now().timestamp() as i32; info.up = true; - self.save_shard(shard_id, info).await?; + + self.save_shard(shard_id).await?; Ok(()) } - pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> { + pub async fn socket_closed(&mut self, shard_id: u32) -> anyhow::Result<()> { gauge!("pluralkit_gateway_shard_up").decrement(1); - let mut info = self.get_shard(shard_id).await?; + + let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + info.shard_id = shard_id as i32; + info.cluster_id = Some(cluster_config().node_id as i32); info.up = false; info.disconnection_count += 1; - self.save_shard(shard_id, info).await?; + + self.save_shard(shard_id).await?; Ok(()) } - pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> { - let mut info = self.get_shard(shard_id).await?; + pub async fn heartbeated(&mut self, shard_id: u32, latency: i32) -> anyhow::Result<()> { + gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency); + + let info = self.shards.entry(shard_id).or_insert(ShardState::default()); + info.shard_id = shard_id as i32; + info.cluster_id = Some(cluster_config().node_id as i32); info.up = true; info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32; - info.latency = latency - .recent() - .first() - .map_or_else(|| 0, |d| d.as_millis()) as i32; - gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()) - .set(info.latency); - self.save_shard(shard_id, info).await?; + info.latency = latency; + + self.save_shard(shard_id).await?; Ok(()) } } diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index 7cab2e76..0518e3ee 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -6,7 +6,7 @@ use chrono::Timelike; use discord::gateway::cluster_config; use event_awaiter::EventAwaiter; use fred::{clients::RedisPool, interfaces::*}; -use libpk::runtime_config::RuntimeConfig; +use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent}; use reqwest::{ClientBuilder, StatusCode}; use std::{sync::Arc, time::Duration, vec::Vec}; use tokio::{ @@ -54,7 +54,6 @@ async fn real_main() -> anyhow::Result<()> { .await?; } - let shard_state = discord::shard_state::new(redis.clone()); let cache = Arc::new(discord::cache::new()); let awaiter = Arc::new(EventAwaiter::new()); tokio::spawn({ @@ -68,6 +67,14 @@ async fn real_main() -> anyhow::Result<()> { // todo: make sure this doesn't fill up let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000); + // todo: make sure this doesn't fill up + let (state_tx, mut state_rx) = channel::<( + ShardId, + ShardStateEvent, + Option, + Option, + )>(1000); + let mut senders = Vec::new(); let mut signal_senders = Vec::new(); @@ -78,12 +85,48 @@ async fn real_main() -> anyhow::Result<()> { set.spawn(tokio::spawn(discord::gateway::runner( shard, event_tx.clone(), - shard_state.clone(), + state_tx.clone(), cache.clone(), runtime_config.clone(), ))); } + set.spawn(tokio::spawn({ + let mut shard_state = discord::shard_state::new(redis.clone()); + + async move { + while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await { + match state_event { + ShardStateEvent::Heartbeat => { + if !latency.is_none() + && let Err(error) = shard_state + .heartbeated(shard_id.number(), latency.unwrap()) + .await + { + error!("failed to update shard state for heartbeat: {error}") + }; + } + ShardStateEvent::Closed => { + if let Err(error) = shard_state.socket_closed(shard_id.number()).await { + error!("failed to update shard state for heartbeat: {error}") + }; + } + ShardStateEvent::Other => { + if let Err(error) = shard_state + .handle_event( + shard_id.number(), + parsed_event.expect("shard state event not provided!"), + ) + .await + { + error!("failed to update shard state for heartbeat: {error}") + }; + } + } + } + } + })); + set.spawn(tokio::spawn({ let runtime_config = runtime_config.clone(); let awaiter = awaiter.clone(); diff --git a/crates/libpk/src/state.rs b/crates/libpk/src/state.rs index 90a77c21..df44ea1d 100644 --- a/crates/libpk/src/state.rs +++ b/crates/libpk/src/state.rs @@ -1,4 +1,4 @@ -#[derive(serde::Serialize, serde::Deserialize, Clone, Default)] +#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)] pub struct ShardState { pub shard_id: i32, pub up: bool, @@ -10,3 +10,9 @@ pub struct ShardState { pub last_connection: i32, pub cluster_id: Option, } + +pub enum ShardStateEvent { + Closed, + Heartbeat, + Other, +}