diff --git a/Cargo.lock b/Cargo.lock index a4a5e80c..75afc365 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1089,6 +1089,7 @@ dependencies = [ "libpk", "metrics", "reqwest 0.12.8", + "serde", "serde_json", "serde_variant", "signal-hook", diff --git a/Myriad/Cache/HTTPDiscordCache.cs b/Myriad/Cache/HTTPDiscordCache.cs index 31035660..311d697d 100644 --- a/Myriad/Cache/HTTPDiscordCache.cs +++ b/Myriad/Cache/HTTPDiscordCache.cs @@ -1,7 +1,10 @@ using Serilog; using System.Net; +using System.Text; using System.Text.Json; +using NodaTime; + using Myriad.Serialization; using Myriad.Types; @@ -12,6 +15,7 @@ public class HttpDiscordCache: IDiscordCache private readonly ILogger _logger; private readonly HttpClient _client; private readonly Uri _cacheEndpoint; + private readonly string? _eventTarget; private readonly int _shardCount; private readonly ulong _ownUserId; @@ -21,11 +25,12 @@ public class HttpDiscordCache: IDiscordCache public EventHandler<(bool?, string)> OnDebug; - public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, int shardCount, ulong ownUserId, bool useInnerCache) + public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, string? eventTarget, int shardCount, ulong ownUserId, bool useInnerCache) { _logger = logger; _client = client; _cacheEndpoint = new Uri(cacheEndpoint); + _eventTarget = eventTarget; _shardCount = shardCount; _ownUserId = ownUserId; _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); @@ -65,6 +70,68 @@ public class HttpDiscordCache: IDiscordCache return JsonSerializer.Deserialize(plaintext, _jsonSerializerOptions); } + private Task AwaitEvent(ulong guildId, object data) + => AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data); + + private async Task AwaitEventShard(int shardId, object data) + { + if (_eventTarget == null) + throw new Exception("missing event target for remote await event"); + + var cluster = _cacheEndpoint.Authority; + // todo: there should not be infra-specific code here + if (cluster.Contains(".service.consul") || cluster.Contains("process.pluralkit-gateway.internal")) + // int(((guild_id >> 22) % shard_count) / 16) + cluster = $"cluster{shardId / 16}.{cluster}"; + + var response = await _client.PostAsync( + $"{_cacheEndpoint.Scheme}://{cluster}/await_event", + new StringContent(JsonSerializer.Serialize(data), Encoding.UTF8) + ); + + if (response.StatusCode != HttpStatusCode.NoContent) + throw new Exception($"failed to await event from gateway: {response.StatusCode}"); + } + + public async Task AwaitReaction(ulong guildId, ulong messageId, ulong userId, Duration? timeout) + { + var obj = new + { + message_id = messageId, + user_id = userId, + target = _eventTarget!, + timeout = timeout?.TotalSeconds, + }; + + await AwaitEvent(guildId, obj); + } + + public async Task AwaitMessage(ulong guildId, ulong channelId, ulong authorId, Duration? timeout, string[] options = null) + { + var obj = new + { + channel_id = channelId, + author_id = authorId, + target = _eventTarget!, + timeout = timeout?.TotalSeconds, + options = options, + }; + + await AwaitEvent(guildId, obj); + } + + public async Task AwaitInteraction(int shardId, string id, Duration? timeout) + { + var obj = new + { + id = id, + target = _eventTarget!, + timeout = timeout?.TotalSeconds, + }; + + await AwaitEventShard(shardId, obj); + } + public async Task TryGetGuild(ulong guildId) { var hres = await QueryCache($"/guilds/{guildId}", guildId); diff --git a/Myriad/Myriad.csproj b/Myriad/Myriad.csproj index 3b92ea30..0a6ed588 100644 --- a/Myriad/Myriad.csproj +++ b/Myriad/Myriad.csproj @@ -21,6 +21,8 @@ + + diff --git a/Myriad/packages.lock.json b/Myriad/packages.lock.json index 945321a9..7680637b 100644 --- a/Myriad/packages.lock.json +++ b/Myriad/packages.lock.json @@ -2,6 +2,22 @@ "version": 1, "dependencies": { "net8.0": { + "NodaTime": { + "type": "Direct", + "requested": "[3.2.0, )", + "resolved": "3.2.0", + "contentHash": "yoRA3jEJn8NM0/rQm78zuDNPA3DonNSZdsorMUj+dltc1D+/Lc5h9YXGqbEEZozMGr37lAoYkcSM/KjTVqD0ow==" + }, + "NodaTime.Serialization.JsonNet": { + "type": "Direct", + "requested": "[3.1.0, )", + "resolved": "3.1.0", + "contentHash": "eEr9lXUz50TYr4rpeJG4TDAABkpxjIKr5mDSi/Zav8d6Njy6fH7x4ZtNwWFj0Vd+vIvEZNrHFQ4Gfy8j4BqRGg==", + "dependencies": { + "Newtonsoft.Json": "13.0.3", + "NodaTime": "[3.0.0, 4.0.0)" + } + }, "Polly": { "type": "Direct", "requested": "[8.5.0, )", @@ -52,6 +68,11 @@ "resolved": "6.0.0", "contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA==" }, + "Newtonsoft.Json": { + "type": "Transitive", + "resolved": "13.0.3", + "contentHash": "HrC5BXdl00IP9zeV+0Z848QWPAoCr9P3bDEZguI+gkLcBKAOxix/tLEAAHC+UvDNPv4a2d18lOReHMOagPa+zQ==" + }, "Pipelines.Sockets.Unofficial": { "type": "Transitive", "resolved": "2.2.8", diff --git a/PluralKit.Bot/BotConfig.cs b/PluralKit.Bot/BotConfig.cs index 0be554a1..1e6e0f0b 100644 --- a/PluralKit.Bot/BotConfig.cs +++ b/PluralKit.Bot/BotConfig.cs @@ -26,6 +26,7 @@ public class BotConfig public string? HttpListenerAddr { get; set; } public bool DisableGateway { get; set; } = false; + public string? EventAwaiterTarget { get; set; } public string? DiscordBaseUrl { get; set; } public string? AvatarServiceUrl { get; set; } diff --git a/PluralKit.Bot/Interactive/BaseInteractive.cs b/PluralKit.Bot/Interactive/BaseInteractive.cs index c2ae33ee..779058fe 100644 --- a/PluralKit.Bot/Interactive/BaseInteractive.cs +++ b/PluralKit.Bot/Interactive/BaseInteractive.cs @@ -28,7 +28,7 @@ public abstract class BaseInteractive ButtonStyle style = ButtonStyle.Secondary, bool disabled = false) { var dispatch = _ctx.Services.Resolve(); - var customId = dispatch.Register(handler, Timeout); + var customId = dispatch.Register(_ctx.ShardId, handler, Timeout); var button = new Button { @@ -89,7 +89,7 @@ public abstract class BaseInteractive { var dispatch = ctx.Services.Resolve(); foreach (var button in _buttons) - button.CustomId = dispatch.Register(button.Handler, Timeout); + button.CustomId = dispatch.Register(_ctx.ShardId, button.Handler, Timeout); } public abstract Task Start(); diff --git a/PluralKit.Bot/Interactive/YesNoPrompt.cs b/PluralKit.Bot/Interactive/YesNoPrompt.cs index 110e4bb9..194dd1f1 100644 --- a/PluralKit.Bot/Interactive/YesNoPrompt.cs +++ b/PluralKit.Bot/Interactive/YesNoPrompt.cs @@ -1,5 +1,6 @@ using Autofac; +using Myriad.Cache; using Myriad.Gateway; using Myriad.Rest.Types; using Myriad.Types; @@ -69,6 +70,9 @@ public class YesNoPrompt: BaseInteractive return true; } + // no need to reawait message + // gateway will already have sent us only matching messages + return false; } @@ -88,6 +92,17 @@ public class YesNoPrompt: BaseInteractive { try { + // check if http gateway and set listener + // todo: this one needs to handle options for message + if (_ctx.Cache is HttpDiscordCache) + await (_ctx.Cache as HttpDiscordCache).AwaitMessage( + _ctx.Guild?.Id ?? 0, + _ctx.Channel.Id, + _ctx.Author.Id, + Timeout, + options: new[] { "yes", "y", "no", "n" } + ); + await queue.WaitFor(MessagePredicate, Timeout, cts.Token); } catch (TimeoutException e) diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index ae6f85ec..668ea9c3 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -49,8 +49,15 @@ public class BotModule: Module if (botConfig.HttpCacheUrl != null) { - var cache = new HttpDiscordCache(c.Resolve(), - c.Resolve(), botConfig.HttpCacheUrl, botConfig.Cluster?.TotalShards ?? 1, botConfig.ClientId, botConfig.HttpUseInnerCache); + var cache = new HttpDiscordCache( + c.Resolve(), + c.Resolve(), + botConfig.HttpCacheUrl, + botConfig.EventAwaiterTarget, + botConfig.Cluster?.TotalShards ?? 1, + botConfig.ClientId, + botConfig.HttpUseInnerCache + ); var metrics = c.Resolve(); diff --git a/PluralKit.Bot/Services/InteractionDispatchService.cs b/PluralKit.Bot/Services/InteractionDispatchService.cs index 968ea35a..f900a792 100644 --- a/PluralKit.Bot/Services/InteractionDispatchService.cs +++ b/PluralKit.Bot/Services/InteractionDispatchService.cs @@ -1,5 +1,7 @@ using System.Collections.Concurrent; +using Myriad.Cache; + using NodaTime; using Serilog; @@ -16,9 +18,12 @@ public class InteractionDispatchService: IDisposable private readonly ConcurrentDictionary _handlers = new(); private readonly ILogger _logger; - public InteractionDispatchService(IClock clock, ILogger logger) + private readonly IDiscordCache _cache; + + public InteractionDispatchService(IClock clock, ILogger logger, IDiscordCache cache) { _clock = clock; + _cache = cache; _logger = logger.ForContext(); _cleanupWorker = CleanupLoop(_cts.Token); @@ -50,9 +55,15 @@ public class InteractionDispatchService: IDisposable _handlers.TryRemove(customIdGuid, out _); } - public string Register(Func callback, Duration? expiry = null) + public string Register(int shardId, Func callback, Duration? expiry = null) { var key = Guid.NewGuid(); + + // if http_cache, return RegisterRemote + // not awaited here, it's probably fine + if (_cache is HttpDiscordCache) + (_cache as HttpDiscordCache).AwaitInteraction(shardId, key.ToString(), expiry); + var handler = new RegisteredInteraction { Callback = callback, diff --git a/PluralKit.Bot/Utils/ContextUtils.cs b/PluralKit.Bot/Utils/ContextUtils.cs index ce353472..7cd3de42 100644 --- a/PluralKit.Bot/Utils/ContextUtils.cs +++ b/PluralKit.Bot/Utils/ContextUtils.cs @@ -1,6 +1,7 @@ using Autofac; using Myriad.Builders; +using Myriad.Cache; using Myriad.Gateway; using Myriad.Rest.Exceptions; using Myriad.Rest.Types.Requests; @@ -40,8 +41,12 @@ public static class ContextUtils } public static async Task AwaitReaction(this Context ctx, Message message, - User user = null, Func predicate = null, Duration? timeout = null) + User user, Func predicate = null, Duration? timeout = null) { + // check if http gateway and set listener + if (ctx.Cache is HttpDiscordCache) + await (ctx.Cache as HttpDiscordCache).AwaitReaction(ctx.Guild?.Id ?? 0, message.Id, user!.Id, timeout); + bool ReactionPredicate(MessageReactionAddEvent evt) { if (message.Id != evt.MessageId) return false; // Ignore reactions for different messages @@ -57,11 +62,17 @@ public static class ContextUtils public static async Task ConfirmWithReply(this Context ctx, string expectedReply, bool treatAsHid = false) { + var timeout = Duration.FromMinutes(1); + + // check if http gateway and set listener + if (ctx.Cache is HttpDiscordCache) + await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout); + bool Predicate(MessageCreateEvent e) => e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id; var msg = await ctx.Services.Resolve>() - .WaitFor(Predicate, Duration.FromMinutes(1)); + .WaitFor(Predicate, timeout); var content = msg.Content; if (treatAsHid) @@ -96,11 +107,17 @@ public static class ContextUtils async Task PromptPageNumber() { + var timeout = Duration.FromMinutes(0.5); + + // check if http gateway and set listener + if (ctx.Cache is HttpDiscordCache) + await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout); + bool Predicate(MessageCreateEvent e) => e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id; var msg = await ctx.Services.Resolve>() - .WaitFor(Predicate, Duration.FromMinutes(0.5)); + .WaitFor(Predicate, timeout); int.TryParse(msg.Content, out var num); diff --git a/PluralKit.Bot/packages.lock.json b/PluralKit.Bot/packages.lock.json index fca78f6e..79da3d52 100644 --- a/PluralKit.Bot/packages.lock.json +++ b/PluralKit.Bot/packages.lock.json @@ -764,6 +764,8 @@ "myriad": { "type": "Project", "dependencies": { + "NodaTime": "[3.2.0, )", + "NodaTime.Serialization.JsonNet": "[3.1.0, )", "Polly": "[8.5.0, )", "Polly.Contrib.WaitAndRetry": "[1.1.1, )", "Serilog": "[4.2.0, )", diff --git a/PluralKit.Tests/packages.lock.json b/PluralKit.Tests/packages.lock.json index c11cedb5..b9e49ba9 100644 --- a/PluralKit.Tests/packages.lock.json +++ b/PluralKit.Tests/packages.lock.json @@ -976,6 +976,8 @@ "myriad": { "type": "Project", "dependencies": { + "NodaTime": "[3.2.0, )", + "NodaTime.Serialization.JsonNet": "[3.1.0, )", "Polly": "[8.5.0, )", "Polly.Contrib.WaitAndRetry": "[1.1.1, )", "Serilog": "[4.2.0, )", diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index e9fd444e..420aef1c 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -14,6 +14,7 @@ lazy_static = { workspace = true } libpk = { path = "../libpk" } metrics = { workspace = true } reqwest = { workspace = true } +serde = { workspace = true } serde_json = { workspace = true } signal-hook = { workspace = true } tokio = { workspace = true } diff --git a/crates/gateway/src/cache_api.rs b/crates/gateway/src/cache_api.rs index e3de4696..d22fb113 100644 --- a/crates/gateway/src/cache_api.rs +++ b/crates/gateway/src/cache_api.rs @@ -10,9 +10,12 @@ use serde_json::{json, to_string}; use tracing::{error, info}; use twilight_model::id::Id; -use crate::discord::{ - cache::{dm_channel, DiscordCache, DM_PERMISSIONS}, - gateway::cluster_config, +use crate::{ + discord::{ + cache::{dm_channel, DiscordCache, DM_PERMISSIONS}, + gateway::cluster_config, + }, + event_awaiter::{AwaitEventRequest, EventAwaiter}, }; use std::sync::Arc; @@ -22,10 +25,11 @@ fn status_code(code: StatusCode, body: String) -> Response { // this function is manually formatted for easier legibility of route_services #[rustfmt::skip] -pub async fn run_server(cache: Arc, runtime_config: Arc) -> anyhow::Result<()> { +pub async fn run_server(cache: Arc, runtime_config: Arc, awaiter: Arc) -> anyhow::Result<()> { // hacky fix for `move` let runtime_config_for_post = runtime_config.clone(); let runtime_config_for_delete = runtime_config.clone(); + let awaiter_for_clear = awaiter.clone(); let app = Router::new() .route( @@ -190,6 +194,19 @@ pub async fn run_server(cache: Arc, runtime_config: Arc(&body) else { + return status_code(StatusCode::BAD_REQUEST, "".to_string()); + }; + awaiter.handle_request(req).await; + status_code(StatusCode::NO_CONTENT, "".to_string()) + })) + .route("/clear_awaiter", post(|| async move { + awaiter_for_clear.clear().await; + status_code(StatusCode::NO_CONTENT, "".to_string()) + })) + .layer(axum::middleware::from_fn(crate::logger::logger)) .with_state(cache); diff --git a/crates/gateway/src/discord/gateway.rs b/crates/gateway/src/discord/gateway.rs index c4a13a81..89dfc26f 100644 --- a/crates/gateway/src/discord/gateway.rs +++ b/crates/gateway/src/discord/gateway.rs @@ -82,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result, - tx: Sender<(ShardId, String)>, + tx: Sender<(ShardId, Event, String)>, shard_state: ShardStateManager, cache: Arc, runtime_config: Arc, @@ -182,21 +182,21 @@ pub async fn runner( // and the default match skips the next block (continues to the next event) match event { Event::InteractionCreate(_) => {} - Event::MessageCreate(m) if m.author.id != our_user_id => {} - Event::MessageUpdate(m) + Event::MessageCreate(ref m) if m.author.id != our_user_id => {} + Event::MessageUpdate(ref m) if let Some(author) = m.author.clone() && author.id != our_user_id && !author.bot => {} Event::MessageDelete(_) => {} Event::MessageDeleteBulk(_) => {} - Event::ReactionAdd(r) if r.user_id != our_user_id => {} + Event::ReactionAdd(ref r) if r.user_id != our_user_id => {} _ => { continue; } } if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await { - tx.send((shard.id(), raw_event)).await.unwrap(); + tx.send((shard.id(), event, raw_event)).await.unwrap(); } } } diff --git a/crates/gateway/src/event_awaiter.rs b/crates/gateway/src/event_awaiter.rs new file mode 100644 index 00000000..9196170d --- /dev/null +++ b/crates/gateway/src/event_awaiter.rs @@ -0,0 +1,223 @@ +// - reaction: (message_id, user_id) +// - message: (author_id, channel_id, ?options) +// - interaction: (custom_id where not_includes "help-menu") + +use std::{ + collections::{hash_map::Entry, HashMap}, + time::Duration, +}; + +use serde::Deserialize; +use tokio::{sync::RwLock, time::Instant}; +use tracing::info; +use twilight_gateway::Event; +use twilight_model::{ + application::interaction::InteractionData, + id::{ + marker::{ChannelMarker, MessageMarker, UserMarker}, + Id, + }, +}; + +static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15); + +#[derive(Deserialize)] +#[serde(untagged)] +pub enum AwaitEventRequest { + Reaction { + message_id: Id, + user_id: Id, + target: String, + timeout: Option, + }, + Message { + channel_id: Id, + author_id: Id, + target: String, + timeout: Option, + options: Option>, + }, + Interaction { + id: String, + target: String, + timeout: Option, + }, +} + +pub struct EventAwaiter { + reactions: RwLock, Id), (Instant, String)>>, + messages: RwLock< + HashMap<(Id, Id), (Instant, String, Option>)>, + >, + interactions: RwLock>, +} + +impl EventAwaiter { + pub fn new() -> Self { + let v = Self { + reactions: RwLock::new(HashMap::new()), + messages: RwLock::new(HashMap::new()), + interactions: RwLock::new(HashMap::new()), + }; + + v + } + + pub async fn cleanup_loop(&self) { + loop { + tokio::time::sleep(Duration::from_secs(30)).await; + info!("running event_awaiter cleanup loop"); + let mut counts = (0, 0, 0); + let now = Instant::now(); + { + let mut reactions = self.reactions.write().await; + for key in reactions.clone().keys() { + if let Entry::Occupied(entry) = reactions.entry(key.clone()) + && entry.get().0 < now + { + counts.0 += 1; + entry.remove(); + } + } + } + { + let mut messages = self.messages.write().await; + for key in messages.clone().keys() { + if let Entry::Occupied(entry) = messages.entry(key.clone()) + && entry.get().0 < now + { + counts.1 += 1; + entry.remove(); + } + } + } + { + let mut interactions = self.interactions.write().await; + for key in interactions.clone().keys() { + if let Entry::Occupied(entry) = interactions.entry(key.clone()) + && entry.get().0 < now + { + counts.2 += 1; + entry.remove(); + } + } + } + info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2); + } + } + + pub async fn target_for_event(&self, event: Event) -> Option { + match event { + Event::MessageCreate(message) => { + let mut messages = self.messages.write().await; + + messages + .remove(&(message.channel_id, message.author.id)) + .map(|(timeout, target, options)| { + if let Some(options) = options + && !options.contains(&message.content) + { + messages.insert( + (message.channel_id, message.author.id), + (timeout, target, Some(options)), + ); + return None; + } + Some((*target).to_string()) + })? + } + Event::ReactionAdd(reaction) + if let Some((_, target)) = self + .reactions + .write() + .await + .remove(&(reaction.message_id, reaction.user_id)) => + { + Some((*target).to_string()) + } + Event::InteractionCreate(interaction) + if let Some(data) = interaction.data.clone() + && let InteractionData::MessageComponent(component) = data + && !component.custom_id.contains("help-menu") + && let Some((_, target)) = + self.interactions.write().await.remove(&component.custom_id) => + { + Some((*target).to_string()) + } + + _ => None, + } + } + + pub async fn handle_request(&self, req: AwaitEventRequest) { + match req { + AwaitEventRequest::Reaction { + message_id, + user_id, + target, + timeout, + } => { + self.reactions.write().await.insert( + (message_id, user_id), + ( + Instant::now() + .checked_add( + timeout + .map(|i| Duration::from_secs(i)) + .unwrap_or(DEFAULT_TIMEOUT), + ) + .expect("invalid time"), + target, + ), + ); + } + AwaitEventRequest::Message { + channel_id, + author_id, + target, + timeout, + options, + } => { + self.messages.write().await.insert( + (channel_id, author_id), + ( + Instant::now() + .checked_add( + timeout + .map(|i| Duration::from_secs(i)) + .unwrap_or(DEFAULT_TIMEOUT), + ) + .expect("invalid time"), + target, + options, + ), + ); + } + AwaitEventRequest::Interaction { + id, + target, + timeout, + } => { + self.interactions.write().await.insert( + id, + ( + Instant::now() + .checked_add( + timeout + .map(|i| Duration::from_secs(i)) + .unwrap_or(DEFAULT_TIMEOUT), + ) + .expect("invalid time"), + target, + ), + ); + } + } + } + + pub async fn clear(&self) { + self.reactions.write().await.clear(); + self.messages.write().await.clear(); + self.interactions.write().await.clear(); + } +} diff --git a/crates/gateway/src/main.rs b/crates/gateway/src/main.rs index 4213b235..ab679c06 100644 --- a/crates/gateway/src/main.rs +++ b/crates/gateway/src/main.rs @@ -1,8 +1,10 @@ #![feature(let_chains)] #![feature(if_let_guard)] +#![feature(duration_constructors)] use chrono::Timelike; use discord::gateway::cluster_config; +use event_awaiter::EventAwaiter; use fred::{clients::RedisPool, interfaces::*}; use libpk::runtime_config::RuntimeConfig; use reqwest::ClientBuilder; @@ -12,12 +14,13 @@ use signal_hook::{ }; use std::{sync::Arc, time::Duration, vec::Vec}; use tokio::{sync::mpsc::channel, task::JoinSet}; -use tracing::{error, info, warn}; +use tracing::{debug, error, info, warn}; use twilight_gateway::{MessageSender, ShardId}; use twilight_model::gateway::payload::outgoing::UpdatePresence; mod cache_api; mod discord; +mod event_awaiter; mod logger; const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target"; @@ -39,6 +42,11 @@ async fn real_main() -> anyhow::Result<()> { let shard_state = discord::shard_state::new(redis.clone()); let cache = Arc::new(discord::cache::new()); + let awaiter = Arc::new(EventAwaiter::new()); + tokio::spawn({ + let awaiter = awaiter.clone(); + async move { awaiter.cleanup_loop().await } + }); let shards = discord::gateway::create_shards(redis.clone())?; @@ -63,22 +71,36 @@ async fn real_main() -> anyhow::Result<()> { set.spawn(tokio::spawn({ let runtime_config = runtime_config.clone(); - async move { - let client = Arc::new(ClientBuilder::new() - .connect_timeout(Duration::from_secs(1)) - .timeout(Duration::from_secs(1)) - .build() - .expect("error making client")); + let awaiter = awaiter.clone(); + + async move { + let client = Arc::new( + ClientBuilder::new() + .connect_timeout(Duration::from_secs(1)) + .timeout(Duration::from_secs(1)) + .build() + .expect("error making client"), + ); + + while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await { + let target = if let Some(target) = awaiter.target_for_event(parsed_event).await { + debug!("sending event to awaiter"); + Some(target) + } else if let Some(target) = + runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await + { + Some(target) + } else { + None + }; - while let Some((shard_id, event)) = event_rx.recv().await { - let target = runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await; if let Some(target) = target { tokio::spawn({ let client = client.clone(); async move { if let Err(error) = client .post(format!("{target}/{}", shard_id.number())) - .body(event) + .body(raw_event) .send() .await { @@ -98,7 +120,7 @@ async fn real_main() -> anyhow::Result<()> { // todo: probably don't do it this way let api_shutdown_tx = shutdown_tx.clone(); set.spawn(tokio::spawn(async move { - match cache_api::run_server(cache, runtime_config).await { + match cache_api::run_server(cache, runtime_config, awaiter.clone()).await { Err(error) => { tracing::error!(?error, "failed to serve cache api"); let _ = api_shutdown_tx.send(());