diff --git a/Myriad/Cache/HTTPDiscordCache.cs b/Myriad/Cache/HTTPDiscordCache.cs index 311d697d..b70fcf12 100644 --- a/Myriad/Cache/HTTPDiscordCache.cs +++ b/Myriad/Cache/HTTPDiscordCache.cs @@ -70,6 +70,9 @@ public class HttpDiscordCache: IDiscordCache return JsonSerializer.Deserialize(plaintext, _jsonSerializerOptions); } + public Task GetLastMessage(ulong guildId, ulong channelId) + => QueryCache($"/guilds/{guildId}/channels/{channelId}/last_message", guildId); + private Task AwaitEvent(ulong guildId, object data) => AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data); diff --git a/PluralKit.Bot/Commands/Message.cs b/PluralKit.Bot/Commands/Message.cs index 692f75b1..94602a06 100644 --- a/PluralKit.Bot/Commands/Message.cs +++ b/PluralKit.Bot/Commands/Message.cs @@ -305,7 +305,7 @@ public class ProxiedMessage throw new PKError(error); } - var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId); + var lastMessage = await _lastMessageCache.GetLastMessage(ctx.Message.GuildId ?? 0, ctx.Message.ChannelId); var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id ? lastMessage?.Previous?.Id == msg.Mid diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index 2bb6bfcf..a14d024b 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -55,7 +55,9 @@ public class MessageCreated: IEventHandler public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId); private bool IsDuplicateMessage(Message msg) => // We consider a message duplicate if it has the same ID as the previous message that hit the gateway - _lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id; + // use only the local cache here + // http gateway sets last message before forwarding the message here, so this will always return true + _lastMessageCache._GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id; public async Task Handle(int shardId, MessageCreateEvent evt) { diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index 8e131c0c..6ffcd00e 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -65,7 +65,7 @@ public class MessageEdited: IEventHandler var guild = await _cache.TryGetGuild(channel.GuildId!.Value); if (guild == null) throw new Exception("could not find self guild in MessageEdited event"); - var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current; + var lastMessage = (await _lastMessageCache.GetLastMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId))?.Current; // Only react to the last message in the channel if (lastMessage?.Id != evt.Id) diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index 43e1ed5e..8a59957d 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -246,7 +246,7 @@ public class ProxyService ChannelId = rootChannel.Id, ThreadId = threadId, MessageId = trigger.Id, - Name = await FixSameName(messageChannel.Id, ctx, match.Member), + Name = await FixSameName(trigger.GuildId!.Value, messageChannel.Id, ctx, match.Member), AvatarUrl = AvatarUtils.TryRewriteCdnUrl(match.Member.ProxyAvatar(ctx)), Content = content, Attachments = trigger.Attachments, @@ -458,11 +458,11 @@ public class ProxyService }; } - private async Task FixSameName(ulong channelId, MessageContext ctx, ProxyMember member) + private async Task FixSameName(ulong guildId, ulong channelId, MessageContext ctx, ProxyMember member) { var proxyName = member.ProxyName(ctx); - var lastMessage = _lastMessage.GetLastMessage(channelId)?.Previous; + var lastMessage = (await _lastMessage.GetLastMessage(guildId, channelId))?.Previous; if (lastMessage == null) // cache is out of date or channel is empty. return proxyName; diff --git a/PluralKit.Bot/Services/LastMessageCacheService.cs b/PluralKit.Bot/Services/LastMessageCacheService.cs index a06da7fa..46a51c64 100644 --- a/PluralKit.Bot/Services/LastMessageCacheService.cs +++ b/PluralKit.Bot/Services/LastMessageCacheService.cs @@ -1,6 +1,7 @@ #nullable enable using System.Collections.Concurrent; +using Myriad.Cache; using Myriad.Types; namespace PluralKit.Bot; @@ -9,9 +10,18 @@ public class LastMessageCacheService { private readonly IDictionary _cache = new ConcurrentDictionary(); + private readonly IDiscordCache _maybeHttp; + + public LastMessageCacheService(IDiscordCache cache) + { + _maybeHttp = cache; + } + public void AddMessage(Message msg) { - var previous = GetLastMessage(msg.ChannelId); + if (_maybeHttp is HttpDiscordCache) return; + + var previous = _GetLastMessage(msg.ChannelId); var current = ToCachedMessage(msg); _cache[msg.ChannelId] = new CacheEntry(current, previous?.Current); } @@ -19,12 +29,26 @@ public class LastMessageCacheService private CachedMessage ToCachedMessage(Message msg) => new(msg.Id, msg.ReferencedMessage.Value?.Id, msg.Author.Username); - public CacheEntry? GetLastMessage(ulong channel) => - _cache.TryGetValue(channel, out var message) ? message : null; + public async Task GetLastMessage(ulong guild, ulong channel) + { + if (_maybeHttp is HttpDiscordCache) + return await (_maybeHttp as HttpDiscordCache).GetLastMessage(guild, channel); + + return _cache.TryGetValue(channel, out var message) ? message : null; + } + + public CacheEntry? _GetLastMessage(ulong channel) + { + if (_maybeHttp is HttpDiscordCache) return null; + + return _cache.TryGetValue(channel, out var message) ? message : null; + } public void HandleMessageDeletion(ulong channel, ulong message) { - var storedMessage = GetLastMessage(channel); + if (_maybeHttp is HttpDiscordCache) return; + + var storedMessage = _GetLastMessage(channel); if (storedMessage == null) return; @@ -39,7 +63,9 @@ public class LastMessageCacheService public void HandleMessageDeletion(ulong channel, List messages) { - var storedMessage = GetLastMessage(channel); + if (_maybeHttp is HttpDiscordCache) return; + + var storedMessage = _GetLastMessage(channel); if (storedMessage == null) return; diff --git a/crates/gateway/src/cache_api.rs b/crates/gateway/src/cache_api.rs index d22fb113..2e91f465 100644 --- a/crates/gateway/src/cache_api.rs +++ b/crates/gateway/src/cache_api.rs @@ -8,7 +8,7 @@ use axum::{ use libpk::runtime_config::RuntimeConfig; use serde_json::{json, to_string}; use tracing::{error, info}; -use twilight_model::id::Id; +use twilight_model::id::{marker::ChannelMarker, Id}; use crate::{ discord::{ @@ -136,7 +136,10 @@ pub async fn run_server(cache: Arc, runtime_config: Arc>, Path((_guild_id, channel_id)): Path<(u64, Id)>| async move { + let lm = cache.get_last_message(channel_id).await; + status_code(StatusCode::FOUND, to_string(&lm).unwrap()) + }), ) .route( diff --git a/crates/gateway/src/discord/cache.rs b/crates/gateway/src/discord/cache.rs index b4a81664..2b5b22b8 100644 --- a/crates/gateway/src/discord/cache.rs +++ b/crates/gateway/src/discord/cache.rs @@ -1,6 +1,7 @@ use anyhow::format_err; use lazy_static::lazy_static; -use std::sync::Arc; +use serde::Serialize; +use std::{collections::HashMap, sync::Arc}; use tokio::sync::RwLock; use twilight_cache_inmemory::{ model::CachedMember, @@ -8,11 +9,12 @@ use twilight_cache_inmemory::{ traits::CacheableChannel, InMemoryCache, ResourceType, }; +use twilight_gateway::Event; use twilight_model::{ channel::{Channel, ChannelType}, guild::{Guild, Member, Permissions}, id::{ - marker::{ChannelMarker, GuildMarker, UserMarker}, + marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker}, Id, }, }; @@ -123,16 +125,134 @@ pub fn new() -> DiscordCache { .build(), ); - DiscordCache(cache, client, RwLock::new(Vec::new())) + DiscordCache( + cache, + client, + RwLock::new(Vec::new()), + RwLock::new(HashMap::new()), + ) +} + +#[derive(Clone, Serialize)] +pub struct CachedMessage { + id: Id, + referenced_message: Option>, + author_username: String, +} + +#[derive(Clone, Serialize)] +pub struct LastMessageCacheEntry { + pub current: CachedMessage, + pub previous: Option, } pub struct DiscordCache( pub Arc, pub Arc, pub RwLock>, + pub RwLock, LastMessageCacheEntry>>, ); impl DiscordCache { + pub async fn get_last_message( + &self, + channel: Id, + ) -> Option { + self.3.read().await.get(&channel).cloned() + } + + pub async fn update(&self, event: &twilight_gateway::Event) { + self.0.update(event); + + match event { + Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) { + std::collections::hash_map::Entry::Occupied(mut e) => { + let cur = e.get(); + e.insert(LastMessageCacheEntry { + current: CachedMessage { + id: m.id, + referenced_message: m.referenced_message.as_ref().map(|v| v.id), + author_username: m.author.name.clone(), + }, + previous: Some(cur.current.clone()), + }); + } + std::collections::hash_map::Entry::Vacant(e) => { + e.insert(LastMessageCacheEntry { + current: CachedMessage { + id: m.id, + referenced_message: m.referenced_message.as_ref().map(|v| v.id), + author_username: m.author.name.clone(), + }, + previous: None, + }); + } + }, + Event::MessageDelete(m) => { + self.handle_message_deletion(m.channel_id, vec![m.id]).await; + } + Event::MessageDeleteBulk(m) => { + self.handle_message_deletion(m.channel_id, m.ids.clone()) + .await; + } + _ => {} + }; + } + + async fn handle_message_deletion( + &self, + channel_id: Id, + mids: Vec>, + ) { + let mut lm = self.3.write().await; + + let Some(entry) = lm.get(&channel_id) else { + return; + }; + + let mut entry = entry.clone(); + + // if none of the deleted messages are relevant, just return + if !mids.contains(&entry.current.id) + && entry + .previous + .clone() + .map(|v| !mids.contains(&v.id)) + .unwrap_or(false) + { + return; + } + + // remove "previous" entry if it was deleted + if let Some(prev) = entry.previous.clone() + && mids.contains(&prev.id) + { + entry.previous = None; + } + + // set "current" entry to "previous" if current entry was deleted + // (if the "previous" entry still exists, it was not deleted) + if let Some(prev) = entry.previous.clone() + && mids.contains(&entry.current.id) + { + entry.current = prev; + entry.previous = None; + } + + // if the current entry was already deleted, but previous wasn't, + // we would've set current to previous + // so if current is deleted this means both current and previous have + // been deleted + // so just drop the cache entry here + if mids.contains(&entry.current.id) && entry.previous.is_none() { + lm.remove(&channel_id); + return; + } + + // ok, update the entry + lm.insert(channel_id, entry.clone()); + } + pub async fn guild_permissions( &self, guild_id: Id, diff --git a/crates/gateway/src/discord/gateway.rs b/crates/gateway/src/discord/gateway.rs index 89dfc26f..fc881f34 100644 --- a/crates/gateway/src/discord/gateway.rs +++ b/crates/gateway/src/discord/gateway.rs @@ -173,7 +173,7 @@ pub async fn runner( cache.2.write().await.push(shard_id); } } - cache.0.update(&event); + cache.update(&event).await; // okay, we've handled the event internally, let's send it to consumers diff --git a/crates/gateway/src/event_awaiter.rs b/crates/gateway/src/event_awaiter.rs index 9196170d..10f1888d 100644 --- a/crates/gateway/src/event_awaiter.rs +++ b/crates/gateway/src/event_awaiter.rs @@ -115,7 +115,7 @@ impl EventAwaiter { .remove(&(message.channel_id, message.author.id)) .map(|(timeout, target, options)| { if let Some(options) = options - && !options.contains(&message.content) + && !options.contains(&message.content.to_lowercase()) { messages.insert( (message.channel_id, message.author.id),