feat(gateway): add /shard_status endpoint

This commit is contained in:
alyssa 2025-05-24 12:23:19 +00:00
parent e2acaf93be
commit b5f18106e1
3 changed files with 51 additions and 20 deletions

View file

@ -14,6 +14,7 @@ use crate::{
discord::{ discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS}, cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config, gateway::cluster_config,
shard_state::ShardStateManager,
}, },
event_awaiter::{AwaitEventRequest, EventAwaiter}, event_awaiter::{AwaitEventRequest, EventAwaiter},
}; };
@ -25,7 +26,7 @@ fn status_code(code: StatusCode, body: String) -> Response {
// this function is manually formatted for easier legibility of route_services // this function is manually formatted for easier legibility of route_services
#[rustfmt::skip] #[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> { pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
// hacky fix for `move` // hacky fix for `move`
let runtime_config_for_post = runtime_config.clone(); let runtime_config_for_post = runtime_config.clone();
let runtime_config_for_delete = runtime_config.clone(); let runtime_config_for_delete = runtime_config.clone();
@ -211,6 +212,10 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
status_code(StatusCode::NO_CONTENT, "".to_string()) status_code(StatusCode::NO_CONTENT, "".to_string())
})) }))
// GET /shard_status: responds with the serialized snapshot returned by `shard_state.get()`
// (presumably JSON via serde_json's `to_string` — confirm against the file's imports).
// NOTE(review): StatusCode::FOUND is HTTP 302; left unchanged here — confirm that consumers
// of this endpoint expect it rather than 200 OK. The `unwrap` panics if serialization fails.
.route("/shard_status", get(|| async move {
    let shards = shard_state.get().await;
    let body = to_string(&shards).unwrap();
    status_code(StatusCode::FOUND, body)
}))
.layer(axum::middleware::from_fn(crate::logger::logger)) .layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache); .with_state(cache);

View file

@ -1,5 +1,6 @@
use fred::{clients::RedisPool, interfaces::HashesInterface}; use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge}; use metrics::{counter, gauge};
use tokio::sync::RwLock;
use tracing::info; use tracing::info;
use twilight_gateway::Event; use twilight_gateway::Event;
@ -9,21 +10,20 @@ use libpk::state::ShardState;
use super::gateway::cluster_config; use super::gateway::cluster_config;
/// Keeps the state of each shard, keyed by shard id, and holds the Redis pool that
/// `save_shard` writes each update to.
///
/// In the post-commit (right-hand) version the map sits behind a `tokio::sync::RwLock`,
/// which is what lets the methods take `&self` instead of `&mut self`.
#[derive(Clone)]
pub struct ShardStateManager { pub struct ShardStateManager {
redis: RedisPool, redis: RedisPool,
shards: HashMap<u32, ShardState>, shards: RwLock<HashMap<u32, ShardState>>,
} }
/// Builds a `ShardStateManager` around the given Redis pool with an empty shard map.
pub fn new(redis: RedisPool) -> ShardStateManager { pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { ShardStateManager {
redis: redis, redis: redis,
shards: HashMap::new(), shards: RwLock::new(HashMap::new()),
} }
} }
impl ShardStateManager { impl ShardStateManager {
pub async fn handle_event(&mut self, shard_id: u32, event: Event) -> anyhow::Result<()> { pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event { match event {
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await, Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await, Event::Resumed => self.ready_or_resumed(shard_id, true).await,
@ -31,21 +31,33 @@ impl ShardStateManager {
} }
} }
/// Stores `state` for shard `id` in the in-memory map, then writes it to the Redis hash
/// `pluralkit:shardstatus` under the field `id.to_string()`, serialized with `serde_json`.
///
/// Returns an error if the Redis `hset` call fails; the in-memory map has already been
/// updated by that point. Panics if `state` cannot be serialized.
async fn save_shard(&mut self, shard_id: u32) -> anyhow::Result<()> { async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
let info = self.shards.get(&shard_id); {
// The write guard lives only inside this inner block, so it is dropped before the
// Redis call below is awaited.
let mut shards = self.shards.write().await;
shards.insert(id, state.clone());
}
self.redis self.redis
.hset::<(), &str, (String, String)>( .hset::<(), &str, (String, String)>(
"pluralkit:shardstatus", "pluralkit:shardstatus",
( (
shard_id.to_string(), id.to_string(),
serde_json::to_string(&info).expect("could not serialize shard"), serde_json::to_string(&state).expect("could not serialize shard"),
), ),
) )
.await?; .await?;
Ok(()) Ok(())
} }
/// Returns a clone of the stored state for shard `id`, or `None` if that shard has no entry.
/// The read lock is held only for the duration of the lookup.
async fn ready_or_resumed(&mut self, shard_id: u32, resumed: bool) -> anyhow::Result<()> { async fn get_shard(&self, id: u32) -> Option<ShardState> {
let shards = self.shards.read().await;
shards.get(&id).cloned()
}
/// Returns a snapshot of the state of every shard currently in the map.
///
/// Each `ShardState` is cloned while the read lock is held, so the returned `Vec` is
/// independent of later updates. Element order is whatever `HashMap::values` yields.
pub async fn get(&self) -> Vec<ShardState> {
    let shards = self.shards.read().await;
    shards.values().cloned().collect()
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!( info!(
"shard {} {}", "shard {} {}",
shard_id, shard_id,
@ -59,40 +71,52 @@ impl ShardStateManager {
.increment(1); .increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1); gauge!("pluralkit_gateway_shard_up").increment(1);
let info = self.shards.entry(shard_id).or_insert(ShardState::default()); let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32; info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32); info.cluster_id = Some(cluster_config().node_id as i32);
info.last_connection = chrono::offset::Utc::now().timestamp() as i32; info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true; info.up = true;
self.save_shard(shard_id).await?; self.save_shard(shard_id, info).await?;
Ok(()) Ok(())
} }
/// Records that the socket for `shard_id` closed: decrements the
/// `pluralkit_gateway_shard_up` gauge, marks the shard as down, bumps its
/// `disconnection_count`, and persists the result through `save_shard`.
///
/// Returns any error produced by `save_shard`.
pub async fn socket_closed(&mut self, shard_id: u32) -> anyhow::Result<()> { pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1); gauge!("pluralkit_gateway_shard_up").decrement(1);
// Start from the stored state if there is one, otherwise from `ShardState::default()`.
// NOTE(review): in the right-hand version the read (`get_shard`) and the write
// (`save_shard`) take the lock separately, so concurrent updates to the same shard could
// overwrite each other — confirm all updates are issued from a single task.
let info = self.shards.entry(shard_id).or_insert(ShardState::default()); let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32; info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32); info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false; info.up = false;
info.disconnection_count += 1; info.disconnection_count += 1;
self.save_shard(shard_id).await?; self.save_shard(shard_id, info).await?;
Ok(()) Ok(())
} }
/// Records a heartbeat for `shard_id`: sets the per-shard
/// `pluralkit_gateway_shard_latency` gauge to `latency`, marks the shard as up, stamps
/// `last_heartbeat` with the current UTC Unix timestamp, stores `latency`, and persists
/// the result through `save_shard`.
///
/// Returns any error produced by `save_shard`.
pub async fn heartbeated(&mut self, shard_id: u32, latency: i32) -> anyhow::Result<()> { pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
// Start from the stored state if there is one, otherwise from `ShardState::default()`.
// NOTE(review): in the right-hand version the read (`get_shard`) and the write
// (`save_shard`) take the lock separately, so concurrent updates to the same shard could
// overwrite each other — confirm all updates are issued from a single task.
let info = self.shards.entry(shard_id).or_insert(ShardState::default()); let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32; info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32); info.cluster_id = Some(cluster_config().node_id as i32);
info.up = true; info.up = true;
// The `as i32` cast truncates the i64 Unix timestamp to fit the field.
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32; info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency; info.latency = latency;
self.save_shard(shard_id).await?; self.save_shard(shard_id, info).await?;
Ok(()) Ok(())
} }
} }

View file

@ -91,8 +91,10 @@ async fn main() -> anyhow::Result<()> {
))); )));
} }
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
set.spawn(tokio::spawn({ set.spawn(tokio::spawn({
let mut shard_state = discord::shard_state::new(redis.clone()); let shard_state = shard_state.clone();
async move { async move {
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await { while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
@ -187,7 +189,7 @@ async fn main() -> anyhow::Result<()> {
)); ));
set.spawn(tokio::spawn(async move { set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache, runtime_config, awaiter.clone()).await { match cache_api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
Err(error) => { Err(error) => {
error!(?error, "failed to serve cache api"); error!(?error, "failed to serve cache api");
} }