mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
feat: remote await events from gateway
This commit is contained in:
parent
64ff69723c
commit
15c992c572
17 changed files with 439 additions and 30 deletions
|
|
@ -14,6 +14,7 @@ lazy_static = { workspace = true }
|
|||
libpk = { path = "../libpk" }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -10,9 +10,12 @@ use serde_json::{json, to_string};
|
|||
use tracing::{error, info};
|
||||
use twilight_model::id::Id;
|
||||
|
||||
use crate::discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
use crate::{
|
||||
discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
},
|
||||
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
|
|
@ -22,10 +25,11 @@ fn status_code(code: StatusCode, body: String) -> Response {
|
|||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>) -> anyhow::Result<()> {
|
||||
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
|
||||
// hacky fix for `move`
|
||||
let runtime_config_for_post = runtime_config.clone();
|
||||
let runtime_config_for_delete = runtime_config.clone();
|
||||
let awaiter_for_clear = awaiter.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
|
|
@ -190,6 +194,19 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
|
|||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
|
||||
.route("/await_event", post(|body: String| async move {
|
||||
info!("got request: {body}");
|
||||
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
||||
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
||||
};
|
||||
awaiter.handle_request(req).await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
.route("/clear_awaiter", post(|| async move {
|
||||
awaiter_for_clear.clear().await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
|
|
|
|||
|
|
@ -82,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
tx: Sender<(ShardId, String)>,
|
||||
tx: Sender<(ShardId, Event, String)>,
|
||||
shard_state: ShardStateManager,
|
||||
cache: Arc<DiscordCache>,
|
||||
runtime_config: Arc<RuntimeConfig>,
|
||||
|
|
@ -182,21 +182,21 @@ pub async fn runner(
|
|||
// and the default match skips the next block (continues to the next event)
|
||||
match event {
|
||||
Event::InteractionCreate(_) => {}
|
||||
Event::MessageCreate(m) if m.author.id != our_user_id => {}
|
||||
Event::MessageUpdate(m)
|
||||
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
|
||||
Event::MessageUpdate(ref m)
|
||||
if let Some(author) = m.author.clone()
|
||||
&& author.id != our_user_id
|
||||
&& !author.bot => {}
|
||||
Event::MessageDelete(_) => {}
|
||||
Event::MessageDeleteBulk(_) => {}
|
||||
Event::ReactionAdd(r) if r.user_id != our_user_id => {}
|
||||
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
|
||||
tx.send((shard.id(), raw_event)).await.unwrap();
|
||||
tx.send((shard.id(), event, raw_event)).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
223
crates/gateway/src/event_awaiter.rs
Normal file
223
crates/gateway/src/event_awaiter.rs
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
// - reaction: (message_id, user_id)
|
||||
// - message: (author_id, channel_id, ?options)
|
||||
// - interaction: (custom_id where not_includes "help-menu")
|
||||
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tokio::{sync::RwLock, time::Instant};
|
||||
use tracing::info;
|
||||
use twilight_gateway::Event;
|
||||
use twilight_model::{
|
||||
application::interaction::InteractionData,
|
||||
id::{
|
||||
marker::{ChannelMarker, MessageMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
|
||||
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AwaitEventRequest {
|
||||
Reaction {
|
||||
message_id: Id<MessageMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
Message {
|
||||
channel_id: Id<ChannelMarker>,
|
||||
author_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
options: Option<Vec<String>>,
|
||||
},
|
||||
Interaction {
|
||||
id: String,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct EventAwaiter {
|
||||
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
|
||||
messages: RwLock<
|
||||
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
|
||||
>,
|
||||
interactions: RwLock<HashMap<String, (Instant, String)>>,
|
||||
}
|
||||
|
||||
impl EventAwaiter {
|
||||
pub fn new() -> Self {
|
||||
let v = Self {
|
||||
reactions: RwLock::new(HashMap::new()),
|
||||
messages: RwLock::new(HashMap::new()),
|
||||
interactions: RwLock::new(HashMap::new()),
|
||||
};
|
||||
|
||||
v
|
||||
}
|
||||
|
||||
pub async fn cleanup_loop(&self) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
info!("running event_awaiter cleanup loop");
|
||||
let mut counts = (0, 0, 0);
|
||||
let now = Instant::now();
|
||||
{
|
||||
let mut reactions = self.reactions.write().await;
|
||||
for key in reactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = reactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.0 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut messages = self.messages.write().await;
|
||||
for key in messages.clone().keys() {
|
||||
if let Entry::Occupied(entry) = messages.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.1 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut interactions = self.interactions.write().await;
|
||||
for key in interactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = interactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.2 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn target_for_event(&self, event: Event) -> Option<String> {
|
||||
match event {
|
||||
Event::MessageCreate(message) => {
|
||||
let mut messages = self.messages.write().await;
|
||||
|
||||
messages
|
||||
.remove(&(message.channel_id, message.author.id))
|
||||
.map(|(timeout, target, options)| {
|
||||
if let Some(options) = options
|
||||
&& !options.contains(&message.content)
|
||||
{
|
||||
messages.insert(
|
||||
(message.channel_id, message.author.id),
|
||||
(timeout, target, Some(options)),
|
||||
);
|
||||
return None;
|
||||
}
|
||||
Some((*target).to_string())
|
||||
})?
|
||||
}
|
||||
Event::ReactionAdd(reaction)
|
||||
if let Some((_, target)) = self
|
||||
.reactions
|
||||
.write()
|
||||
.await
|
||||
.remove(&(reaction.message_id, reaction.user_id)) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
Event::InteractionCreate(interaction)
|
||||
if let Some(data) = interaction.data.clone()
|
||||
&& let InteractionData::MessageComponent(component) = data
|
||||
&& !component.custom_id.contains("help-menu")
|
||||
&& let Some((_, target)) =
|
||||
self.interactions.write().await.remove(&component.custom_id) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_request(&self, req: AwaitEventRequest) {
|
||||
match req {
|
||||
AwaitEventRequest::Reaction {
|
||||
message_id,
|
||||
user_id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.reactions.write().await.insert(
|
||||
(message_id, user_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target,
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Message {
|
||||
channel_id,
|
||||
author_id,
|
||||
target,
|
||||
timeout,
|
||||
options,
|
||||
} => {
|
||||
self.messages.write().await.insert(
|
||||
(channel_id, author_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target,
|
||||
options,
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Interaction {
|
||||
id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.interactions.write().await.insert(
|
||||
id,
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target,
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear(&self) {
|
||||
self.reactions.write().await.clear();
|
||||
self.messages.write().await.clear();
|
||||
self.interactions.write().await.clear();
|
||||
}
|
||||
}
|
||||
|
|
@ -1,8 +1,10 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(if_let_guard)]
|
||||
#![feature(duration_constructors)]
|
||||
|
||||
use chrono::Timelike;
|
||||
use discord::gateway::cluster_config;
|
||||
use event_awaiter::EventAwaiter;
|
||||
use fred::{clients::RedisPool, interfaces::*};
|
||||
use libpk::runtime_config::RuntimeConfig;
|
||||
use reqwest::ClientBuilder;
|
||||
|
|
@ -12,12 +14,13 @@ use signal_hook::{
|
|||
};
|
||||
use std::{sync::Arc, time::Duration, vec::Vec};
|
||||
use tokio::{sync::mpsc::channel, task::JoinSet};
|
||||
use tracing::{error, info, warn};
|
||||
use tracing::{debug, error, info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod discord;
|
||||
mod event_awaiter;
|
||||
mod logger;
|
||||
|
||||
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
|
||||
|
|
@ -39,6 +42,11 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
let awaiter = Arc::new(EventAwaiter::new());
|
||||
tokio::spawn({
|
||||
let awaiter = awaiter.clone();
|
||||
async move { awaiter.cleanup_loop().await }
|
||||
});
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
|
|
@ -63,22 +71,36 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
set.spawn(tokio::spawn({
|
||||
let runtime_config = runtime_config.clone();
|
||||
async move {
|
||||
let client = Arc::new(ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.timeout(Duration::from_secs(1))
|
||||
.build()
|
||||
.expect("error making client"));
|
||||
let awaiter = awaiter.clone();
|
||||
|
||||
async move {
|
||||
let client = Arc::new(
|
||||
ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.timeout(Duration::from_secs(1))
|
||||
.build()
|
||||
.expect("error making client"),
|
||||
);
|
||||
|
||||
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
|
||||
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
|
||||
debug!("sending event to awaiter");
|
||||
Some(target)
|
||||
} else if let Some(target) =
|
||||
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
|
||||
{
|
||||
Some(target)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
while let Some((shard_id, event)) = event_rx.recv().await {
|
||||
let target = runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await;
|
||||
if let Some(target) = target {
|
||||
tokio::spawn({
|
||||
let client = client.clone();
|
||||
async move {
|
||||
if let Err(error) = client
|
||||
.post(format!("{target}/{}", shard_id.number()))
|
||||
.body(event)
|
||||
.body(raw_event)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
|
|
@ -98,7 +120,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache, runtime_config).await {
|
||||
match cache_api::run_server(cache, runtime_config, awaiter.clone()).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue