feat(gateway): add /shard_status endpoint

This commit is contained in:
alyssa 2025-05-24 12:23:19 +00:00
parent e2acaf93be
commit b5f18106e1
3 changed files with 51 additions and 20 deletions

View file

@ -14,6 +14,7 @@ use crate::{
discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
shard_state::ShardStateManager,
},
event_awaiter::{AwaitEventRequest, EventAwaiter},
};
@ -25,7 +26,7 @@ fn status_code(code: StatusCode, body: String) -> Response {
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
// hacky fix for `move`
let runtime_config_for_post = runtime_config.clone();
let runtime_config_for_delete = runtime_config.clone();
@ -211,6 +212,10 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
status_code(StatusCode::NO_CONTENT, "".to_string())
}))
.route("/shard_status", get(|| async move {
status_code(StatusCode::FOUND, to_string(&shard_state.get().await).unwrap())
}))
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);

View file

@ -1,5 +1,6 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tokio::sync::RwLock;
use tracing::info;
use twilight_gateway::Event;
@ -9,21 +10,20 @@ use libpk::state::ShardState;
use super::gateway::cluster_config;
#[derive(Clone)]
pub struct ShardStateManager {
redis: RedisPool,
shards: HashMap<u32, ShardState>,
shards: RwLock<HashMap<u32, ShardState>>,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager {
redis: redis,
shards: HashMap::new(),
shards: RwLock::new(HashMap::new()),
}
}
impl ShardStateManager {
pub async fn handle_event(&mut self, shard_id: u32, event: Event) -> anyhow::Result<()> {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
@ -31,21 +31,33 @@ impl ShardStateManager {
}
}
async fn save_shard(&mut self, shard_id: u32) -> anyhow::Result<()> {
let info = self.shards.get(&shard_id);
async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
{
let mut shards = self.shards.write().await;
shards.insert(id, state.clone());
}
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
(
shard_id.to_string(),
serde_json::to_string(&info).expect("could not serialize shard"),
id.to_string(),
serde_json::to_string(&state).expect("could not serialize shard"),
),
)
.await?;
Ok(())
}
async fn ready_or_resumed(&mut self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
/// Looks up the cached state for a single shard.
///
/// Takes the read side of the `shards` lock, and returns an owned clone of
/// the entry stored under `id`, or `None` when no entry exists for that id.
/// The read guard is a temporary and is released once the clone is produced.
async fn get_shard(&self, id: u32) -> Option<ShardState> {
self.shards.read().await.get(&id).cloned()
}
/// Returns owned clones of every shard state currently held in the map.
///
/// The values come out in the map's iteration order, which for a `HashMap`
/// is unspecified; callers should not rely on any particular ordering.
pub async fn get(&self) -> Vec<ShardState> {
// Hold the read guard only for the duration of the copy.
let guard = self.shards.read().await;
guard.values().cloned().collect()
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
shard_id,
@ -59,40 +71,52 @@ impl ShardStateManager {
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id).await?;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn socket_closed(&mut self, shard_id: u32) -> anyhow::Result<()> {
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id).await?;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn heartbeated(&mut self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
let info = self.shards.entry(shard_id).or_insert(ShardState::default());
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency;
self.save_shard(shard_id).await?;
self.save_shard(shard_id, info).await?;
Ok(())
}
}

View file

@ -91,8 +91,10 @@ async fn main() -> anyhow::Result<()> {
)));
}
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
set.spawn(tokio::spawn({
let mut shard_state = discord::shard_state::new(redis.clone());
let shard_state = shard_state.clone();
async move {
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
@ -187,7 +189,7 @@ async fn main() -> anyhow::Result<()> {
));
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache, runtime_config, awaiter.clone()).await {
match cache_api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
Err(error) => {
error!(?error, "failed to serve cache api");
}