mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-16 10:40:12 +00:00
feat: remote await events from gateway
This commit is contained in:
parent
64ff69723c
commit
15c992c572
17 changed files with 439 additions and 30 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -1089,6 +1089,7 @@ dependencies = [
|
||||||
"libpk",
|
"libpk",
|
||||||
"metrics",
|
"metrics",
|
||||||
"reqwest 0.12.8",
|
"reqwest 0.12.8",
|
||||||
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_variant",
|
"serde_variant",
|
||||||
"signal-hook",
|
"signal-hook",
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
using Serilog;
|
using Serilog;
|
||||||
using System.Net;
|
using System.Net;
|
||||||
|
using System.Text;
|
||||||
using System.Text.Json;
|
using System.Text.Json;
|
||||||
|
|
||||||
|
using NodaTime;
|
||||||
|
|
||||||
using Myriad.Serialization;
|
using Myriad.Serialization;
|
||||||
using Myriad.Types;
|
using Myriad.Types;
|
||||||
|
|
||||||
|
|
@ -12,6 +15,7 @@ public class HttpDiscordCache: IDiscordCache
|
||||||
private readonly ILogger _logger;
|
private readonly ILogger _logger;
|
||||||
private readonly HttpClient _client;
|
private readonly HttpClient _client;
|
||||||
private readonly Uri _cacheEndpoint;
|
private readonly Uri _cacheEndpoint;
|
||||||
|
private readonly string? _eventTarget;
|
||||||
private readonly int _shardCount;
|
private readonly int _shardCount;
|
||||||
private readonly ulong _ownUserId;
|
private readonly ulong _ownUserId;
|
||||||
|
|
||||||
|
|
@ -21,11 +25,12 @@ public class HttpDiscordCache: IDiscordCache
|
||||||
|
|
||||||
public EventHandler<(bool?, string)> OnDebug;
|
public EventHandler<(bool?, string)> OnDebug;
|
||||||
|
|
||||||
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, int shardCount, ulong ownUserId, bool useInnerCache)
|
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, string? eventTarget, int shardCount, ulong ownUserId, bool useInnerCache)
|
||||||
{
|
{
|
||||||
_logger = logger;
|
_logger = logger;
|
||||||
_client = client;
|
_client = client;
|
||||||
_cacheEndpoint = new Uri(cacheEndpoint);
|
_cacheEndpoint = new Uri(cacheEndpoint);
|
||||||
|
_eventTarget = eventTarget;
|
||||||
_shardCount = shardCount;
|
_shardCount = shardCount;
|
||||||
_ownUserId = ownUserId;
|
_ownUserId = ownUserId;
|
||||||
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
|
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
|
||||||
|
|
@ -65,6 +70,68 @@ public class HttpDiscordCache: IDiscordCache
|
||||||
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
|
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private Task AwaitEvent(ulong guildId, object data)
|
||||||
|
=> AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data);
|
||||||
|
|
||||||
|
private async Task AwaitEventShard(int shardId, object data)
|
||||||
|
{
|
||||||
|
if (_eventTarget == null)
|
||||||
|
throw new Exception("missing event target for remote await event");
|
||||||
|
|
||||||
|
var cluster = _cacheEndpoint.Authority;
|
||||||
|
// todo: there should not be infra-specific code here
|
||||||
|
if (cluster.Contains(".service.consul") || cluster.Contains("process.pluralkit-gateway.internal"))
|
||||||
|
// int(((guild_id >> 22) % shard_count) / 16)
|
||||||
|
cluster = $"cluster{shardId / 16}.{cluster}";
|
||||||
|
|
||||||
|
var response = await _client.PostAsync(
|
||||||
|
$"{_cacheEndpoint.Scheme}://{cluster}/await_event",
|
||||||
|
new StringContent(JsonSerializer.Serialize(data), Encoding.UTF8)
|
||||||
|
);
|
||||||
|
|
||||||
|
if (response.StatusCode != HttpStatusCode.NoContent)
|
||||||
|
throw new Exception($"failed to await event from gateway: {response.StatusCode}");
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task AwaitReaction(ulong guildId, ulong messageId, ulong userId, Duration? timeout)
|
||||||
|
{
|
||||||
|
var obj = new
|
||||||
|
{
|
||||||
|
message_id = messageId,
|
||||||
|
user_id = userId,
|
||||||
|
target = _eventTarget!,
|
||||||
|
timeout = timeout?.TotalSeconds,
|
||||||
|
};
|
||||||
|
|
||||||
|
await AwaitEvent(guildId, obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task AwaitMessage(ulong guildId, ulong channelId, ulong authorId, Duration? timeout, string[] options = null)
|
||||||
|
{
|
||||||
|
var obj = new
|
||||||
|
{
|
||||||
|
channel_id = channelId,
|
||||||
|
author_id = authorId,
|
||||||
|
target = _eventTarget!,
|
||||||
|
timeout = timeout?.TotalSeconds,
|
||||||
|
options = options,
|
||||||
|
};
|
||||||
|
|
||||||
|
await AwaitEvent(guildId, obj);
|
||||||
|
}
|
||||||
|
|
||||||
|
public async Task AwaitInteraction(int shardId, string id, Duration? timeout)
|
||||||
|
{
|
||||||
|
var obj = new
|
||||||
|
{
|
||||||
|
id = id,
|
||||||
|
target = _eventTarget!,
|
||||||
|
timeout = timeout?.TotalSeconds,
|
||||||
|
};
|
||||||
|
|
||||||
|
await AwaitEventShard(shardId, obj);
|
||||||
|
}
|
||||||
|
|
||||||
public async Task<Guild?> TryGetGuild(ulong guildId)
|
public async Task<Guild?> TryGetGuild(ulong guildId)
|
||||||
{
|
{
|
||||||
var hres = await QueryCache<Guild?>($"/guilds/{guildId}", guildId);
|
var hres = await QueryCache<Guild?>($"/guilds/{guildId}", guildId);
|
||||||
|
|
|
||||||
|
|
@ -21,6 +21,8 @@
|
||||||
</PropertyGroup>
|
</PropertyGroup>
|
||||||
|
|
||||||
<ItemGroup>
|
<ItemGroup>
|
||||||
|
<PackageReference Include="NodaTime" Version="3.2.0" />
|
||||||
|
<PackageReference Include="NodaTime.Serialization.JsonNet" Version="3.1.0" />
|
||||||
<PackageReference Include="Polly" Version="8.5.0" />
|
<PackageReference Include="Polly" Version="8.5.0" />
|
||||||
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
|
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
|
||||||
<PackageReference Include="Serilog" Version="4.2.0" />
|
<PackageReference Include="Serilog" Version="4.2.0" />
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,22 @@
|
||||||
"version": 1,
|
"version": 1,
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"net8.0": {
|
"net8.0": {
|
||||||
|
"NodaTime": {
|
||||||
|
"type": "Direct",
|
||||||
|
"requested": "[3.2.0, )",
|
||||||
|
"resolved": "3.2.0",
|
||||||
|
"contentHash": "yoRA3jEJn8NM0/rQm78zuDNPA3DonNSZdsorMUj+dltc1D+/Lc5h9YXGqbEEZozMGr37lAoYkcSM/KjTVqD0ow=="
|
||||||
|
},
|
||||||
|
"NodaTime.Serialization.JsonNet": {
|
||||||
|
"type": "Direct",
|
||||||
|
"requested": "[3.1.0, )",
|
||||||
|
"resolved": "3.1.0",
|
||||||
|
"contentHash": "eEr9lXUz50TYr4rpeJG4TDAABkpxjIKr5mDSi/Zav8d6Njy6fH7x4ZtNwWFj0Vd+vIvEZNrHFQ4Gfy8j4BqRGg==",
|
||||||
|
"dependencies": {
|
||||||
|
"Newtonsoft.Json": "13.0.3",
|
||||||
|
"NodaTime": "[3.0.0, 4.0.0)"
|
||||||
|
}
|
||||||
|
},
|
||||||
"Polly": {
|
"Polly": {
|
||||||
"type": "Direct",
|
"type": "Direct",
|
||||||
"requested": "[8.5.0, )",
|
"requested": "[8.5.0, )",
|
||||||
|
|
@ -52,6 +68,11 @@
|
||||||
"resolved": "6.0.0",
|
"resolved": "6.0.0",
|
||||||
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
|
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
|
||||||
},
|
},
|
||||||
|
"Newtonsoft.Json": {
|
||||||
|
"type": "Transitive",
|
||||||
|
"resolved": "13.0.3",
|
||||||
|
"contentHash": "HrC5BXdl00IP9zeV+0Z848QWPAoCr9P3bDEZguI+gkLcBKAOxix/tLEAAHC+UvDNPv4a2d18lOReHMOagPa+zQ=="
|
||||||
|
},
|
||||||
"Pipelines.Sockets.Unofficial": {
|
"Pipelines.Sockets.Unofficial": {
|
||||||
"type": "Transitive",
|
"type": "Transitive",
|
||||||
"resolved": "2.2.8",
|
"resolved": "2.2.8",
|
||||||
|
|
|
||||||
|
|
@ -26,6 +26,7 @@ public class BotConfig
|
||||||
|
|
||||||
public string? HttpListenerAddr { get; set; }
|
public string? HttpListenerAddr { get; set; }
|
||||||
public bool DisableGateway { get; set; } = false;
|
public bool DisableGateway { get; set; } = false;
|
||||||
|
public string? EventAwaiterTarget { get; set; }
|
||||||
|
|
||||||
public string? DiscordBaseUrl { get; set; }
|
public string? DiscordBaseUrl { get; set; }
|
||||||
public string? AvatarServiceUrl { get; set; }
|
public string? AvatarServiceUrl { get; set; }
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ public abstract class BaseInteractive
|
||||||
ButtonStyle style = ButtonStyle.Secondary, bool disabled = false)
|
ButtonStyle style = ButtonStyle.Secondary, bool disabled = false)
|
||||||
{
|
{
|
||||||
var dispatch = _ctx.Services.Resolve<InteractionDispatchService>();
|
var dispatch = _ctx.Services.Resolve<InteractionDispatchService>();
|
||||||
var customId = dispatch.Register(handler, Timeout);
|
var customId = dispatch.Register(_ctx.ShardId, handler, Timeout);
|
||||||
|
|
||||||
var button = new Button
|
var button = new Button
|
||||||
{
|
{
|
||||||
|
|
@ -89,7 +89,7 @@ public abstract class BaseInteractive
|
||||||
{
|
{
|
||||||
var dispatch = ctx.Services.Resolve<InteractionDispatchService>();
|
var dispatch = ctx.Services.Resolve<InteractionDispatchService>();
|
||||||
foreach (var button in _buttons)
|
foreach (var button in _buttons)
|
||||||
button.CustomId = dispatch.Register(button.Handler, Timeout);
|
button.CustomId = dispatch.Register(_ctx.ShardId, button.Handler, Timeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
public abstract Task Start();
|
public abstract Task Start();
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
using Autofac;
|
using Autofac;
|
||||||
|
|
||||||
|
using Myriad.Cache;
|
||||||
using Myriad.Gateway;
|
using Myriad.Gateway;
|
||||||
using Myriad.Rest.Types;
|
using Myriad.Rest.Types;
|
||||||
using Myriad.Types;
|
using Myriad.Types;
|
||||||
|
|
@ -69,6 +70,9 @@ public class YesNoPrompt: BaseInteractive
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// no need to reawait message
|
||||||
|
// gateway will already have sent us only matching messages
|
||||||
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -88,6 +92,17 @@ public class YesNoPrompt: BaseInteractive
|
||||||
{
|
{
|
||||||
try
|
try
|
||||||
{
|
{
|
||||||
|
// check if http gateway and set listener
|
||||||
|
// todo: this one needs to handle options for message
|
||||||
|
if (_ctx.Cache is HttpDiscordCache)
|
||||||
|
await (_ctx.Cache as HttpDiscordCache).AwaitMessage(
|
||||||
|
_ctx.Guild?.Id ?? 0,
|
||||||
|
_ctx.Channel.Id,
|
||||||
|
_ctx.Author.Id,
|
||||||
|
Timeout,
|
||||||
|
options: new[] { "yes", "y", "no", "n" }
|
||||||
|
);
|
||||||
|
|
||||||
await queue.WaitFor(MessagePredicate, Timeout, cts.Token);
|
await queue.WaitFor(MessagePredicate, Timeout, cts.Token);
|
||||||
}
|
}
|
||||||
catch (TimeoutException e)
|
catch (TimeoutException e)
|
||||||
|
|
|
||||||
|
|
@ -49,8 +49,15 @@ public class BotModule: Module
|
||||||
|
|
||||||
if (botConfig.HttpCacheUrl != null)
|
if (botConfig.HttpCacheUrl != null)
|
||||||
{
|
{
|
||||||
var cache = new HttpDiscordCache(c.Resolve<ILogger>(),
|
var cache = new HttpDiscordCache(
|
||||||
c.Resolve<HttpClient>(), botConfig.HttpCacheUrl, botConfig.Cluster?.TotalShards ?? 1, botConfig.ClientId, botConfig.HttpUseInnerCache);
|
c.Resolve<ILogger>(),
|
||||||
|
c.Resolve<HttpClient>(),
|
||||||
|
botConfig.HttpCacheUrl,
|
||||||
|
botConfig.EventAwaiterTarget,
|
||||||
|
botConfig.Cluster?.TotalShards ?? 1,
|
||||||
|
botConfig.ClientId,
|
||||||
|
botConfig.HttpUseInnerCache
|
||||||
|
);
|
||||||
|
|
||||||
var metrics = c.Resolve<IMetrics>();
|
var metrics = c.Resolve<IMetrics>();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
using System.Collections.Concurrent;
|
using System.Collections.Concurrent;
|
||||||
|
|
||||||
|
using Myriad.Cache;
|
||||||
|
|
||||||
using NodaTime;
|
using NodaTime;
|
||||||
|
|
||||||
using Serilog;
|
using Serilog;
|
||||||
|
|
@ -16,9 +18,12 @@ public class InteractionDispatchService: IDisposable
|
||||||
private readonly ConcurrentDictionary<Guid, RegisteredInteraction> _handlers = new();
|
private readonly ConcurrentDictionary<Guid, RegisteredInteraction> _handlers = new();
|
||||||
private readonly ILogger _logger;
|
private readonly ILogger _logger;
|
||||||
|
|
||||||
public InteractionDispatchService(IClock clock, ILogger logger)
|
private readonly IDiscordCache _cache;
|
||||||
|
|
||||||
|
public InteractionDispatchService(IClock clock, ILogger logger, IDiscordCache cache)
|
||||||
{
|
{
|
||||||
_clock = clock;
|
_clock = clock;
|
||||||
|
_cache = cache;
|
||||||
_logger = logger.ForContext<InteractionDispatchService>();
|
_logger = logger.ForContext<InteractionDispatchService>();
|
||||||
|
|
||||||
_cleanupWorker = CleanupLoop(_cts.Token);
|
_cleanupWorker = CleanupLoop(_cts.Token);
|
||||||
|
|
@ -50,9 +55,15 @@ public class InteractionDispatchService: IDisposable
|
||||||
_handlers.TryRemove(customIdGuid, out _);
|
_handlers.TryRemove(customIdGuid, out _);
|
||||||
}
|
}
|
||||||
|
|
||||||
public string Register(Func<InteractionContext, Task> callback, Duration? expiry = null)
|
public string Register(int shardId, Func<InteractionContext, Task> callback, Duration? expiry = null)
|
||||||
{
|
{
|
||||||
var key = Guid.NewGuid();
|
var key = Guid.NewGuid();
|
||||||
|
|
||||||
|
// if http_cache, return RegisterRemote
|
||||||
|
// not awaited here, it's probably fine
|
||||||
|
if (_cache is HttpDiscordCache)
|
||||||
|
(_cache as HttpDiscordCache).AwaitInteraction(shardId, key.ToString(), expiry);
|
||||||
|
|
||||||
var handler = new RegisteredInteraction
|
var handler = new RegisteredInteraction
|
||||||
{
|
{
|
||||||
Callback = callback,
|
Callback = callback,
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
using Autofac;
|
using Autofac;
|
||||||
|
|
||||||
using Myriad.Builders;
|
using Myriad.Builders;
|
||||||
|
using Myriad.Cache;
|
||||||
using Myriad.Gateway;
|
using Myriad.Gateway;
|
||||||
using Myriad.Rest.Exceptions;
|
using Myriad.Rest.Exceptions;
|
||||||
using Myriad.Rest.Types.Requests;
|
using Myriad.Rest.Types.Requests;
|
||||||
|
|
@ -40,8 +41,12 @@ public static class ContextUtils
|
||||||
}
|
}
|
||||||
|
|
||||||
public static async Task<MessageReactionAddEvent> AwaitReaction(this Context ctx, Message message,
|
public static async Task<MessageReactionAddEvent> AwaitReaction(this Context ctx, Message message,
|
||||||
User user = null, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
|
User user, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
|
||||||
{
|
{
|
||||||
|
// check if http gateway and set listener
|
||||||
|
if (ctx.Cache is HttpDiscordCache)
|
||||||
|
await (ctx.Cache as HttpDiscordCache).AwaitReaction(ctx.Guild?.Id ?? 0, message.Id, user!.Id, timeout);
|
||||||
|
|
||||||
bool ReactionPredicate(MessageReactionAddEvent evt)
|
bool ReactionPredicate(MessageReactionAddEvent evt)
|
||||||
{
|
{
|
||||||
if (message.Id != evt.MessageId) return false; // Ignore reactions for different messages
|
if (message.Id != evt.MessageId) return false; // Ignore reactions for different messages
|
||||||
|
|
@ -57,11 +62,17 @@ public static class ContextUtils
|
||||||
|
|
||||||
public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply, bool treatAsHid = false)
|
public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply, bool treatAsHid = false)
|
||||||
{
|
{
|
||||||
|
var timeout = Duration.FromMinutes(1);
|
||||||
|
|
||||||
|
// check if http gateway and set listener
|
||||||
|
if (ctx.Cache is HttpDiscordCache)
|
||||||
|
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
|
||||||
|
|
||||||
bool Predicate(MessageCreateEvent e) =>
|
bool Predicate(MessageCreateEvent e) =>
|
||||||
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
||||||
|
|
||||||
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
||||||
.WaitFor(Predicate, Duration.FromMinutes(1));
|
.WaitFor(Predicate, timeout);
|
||||||
|
|
||||||
var content = msg.Content;
|
var content = msg.Content;
|
||||||
if (treatAsHid)
|
if (treatAsHid)
|
||||||
|
|
@ -96,11 +107,17 @@ public static class ContextUtils
|
||||||
|
|
||||||
async Task<int> PromptPageNumber()
|
async Task<int> PromptPageNumber()
|
||||||
{
|
{
|
||||||
|
var timeout = Duration.FromMinutes(0.5);
|
||||||
|
|
||||||
|
// check if http gateway and set listener
|
||||||
|
if (ctx.Cache is HttpDiscordCache)
|
||||||
|
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
|
||||||
|
|
||||||
bool Predicate(MessageCreateEvent e) =>
|
bool Predicate(MessageCreateEvent e) =>
|
||||||
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
||||||
|
|
||||||
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
||||||
.WaitFor(Predicate, Duration.FromMinutes(0.5));
|
.WaitFor(Predicate, timeout);
|
||||||
|
|
||||||
int.TryParse(msg.Content, out var num);
|
int.TryParse(msg.Content, out var num);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -764,6 +764,8 @@
|
||||||
"myriad": {
|
"myriad": {
|
||||||
"type": "Project",
|
"type": "Project",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"NodaTime": "[3.2.0, )",
|
||||||
|
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||||
"Polly": "[8.5.0, )",
|
"Polly": "[8.5.0, )",
|
||||||
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
||||||
"Serilog": "[4.2.0, )",
|
"Serilog": "[4.2.0, )",
|
||||||
|
|
|
||||||
|
|
@ -976,6 +976,8 @@
|
||||||
"myriad": {
|
"myriad": {
|
||||||
"type": "Project",
|
"type": "Project",
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"NodaTime": "[3.2.0, )",
|
||||||
|
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||||
"Polly": "[8.5.0, )",
|
"Polly": "[8.5.0, )",
|
||||||
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
||||||
"Serilog": "[4.2.0, )",
|
"Serilog": "[4.2.0, )",
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,7 @@ lazy_static = { workspace = true }
|
||||||
libpk = { path = "../libpk" }
|
libpk = { path = "../libpk" }
|
||||||
metrics = { workspace = true }
|
metrics = { workspace = true }
|
||||||
reqwest = { workspace = true }
|
reqwest = { workspace = true }
|
||||||
|
serde = { workspace = true }
|
||||||
serde_json = { workspace = true }
|
serde_json = { workspace = true }
|
||||||
signal-hook = { workspace = true }
|
signal-hook = { workspace = true }
|
||||||
tokio = { workspace = true }
|
tokio = { workspace = true }
|
||||||
|
|
|
||||||
|
|
@ -10,9 +10,12 @@ use serde_json::{json, to_string};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use twilight_model::id::Id;
|
use twilight_model::id::Id;
|
||||||
|
|
||||||
use crate::discord::{
|
use crate::{
|
||||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
discord::{
|
||||||
gateway::cluster_config,
|
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||||
|
gateway::cluster_config,
|
||||||
|
},
|
||||||
|
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
|
@ -22,10 +25,11 @@ fn status_code(code: StatusCode, body: String) -> Response {
|
||||||
|
|
||||||
// this function is manually formatted for easier legibility of route_services
|
// this function is manually formatted for easier legibility of route_services
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>) -> anyhow::Result<()> {
|
pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
|
||||||
// hacky fix for `move`
|
// hacky fix for `move`
|
||||||
let runtime_config_for_post = runtime_config.clone();
|
let runtime_config_for_post = runtime_config.clone();
|
||||||
let runtime_config_for_delete = runtime_config.clone();
|
let runtime_config_for_delete = runtime_config.clone();
|
||||||
|
let awaiter_for_clear = awaiter.clone();
|
||||||
|
|
||||||
let app = Router::new()
|
let app = Router::new()
|
||||||
.route(
|
.route(
|
||||||
|
|
@ -190,6 +194,19 @@ pub async fn run_server(cache: Arc<DiscordCache>, runtime_config: Arc<RuntimeCon
|
||||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
.route("/await_event", post(|body: String| async move {
|
||||||
|
info!("got request: {body}");
|
||||||
|
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
||||||
|
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
||||||
|
};
|
||||||
|
awaiter.handle_request(req).await;
|
||||||
|
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||||
|
}))
|
||||||
|
.route("/clear_awaiter", post(|| async move {
|
||||||
|
awaiter_for_clear.clear().await;
|
||||||
|
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||||
|
}))
|
||||||
|
|
||||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||||
.with_state(cache);
|
.with_state(cache);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -82,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
||||||
|
|
||||||
pub async fn runner(
|
pub async fn runner(
|
||||||
mut shard: Shard<RedisQueue>,
|
mut shard: Shard<RedisQueue>,
|
||||||
tx: Sender<(ShardId, String)>,
|
tx: Sender<(ShardId, Event, String)>,
|
||||||
shard_state: ShardStateManager,
|
shard_state: ShardStateManager,
|
||||||
cache: Arc<DiscordCache>,
|
cache: Arc<DiscordCache>,
|
||||||
runtime_config: Arc<RuntimeConfig>,
|
runtime_config: Arc<RuntimeConfig>,
|
||||||
|
|
@ -182,21 +182,21 @@ pub async fn runner(
|
||||||
// and the default match skips the next block (continues to the next event)
|
// and the default match skips the next block (continues to the next event)
|
||||||
match event {
|
match event {
|
||||||
Event::InteractionCreate(_) => {}
|
Event::InteractionCreate(_) => {}
|
||||||
Event::MessageCreate(m) if m.author.id != our_user_id => {}
|
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
|
||||||
Event::MessageUpdate(m)
|
Event::MessageUpdate(ref m)
|
||||||
if let Some(author) = m.author.clone()
|
if let Some(author) = m.author.clone()
|
||||||
&& author.id != our_user_id
|
&& author.id != our_user_id
|
||||||
&& !author.bot => {}
|
&& !author.bot => {}
|
||||||
Event::MessageDelete(_) => {}
|
Event::MessageDelete(_) => {}
|
||||||
Event::MessageDeleteBulk(_) => {}
|
Event::MessageDeleteBulk(_) => {}
|
||||||
Event::ReactionAdd(r) if r.user_id != our_user_id => {}
|
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
|
||||||
_ => {
|
_ => {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
|
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
|
||||||
tx.send((shard.id(), raw_event)).await.unwrap();
|
tx.send((shard.id(), event, raw_event)).await.unwrap();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
223
crates/gateway/src/event_awaiter.rs
Normal file
223
crates/gateway/src/event_awaiter.rs
Normal file
|
|
@ -0,0 +1,223 @@
|
||||||
|
// - reaction: (message_id, user_id)
|
||||||
|
// - message: (author_id, channel_id, ?options)
|
||||||
|
// - interaction: (custom_id where not_includes "help-menu")
|
||||||
|
|
||||||
|
use std::{
|
||||||
|
collections::{hash_map::Entry, HashMap},
|
||||||
|
time::Duration,
|
||||||
|
};
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
use tokio::{sync::RwLock, time::Instant};
|
||||||
|
use tracing::info;
|
||||||
|
use twilight_gateway::Event;
|
||||||
|
use twilight_model::{
|
||||||
|
application::interaction::InteractionData,
|
||||||
|
id::{
|
||||||
|
marker::{ChannelMarker, MessageMarker, UserMarker},
|
||||||
|
Id,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
pub enum AwaitEventRequest {
|
||||||
|
Reaction {
|
||||||
|
message_id: Id<MessageMarker>,
|
||||||
|
user_id: Id<UserMarker>,
|
||||||
|
target: String,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
},
|
||||||
|
Message {
|
||||||
|
channel_id: Id<ChannelMarker>,
|
||||||
|
author_id: Id<UserMarker>,
|
||||||
|
target: String,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
options: Option<Vec<String>>,
|
||||||
|
},
|
||||||
|
Interaction {
|
||||||
|
id: String,
|
||||||
|
target: String,
|
||||||
|
timeout: Option<u64>,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
pub struct EventAwaiter {
|
||||||
|
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
|
||||||
|
messages: RwLock<
|
||||||
|
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
|
||||||
|
>,
|
||||||
|
interactions: RwLock<HashMap<String, (Instant, String)>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl EventAwaiter {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let v = Self {
|
||||||
|
reactions: RwLock::new(HashMap::new()),
|
||||||
|
messages: RwLock::new(HashMap::new()),
|
||||||
|
interactions: RwLock::new(HashMap::new()),
|
||||||
|
};
|
||||||
|
|
||||||
|
v
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn cleanup_loop(&self) {
|
||||||
|
loop {
|
||||||
|
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||||
|
info!("running event_awaiter cleanup loop");
|
||||||
|
let mut counts = (0, 0, 0);
|
||||||
|
let now = Instant::now();
|
||||||
|
{
|
||||||
|
let mut reactions = self.reactions.write().await;
|
||||||
|
for key in reactions.clone().keys() {
|
||||||
|
if let Entry::Occupied(entry) = reactions.entry(key.clone())
|
||||||
|
&& entry.get().0 < now
|
||||||
|
{
|
||||||
|
counts.0 += 1;
|
||||||
|
entry.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut messages = self.messages.write().await;
|
||||||
|
for key in messages.clone().keys() {
|
||||||
|
if let Entry::Occupied(entry) = messages.entry(key.clone())
|
||||||
|
&& entry.get().0 < now
|
||||||
|
{
|
||||||
|
counts.1 += 1;
|
||||||
|
entry.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut interactions = self.interactions.write().await;
|
||||||
|
for key in interactions.clone().keys() {
|
||||||
|
if let Entry::Occupied(entry) = interactions.entry(key.clone())
|
||||||
|
&& entry.get().0 < now
|
||||||
|
{
|
||||||
|
counts.2 += 1;
|
||||||
|
entry.remove();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn target_for_event(&self, event: Event) -> Option<String> {
|
||||||
|
match event {
|
||||||
|
Event::MessageCreate(message) => {
|
||||||
|
let mut messages = self.messages.write().await;
|
||||||
|
|
||||||
|
messages
|
||||||
|
.remove(&(message.channel_id, message.author.id))
|
||||||
|
.map(|(timeout, target, options)| {
|
||||||
|
if let Some(options) = options
|
||||||
|
&& !options.contains(&message.content)
|
||||||
|
{
|
||||||
|
messages.insert(
|
||||||
|
(message.channel_id, message.author.id),
|
||||||
|
(timeout, target, Some(options)),
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some((*target).to_string())
|
||||||
|
})?
|
||||||
|
}
|
||||||
|
Event::ReactionAdd(reaction)
|
||||||
|
if let Some((_, target)) = self
|
||||||
|
.reactions
|
||||||
|
.write()
|
||||||
|
.await
|
||||||
|
.remove(&(reaction.message_id, reaction.user_id)) =>
|
||||||
|
{
|
||||||
|
Some((*target).to_string())
|
||||||
|
}
|
||||||
|
Event::InteractionCreate(interaction)
|
||||||
|
if let Some(data) = interaction.data.clone()
|
||||||
|
&& let InteractionData::MessageComponent(component) = data
|
||||||
|
&& !component.custom_id.contains("help-menu")
|
||||||
|
&& let Some((_, target)) =
|
||||||
|
self.interactions.write().await.remove(&component.custom_id) =>
|
||||||
|
{
|
||||||
|
Some((*target).to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_request(&self, req: AwaitEventRequest) {
|
||||||
|
match req {
|
||||||
|
AwaitEventRequest::Reaction {
|
||||||
|
message_id,
|
||||||
|
user_id,
|
||||||
|
target,
|
||||||
|
timeout,
|
||||||
|
} => {
|
||||||
|
self.reactions.write().await.insert(
|
||||||
|
(message_id, user_id),
|
||||||
|
(
|
||||||
|
Instant::now()
|
||||||
|
.checked_add(
|
||||||
|
timeout
|
||||||
|
.map(|i| Duration::from_secs(i))
|
||||||
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
|
)
|
||||||
|
.expect("invalid time"),
|
||||||
|
target,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
AwaitEventRequest::Message {
|
||||||
|
channel_id,
|
||||||
|
author_id,
|
||||||
|
target,
|
||||||
|
timeout,
|
||||||
|
options,
|
||||||
|
} => {
|
||||||
|
self.messages.write().await.insert(
|
||||||
|
(channel_id, author_id),
|
||||||
|
(
|
||||||
|
Instant::now()
|
||||||
|
.checked_add(
|
||||||
|
timeout
|
||||||
|
.map(|i| Duration::from_secs(i))
|
||||||
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
|
)
|
||||||
|
.expect("invalid time"),
|
||||||
|
target,
|
||||||
|
options,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
AwaitEventRequest::Interaction {
|
||||||
|
id,
|
||||||
|
target,
|
||||||
|
timeout,
|
||||||
|
} => {
|
||||||
|
self.interactions.write().await.insert(
|
||||||
|
id,
|
||||||
|
(
|
||||||
|
Instant::now()
|
||||||
|
.checked_add(
|
||||||
|
timeout
|
||||||
|
.map(|i| Duration::from_secs(i))
|
||||||
|
.unwrap_or(DEFAULT_TIMEOUT),
|
||||||
|
)
|
||||||
|
.expect("invalid time"),
|
||||||
|
target,
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn clear(&self) {
|
||||||
|
self.reactions.write().await.clear();
|
||||||
|
self.messages.write().await.clear();
|
||||||
|
self.interactions.write().await.clear();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
#![feature(let_chains)]
|
#![feature(let_chains)]
|
||||||
#![feature(if_let_guard)]
|
#![feature(if_let_guard)]
|
||||||
|
#![feature(duration_constructors)]
|
||||||
|
|
||||||
use chrono::Timelike;
|
use chrono::Timelike;
|
||||||
use discord::gateway::cluster_config;
|
use discord::gateway::cluster_config;
|
||||||
|
use event_awaiter::EventAwaiter;
|
||||||
use fred::{clients::RedisPool, interfaces::*};
|
use fred::{clients::RedisPool, interfaces::*};
|
||||||
use libpk::runtime_config::RuntimeConfig;
|
use libpk::runtime_config::RuntimeConfig;
|
||||||
use reqwest::ClientBuilder;
|
use reqwest::ClientBuilder;
|
||||||
|
|
@ -12,12 +14,13 @@ use signal_hook::{
|
||||||
};
|
};
|
||||||
use std::{sync::Arc, time::Duration, vec::Vec};
|
use std::{sync::Arc, time::Duration, vec::Vec};
|
||||||
use tokio::{sync::mpsc::channel, task::JoinSet};
|
use tokio::{sync::mpsc::channel, task::JoinSet};
|
||||||
use tracing::{error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
use twilight_gateway::{MessageSender, ShardId};
|
use twilight_gateway::{MessageSender, ShardId};
|
||||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||||
|
|
||||||
mod cache_api;
|
mod cache_api;
|
||||||
mod discord;
|
mod discord;
|
||||||
|
mod event_awaiter;
|
||||||
mod logger;
|
mod logger;
|
||||||
|
|
||||||
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
|
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
|
||||||
|
|
@ -39,6 +42,11 @@ async fn real_main() -> anyhow::Result<()> {
|
||||||
|
|
||||||
let shard_state = discord::shard_state::new(redis.clone());
|
let shard_state = discord::shard_state::new(redis.clone());
|
||||||
let cache = Arc::new(discord::cache::new());
|
let cache = Arc::new(discord::cache::new());
|
||||||
|
let awaiter = Arc::new(EventAwaiter::new());
|
||||||
|
tokio::spawn({
|
||||||
|
let awaiter = awaiter.clone();
|
||||||
|
async move { awaiter.cleanup_loop().await }
|
||||||
|
});
|
||||||
|
|
||||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||||
|
|
||||||
|
|
@ -63,22 +71,36 @@ async fn real_main() -> anyhow::Result<()> {
|
||||||
|
|
||||||
set.spawn(tokio::spawn({
|
set.spawn(tokio::spawn({
|
||||||
let runtime_config = runtime_config.clone();
|
let runtime_config = runtime_config.clone();
|
||||||
async move {
|
let awaiter = awaiter.clone();
|
||||||
let client = Arc::new(ClientBuilder::new()
|
|
||||||
.connect_timeout(Duration::from_secs(1))
|
async move {
|
||||||
.timeout(Duration::from_secs(1))
|
let client = Arc::new(
|
||||||
.build()
|
ClientBuilder::new()
|
||||||
.expect("error making client"));
|
.connect_timeout(Duration::from_secs(1))
|
||||||
|
.timeout(Duration::from_secs(1))
|
||||||
|
.build()
|
||||||
|
.expect("error making client"),
|
||||||
|
);
|
||||||
|
|
||||||
|
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
|
||||||
|
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
|
||||||
|
debug!("sending event to awaiter");
|
||||||
|
Some(target)
|
||||||
|
} else if let Some(target) =
|
||||||
|
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
|
||||||
|
{
|
||||||
|
Some(target)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
while let Some((shard_id, event)) = event_rx.recv().await {
|
|
||||||
let target = runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await;
|
|
||||||
if let Some(target) = target {
|
if let Some(target) = target {
|
||||||
tokio::spawn({
|
tokio::spawn({
|
||||||
let client = client.clone();
|
let client = client.clone();
|
||||||
async move {
|
async move {
|
||||||
if let Err(error) = client
|
if let Err(error) = client
|
||||||
.post(format!("{target}/{}", shard_id.number()))
|
.post(format!("{target}/{}", shard_id.number()))
|
||||||
.body(event)
|
.body(raw_event)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
|
|
@ -98,7 +120,7 @@ async fn real_main() -> anyhow::Result<()> {
|
||||||
// todo: probably don't do it this way
|
// todo: probably don't do it this way
|
||||||
let api_shutdown_tx = shutdown_tx.clone();
|
let api_shutdown_tx = shutdown_tx.clone();
|
||||||
set.spawn(tokio::spawn(async move {
|
set.spawn(tokio::spawn(async move {
|
||||||
match cache_api::run_server(cache, runtime_config).await {
|
match cache_api::run_server(cache, runtime_config, awaiter.clone()).await {
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
tracing::error!(?error, "failed to serve cache api");
|
tracing::error!(?error, "failed to serve cache api");
|
||||||
let _ = api_shutdown_tx.send(());
|
let _ = api_shutdown_tx.send(());
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue