feat: add last message cache to gateway

This commit is contained in:
alyssa 2025-04-01 10:48:20 +00:00
parent 15c992c572
commit a8664665a6
10 changed files with 172 additions and 18 deletions

View file

@ -70,6 +70,9 @@ public class HttpDiscordCache: IDiscordCache
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
}
public Task<T> GetLastMessage<T>(ulong guildId, ulong channelId)
=> QueryCache<T>($"/guilds/{guildId}/channels/{channelId}/last_message", guildId);
private Task AwaitEvent(ulong guildId, object data)
=> AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data);

View file

@ -305,7 +305,7 @@ public class ProxiedMessage
throw new PKError(error);
}
var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId);
var lastMessage = await _lastMessageCache.GetLastMessage(ctx.Message.GuildId ?? 0, ctx.Message.ChannelId);
var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id
? lastMessage?.Previous?.Id == msg.Mid

View file

@ -55,7 +55,9 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId);
private bool IsDuplicateMessage(Message msg) =>
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
_lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
// use only the local cache here
// http gateway sets last message before forwarding the message here, so this will always return true
_lastMessageCache._GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
public async Task Handle(int shardId, MessageCreateEvent evt)
{

View file

@ -65,7 +65,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
var guild = await _cache.TryGetGuild(channel.GuildId!.Value);
if (guild == null)
throw new Exception("could not find self guild in MessageEdited event");
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
var lastMessage = (await _lastMessageCache.GetLastMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId))?.Current;
// Only react to the last message in the channel
if (lastMessage?.Id != evt.Id)

View file

@ -246,7 +246,7 @@ public class ProxyService
ChannelId = rootChannel.Id,
ThreadId = threadId,
MessageId = trigger.Id,
Name = await FixSameName(messageChannel.Id, ctx, match.Member),
Name = await FixSameName(trigger.GuildId!.Value, messageChannel.Id, ctx, match.Member),
AvatarUrl = AvatarUtils.TryRewriteCdnUrl(match.Member.ProxyAvatar(ctx)),
Content = content,
Attachments = trigger.Attachments,
@ -458,11 +458,11 @@ public class ProxyService
};
}
private async Task<string> FixSameName(ulong channelId, MessageContext ctx, ProxyMember member)
private async Task<string> FixSameName(ulong guildId, ulong channelId, MessageContext ctx, ProxyMember member)
{
var proxyName = member.ProxyName(ctx);
var lastMessage = _lastMessage.GetLastMessage(channelId)?.Previous;
var lastMessage = (await _lastMessage.GetLastMessage(guildId, channelId))?.Previous;
if (lastMessage == null)
// cache is out of date or channel is empty.
return proxyName;

View file

@ -1,6 +1,7 @@
#nullable enable
using System.Collections.Concurrent;
using Myriad.Cache;
using Myriad.Types;
namespace PluralKit.Bot;
@ -9,9 +10,18 @@ public class LastMessageCacheService
{
private readonly IDictionary<ulong, CacheEntry> _cache = new ConcurrentDictionary<ulong, CacheEntry>();
private readonly IDiscordCache _maybeHttp;
public LastMessageCacheService(IDiscordCache cache)
{
_maybeHttp = cache;
}
public void AddMessage(Message msg)
{
var previous = GetLastMessage(msg.ChannelId);
if (_maybeHttp is HttpDiscordCache) return;
var previous = _GetLastMessage(msg.ChannelId);
var current = ToCachedMessage(msg);
_cache[msg.ChannelId] = new CacheEntry(current, previous?.Current);
}
@ -19,12 +29,26 @@ public class LastMessageCacheService
private CachedMessage ToCachedMessage(Message msg) =>
new(msg.Id, msg.ReferencedMessage.Value?.Id, msg.Author.Username);
public CacheEntry? GetLastMessage(ulong channel) =>
_cache.TryGetValue(channel, out var message) ? message : null;
public async Task<CacheEntry?> GetLastMessage(ulong guild, ulong channel)
{
if (_maybeHttp is HttpDiscordCache)
return await (_maybeHttp as HttpDiscordCache).GetLastMessage<CacheEntry>(guild, channel);
return _cache.TryGetValue(channel, out var message) ? message : null;
}
public CacheEntry? _GetLastMessage(ulong channel)
{
if (_maybeHttp is HttpDiscordCache) return null;
return _cache.TryGetValue(channel, out var message) ? message : null;
}
public void HandleMessageDeletion(ulong channel, ulong message)
{
var storedMessage = GetLastMessage(channel);
if (_maybeHttp is HttpDiscordCache) return;
var storedMessage = _GetLastMessage(channel);
if (storedMessage == null)
return;
@ -39,7 +63,9 @@ public class LastMessageCacheService
public void HandleMessageDeletion(ulong channel, List<ulong> messages)
{
var storedMessage = GetLastMessage(channel);
if (_maybeHttp is HttpDiscordCache) return;
var storedMessage = _GetLastMessage(channel);
if (storedMessage == null)
return;

View file

@ -8,7 +8,7 @@ use axum::{
use libpk::runtime_config::RuntimeConfig;
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::Id;
use twilight_model::id::{marker::ChannelMarker, Id};
use crate::{
discord::{
@ -136,7 +136,10 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
get(|State(cache): State<Arc<DiscordCache>>, Path((_guild_id, channel_id)): Path<(u64, Id<ChannelMarker>)>| async move {
let lm = cache.get_last_message(channel_id).await;
status_code(StatusCode::FOUND, to_string(&lm).unwrap())
}),
)
.route(

View file

@ -1,6 +1,7 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
model::CachedMember,
@ -8,11 +9,12 @@ use twilight_cache_inmemory::{
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_gateway::Event;
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
Id,
},
};
@ -123,16 +125,134 @@ pub fn new() -> DiscordCache {
.build(),
);
DiscordCache(cache, client, RwLock::new(Vec::new()))
DiscordCache(
cache,
client,
RwLock::new(Vec::new()),
RwLock::new(HashMap::new()),
)
}
#[derive(Clone, Serialize)]
pub struct CachedMessage {
id: Id<MessageMarker>,
referenced_message: Option<Id<MessageMarker>>,
author_username: String,
}
#[derive(Clone, Serialize)]
pub struct LastMessageCacheEntry {
pub current: CachedMessage,
pub previous: Option<CachedMessage>,
}
pub struct DiscordCache(
pub Arc<InMemoryCache>,
pub Arc<twilight_http::Client>,
pub RwLock<Vec<u32>>,
pub RwLock<HashMap<Id<ChannelMarker>, LastMessageCacheEntry>>,
);
impl DiscordCache {
pub async fn get_last_message(
&self,
channel: Id<ChannelMarker>,
) -> Option<LastMessageCacheEntry> {
self.3.read().await.get(&channel).cloned()
}
pub async fn update(&self, event: &twilight_gateway::Event) {
self.0.update(event);
match event {
Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let cur = e.get();
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: Some(cur.current.clone()),
});
}
std::collections::hash_map::Entry::Vacant(e) => {
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: None,
});
}
},
Event::MessageDelete(m) => {
self.handle_message_deletion(m.channel_id, vec![m.id]).await;
}
Event::MessageDeleteBulk(m) => {
self.handle_message_deletion(m.channel_id, m.ids.clone())
.await;
}
_ => {}
};
}
async fn handle_message_deletion(
&self,
channel_id: Id<ChannelMarker>,
mids: Vec<Id<MessageMarker>>,
) {
let mut lm = self.3.write().await;
let Some(entry) = lm.get(&channel_id) else {
return;
};
let mut entry = entry.clone();
// if none of the deleted messages are relevant, just return
if !mids.contains(&entry.current.id)
&& entry
.previous
.clone()
.map(|v| !mids.contains(&v.id))
.unwrap_or(false)
{
return;
}
// remove "previous" entry if it was deleted
if let Some(prev) = entry.previous.clone()
&& mids.contains(&prev.id)
{
entry.previous = None;
}
// set "current" entry to "previous" if current entry was deleted
// (if the "previous" entry still exists, it was not deleted)
if let Some(prev) = entry.previous.clone()
&& mids.contains(&entry.current.id)
{
entry.current = prev;
entry.previous = None;
}
// if the current entry was already deleted, but previous wasn't,
// we would've set current to previous
// so if current is deleted this means both current and previous have
// been deleted
// so just drop the cache entry here
if mids.contains(&entry.current.id) && entry.previous.is_none() {
lm.remove(&channel_id);
return;
}
// ok, update the entry
lm.insert(channel_id, entry.clone());
}
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,

View file

@ -173,7 +173,7 @@ pub async fn runner(
cache.2.write().await.push(shard_id);
}
}
cache.0.update(&event);
cache.update(&event).await;
// okay, we've handled the event internally, let's send it to consumers

View file

@ -115,7 +115,7 @@ impl EventAwaiter {
.remove(&(message.channel_id, message.author.id))
.map(|(timeout, target, options)| {
if let Some(options) = options
&& !options.contains(&message.content)
&& !options.contains(&message.content.to_lowercase())
{
messages.insert(
(message.channel_id, message.author.id),