fix(gateway): move shard state updates, store in hashmap

This commit is contained in:
asleepyskye 2025-05-22 17:31:29 -04:00 committed by alyssa
parent 0167519804
commit e2acaf93be
4 changed files with 115 additions and 40 deletions

View file

@ -1,6 +1,6 @@
use anyhow::anyhow;
use futures::StreamExt;
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig};
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent};
use metrics::counter;
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
@ -19,7 +19,7 @@ use crate::{
RUNTIME_CONFIG_KEY_EVENT_TARGET,
};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
use super::cache::DiscordCache;
pub fn cluster_config() -> ClusterSettings {
libpk::config
@ -98,7 +98,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
pub async fn runner(
mut shard: Shard<RedisQueue>,
tx: Sender<(ShardId, Event, String)>,
shard_state: ShardStateManager,
tx_state: Sender<(ShardId, ShardStateEvent, Option<Event>, Option<i32>)>,
cache: Arc<DiscordCache>,
runtime_config: Arc<RuntimeConfig>,
) {
@ -131,8 +131,10 @@ pub async fn runner(
)
.increment(1);
if let Err(error) = shard_state.socket_closed(shard_id).await {
error!(?error, "failed to update shard state for socket closure");
if let Err(error) =
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
{
error!("failed to update shard state for socket closure: {error}");
}
continue;
@ -173,14 +175,29 @@ pub async fn runner(
.increment(1);
// update shard state and discord cache
if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await {
tracing::error!(?error, "error updating redis state");
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Other,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state");
}
// need to do heartbeat separately, to get the latency
let latency_num = shard
.latency()
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await
&& let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Heartbeat,
Some(event.clone()),
Some(latency_num),
))
{
tracing::error!(?error, "error updating redis state for latency");
tracing::error!(?error, "error updating shard state for latency");
}
if let Event::Ready(_) = event {

View file

@ -1,21 +1,29 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tracing::info;
use twilight_gateway::{Event, Latency};
use twilight_gateway::Event;
use std::collections::HashMap;
use libpk::state::ShardState;
use super::gateway::cluster_config;
/// Tracks per-shard gateway state and writes it to Redis.
#[derive(Clone)]
pub struct ShardStateManager {
    // pool used to write shard state into the "pluralkit:shardstatus" hash
    redis: RedisPool,
    // shard state held in memory, keyed by shard id
    shards: HashMap<u32, ShardState>,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
ShardStateManager {
redis: redis,
shards: HashMap::new(),
}
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
pub async fn handle_event(&mut self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
@ -23,15 +31,8 @@ impl ShardStateManager {
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<String> = self.redis.hget("pluralkit:shardstatus", shard_id).await?;
match data {
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
None => Ok(ShardState::default()),
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
async fn save_shard(&mut self, shard_id: u32) -> anyhow::Result<()> {
let info = self.shards.get(&shard_id);
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
@ -44,7 +45,7 @@ impl ShardStateManager {
Ok(())
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
async fn ready_or_resumed(&mut self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
shard_id,
@ -57,33 +58,41 @@ impl ShardStateManager {
)
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let mut info = self.get_shard(shard_id).await?;
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
self.save_shard(shard_id).await?;
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
pub async fn socket_closed(&mut self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self.get_shard(shard_id).await?;
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
self.save_shard(shard_id).await?;
Ok(())
}
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
pub async fn heartbeated(&mut self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string())
.set(info.latency);
self.save_shard(shard_id, info).await?;
info.latency = latency;
self.save_shard(shard_id).await?;
Ok(())
}
}

View file

@ -6,7 +6,7 @@ use chrono::Timelike;
use discord::gateway::cluster_config;
use event_awaiter::EventAwaiter;
use fred::{clients::RedisPool, interfaces::*};
use libpk::runtime_config::RuntimeConfig;
use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
use reqwest::{ClientBuilder, StatusCode};
use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{
@ -54,7 +54,6 @@ async fn main() -> anyhow::Result<()> {
.await?;
}
let shard_state = discord::shard_state::new(redis.clone());
let cache = Arc::new(discord::cache::new());
let awaiter = Arc::new(EventAwaiter::new());
tokio::spawn({
@ -68,6 +67,14 @@ async fn main() -> anyhow::Result<()> {
// todo: make sure this doesn't fill up
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
// todo: make sure this doesn't fill up
let (state_tx, mut state_rx) = channel::<(
ShardId,
ShardStateEvent,
Option<twilight_gateway::Event>,
Option<i32>,
)>(1000);
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
@ -78,12 +85,48 @@ async fn main() -> anyhow::Result<()> {
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
state_tx.clone(),
cache.clone(),
runtime_config.clone(),
)));
}
// Consumer task for shard state updates: it owns the `ShardStateManager` and
// applies every message received on `state_rx`, one at a time, until the
// channel is closed.
set.spawn(tokio::spawn({
    let mut shard_state = discord::shard_state::new(redis.clone());

    async move {
        while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
            let shard_number = shard_id.number();
            match state_event {
                ShardStateEvent::Heartbeat => {
                    // a heartbeat message without a latency sample has nothing to record
                    let Some(latency) = latency else { continue };
                    if let Err(error) = shard_state.heartbeated(shard_number, latency).await {
                        error!("failed to update shard state for heartbeat: {error}");
                    }
                }
                ShardStateEvent::Closed => {
                    if let Err(error) = shard_state.socket_closed(shard_number).await {
                        error!("failed to update shard state for socket closure: {error}");
                    }
                }
                ShardStateEvent::Other => {
                    // log and skip instead of panicking: a panic here would end this
                    // task and stop state updates for every shard
                    let Some(event) = parsed_event else {
                        error!("shard state event not provided!");
                        continue;
                    };
                    if let Err(error) = shard_state.handle_event(shard_number, event).await {
                        error!("failed to update shard state for event: {error}");
                    }
                }
            }
        }
    }
}));
set.spawn(tokio::spawn({
let runtime_config = runtime_config.clone();
let awaiter = awaiter.clone();

View file

@ -1,4 +1,4 @@
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)]
pub struct ShardState {
pub shard_id: i32,
pub up: bool,
@ -10,3 +10,9 @@ pub struct ShardState {
pub last_connection: i32,
pub cluster_id: Option<i32>,
}
/// Kind of shard state change a gateway runner reports to the shard state task.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShardStateEvent {
    /// The shard's socket was closed.
    Closed,
    /// A heartbeat was acknowledged; sent together with the measured latency.
    Heartbeat,
    /// Any other gateway event; sent together with the parsed event.
    Other,
}