mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
Merge remote-tracking branch 'upstream/main' into rust-command-parser
This commit is contained in:
commit
f721b850d4
183 changed files with 5121 additions and 1909 deletions
|
|
@ -4,6 +4,7 @@
|
|||
# Include project code and build files
|
||||
!PluralKit.*/
|
||||
!Myriad/
|
||||
!Serilog/
|
||||
!.git
|
||||
!dashboard
|
||||
!crates/
|
||||
|
|
|
|||
8
.github/workflows/dotnet-docker.yml
vendored
8
.github/workflows/dotnet-docker.yml
vendored
|
|
@ -2,7 +2,9 @@ name: Build and push Docker image
|
|||
on:
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/docker.yml'
|
||||
- '.dockerignore'
|
||||
- '.github/workflows/dotnet-docker.yml'
|
||||
- 'ci/Dockerfile.dotnet'
|
||||
- 'ci/dotnet-version.sh'
|
||||
- 'Myriad/**'
|
||||
- 'PluralKit.API/**'
|
||||
|
|
@ -23,6 +25,9 @@ jobs:
|
|||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CR_PAT }}
|
||||
- uses: actions/checkout@v2
|
||||
with:
|
||||
submodules: true
|
||||
|
||||
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" | sed 's|/|-|g' >> $GITHUB_ENV
|
||||
|
||||
- name: Extract Docker metadata
|
||||
|
|
@ -41,6 +46,7 @@ jobs:
|
|||
with:
|
||||
# https://github.com/docker/build-push-action/issues/378
|
||||
context: .
|
||||
file: ci/Dockerfile.dotnet
|
||||
push: true
|
||||
tags: ${{ steps.meta.outputs.tags }}
|
||||
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
|
||||
|
|
|
|||
1
.github/workflows/rust-docker.yml
vendored
1
.github/workflows/rust-docker.yml
vendored
|
|
@ -3,6 +3,7 @@ on:
|
|||
push:
|
||||
paths:
|
||||
- 'crates/**'
|
||||
- '.dockerignore'
|
||||
- '.github/workflows/rust.yml'
|
||||
- 'ci/Dockerfile.rust'
|
||||
- 'ci/rust-docker-target.sh'
|
||||
|
|
|
|||
4
.gitmodules
vendored
4
.gitmodules
vendored
|
|
@ -0,0 +1,4 @@
|
|||
[submodule "Serilog"]
|
||||
path = Serilog
|
||||
url = https://github.com/pluralkit/serilog
|
||||
branch = f5eb991cb4c4a0c1e2407de7504c543536786598
|
||||
2722
Cargo.lock
generated
2722
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
26
Cargo.toml
26
Cargo.toml
|
|
@ -6,7 +6,6 @@ members = [
|
|||
|
||||
[workspace.dependencies]
|
||||
anyhow = "1"
|
||||
axum = "0.7.5"
|
||||
axum-macros = "0.4.1"
|
||||
bytes = "1.6.0"
|
||||
chrono = "0.4"
|
||||
|
|
@ -18,21 +17,22 @@ reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-t
|
|||
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
|
||||
serde = { version = "1.0.196", features = ["derive"] }
|
||||
serde_json = "1.0.117"
|
||||
signal-hook = "0.3.17"
|
||||
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "chrono", "macros", "uuid"] }
|
||||
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "macros", "uuid"] }
|
||||
tokio = { version = "1.36.0", features = ["full"] }
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
|
||||
uuid = { version = "1.7.0", features = ["serde"] }
|
||||
|
||||
twilight-gateway = { git = "https://github.com/pluralkit/twilight" }
|
||||
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
|
||||
twilight-util = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
|
||||
twilight-model = { git = "https://github.com/pluralkit/twilight" }
|
||||
twilight-http = { git = "https://github.com/pluralkit/twilight", default-features = false, features = ["rustls-native-roots"] }
|
||||
axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" }
|
||||
|
||||
#twilight-gateway = { path = "../twilight/twilight-gateway" }
|
||||
#twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] }
|
||||
#twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
|
||||
#twilight-model = { path = "../twilight/twilight-model" }
|
||||
#twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-native-roots"] }
|
||||
twilight-gateway = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef" }
|
||||
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", features = ["permission-calculator"] }
|
||||
twilight-util = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", features = ["permission-calculator"] }
|
||||
twilight-model = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef" }
|
||||
twilight-http = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] }
|
||||
|
||||
# twilight-gateway = { path = "../twilight/twilight-gateway" }
|
||||
# twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] }
|
||||
# twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
|
||||
# twilight-model = { path = "../twilight/twilight-model" }
|
||||
# twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] }
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
using Serilog;
|
||||
using System.Net;
|
||||
using System.Text;
|
||||
using System.Text.Json;
|
||||
|
||||
using NodaTime;
|
||||
|
||||
using Myriad.Serialization;
|
||||
using Myriad.Types;
|
||||
|
||||
|
|
@ -11,7 +14,8 @@ public class HttpDiscordCache: IDiscordCache
|
|||
{
|
||||
private readonly ILogger _logger;
|
||||
private readonly HttpClient _client;
|
||||
private readonly Uri _cacheEndpoint;
|
||||
private readonly string _cacheEndpoint;
|
||||
private readonly string? _eventTarget;
|
||||
private readonly int _shardCount;
|
||||
private readonly ulong _ownUserId;
|
||||
|
||||
|
|
@ -21,11 +25,12 @@ public class HttpDiscordCache: IDiscordCache
|
|||
|
||||
public EventHandler<(bool?, string)> OnDebug;
|
||||
|
||||
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, int shardCount, ulong ownUserId, bool useInnerCache)
|
||||
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, string? eventTarget, int shardCount, ulong ownUserId, bool useInnerCache)
|
||||
{
|
||||
_logger = logger;
|
||||
_client = client;
|
||||
_cacheEndpoint = new Uri(cacheEndpoint);
|
||||
_cacheEndpoint = cacheEndpoint;
|
||||
_eventTarget = eventTarget;
|
||||
_shardCount = shardCount;
|
||||
_ownUserId = ownUserId;
|
||||
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
|
||||
|
|
@ -47,13 +52,12 @@ public class HttpDiscordCache: IDiscordCache
|
|||
|
||||
private async Task<T?> QueryCache<T>(string endpoint, ulong guildId)
|
||||
{
|
||||
var cluster = _cacheEndpoint.Authority;
|
||||
// todo: there should not be infra-specific code here
|
||||
if (cluster.Contains(".service.consul") || cluster.Contains("process.pluralkit-gateway.internal"))
|
||||
// int(((guild_id >> 22) % shard_count) / 16)
|
||||
cluster = $"cluster{(int)(((guildId >> 22) % (ulong)_shardCount) / 16)}.{cluster}";
|
||||
var cluster = _cacheEndpoint;
|
||||
|
||||
var response = await _client.GetAsync($"{_cacheEndpoint.Scheme}://{cluster}{endpoint}");
|
||||
if (cluster.Contains("{clusterid}"))
|
||||
cluster = cluster.Replace("{clusterid}", $"{(int)(((guildId >> 22) % (ulong)_shardCount) / 16)}");
|
||||
|
||||
var response = await _client.GetAsync($"http://{cluster}{endpoint}");
|
||||
|
||||
if (response.StatusCode == HttpStatusCode.NotFound)
|
||||
return default;
|
||||
|
|
@ -65,6 +69,70 @@ public class HttpDiscordCache: IDiscordCache
|
|||
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
|
||||
}
|
||||
|
||||
public Task<T> GetLastMessage<T>(ulong guildId, ulong channelId)
|
||||
=> QueryCache<T>($"/guilds/{guildId}/channels/{channelId}/last_message", guildId);
|
||||
|
||||
private Task AwaitEvent(ulong guildId, object data)
|
||||
=> AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data);
|
||||
|
||||
private async Task AwaitEventShard(int shardId, object data)
|
||||
{
|
||||
if (_eventTarget == null)
|
||||
throw new Exception("missing event target for remote await event");
|
||||
|
||||
var cluster = _cacheEndpoint;
|
||||
|
||||
if (cluster.Contains("{clusterid}"))
|
||||
cluster = cluster.Replace("{clusterid}", $"{(int)(shardId / 16)}");
|
||||
|
||||
var response = await _client.PostAsync(
|
||||
$"http://{cluster}/await_event",
|
||||
new StringContent(JsonSerializer.Serialize(data), Encoding.UTF8)
|
||||
);
|
||||
|
||||
if (response.StatusCode != HttpStatusCode.NoContent)
|
||||
throw new Exception($"failed to await event from gateway: {response.StatusCode}");
|
||||
}
|
||||
|
||||
public async Task AwaitReaction(ulong guildId, ulong messageId, ulong userId, Duration? timeout)
|
||||
{
|
||||
var obj = new
|
||||
{
|
||||
message_id = messageId,
|
||||
user_id = userId,
|
||||
target = _eventTarget!,
|
||||
timeout = timeout?.TotalSeconds,
|
||||
};
|
||||
|
||||
await AwaitEvent(guildId, obj);
|
||||
}
|
||||
|
||||
public async Task AwaitMessage(ulong guildId, ulong channelId, ulong authorId, Duration? timeout, string[] options = null)
|
||||
{
|
||||
var obj = new
|
||||
{
|
||||
channel_id = channelId,
|
||||
author_id = authorId,
|
||||
target = _eventTarget!,
|
||||
timeout = timeout?.TotalSeconds,
|
||||
options = options,
|
||||
};
|
||||
|
||||
await AwaitEvent(guildId, obj);
|
||||
}
|
||||
|
||||
public async Task AwaitInteraction(int shardId, string id, Duration? timeout)
|
||||
{
|
||||
var obj = new
|
||||
{
|
||||
id = id,
|
||||
target = _eventTarget!,
|
||||
timeout = timeout?.TotalSeconds,
|
||||
};
|
||||
|
||||
await AwaitEventShard(shardId, obj);
|
||||
}
|
||||
|
||||
public async Task<Guild?> TryGetGuild(ulong guildId)
|
||||
{
|
||||
var hres = await QueryCache<Guild?>($"/guilds/{guildId}", guildId);
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent
|
|||
public Optional<GuildMemberPartial> Member { get; init; }
|
||||
public Optional<Message.Attachment[]> Attachments { get; init; }
|
||||
|
||||
public Message.MessageType Type { get; init; }
|
||||
|
||||
public Optional<ulong?> GuildId { get; init; }
|
||||
// TODO: lots of partials
|
||||
}
|
||||
|
|
@ -21,9 +21,14 @@
|
|||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Serilog\src\Serilog\Serilog.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="NodaTime" Version="3.2.0" />
|
||||
<PackageReference Include="NodaTime.Serialization.JsonNet" Version="3.1.0" />
|
||||
<PackageReference Include="Polly" Version="8.5.0" />
|
||||
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
|
||||
<PackageReference Include="Serilog" Version="4.2.0" />
|
||||
<PackageReference Include="StackExchange.Redis" Version="2.8.22" />
|
||||
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
|
||||
</ItemGroup>
|
||||
|
|
|
|||
|
|
@ -2,6 +2,22 @@
|
|||
"version": 1,
|
||||
"dependencies": {
|
||||
"net8.0": {
|
||||
"NodaTime": {
|
||||
"type": "Direct",
|
||||
"requested": "[3.2.0, )",
|
||||
"resolved": "3.2.0",
|
||||
"contentHash": "yoRA3jEJn8NM0/rQm78zuDNPA3DonNSZdsorMUj+dltc1D+/Lc5h9YXGqbEEZozMGr37lAoYkcSM/KjTVqD0ow=="
|
||||
},
|
||||
"NodaTime.Serialization.JsonNet": {
|
||||
"type": "Direct",
|
||||
"requested": "[3.1.0, )",
|
||||
"resolved": "3.1.0",
|
||||
"contentHash": "eEr9lXUz50TYr4rpeJG4TDAABkpxjIKr5mDSi/Zav8d6Njy6fH7x4ZtNwWFj0Vd+vIvEZNrHFQ4Gfy8j4BqRGg==",
|
||||
"dependencies": {
|
||||
"Newtonsoft.Json": "13.0.3",
|
||||
"NodaTime": "[3.0.0, 4.0.0)"
|
||||
}
|
||||
},
|
||||
"Polly": {
|
||||
"type": "Direct",
|
||||
"requested": "[8.5.0, )",
|
||||
|
|
@ -17,12 +33,6 @@
|
|||
"resolved": "1.1.1",
|
||||
"contentHash": "1MUQLiSo4KDkQe6nzQRhIU05lm9jlexX5BVsbuw0SL82ynZ+GzAHQxJVDPVBboxV37Po3SG077aX8DuSy8TkaA=="
|
||||
},
|
||||
"Serilog": {
|
||||
"type": "Direct",
|
||||
"requested": "[4.2.0, )",
|
||||
"resolved": "4.2.0",
|
||||
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
|
||||
},
|
||||
"StackExchange.Redis": {
|
||||
"type": "Direct",
|
||||
"requested": "[2.8.22, )",
|
||||
|
|
@ -52,6 +62,11 @@
|
|||
"resolved": "6.0.0",
|
||||
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
|
||||
},
|
||||
"Newtonsoft.Json": {
|
||||
"type": "Transitive",
|
||||
"resolved": "13.0.3",
|
||||
"contentHash": "HrC5BXdl00IP9zeV+0Z848QWPAoCr9P3bDEZguI+gkLcBKAOxix/tLEAAHC+UvDNPv4a2d18lOReHMOagPa+zQ=="
|
||||
},
|
||||
"Pipelines.Sockets.Unofficial": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.2.8",
|
||||
|
|
@ -69,6 +84,9 @@
|
|||
"type": "Transitive",
|
||||
"resolved": "5.0.1",
|
||||
"contentHash": "qEePWsaq9LoEEIqhbGe6D5J8c9IqQOUuTzzV6wn1POlfdLkJliZY3OlB0j0f17uMWlqZYjH7txj+2YbyrIA8Yg=="
|
||||
},
|
||||
"serilog": {
|
||||
"type": "Project"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,4 +6,6 @@ public class ApiConfig
|
|||
public string? ClientId { get; set; }
|
||||
public string? ClientSecret { get; set; }
|
||||
public bool TrustAuth { get; set; } = false;
|
||||
public string? AvatarServiceUrl { get; set; }
|
||||
public bool SearchGuildSettings = false;
|
||||
}
|
||||
|
|
@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware
|
|||
&& int.TryParse(sidHeaders[0], out var systemId))
|
||||
ctx.Items.Add("SystemId", new SystemId(systemId));
|
||||
|
||||
if (cfg.TrustAuth
|
||||
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)
|
||||
&& aidHeaders.Count > 0
|
||||
&& int.TryParse(aidHeaders[0], out var appId))
|
||||
ctx.Items.Add("AppId", appId);
|
||||
|
||||
await _next.Invoke(ctx);
|
||||
}
|
||||
}
|
||||
|
|
@ -20,7 +20,7 @@ public class DiscordControllerV2: PKControllerBase
|
|||
if (ContextFor(system) != LookupContext.ByOwner)
|
||||
throw Errors.GenericMissingPermissions;
|
||||
|
||||
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false);
|
||||
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings);
|
||||
if (settings == null)
|
||||
throw Errors.SystemGuildNotFound;
|
||||
|
||||
|
|
@ -34,7 +34,7 @@ public class DiscordControllerV2: PKControllerBase
|
|||
if (ContextFor(system) != LookupContext.ByOwner)
|
||||
throw Errors.GenericMissingPermissions;
|
||||
|
||||
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false);
|
||||
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings);
|
||||
if (settings == null)
|
||||
throw Errors.SystemGuildNotFound;
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ public class DiscordControllerV2: PKControllerBase
|
|||
if (member.System != system.Id)
|
||||
throw Errors.NotOwnMemberError;
|
||||
|
||||
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false);
|
||||
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null);
|
||||
if (settings == null)
|
||||
throw Errors.MemberGuildNotFound;
|
||||
|
||||
|
|
@ -75,7 +75,7 @@ public class DiscordControllerV2: PKControllerBase
|
|||
if (member.System != system.Id)
|
||||
throw Errors.NotOwnMemberError;
|
||||
|
||||
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false);
|
||||
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null);
|
||||
if (settings == null)
|
||||
throw Errors.MemberGuildNotFound;
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,7 @@
|
|||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Json;
|
||||
|
||||
using Microsoft.AspNetCore.Mvc;
|
||||
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
|
@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase
|
|||
if (patch.Errors.Count > 0)
|
||||
throw new ModelParseError(patch.Errors);
|
||||
|
||||
if (patch.AvatarUrl.Value != null)
|
||||
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||
|
||||
using var conn = await _db.Obtain();
|
||||
using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
|
|
@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase
|
|||
if (patch.Errors.Count > 0)
|
||||
throw new ModelParseError(patch.Errors);
|
||||
|
||||
if (patch.AvatarUrl.Value != null)
|
||||
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||
|
||||
var newMember = await _repo.UpdateMember(member.Id, patch);
|
||||
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
|
||||
}
|
||||
|
|
@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase
|
|||
|
||||
return NoContent();
|
||||
}
|
||||
|
||||
private async Task<string> TryUploadAvatar(string avatarUrl, PKSystem system)
|
||||
{
|
||||
if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl;
|
||||
if (_config.AvatarServiceUrl == null) return avatarUrl;
|
||||
if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl;
|
||||
|
||||
using var client = new HttpClient();
|
||||
var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull",
|
||||
new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() });
|
||||
|
||||
if (response.StatusCode != HttpStatusCode.OK)
|
||||
{
|
||||
var error = await response.Content.ReadFromJsonAsync<ErrorResponse>();
|
||||
throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}");
|
||||
}
|
||||
|
||||
var success = await response.Content.ReadFromJsonAsync<SuccessResponse>();
|
||||
return success.Url;
|
||||
}
|
||||
|
||||
public record ErrorResponse(string Error);
|
||||
|
||||
public record SuccessResponse(string Url, bool New);
|
||||
}
|
||||
|
|
@ -32,6 +32,6 @@
|
|||
<PackageReference Include="Microsoft.AspNetCore.Mvc.Versioning" Version="5.1.0" />
|
||||
<PackageReference Include="Microsoft.AspNetCore.Mvc.Versioning.ApiExplorer" Version="5.1.0" />
|
||||
<PackageReference Include="Sentry" Version="4.13.0" />
|
||||
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
|
||||
<PackageReference Include="Serilog.AspNetCore" Version="8.0.0" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -35,7 +35,7 @@ public class Startup
|
|||
builder.RegisterInstance(InitUtils.BuildConfiguration(Environment.GetCommandLineArgs()).Build())
|
||||
.As<IConfiguration>();
|
||||
builder.RegisterModule(new ConfigModule<ApiConfig>("API"));
|
||||
builder.RegisterModule(new LoggingModule("api",
|
||||
builder.RegisterModule(new LoggingModule("dotnet-api",
|
||||
cfg: new LoggerConfiguration().Filter.ByExcluding(
|
||||
exc => exc.Exception is PKError || exc.Exception.IsUserError()
|
||||
)));
|
||||
|
|
|
|||
|
|
@ -36,17 +36,20 @@
|
|||
},
|
||||
"Serilog.AspNetCore": {
|
||||
"type": "Direct",
|
||||
"requested": "[9.0.0, )",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "JslDajPlBsn3Pww1554flJFTqROvK9zz9jONNQgn0D8Lx2Trw8L0A8/n6zEQK1DAZWXrJwiVLw8cnTR3YFuYsg==",
|
||||
"requested": "[8.0.0, )",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "FAjtKPZ4IzqFQBqZKPv6evcXK/F0ls7RoXI/62Pnx2igkDZ6nZ/jn/C/FxVATqQbEQvtqP+KViWYIe4NZIHa2w==",
|
||||
"dependencies": {
|
||||
"Serilog": "4.2.0",
|
||||
"Serilog.Extensions.Hosting": "9.0.0",
|
||||
"Serilog.Formatting.Compact": "3.0.0",
|
||||
"Serilog.Settings.Configuration": "9.0.0",
|
||||
"Serilog.Sinks.Console": "6.0.0",
|
||||
"Serilog.Sinks.Debug": "3.0.0",
|
||||
"Serilog.Sinks.File": "6.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection": "8.0.0",
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1",
|
||||
"Serilog.Extensions.Hosting": "8.0.0",
|
||||
"Serilog.Extensions.Logging": "8.0.0",
|
||||
"Serilog.Formatting.Compact": "2.0.0",
|
||||
"Serilog.Settings.Configuration": "8.0.0",
|
||||
"Serilog.Sinks.Console": "5.0.0",
|
||||
"Serilog.Sinks.Debug": "2.0.0",
|
||||
"Serilog.Sinks.File": "5.0.0"
|
||||
}
|
||||
},
|
||||
"App.Metrics": {
|
||||
|
|
@ -296,21 +299,21 @@
|
|||
},
|
||||
"Microsoft.Extensions.DependencyModel": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "saxr2XzwgDU77LaQfYFXmddEDRUKHF4DaGMZkNB3qjdVSZlax3//dGJagJkKrGMIPNZs2jVFXITyCCR6UHJNdA==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "NSmDw3K0ozNDgShSIpsZcbFIzBX4w28nDag+TfaQujkXGazBm+lid5onlWoCBy4VsLxqnnKjEBbGSJVWJMf43g==",
|
||||
"dependencies": {
|
||||
"System.Text.Encodings.Web": "9.0.0",
|
||||
"System.Text.Json": "9.0.0"
|
||||
"System.Text.Encodings.Web": "8.0.0",
|
||||
"System.Text.Json": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "1K8P7XzuzX8W8pmXcZjcrqS6x5eSSdvhQohmcpgiQNY/HlDAlnrhR9dvlURfFz428A+RTCJpUyB+aKTA6AgVcQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Options": "9.0.0",
|
||||
"System.Diagnostics.DiagnosticSource": "9.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Options": "8.0.0",
|
||||
"System.Diagnostics.DiagnosticSource": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": {
|
||||
|
|
@ -338,14 +341,14 @@
|
|||
},
|
||||
"Microsoft.Extensions.Hosting.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "yUKJgu81ExjvqbNWqZKshBbLntZMbMVz/P7Way2SBx7bMqA08Mfdc9O7hWDKAiSp+zPUGT6LKcSCQIPeDK+CCw==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Configuration.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "9.0.0"
|
||||
"Microsoft.Extensions.Configuration.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Logging": {
|
||||
|
|
@ -461,30 +464,25 @@
|
|||
"System.IO.Pipelines": "5.0.1"
|
||||
}
|
||||
},
|
||||
"Serilog": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.2.0",
|
||||
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
|
||||
},
|
||||
"Serilog.Extensions.Hosting": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "u2TRxuxbjvTAldQn7uaAwePkWxTHIqlgjelekBtilAGL5sYyF3+65NWctN4UrwwGLsDC7c3Vz3HnOlu+PcoxXg==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "db0OcbWeSCvYQkHWu6n0v40N4kKaTAXNjlM3BKvcbwvNzYphQFcBR+36eQ/7hMMwOkJvAyLC2a9/jNdUL5NjtQ==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Hosting.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "9.0.0",
|
||||
"Serilog": "4.2.0",
|
||||
"Serilog.Extensions.Logging": "9.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Hosting.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "8.0.0",
|
||||
"Serilog": "3.1.1",
|
||||
"Serilog.Extensions.Logging": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Serilog.Extensions.Logging": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Logging": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Formatting.Compact": {
|
||||
|
|
@ -514,12 +512,12 @@
|
|||
},
|
||||
"Serilog.Settings.Configuration": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "4/Et4Cqwa+F88l5SeFeNZ4c4Z6dEAIKbu3MaQb2Zz9F/g27T5a3wvfMcmCOaAiACjfUb4A6wrlTVfyYUZk3RRQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "nR0iL5HwKj5v6ULo3/zpP8NMcq9E2pxYA6XKTSWCbugVs4YqPyvaqaKOY+OMpPivKp7zMEpax2UKHnDodbRB0Q==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Configuration.Binder": "9.0.0",
|
||||
"Microsoft.Extensions.DependencyModel": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Configuration.Binder": "8.0.0",
|
||||
"Microsoft.Extensions.DependencyModel": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Sinks.Async": {
|
||||
|
|
@ -540,10 +538,10 @@
|
|||
},
|
||||
"Serilog.Sinks.Debug": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.0",
|
||||
"contentHash": "4BzXcdrgRX7wde9PmHuYd9U6YqycCC28hhpKonK7hx0wb19eiuRj16fPcPSVp0o/Y1ipJuNLYQ00R3q2Zs8FDA==",
|
||||
"resolved": "2.0.0",
|
||||
"contentHash": "Y6g3OBJ4JzTyyw16fDqtFcQ41qQAydnEvEqmXjhwhgjsnG/FaJ8GUqF5ldsC/bVkK8KYmqrPhDO+tm4dF6xx4A==",
|
||||
"dependencies": {
|
||||
"Serilog": "4.0.0"
|
||||
"Serilog": "2.10.0"
|
||||
}
|
||||
},
|
||||
"Serilog.Sinks.Elasticsearch": {
|
||||
|
|
@ -837,8 +835,8 @@
|
|||
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||
"Npgsql": "[9.0.2, )",
|
||||
"Npgsql.NodaTime": "[9.0.2, )",
|
||||
"Serilog": "[4.2.0, )",
|
||||
"Serilog.Extensions.Logging": "[9.0.0, )",
|
||||
"Serilog": "[4.1.0, )",
|
||||
"Serilog.Extensions.Logging": "[8.0.0, )",
|
||||
"Serilog.Formatting.Compact": "[3.0.0, )",
|
||||
"Serilog.NodaTime": "[3.0.0, )",
|
||||
"Serilog.Sinks.Async": "[2.1.0, )",
|
||||
|
|
@ -852,6 +850,9 @@
|
|||
"System.Interactive.Async": "[6.0.1, )",
|
||||
"ipnetwork2": "[3.0.667, )"
|
||||
}
|
||||
},
|
||||
"serilog": {
|
||||
"type": "Project"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -33,7 +33,10 @@ public class ApplicationCommandProxiedMessage
|
|||
var messageId = ctx.Event.Data!.TargetId!.Value;
|
||||
var msg = await ctx.Repository.GetFullMessage(messageId);
|
||||
if (msg == null)
|
||||
throw Errors.MessageNotFound(messageId);
|
||||
{
|
||||
await QueryCommandMessage(ctx);
|
||||
return;
|
||||
}
|
||||
|
||||
var showContent = true;
|
||||
var channel = await _rest.GetChannelOrNull(msg.Message.Channel);
|
||||
|
|
@ -58,6 +61,20 @@ public class ApplicationCommandProxiedMessage
|
|||
await ctx.Reply(embeds: embeds.ToArray());
|
||||
}
|
||||
|
||||
private async Task QueryCommandMessage(InteractionContext ctx)
|
||||
{
|
||||
var messageId = ctx.Event.Data!.TargetId!.Value;
|
||||
var msg = await ctx.Repository.GetCommandMessage(messageId);
|
||||
if (msg == null)
|
||||
throw Errors.MessageNotFound(messageId);
|
||||
|
||||
var embeds = new List<Embed>();
|
||||
|
||||
embeds.Add(await _embeds.CreateCommandMessageInfoEmbed(msg, true));
|
||||
|
||||
await ctx.Reply(embeds: embeds.ToArray());
|
||||
}
|
||||
|
||||
public async Task DeleteMessage(InteractionContext ctx)
|
||||
{
|
||||
var messageId = ctx.Event.Data!.TargetId!.Value;
|
||||
|
|
|
|||
|
|
@ -32,13 +32,14 @@ public class Bot
|
|||
private readonly DiscordApiClient _rest;
|
||||
private readonly RedisService _redis;
|
||||
private readonly ILifetimeScope _services;
|
||||
private readonly RuntimeConfigService _runtimeConfig;
|
||||
|
||||
private Timer _periodicTask; // Never read, just kept here for GC reasons
|
||||
|
||||
public Bot(ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
|
||||
BotConfig config, RedisService redis,
|
||||
ErrorMessageService errorMessageService, CommandMessageService commandMessageService,
|
||||
Cluster cluster, DiscordApiClient rest, IDiscordCache cache)
|
||||
Cluster cluster, DiscordApiClient rest, IDiscordCache cache, RuntimeConfigService runtimeConfig)
|
||||
{
|
||||
_logger = logger.ForContext<Bot>();
|
||||
_services = services;
|
||||
|
|
@ -51,6 +52,7 @@ public class Bot
|
|||
_rest = rest;
|
||||
_redis = redis;
|
||||
_cache = cache;
|
||||
_runtimeConfig = runtimeConfig;
|
||||
}
|
||||
|
||||
private string BotStatus => $"{(_config.Prefixes ?? BotConfig.DefaultPrefixes)[0]}help"
|
||||
|
|
@ -97,13 +99,15 @@ public class Bot
|
|||
|
||||
private async Task OnEventReceived(int shardId, IGatewayEvent evt)
|
||||
{
|
||||
if (_runtimeConfig.Exists("disable_events")) return;
|
||||
|
||||
// we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent
|
||||
await _cache.HandleGatewayEvent(evt);
|
||||
await _cache.TryUpdateSelfMember(_config.ClientId, evt);
|
||||
await OnEventReceivedInner(shardId, evt);
|
||||
}
|
||||
|
||||
private async Task OnEventReceivedInner(int shardId, IGatewayEvent evt)
|
||||
public async Task OnEventReceivedInner(int shardId, IGatewayEvent evt)
|
||||
{
|
||||
// HandleEvent takes a type parameter, automatically inferred by the event type
|
||||
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
|
||||
|
|
@ -278,7 +282,7 @@ public class Bot
|
|||
_logger.Debug("Running once-per-minute scheduled tasks");
|
||||
|
||||
// Check from a new custom status from Redis and update Discord accordingly
|
||||
if (true)
|
||||
if (!_config.DisableGateway)
|
||||
{
|
||||
var newStatus = await _redis.Connection.GetDatabase().StringGetAsync("pluralkit:botstatus");
|
||||
if (newStatus != CustomStatusMessage)
|
||||
|
|
|
|||
|
|
@ -24,6 +24,10 @@ public class BotConfig
|
|||
public string? HttpCacheUrl { get; set; }
|
||||
public bool HttpUseInnerCache { get; set; } = false;
|
||||
|
||||
public string? HttpListenerAddr { get; set; }
|
||||
public bool DisableGateway { get; set; } = false;
|
||||
public string? EventAwaiterTarget { get; set; }
|
||||
|
||||
public string? DiscordBaseUrl { get; set; }
|
||||
public string? AvatarServiceUrl { get; set; }
|
||||
|
||||
|
|
|
|||
|
|
@ -108,7 +108,7 @@ public partial class CommandTree
|
|||
ctx.Reply(
|
||||
$"{Emojis.Error} Parsed command {ctx.Parameters.Callback().AsCode()} not implemented in PluralKit.Bot!"),
|
||||
};
|
||||
if (ctx.Match("system", "s"))
|
||||
if (ctx.Match("system", "s", "account", "acc"))
|
||||
return HandleSystemCommand(ctx);
|
||||
if (ctx.Match("member", "m"))
|
||||
return HandleMemberCommand(ctx);
|
||||
|
|
@ -554,6 +554,8 @@ public partial class CommandTree
|
|||
case "system":
|
||||
case "systems":
|
||||
case "s":
|
||||
case "account":
|
||||
case "acc":
|
||||
await PrintCommandList(ctx, "systems", SystemCommands);
|
||||
break;
|
||||
case "member":
|
||||
|
|
|
|||
|
|
@ -80,7 +80,7 @@ public class Context
|
|||
internal readonly ModelRepository Repository;
|
||||
internal readonly RedisService Redis;
|
||||
|
||||
public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null)
|
||||
public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null, MultipartFile[]? files = null)
|
||||
{
|
||||
var botPerms = await BotPermissions;
|
||||
|
||||
|
|
@ -91,20 +91,28 @@ public class Context
|
|||
if (embed != null && !botPerms.HasFlag(PermissionSet.EmbedLinks))
|
||||
throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled.");
|
||||
|
||||
if (files != null && !botPerms.HasFlag(PermissionSet.AttachFiles))
|
||||
throw new PKError("PluralKit does not have permission to attach files in this channel. Please ensure I have the **Attach Files** permission enabled.");
|
||||
|
||||
var msg = await Rest.CreateMessage(Channel.Id, new MessageRequest
|
||||
{
|
||||
Content = text,
|
||||
Embeds = embed != null ? new[] { embed } : null,
|
||||
// Default to an empty allowed mentions object instead of null (which means no mentions allowed)
|
||||
AllowedMentions = mentions ?? new AllowedMentions()
|
||||
});
|
||||
}, files: files);
|
||||
|
||||
// if (embed != null)
|
||||
// {
|
||||
// Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example)
|
||||
// but since we can, we just store all sent messages for possible deletion
|
||||
await _commandMessageService.RegisterMessage(msg.Id, Guild?.Id ?? 0, msg.ChannelId, Author.Id);
|
||||
// }
|
||||
// store log of sent message, so it can be queried or deleted later
|
||||
// skip DMs as DM messages can always be deleted
|
||||
if (Guild != null)
|
||||
await Repository.AddCommandMessage(new Core.CommandMessage
|
||||
{
|
||||
Mid = msg.Id,
|
||||
Guild = Guild!.Id,
|
||||
Channel = Channel.Id,
|
||||
Sender = Author.Id,
|
||||
OriginalMid = Message.Id,
|
||||
});
|
||||
|
||||
return msg;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -443,10 +443,11 @@ public class Groups
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title("Group color")
|
||||
.Color(target.Color.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Description($"This group's color is **#{target.Color}**."
|
||||
+ (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}group {target.Reference(ctx)} color -clear`." : ""))
|
||||
.Build());
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(target.Color)]);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -471,8 +472,9 @@ public class Groups
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title($"{Emojis.Success} Group color changed.")
|
||||
.Color(color.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{color}/?text=%20"))
|
||||
.Build());
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(color)]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -308,10 +308,11 @@ public class MemberEdit
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title("Member color")
|
||||
.Color(target.Color.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Description($"This member's color is **#{target.Color}**."
|
||||
+ (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}member {target.Reference(ctx)} color -clear`." : ""))
|
||||
.Build());
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(target.Color)]);
|
||||
return;
|
||||
}
|
||||
|
||||
|
|
@ -336,8 +337,9 @@ public class MemberEdit
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title($"{Emojis.Success} Member color changed.")
|
||||
.Color(color.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{color}/?text=%20"))
|
||||
.Build());
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(color)]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -305,7 +305,7 @@ public class ProxiedMessage
|
|||
throw new PKError(error);
|
||||
}
|
||||
|
||||
var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId);
|
||||
var lastMessage = await _lastMessageCache.GetLastMessage(ctx.Message.GuildId ?? 0, ctx.Message.ChannelId);
|
||||
|
||||
var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id
|
||||
? lastMessage?.Previous?.Id == msg.Mid
|
||||
|
|
@ -347,13 +347,8 @@ public class ProxiedMessage
|
|||
var message = await ctx.Repository.GetFullMessage(messageId.Value);
|
||||
if (message == null)
|
||||
{
|
||||
if (isDelete)
|
||||
{
|
||||
await DeleteCommandMessage(ctx, messageId.Value);
|
||||
return;
|
||||
}
|
||||
|
||||
throw Errors.MessageNotFound(messageId.Value);
|
||||
await GetCommandMessage(ctx, messageId.Value, isDelete);
|
||||
return;
|
||||
}
|
||||
|
||||
var showContent = true;
|
||||
|
|
@ -448,20 +443,35 @@ public class ProxiedMessage
|
|||
await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(message, showContent, ctx.Config));
|
||||
}
|
||||
|
||||
private async Task DeleteCommandMessage(Context ctx, ulong messageId)
|
||||
private async Task GetCommandMessage(Context ctx, ulong messageId, bool isDelete)
|
||||
{
|
||||
var cmessage = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
|
||||
if (cmessage == null)
|
||||
var msg = await _repo.GetCommandMessage(messageId);
|
||||
if (msg == null)
|
||||
throw Errors.MessageNotFound(messageId);
|
||||
|
||||
if (cmessage!.AuthorId != ctx.Author.Id)
|
||||
throw new PKError("You can only delete command messages queried by this account.");
|
||||
if (isDelete)
|
||||
{
|
||||
if (msg.Sender != ctx.Author.Id)
|
||||
throw new PKError("You can only delete command messages queried by this account.");
|
||||
|
||||
await ctx.Rest.DeleteMessage(cmessage.ChannelId, messageId);
|
||||
await ctx.Rest.DeleteMessage(msg.Channel, messageId);
|
||||
|
||||
if (ctx.Guild != null)
|
||||
await ctx.Rest.DeleteMessage(ctx.Message);
|
||||
else
|
||||
await ctx.Rest.CreateReaction(ctx.Message.ChannelId, ctx.Message.Id, new Emoji { Name = Emojis.Success });
|
||||
if (ctx.Guild != null)
|
||||
await ctx.Rest.DeleteMessage(ctx.Message);
|
||||
else
|
||||
await ctx.Rest.CreateReaction(ctx.Message.ChannelId, ctx.Message.Id, new Emoji { Name = Emojis.Success });
|
||||
|
||||
return;
|
||||
}
|
||||
|
||||
var showContent = true;
|
||||
|
||||
var channel = await _rest.GetChannelOrNull(msg.Channel);
|
||||
if (channel == null)
|
||||
showContent = false;
|
||||
else if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel))
|
||||
showContent = false;
|
||||
|
||||
await ctx.Reply(embed: await _embeds.CreateCommandMessageInfoEmbed(msg, showContent));
|
||||
}
|
||||
}
|
||||
|
|
@ -37,10 +37,13 @@ public class System
|
|||
.Field(new Embed.Field("Getting Started",
|
||||
"New to PK? Check out our Getting Started guide on setting up members and proxies: https://pluralkit.me/start\n" +
|
||||
$"Otherwise, type `{ctx.DefaultPrefix}system` to view your system and `{ctx.DefaultPrefix}system help` for more information about commands you can use."))
|
||||
.Field(new Embed.Field($"{Emojis.Warn} Notice {Emojis.Warn}", "PluralKit is a bot meant to help you share information about your system. " +
|
||||
.Field(new Embed.Field($"{Emojis.Warn} Notice: Public By Default {Emojis.Warn}", "PluralKit is a bot meant to help you share information about your system. " +
|
||||
"Member descriptions are meant to be the equivalent to a Discord About Me. Because of this, any info you put in PK is **public by default**.\n" +
|
||||
"Note that this does **not** include message content, only member fields. For more information, check out " +
|
||||
"[the privacy section of the user guide](https://pluralkit.me/guide/#privacy). "))
|
||||
.Field(new Embed.Field($"{Emojis.Warn} Notice: Implicit Acceptance of ToS {Emojis.Warn}", "By using the PluralKit bot you implicitly agree to our " +
|
||||
"[Terms of Service](https://pluralkit.me/terms-of-service/). For questions please ask in our [support server](<https://discord.gg/PczBt78>) or " +
|
||||
"email legal@pluralkit.me"))
|
||||
.Field(new Embed.Field("System Recovery", "In the case of your Discord account getting lost or deleted, the PluralKit staff can help you recover your system. " +
|
||||
"In order to do so, we will need your **PluralKit token**. This is the *only* way you can prove ownership so we can help you recover your system. " +
|
||||
$"To get it, run `{ctx.DefaultPrefix}token` and then store it in a safe place.\n\n" +
|
||||
|
|
|
|||
|
|
@ -233,8 +233,9 @@ public class SystemEdit
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title($"{Emojis.Success} System color changed.")
|
||||
.Color(newColor.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{newColor}/?text=%20"))
|
||||
.Build());
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(color)]););
|
||||
}
|
||||
|
||||
public async Task ClearColor(Context ctx, PKSystem target, bool flagConfirmYes)
|
||||
|
|
@ -274,10 +275,11 @@ public class SystemEdit
|
|||
await ctx.Reply(embed: new EmbedBuilder()
|
||||
.Title("System color")
|
||||
.Color(target.Color.ToDiscordColor())
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
|
||||
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
|
||||
.Description(
|
||||
$"This system's color is **#{target.Color}**." + (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}s color -clear`." : ""))
|
||||
.Build());
|
||||
.Build(),
|
||||
files: [MiscUtils.GenerateColorPreview(target.Color)]);
|
||||
}
|
||||
|
||||
public async Task ClearTag(Context ctx, PKSystem target, bool flagConfirmYes)
|
||||
|
|
@ -475,7 +477,7 @@ public class SystemEdit
|
|||
else
|
||||
str +=
|
||||
" Member names will now use the global system tag when proxied in the current server, if there is one set."
|
||||
+ "\n\nTo check or change where your tag appears in your name use the command `{ctx.DefaultPrefix}cfg name format`.";
|
||||
+ $"\n\nTo check or change where your tag appears in your name use the command `{ctx.DefaultPrefix}cfg name format`.";
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,9 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
|
|||
public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId);
|
||||
private bool IsDuplicateMessage(Message msg) =>
|
||||
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
|
||||
_lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
|
||||
// use only the local cache here
|
||||
// http gateway sets last message before forwarding the message here, so this will always return true
|
||||
_lastMessageCache._GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
|
||||
|
||||
public async Task Handle(int shardId, MessageCreateEvent evt)
|
||||
{
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
|||
var guild = await _cache.TryGetGuild(channel.GuildId!.Value);
|
||||
if (guild == null)
|
||||
throw new Exception("could not find self guild in MessageEdited event");
|
||||
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
|
||||
var lastMessage = (await _lastMessageCache.GetLastMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId))?.Current;
|
||||
|
||||
// Only react to the last message in the channel
|
||||
if (lastMessage?.Id != evt.Id)
|
||||
|
|
@ -107,10 +107,6 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
|||
? new Message.Reference(channel.GuildId, evt.ChannelId, lastMessage.ReferencedMessage.Value)
|
||||
: null;
|
||||
|
||||
var messageType = lastMessage.ReferencedMessage != null
|
||||
? Message.MessageType.Reply
|
||||
: Message.MessageType.Default;
|
||||
|
||||
// TODO: is this missing anything?
|
||||
var equivalentEvt = new MessageCreateEvent
|
||||
{
|
||||
|
|
@ -123,7 +119,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
|
|||
Attachments = evt.Attachments.Value ?? Array.Empty<Message.Attachment>(),
|
||||
MessageReference = messageReference,
|
||||
ReferencedMessage = referencedMessage,
|
||||
Type = messageType,
|
||||
Type = evt.Type,
|
||||
};
|
||||
return equivalentEvt;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -111,6 +111,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
|
|||
case "\U0001F514": // Bell
|
||||
case "\U0001F6CE": // Bellhop bell
|
||||
case "\U0001F3D3": // Ping pong paddle (lol)
|
||||
case "\U0001FAD1": // Bell pepper
|
||||
case "\u23F0": // Alarm clock
|
||||
case "\u2757": // Exclamation mark
|
||||
{
|
||||
|
|
|
|||
|
|
@ -57,16 +57,6 @@ public class Init
|
|||
|
||||
var cache = services.Resolve<IDiscordCache>();
|
||||
|
||||
if (config.Cluster == null)
|
||||
{
|
||||
// "Connect to the database" (ie. set off database migrations and ensure state)
|
||||
logger.Information("Connecting to database");
|
||||
await services.Resolve<IDatabase>().ApplyMigrations();
|
||||
|
||||
// Clear shard status from Redis
|
||||
await redis.Connection.GetDatabase().KeyDeleteAsync("pluralkit:shardstatus");
|
||||
}
|
||||
|
||||
logger.Information("Initializing bot");
|
||||
var bot = services.Resolve<Bot>();
|
||||
|
||||
|
|
@ -76,10 +66,19 @@ public class Init
|
|||
// Init the bot instance itself, register handlers and such to the client before beginning to connect
|
||||
bot.Init();
|
||||
|
||||
// Start the Discord shards themselves (handlers already set up)
|
||||
logger.Information("Connecting to Discord");
|
||||
await StartCluster(services);
|
||||
// load runtime config from redis
|
||||
await services.Resolve<RuntimeConfigService>().LoadConfig();
|
||||
|
||||
// Start HTTP server
|
||||
if (config.HttpListenerAddr != null)
|
||||
services.Resolve<HttpListenerService>().Start(config.HttpListenerAddr);
|
||||
|
||||
// Start the Discord shards themselves (handlers already set up)
|
||||
if (!config.DisableGateway)
|
||||
{
|
||||
logger.Information("Connecting to Discord");
|
||||
await StartCluster(services);
|
||||
}
|
||||
logger.Information("Connected! All is good (probably).");
|
||||
|
||||
// Lastly, we just... wait. Everything else is handled in the DiscordClient event loop
|
||||
|
|
@ -149,7 +148,7 @@ public class Init
|
|||
var builder = new ContainerBuilder();
|
||||
builder.RegisterInstance(config);
|
||||
builder.RegisterModule(new ConfigModule<BotConfig>("Bot"));
|
||||
builder.RegisterModule(new LoggingModule("bot"));
|
||||
builder.RegisterModule(new LoggingModule("dotnet-bot"));
|
||||
builder.RegisterModule(new MetricsModule());
|
||||
builder.RegisterModule<DataStoreModule>();
|
||||
builder.RegisterModule<BotModule>();
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ public abstract class BaseInteractive
|
|||
ButtonStyle style = ButtonStyle.Secondary, bool disabled = false)
|
||||
{
|
||||
var dispatch = _ctx.Services.Resolve<InteractionDispatchService>();
|
||||
var customId = dispatch.Register(handler, Timeout);
|
||||
var customId = dispatch.Register(_ctx.ShardId, handler, Timeout);
|
||||
|
||||
var button = new Button
|
||||
{
|
||||
|
|
@ -89,7 +89,7 @@ public abstract class BaseInteractive
|
|||
{
|
||||
var dispatch = ctx.Services.Resolve<InteractionDispatchService>();
|
||||
foreach (var button in _buttons)
|
||||
button.CustomId = dispatch.Register(button.Handler, Timeout);
|
||||
button.CustomId = dispatch.Register(_ctx.ShardId, button.Handler, Timeout);
|
||||
}
|
||||
|
||||
public abstract Task Start();
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
using Autofac;
|
||||
|
||||
using Myriad.Cache;
|
||||
using Myriad.Gateway;
|
||||
using Myriad.Rest.Types;
|
||||
using Myriad.Types;
|
||||
|
|
@ -69,6 +70,9 @@ public class YesNoPrompt: BaseInteractive
|
|||
return true;
|
||||
}
|
||||
|
||||
// no need to reawait message
|
||||
// gateway will already have sent us only matching messages
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -88,6 +92,17 @@ public class YesNoPrompt: BaseInteractive
|
|||
{
|
||||
try
|
||||
{
|
||||
// check if http gateway and set listener
|
||||
// todo: this one needs to handle options for message
|
||||
if (_ctx.Cache is HttpDiscordCache)
|
||||
await (_ctx.Cache as HttpDiscordCache).AwaitMessage(
|
||||
_ctx.Guild?.Id ?? 0,
|
||||
_ctx.Channel.Id,
|
||||
_ctx.Author.Id,
|
||||
Timeout,
|
||||
options: new[] { "yes", "y", "no", "n" }
|
||||
);
|
||||
|
||||
await queue.WaitFor(MessagePredicate, Timeout, cts.Token);
|
||||
}
|
||||
catch (TimeoutException e)
|
||||
|
|
|
|||
|
|
@ -49,8 +49,15 @@ public class BotModule: Module
|
|||
|
||||
if (botConfig.HttpCacheUrl != null)
|
||||
{
|
||||
var cache = new HttpDiscordCache(c.Resolve<ILogger>(),
|
||||
c.Resolve<HttpClient>(), botConfig.HttpCacheUrl, botConfig.Cluster?.TotalShards ?? 1, botConfig.ClientId, botConfig.HttpUseInnerCache);
|
||||
var cache = new HttpDiscordCache(
|
||||
c.Resolve<ILogger>(),
|
||||
c.Resolve<HttpClient>(),
|
||||
botConfig.HttpCacheUrl,
|
||||
botConfig.EventAwaiterTarget,
|
||||
botConfig.Cluster?.TotalShards ?? 1,
|
||||
botConfig.ClientId,
|
||||
botConfig.HttpUseInnerCache
|
||||
);
|
||||
|
||||
var metrics = c.Resolve<IMetrics>();
|
||||
|
||||
|
|
@ -153,6 +160,8 @@ public class BotModule: Module
|
|||
builder.RegisterType<CommandMessageService>().AsSelf().SingleInstance();
|
||||
builder.RegisterType<InteractionDispatchService>().AsSelf().SingleInstance();
|
||||
builder.RegisterType<AvatarHostingService>().AsSelf().SingleInstance();
|
||||
builder.RegisterType<HttpListenerService>().AsSelf().SingleInstance();
|
||||
builder.RegisterType<RuntimeConfigService>().AsSelf().SingleInstance();
|
||||
|
||||
// Sentry stuff
|
||||
builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope();
|
||||
|
|
|
|||
|
|
@ -25,5 +25,6 @@
|
|||
<ItemGroup>
|
||||
<PackageReference Include="Humanizer.Core" Version="2.14.1" />
|
||||
<PackageReference Include="Sentry" Version="4.13.0" />
|
||||
<PackageReference Include="Watson.Lite" Version="6.3.5" />
|
||||
</ItemGroup>
|
||||
</Project>
|
||||
|
|
|
|||
|
|
@ -246,7 +246,7 @@ public class ProxyService
|
|||
ChannelId = rootChannel.Id,
|
||||
ThreadId = threadId,
|
||||
MessageId = trigger.Id,
|
||||
Name = await FixSameName(messageChannel.Id, ctx, match.Member),
|
||||
Name = await FixSameName(trigger.GuildId!.Value, messageChannel.Id, ctx, match.Member),
|
||||
AvatarUrl = AvatarUtils.TryRewriteCdnUrl(match.Member.ProxyAvatar(ctx)),
|
||||
Content = content,
|
||||
Attachments = trigger.Attachments,
|
||||
|
|
@ -458,11 +458,11 @@ public class ProxyService
|
|||
};
|
||||
}
|
||||
|
||||
private async Task<string> FixSameName(ulong channelId, MessageContext ctx, ProxyMember member)
|
||||
private async Task<string> FixSameName(ulong guildId, ulong channelId, MessageContext ctx, ProxyMember member)
|
||||
{
|
||||
var proxyName = member.ProxyName(ctx);
|
||||
|
||||
var lastMessage = _lastMessage.GetLastMessage(channelId)?.Previous;
|
||||
var lastMessage = (await _lastMessage.GetLastMessage(guildId, channelId))?.Previous;
|
||||
if (lastMessage == null)
|
||||
// cache is out of date or channel is empty.
|
||||
return proxyName;
|
||||
|
|
|
|||
|
|
@ -9,35 +9,35 @@ namespace PluralKit.Bot;
|
|||
public class CommandMessageService
|
||||
{
|
||||
private readonly RedisService _redis;
|
||||
private readonly ModelRepository _repo;
|
||||
private readonly ILogger _logger;
|
||||
private static readonly TimeSpan CommandMessageRetention = TimeSpan.FromHours(24);
|
||||
|
||||
public CommandMessageService(RedisService redis, IClock clock, ILogger logger)
|
||||
public CommandMessageService(RedisService redis, ModelRepository repo, IClock clock, ILogger logger)
|
||||
{
|
||||
_redis = redis;
|
||||
_repo = repo;
|
||||
_logger = logger.ForContext<CommandMessageService>();
|
||||
}
|
||||
|
||||
public async Task RegisterMessage(ulong messageId, ulong guildId, ulong channelId, ulong authorId)
|
||||
{
|
||||
if (_redis.Connection == null) return;
|
||||
|
||||
_logger.Debug(
|
||||
"Registering command response {MessageId} from author {AuthorId} in {ChannelId}",
|
||||
messageId, authorId, channelId
|
||||
);
|
||||
|
||||
await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}-{guildId}", expiry: CommandMessageRetention);
|
||||
}
|
||||
|
||||
public async Task<CommandMessage?> GetCommandMessage(ulong messageId)
|
||||
{
|
||||
var repoMsg = await _repo.GetCommandMessage(messageId);
|
||||
if (repoMsg != null)
|
||||
return new CommandMessage(repoMsg.Sender, repoMsg.Channel, repoMsg.Guild);
|
||||
|
||||
var str = await _redis.Connection.GetDatabase().StringGetAsync(messageId.ToString());
|
||||
if (str.HasValue)
|
||||
{
|
||||
var split = ((string)str).Split("-");
|
||||
return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2]));
|
||||
}
|
||||
str = await _redis.Connection.GetDatabase().StringGetAsync("command_message:" + messageId.ToString());
|
||||
if (str.HasValue)
|
||||
{
|
||||
var split = ((string)str).Split("-");
|
||||
return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2]));
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -420,6 +420,24 @@ public class EmbedService
|
|||
return eb.Build();
|
||||
}
|
||||
|
||||
public async Task<Embed> CreateCommandMessageInfoEmbed(Core.CommandMessage msg, bool showContent)
|
||||
{
|
||||
var content = "*(command message deleted or inaccessible)*";
|
||||
if (showContent)
|
||||
{
|
||||
var discordMessage = await _rest.GetMessageOrNull(msg.Channel, msg.OriginalMid);
|
||||
if (discordMessage != null)
|
||||
content = discordMessage.Content;
|
||||
}
|
||||
|
||||
return new EmbedBuilder()
|
||||
.Title("Command response message")
|
||||
.Description(content)
|
||||
.Field(new("Original message", $"https://discord.com/channels/{msg.Guild}/{msg.Channel}/{msg.OriginalMid}", true))
|
||||
.Field(new("Sent by", $"<@{msg.Sender}>", true))
|
||||
.Build();
|
||||
}
|
||||
|
||||
public Task<Embed> CreateFrontPercentEmbed(FrontBreakdown breakdown, PKSystem system, PKGroup group,
|
||||
DateTimeZone tz, LookupContext ctx, string embedTitle,
|
||||
bool ignoreNoFronters, bool showFlat)
|
||||
|
|
|
|||
146
PluralKit.Bot/Services/HttpListenerService.cs
Normal file
146
PluralKit.Bot/Services/HttpListenerService.cs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
using System.Text;
|
||||
using System.Text.Json;
|
||||
|
||||
using Serilog;
|
||||
|
||||
using WatsonWebserver.Lite;
|
||||
using WatsonWebserver.Core;
|
||||
|
||||
using Myriad.Gateway;
|
||||
using Myriad.Serialization;
|
||||
|
||||
namespace PluralKit.Bot;
|
||||
|
||||
public class HttpListenerService
|
||||
{
|
||||
private readonly ILogger _logger;
|
||||
private readonly RuntimeConfigService _runtimeConfig;
|
||||
private readonly Bot _bot;
|
||||
|
||||
public HttpListenerService(ILogger logger, RuntimeConfigService runtimeConfig, Bot bot)
|
||||
{
|
||||
_logger = logger.ForContext<HttpListenerService>();
|
||||
_runtimeConfig = runtimeConfig;
|
||||
_bot = bot;
|
||||
}
|
||||
|
||||
public void Start(string host)
|
||||
{
|
||||
var hosts = new[] { host };
|
||||
if (host == "allv4v6")
|
||||
{
|
||||
hosts = new[] { "[::]", "0.0.0.0" };
|
||||
}
|
||||
foreach (var h in hosts)
|
||||
{
|
||||
var server = new WebserverLite(new WebserverSettings(h, 5002), DefaultRoute);
|
||||
|
||||
server.Routes.PreAuthentication.Static.Add(WatsonWebserver.Core.HttpMethod.GET, "/runtime_config", RuntimeConfigGet);
|
||||
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.POST, "/runtime_config/{key}", RuntimeConfigSet);
|
||||
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.DELETE, "/runtime_config/{key}", RuntimeConfigDelete);
|
||||
|
||||
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.POST, "/events/{shard_id}", GatewayEvent);
|
||||
|
||||
server.Start();
|
||||
}
|
||||
}
|
||||
|
||||
private async Task DefaultRoute(HttpContextBase ctx)
|
||||
=> await ctx.Response.Send("hellorld");
|
||||
|
||||
private async Task RuntimeConfigGet(HttpContextBase ctx)
|
||||
{
|
||||
var config = _runtimeConfig.GetAll();
|
||||
ctx.Response.Headers.Add("content-type", "application/json");
|
||||
await ctx.Response.Send(JsonSerializer.Serialize(config));
|
||||
}
|
||||
|
||||
private async Task RuntimeConfigSet(HttpContextBase ctx)
|
||||
{
|
||||
var key = ctx.Request.Url.Parameters["key"];
|
||||
var value = ReadStream(ctx.Request.Data, ctx.Request.ContentLength);
|
||||
await _runtimeConfig.Set(key, value);
|
||||
await RuntimeConfigGet(ctx);
|
||||
}
|
||||
|
||||
private async Task RuntimeConfigDelete(HttpContextBase ctx)
|
||||
{
|
||||
var key = ctx.Request.Url.Parameters["key"];
|
||||
await _runtimeConfig.Delete(key);
|
||||
await RuntimeConfigGet(ctx);
|
||||
}
|
||||
|
||||
private JsonSerializerOptions _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
|
||||
|
||||
private async Task GatewayEvent(HttpContextBase ctx)
|
||||
{
|
||||
var shardIdString = ctx.Request.Url.Parameters["shard_id"];
|
||||
if (!int.TryParse(shardIdString, out var shardId)) return;
|
||||
|
||||
var packet = JsonSerializer.Deserialize<GatewayPacket>(ReadStream(ctx.Request.Data, ctx.Request.ContentLength), _jsonSerializerOptions);
|
||||
var evt = DeserializeEvent(shardId, packet.EventType!, (JsonElement)packet.Payload!);
|
||||
if (evt != null)
|
||||
{
|
||||
await _bot.OnEventReceivedInner(shardId, evt);
|
||||
}
|
||||
await ctx.Response.Send("a");
|
||||
}
|
||||
|
||||
private IGatewayEvent? DeserializeEvent(int shardId, string eventType, JsonElement payload)
|
||||
{
|
||||
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
|
||||
{
|
||||
_logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", shardId, eventType);
|
||||
return null;
|
||||
}
|
||||
|
||||
try
|
||||
{
|
||||
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", shardId, eventType,
|
||||
clrType);
|
||||
return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions)
|
||||
as IGatewayEvent;
|
||||
}
|
||||
catch (JsonException e)
|
||||
{
|
||||
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", shardId,
|
||||
eventType, clrType);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
//temporary re-implementation of the ReadStream function found in WatsonWebserver.Lite, but with handling for closed connections
|
||||
//https://github.com/dotnet/WatsonWebserver/issues/171
|
||||
private static string ReadStream(Stream input, long contentLength)
|
||||
{
|
||||
if (input == null) throw new ArgumentNullException(nameof(input));
|
||||
if (!input.CanRead) throw new InvalidOperationException("Input stream is not readable");
|
||||
if (contentLength < 1) return "";
|
||||
|
||||
byte[] buffer = new byte[65536];
|
||||
long bytesRemaining = contentLength;
|
||||
|
||||
using (MemoryStream ms = new MemoryStream())
|
||||
{
|
||||
int read;
|
||||
|
||||
while (bytesRemaining > 0)
|
||||
{
|
||||
read = input.Read(buffer, 0, buffer.Length);
|
||||
if (read > 0)
|
||||
{
|
||||
ms.Write(buffer, 0, read);
|
||||
bytesRemaining -= read;
|
||||
}
|
||||
else
|
||||
{
|
||||
throw new IOException("Connection closed before reading end of stream.");
|
||||
}
|
||||
}
|
||||
|
||||
if (ms.Length < 1) return null;
|
||||
var str = Encoding.Default.GetString(ms.ToArray());
|
||||
return str;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,7 @@
|
|||
using System.Collections.Concurrent;
|
||||
|
||||
using Myriad.Cache;
|
||||
|
||||
using NodaTime;
|
||||
|
||||
using Serilog;
|
||||
|
|
@ -16,9 +18,12 @@ public class InteractionDispatchService: IDisposable
|
|||
private readonly ConcurrentDictionary<Guid, RegisteredInteraction> _handlers = new();
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public InteractionDispatchService(IClock clock, ILogger logger)
|
||||
private readonly IDiscordCache _cache;
|
||||
|
||||
public InteractionDispatchService(IClock clock, ILogger logger, IDiscordCache cache)
|
||||
{
|
||||
_clock = clock;
|
||||
_cache = cache;
|
||||
_logger = logger.ForContext<InteractionDispatchService>();
|
||||
|
||||
_cleanupWorker = CleanupLoop(_cts.Token);
|
||||
|
|
@ -50,9 +55,15 @@ public class InteractionDispatchService: IDisposable
|
|||
_handlers.TryRemove(customIdGuid, out _);
|
||||
}
|
||||
|
||||
public string Register(Func<InteractionContext, Task> callback, Duration? expiry = null)
|
||||
public string Register(int shardId, Func<InteractionContext, Task> callback, Duration? expiry = null)
|
||||
{
|
||||
var key = Guid.NewGuid();
|
||||
|
||||
// if http_cache, return RegisterRemote
|
||||
// not awaited here, it's probably fine
|
||||
if (_cache is HttpDiscordCache)
|
||||
(_cache as HttpDiscordCache).AwaitInteraction(shardId, key.ToString(), expiry);
|
||||
|
||||
var handler = new RegisteredInteraction
|
||||
{
|
||||
Callback = callback,
|
||||
|
|
|
|||
|
|
@ -1,6 +1,7 @@
|
|||
#nullable enable
|
||||
using System.Collections.Concurrent;
|
||||
|
||||
using Myriad.Cache;
|
||||
using Myriad.Types;
|
||||
|
||||
namespace PluralKit.Bot;
|
||||
|
|
@ -9,9 +10,18 @@ public class LastMessageCacheService
|
|||
{
|
||||
private readonly IDictionary<ulong, CacheEntry> _cache = new ConcurrentDictionary<ulong, CacheEntry>();
|
||||
|
||||
private readonly IDiscordCache _maybeHttp;
|
||||
|
||||
public LastMessageCacheService(IDiscordCache cache)
|
||||
{
|
||||
_maybeHttp = cache;
|
||||
}
|
||||
|
||||
public void AddMessage(Message msg)
|
||||
{
|
||||
var previous = GetLastMessage(msg.ChannelId);
|
||||
if (_maybeHttp is HttpDiscordCache) return;
|
||||
|
||||
var previous = _GetLastMessage(msg.ChannelId);
|
||||
var current = ToCachedMessage(msg);
|
||||
_cache[msg.ChannelId] = new CacheEntry(current, previous?.Current);
|
||||
}
|
||||
|
|
@ -19,12 +29,26 @@ public class LastMessageCacheService
|
|||
private CachedMessage ToCachedMessage(Message msg) =>
|
||||
new(msg.Id, msg.ReferencedMessage.Value?.Id, msg.Author.Username);
|
||||
|
||||
public CacheEntry? GetLastMessage(ulong channel) =>
|
||||
_cache.TryGetValue(channel, out var message) ? message : null;
|
||||
public async Task<CacheEntry?> GetLastMessage(ulong guild, ulong channel)
|
||||
{
|
||||
if (_maybeHttp is HttpDiscordCache)
|
||||
return await (_maybeHttp as HttpDiscordCache).GetLastMessage<CacheEntry>(guild, channel);
|
||||
|
||||
return _cache.TryGetValue(channel, out var message) ? message : null;
|
||||
}
|
||||
|
||||
public CacheEntry? _GetLastMessage(ulong channel)
|
||||
{
|
||||
if (_maybeHttp is HttpDiscordCache) return null;
|
||||
|
||||
return _cache.TryGetValue(channel, out var message) ? message : null;
|
||||
}
|
||||
|
||||
public void HandleMessageDeletion(ulong channel, ulong message)
|
||||
{
|
||||
var storedMessage = GetLastMessage(channel);
|
||||
if (_maybeHttp is HttpDiscordCache) return;
|
||||
|
||||
var storedMessage = _GetLastMessage(channel);
|
||||
if (storedMessage == null)
|
||||
return;
|
||||
|
||||
|
|
@ -39,7 +63,9 @@ public class LastMessageCacheService
|
|||
|
||||
public void HandleMessageDeletion(ulong channel, List<ulong> messages)
|
||||
{
|
||||
var storedMessage = GetLastMessage(channel);
|
||||
if (_maybeHttp is HttpDiscordCache) return;
|
||||
|
||||
var storedMessage = _GetLastMessage(channel);
|
||||
if (storedMessage == null)
|
||||
return;
|
||||
|
||||
|
|
|
|||
|
|
@ -45,6 +45,7 @@ public class LoggerCleanService
|
|||
private static readonly Regex _AnnabelleRegex = new("```\n(\\d{17,19})\n```");
|
||||
private static readonly Regex _AnnabelleRegexFuzzy = new("\\<t:(\\d+)\\> A message from \\*\\*[\\w.]{2,32}\\*\\* \\(`(\\d{17,19})`\\) was deleted in <#\\d{17,19}>");
|
||||
private static readonly Regex _koiraRegex = new("ID:\\*\\* (\\d{17,19})");
|
||||
private static readonly Regex _zeppelinRegex = new("🗑 Message \\(`(\\d{17,19})`\\)");
|
||||
|
||||
private static readonly Regex _VortexRegex =
|
||||
new("`\\[(\\d\\d:\\d\\d:\\d\\d)\\]` .* \\(ID:(\\d{17,19})\\).* <#\\d{17,19}>:");
|
||||
|
|
@ -83,7 +84,8 @@ public class LoggerCleanService
|
|||
new LoggerBot("Dozer", 356535250932858885, ExtractDozer),
|
||||
new LoggerBot("Skyra", 266624760782258186, ExtractSkyra),
|
||||
new LoggerBot("Annabelle", 231241068383961088, ExtractAnnabelle, fuzzyExtractFunc: ExtractAnnabelleFuzzy),
|
||||
new LoggerBot("Koira", 1247013404569239624, ExtractKoira)
|
||||
new LoggerBot("Koira", 1247013404569239624, ExtractKoira),
|
||||
new LoggerBot("Zeppelin", 473868086773153793, ExtractZeppelin) // webhook
|
||||
}.ToDictionary(b => b.Id);
|
||||
|
||||
private static Dictionary<ulong, LoggerBot> _botsByApplicationId
|
||||
|
|
@ -441,6 +443,23 @@ public class LoggerCleanService
|
|||
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
|
||||
}
|
||||
|
||||
private static ulong? ExtractZeppelin(Message msg)
|
||||
{
|
||||
// zeppelin uses a non-embed format by default but can be configured to use a customizable embed
|
||||
// if it's an embed, assume the footer contains the message ID
|
||||
var embed = msg.Embeds?.FirstOrDefault();
|
||||
if (embed == null)
|
||||
{
|
||||
var match = _zeppelinRegex.Match(msg.Content ?? "");
|
||||
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
|
||||
}
|
||||
else
|
||||
{
|
||||
var match = _basicRegex.Match(embed.Footer?.Text ?? "");
|
||||
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
|
||||
}
|
||||
}
|
||||
|
||||
public class LoggerBot
|
||||
{
|
||||
public ulong Id;
|
||||
|
|
|
|||
58
PluralKit.Bot/Services/RuntimeConfigService.cs
Normal file
58
PluralKit.Bot/Services/RuntimeConfigService.cs
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
using Newtonsoft.Json;
|
||||
|
||||
using Serilog;
|
||||
|
||||
using StackExchange.Redis;
|
||||
|
||||
using PluralKit.Core;
|
||||
|
||||
namespace PluralKit.Bot;
|
||||
|
||||
public class RuntimeConfigService
|
||||
{
|
||||
private readonly RedisService _redis;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
private Dictionary<string, string> settings = new();
|
||||
|
||||
private string RedisKey;
|
||||
|
||||
public RuntimeConfigService(ILogger logger, RedisService redis, BotConfig config)
|
||||
{
|
||||
_logger = logger.ForContext<RuntimeConfigService>();
|
||||
_redis = redis;
|
||||
|
||||
var clusterId = config.Cluster?.NodeIndex ?? 0;
|
||||
RedisKey = $"remote_config:dotnet_bot:{clusterId}";
|
||||
}
|
||||
|
||||
public async Task LoadConfig()
|
||||
{
|
||||
var redisConfig = await _redis.Connection.GetDatabase().HashGetAllAsync(RedisKey);
|
||||
foreach (var entry in redisConfig)
|
||||
settings.Add(entry.Name, entry.Value);
|
||||
|
||||
var configStr = JsonConvert.SerializeObject(settings);
|
||||
_logger.Information($"starting with runtime config: {configStr}");
|
||||
}
|
||||
|
||||
public async Task Set(string key, string value)
|
||||
{
|
||||
await _redis.Connection.GetDatabase().HashSetAsync(RedisKey, new[] { new HashEntry(key, new RedisValue(value)) });
|
||||
settings.Add(key, value);
|
||||
_logger.Information($"updated runtime config: {key}={value}");
|
||||
}
|
||||
|
||||
public async Task Delete(string key)
|
||||
{
|
||||
await _redis.Connection.GetDatabase().HashDeleteAsync(RedisKey, key);
|
||||
settings.Remove(key);
|
||||
_logger.Information($"updated runtime config: {key} removed");
|
||||
}
|
||||
|
||||
public object? Get(string key) => settings.GetValueOrDefault(key);
|
||||
|
||||
public bool Exists(string key) => settings.ContainsKey(key);
|
||||
|
||||
public Dictionary<string, string> GetAll() => settings;
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
using Autofac;
|
||||
|
||||
using Myriad.Builders;
|
||||
using Myriad.Cache;
|
||||
using Myriad.Gateway;
|
||||
using Myriad.Rest.Exceptions;
|
||||
using Myriad.Rest.Types.Requests;
|
||||
|
|
@ -40,8 +41,12 @@ public static class ContextUtils
|
|||
}
|
||||
|
||||
public static async Task<MessageReactionAddEvent> AwaitReaction(this Context ctx, Message message,
|
||||
User user = null, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
|
||||
User user, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
|
||||
{
|
||||
// check if http gateway and set listener
|
||||
if (ctx.Cache is HttpDiscordCache)
|
||||
await (ctx.Cache as HttpDiscordCache).AwaitReaction(ctx.Guild?.Id ?? 0, message.Id, user!.Id, timeout);
|
||||
|
||||
bool ReactionPredicate(MessageReactionAddEvent evt)
|
||||
{
|
||||
if (message.Id != evt.MessageId) return false; // Ignore reactions for different messages
|
||||
|
|
@ -57,11 +62,17 @@ public static class ContextUtils
|
|||
|
||||
public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply, bool treatAsHid = false)
|
||||
{
|
||||
var timeout = Duration.FromMinutes(1);
|
||||
|
||||
// check if http gateway and set listener
|
||||
if (ctx.Cache is HttpDiscordCache)
|
||||
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
|
||||
|
||||
bool Predicate(MessageCreateEvent e) =>
|
||||
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
||||
|
||||
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
||||
.WaitFor(Predicate, Duration.FromMinutes(1));
|
||||
.WaitFor(Predicate, timeout);
|
||||
|
||||
var content = msg.Content;
|
||||
if (treatAsHid)
|
||||
|
|
@ -96,11 +107,17 @@ public static class ContextUtils
|
|||
|
||||
async Task<int> PromptPageNumber()
|
||||
{
|
||||
var timeout = Duration.FromMinutes(0.5);
|
||||
|
||||
// check if http gateway and set listener
|
||||
if (ctx.Cache is HttpDiscordCache)
|
||||
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
|
||||
|
||||
bool Predicate(MessageCreateEvent e) =>
|
||||
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
|
||||
|
||||
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
|
||||
.WaitFor(Predicate, Duration.FromMinutes(0.5));
|
||||
.WaitFor(Predicate, timeout);
|
||||
|
||||
int.TryParse(msg.Content, out var num);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,8 @@
|
|||
using System.Net;
|
||||
using System.Net.Sockets;
|
||||
|
||||
using System.Globalization;
|
||||
using Myriad.Rest.Exceptions;
|
||||
using Myriad.Rest.Types;
|
||||
|
||||
using Newtonsoft.Json;
|
||||
|
||||
|
|
@ -102,4 +103,26 @@ public static class MiscUtils
|
|||
|
||||
return true;
|
||||
}
|
||||
|
||||
public static MultipartFile GenerateColorPreview(string color)
|
||||
{
|
||||
//generate a 128x128 solid color gif from bytes
|
||||
//image data is a 1x1 pixel, using the background color to fill the rest of the canvas
|
||||
var imgBytes = new byte[]
|
||||
{
|
||||
0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // Header
|
||||
0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x00, // Logical Screen Descriptor
|
||||
0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, // Global Color Table
|
||||
0x21, 0xF9, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, // Graphics Control Extension
|
||||
0x2C, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, // Image Descriptor
|
||||
0x02, 0x02, 0x4C, 0x01, 0x00, // Image Data
|
||||
0x3B // Trailer
|
||||
}; //indices 13, 14 and 15 are the R, G, and B values respectively
|
||||
|
||||
imgBytes[13] = byte.Parse(color.Substring(0, 2), NumberStyles.HexNumber);
|
||||
imgBytes[14] = byte.Parse(color.Substring(2, 2), NumberStyles.HexNumber);
|
||||
imgBytes[15] = byte.Parse(color.Substring(4, 2), NumberStyles.HexNumber);
|
||||
|
||||
return new MultipartFile("color.gif", new MemoryStream(imgBytes), null, null, null);
|
||||
}
|
||||
}
|
||||
|
|
@ -14,6 +14,16 @@
|
|||
"resolved": "4.13.0",
|
||||
"contentHash": "Wfw3M1WpFcrYaGzPm7QyUTfIOYkVXQ1ry6p4WYjhbLz9fPwV23SGQZTFDpdox67NHM0V0g1aoQ4YKLm4ANtEEg=="
|
||||
},
|
||||
"Watson.Lite": {
|
||||
"type": "Direct",
|
||||
"requested": "[6.3.5, )",
|
||||
"resolved": "6.3.5",
|
||||
"contentHash": "YF8+se3IVenn8YlyNeb4wSJK6QMnVD0QHIOEiZ22wS4K2wkwoSDzWS+ZAjk1MaPeB+XO5gRoENUN//pOc+wI2g==",
|
||||
"dependencies": {
|
||||
"CavemanTcp": "2.0.5",
|
||||
"Watson.Core": "6.3.5"
|
||||
}
|
||||
},
|
||||
"App.Metrics": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.3.0",
|
||||
|
|
@ -107,6 +117,11 @@
|
|||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1"
|
||||
}
|
||||
},
|
||||
"CavemanTcp": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.0.5",
|
||||
"contentHash": "90wywmGpjrj26HMAkufYZwuZI8sVYB1mRwEdqugSR3kgDnPX+3l0jO86gwtFKsPvsEpsS4Dn/1EbhguzUxMU8Q=="
|
||||
},
|
||||
"Dapper": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.1.35",
|
||||
|
|
@ -130,6 +145,11 @@
|
|||
"System.Diagnostics.DiagnosticSource": "5.0.0"
|
||||
}
|
||||
},
|
||||
"IpMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.0.5",
|
||||
"contentHash": "WXNlWERj+0GN699AnMNsuJ7PfUAbU4xhOHP3nrNXLHqbOaBxybu25luSYywX1133NSlitA4YkSNmJuyPvea4sw=="
|
||||
},
|
||||
"IPNetwork2": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.667",
|
||||
|
|
@ -391,18 +411,18 @@
|
|||
"resolved": "8.5.0",
|
||||
"contentHash": "VYYMZNitZ85UEhwOKkTQI63WEMvzUqwQc74I2mm8h/DBVAMcBBxqYPni4DmuRtbCwngmuONuK2yBJfWNRKzI+A=="
|
||||
},
|
||||
"Serilog": {
|
||||
"RegexMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.2.0",
|
||||
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
|
||||
"resolved": "1.0.9",
|
||||
"contentHash": "RkQGXIrqHjD5h1mqefhgCbkaSdRYNRG5rrbzyw5zeLWiS0K1wq9xR3cNhQdzYR2MsKZ3GN523yRUsEQIMPxh3Q=="
|
||||
},
|
||||
"Serilog.Extensions.Logging": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Logging": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Formatting.Compact": {
|
||||
|
|
@ -714,12 +734,36 @@
|
|||
"System.Runtime": "4.3.0"
|
||||
}
|
||||
},
|
||||
"Timestamps": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.0.11",
|
||||
"contentHash": "SnWhXm3FkEStQGgUTfWMh9mKItNW032o/v8eAtFrOGqG0/ejvPPA1LdLZx0N/qqoY0TH3x11+dO00jeVcM8xNQ=="
|
||||
},
|
||||
"UrlMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.1",
|
||||
"contentHash": "hHBZVzFSfikrx4XsRsnCIwmGLgbNKtntnlqf4z+ygcNA6Y/L/J0x5GiZZWfXdTfpxhy5v7mlt2zrZs/L9SvbOA=="
|
||||
},
|
||||
"Watson.Core": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.3.5",
|
||||
"contentHash": "Y5YxKOCSLe2KDmfwvI/J0qApgmmZR77LwyoufRVfKH7GLdHiE7fY0IfoNxWTG7nNv8knBfgwyOxdehRm+4HaCg==",
|
||||
"dependencies": {
|
||||
"IpMatcher": "1.0.5",
|
||||
"RegexMatcher": "1.0.9",
|
||||
"System.Text.Json": "8.0.5",
|
||||
"Timestamps": "1.0.11",
|
||||
"UrlMatcher": "3.0.1"
|
||||
}
|
||||
},
|
||||
"myriad": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"NodaTime": "[3.2.0, )",
|
||||
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||
"Polly": "[8.5.0, )",
|
||||
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
||||
"Serilog": "[4.2.0, )",
|
||||
"Serilog": "[4.1.0, )",
|
||||
"StackExchange.Redis": "[2.8.22, )",
|
||||
"System.Linq.Async": "[6.0.1, )"
|
||||
}
|
||||
|
|
@ -747,8 +791,8 @@
|
|||
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||
"Npgsql": "[9.0.2, )",
|
||||
"Npgsql.NodaTime": "[9.0.2, )",
|
||||
"Serilog": "[4.2.0, )",
|
||||
"Serilog.Extensions.Logging": "[9.0.0, )",
|
||||
"Serilog": "[4.1.0, )",
|
||||
"Serilog.Extensions.Logging": "[8.0.0, )",
|
||||
"Serilog.Formatting.Compact": "[3.0.0, )",
|
||||
"Serilog.NodaTime": "[3.0.0, )",
|
||||
"Serilog.Sinks.Async": "[2.1.0, )",
|
||||
|
|
@ -762,6 +806,9 @@
|
|||
"System.Interactive.Async": "[6.0.1, )",
|
||||
"ipnetwork2": "[3.0.667, )"
|
||||
}
|
||||
},
|
||||
"serilog": {
|
||||
"type": "Project"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -19,5 +19,4 @@ public class CoreConfig
|
|||
|
||||
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Debug;
|
||||
public LogEventLevel ElasticLogLevel { get; set; } = LogEventLevel.Information;
|
||||
public LogEventLevel FileLogLevel { get; set; } = LogEventLevel.Information;
|
||||
}
|
||||
|
|
@ -19,16 +19,40 @@ public partial class ModelRepository
|
|||
}
|
||||
|
||||
|
||||
public Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true)
|
||||
public async Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true, bool search = false)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
|
||||
{
|
||||
var simpleRes = await _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("system", system)
|
||||
);
|
||||
if (simpleRes != null || !search)
|
||||
return simpleRes;
|
||||
|
||||
var accounts = await GetSystemAccounts(system);
|
||||
|
||||
var searchRes = await _db.QueryFirst<bool>(
|
||||
"select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))",
|
||||
new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() },
|
||||
queryName: "find_system_from_commands",
|
||||
messages: true
|
||||
);
|
||||
|
||||
if (!searchRes)
|
||||
searchRes = await _db.QueryFirst<bool>(
|
||||
"select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))",
|
||||
new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() },
|
||||
queryName: "find_system_from_messages",
|
||||
messages: true
|
||||
);
|
||||
|
||||
if (!searchRes)
|
||||
return null;
|
||||
}
|
||||
|
||||
var query = new Query("system_guild").AsInsert(new { guild, system });
|
||||
return _db.QueryFirst<SystemGuildSettings>(query,
|
||||
return await _db.QueryFirst<SystemGuildSettings>(query,
|
||||
"on conflict (guild, system) do update set guild = $1, system = $2 returning *"
|
||||
);
|
||||
}
|
||||
|
|
@ -42,16 +66,25 @@ public partial class ModelRepository
|
|||
return settings;
|
||||
}
|
||||
|
||||
public Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true)
|
||||
public async Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true, SystemId? search = null)
|
||||
{
|
||||
if (!defaultInsert)
|
||||
return _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
|
||||
{
|
||||
var simpleRes = await _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
|
||||
.Where("guild", guild)
|
||||
.Where("member", member)
|
||||
);
|
||||
if (simpleRes != null || !search.HasValue)
|
||||
return simpleRes;
|
||||
|
||||
var systemConfig = await GetSystemGuild(guild, search.Value, defaultInsert: false, search: true);
|
||||
|
||||
if (systemConfig == null)
|
||||
return null;
|
||||
}
|
||||
|
||||
var query = new Query("member_guild").AsInsert(new { guild, member });
|
||||
return _db.QueryFirst<MemberGuildSettings>(query,
|
||||
return await _db.QueryFirst<MemberGuildSettings>(query,
|
||||
"on conflict (guild, member) do update set guild = $1, member = $2 returning *"
|
||||
);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -42,12 +42,37 @@ public partial class ModelRepository
|
|||
};
|
||||
}
|
||||
|
||||
public async Task AddCommandMessage(CommandMessage msg)
|
||||
{
|
||||
var query = new Query("command_messages").AsInsert(new
|
||||
{
|
||||
mid = msg.Mid,
|
||||
guild = msg.Guild,
|
||||
channel = msg.Channel,
|
||||
sender = msg.Sender,
|
||||
original_mid = msg.OriginalMid
|
||||
});
|
||||
await _db.ExecuteQuery(query, messages: true);
|
||||
|
||||
_logger.Debug("Stored command message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
|
||||
}
|
||||
|
||||
public Task<CommandMessage?> GetCommandMessage(ulong id)
|
||||
=> _db.QueryFirst<CommandMessage?>(new Query("command_messages").Where("mid", id), messages: true);
|
||||
|
||||
public async Task DeleteMessage(ulong id)
|
||||
{
|
||||
var query = new Query("messages").AsDelete().Where("mid", id);
|
||||
var rowCount = await _db.ExecuteQuery(query, messages: true);
|
||||
if (rowCount > 0)
|
||||
_logger.Information("Deleted message {MessageId} from database", id);
|
||||
else
|
||||
{
|
||||
var cquery = new Query("command_messages").AsDelete().Where("mid", id);
|
||||
var crowCount = await _db.ExecuteQuery(query, messages: true);
|
||||
if (crowCount > 0)
|
||||
_logger.Information("Deleted command message {MessageId} from database", id);
|
||||
}
|
||||
}
|
||||
|
||||
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
|
||||
|
|
@ -59,5 +84,19 @@ public partial class ModelRepository
|
|||
if (rowCount > 0)
|
||||
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
|
||||
ids);
|
||||
var cquery = new Query("command_messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray());
|
||||
var crowCount = await _db.ExecuteQuery(query, messages: true);
|
||||
if (crowCount > 0)
|
||||
_logger.Information("Bulk deleted command messages ({FoundCount} found) from database: {MessageIds}", rowCount,
|
||||
ids);
|
||||
}
|
||||
}
|
||||
|
||||
public class CommandMessage
|
||||
{
|
||||
public ulong Mid { get; set; }
|
||||
public ulong Guild { get; set; }
|
||||
public ulong Channel { get; set; }
|
||||
public ulong Sender { get; set; }
|
||||
public ulong OriginalMid { get; set; }
|
||||
}
|
||||
|
|
@ -9,7 +9,7 @@ namespace PluralKit.Core;
|
|||
internal class DatabaseMigrator
|
||||
{
|
||||
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
|
||||
private const int TargetSchemaVersion = 51;
|
||||
private const int TargetSchemaVersion = 52;
|
||||
private readonly ILogger _logger;
|
||||
|
||||
public DatabaseMigrator(ILogger logger)
|
||||
|
|
|
|||
|
|
@ -10,8 +10,7 @@ using NodaTime;
|
|||
|
||||
using Serilog;
|
||||
using Serilog.Events;
|
||||
using Serilog.Formatting.Compact;
|
||||
using Serilog.Sinks.Seq;
|
||||
using Serilog.Formatting.Json;
|
||||
using Serilog.Sinks.SystemConsole.Themes;
|
||||
|
||||
using ILogger = Serilog.ILogger;
|
||||
|
|
@ -50,14 +49,9 @@ public class LoggingModule: Module
|
|||
|
||||
private ILogger InitLogger(CoreConfig config)
|
||||
{
|
||||
var consoleTemplate = "[{Timestamp:HH:mm:ss.fff}] {Level:u3} {Message:lj}{NewLine}{Exception}";
|
||||
var outputTemplate = "[{Timestamp:yyyy-MM-dd HH:mm:ss.ffffff}] {Level:u3} {Message:lj}{NewLine}{Exception}";
|
||||
|
||||
var logCfg = _cfg
|
||||
.Enrich.FromLogContext()
|
||||
.Enrich.WithProperty("GitCommitHash", BuildInfoService.FullVersion)
|
||||
.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb)
|
||||
.Enrich.WithProperty("Component", _component)
|
||||
.MinimumLevel.Is(config.ConsoleLogLevel)
|
||||
|
||||
// Don't want App.Metrics/D#+ spam
|
||||
|
|
@ -73,35 +67,10 @@ public class LoggingModule: Module
|
|||
.Destructure.AsScalar<SwitchId>()
|
||||
.Destructure.ByTransforming<ProxyTag>(t => new { t.Prefix, t.Suffix })
|
||||
.Destructure.With<PatchObjectDestructuring>()
|
||||
.WriteTo.Async(a =>
|
||||
{
|
||||
// Both the same output, except one is raw compact JSON and one is plain text.
|
||||
// Output simultaneously. May remove the JSON formatter later, keeping it just in cast.
|
||||
// Flush interval is 50ms (down from 10s) to make "tail -f" easier. May be too low?
|
||||
a.File(
|
||||
(config.LogDir ?? "logs") + $"/pluralkit.{_component}.log",
|
||||
outputTemplate: outputTemplate,
|
||||
retainedFileCountLimit: 10,
|
||||
rollingInterval: RollingInterval.Day,
|
||||
fileSizeLimitBytes: null,
|
||||
flushToDiskInterval: TimeSpan.FromMilliseconds(50),
|
||||
restrictedToMinimumLevel: config.FileLogLevel,
|
||||
formatProvider: new UTCTimestampFormatProvider(),
|
||||
buffered: true);
|
||||
|
||||
a.File(
|
||||
new RenderedCompactJsonFormatter(new ScalarFormatting.JsonValue()),
|
||||
(config.LogDir ?? "logs") + $"/pluralkit.{_component}.json",
|
||||
rollingInterval: RollingInterval.Day,
|
||||
flushToDiskInterval: TimeSpan.FromMilliseconds(50),
|
||||
restrictedToMinimumLevel: config.FileLogLevel,
|
||||
buffered: true);
|
||||
})
|
||||
.WriteTo.Async(a =>
|
||||
a.Console(
|
||||
theme: AnsiConsoleTheme.Code,
|
||||
outputTemplate: consoleTemplate,
|
||||
restrictedToMinimumLevel: config.ConsoleLogLevel));
|
||||
new CustomJsonFormatter(_component),
|
||||
config.ConsoleLogLevel));
|
||||
|
||||
if (config.ElasticUrl != null)
|
||||
{
|
||||
|
|
@ -113,15 +82,6 @@ public class LoggingModule: Module
|
|||
);
|
||||
}
|
||||
|
||||
if (config.SeqLogUrl != null)
|
||||
{
|
||||
logCfg.WriteTo.Seq(
|
||||
config.SeqLogUrl,
|
||||
restrictedToMinimumLevel: LogEventLevel.Verbose
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
_fn.Invoke(logCfg);
|
||||
return Log.Logger = logCfg.CreateLogger();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -15,6 +15,10 @@
|
|||
<RestorePackagesWithLockFile>true</RestorePackagesWithLockFile>
|
||||
</PropertyGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<ProjectReference Include="..\Serilog\src\Serilog\Serilog.csproj" />
|
||||
</ItemGroup>
|
||||
|
||||
<ItemGroup>
|
||||
<PackageReference Include="App.Metrics" Version="4.3.0" />
|
||||
<PackageReference Include="App.Metrics.Reporting.InfluxDB" Version="4.3.0" />
|
||||
|
|
@ -37,8 +41,7 @@
|
|||
<PackageReference Include="NodaTime.Serialization.JsonNet" Version="3.1.0" />
|
||||
<PackageReference Include="Npgsql" Version="9.0.2" />
|
||||
<PackageReference Include="Npgsql.NodaTime" Version="9.0.2" />
|
||||
<PackageReference Include="Serilog" Version="4.2.0" />
|
||||
<PackageReference Include="Serilog.Extensions.Logging" Version="9.0.0" />
|
||||
<PackageReference Include="Serilog.Extensions.Logging" Version="8.0.0" />
|
||||
<PackageReference Include="Serilog.Formatting.Compact" Version="3.0.0" />
|
||||
<PackageReference Include="Serilog.NodaTime" Version="3.0.0" />
|
||||
<PackageReference Include="Serilog.Sinks.Async" Version="2.1.0" />
|
||||
|
|
|
|||
|
|
@ -205,7 +205,7 @@ public partial class BulkImporter
|
|||
? existingSwitches.Select(sw => sw.Id).Max()
|
||||
: (SwitchId?)null;
|
||||
|
||||
if (switches.Count > 10000)
|
||||
if (switches.Count > 100000)
|
||||
throw new ImportException("Too many switches present in import file.");
|
||||
|
||||
// Import switch definitions
|
||||
|
|
|
|||
215
PluralKit.Core/Utils/SerilogJsonFormatter.cs
Normal file
215
PluralKit.Core/Utils/SerilogJsonFormatter.cs
Normal file
|
|
@ -0,0 +1,215 @@
|
|||
using System.Runtime.CompilerServices;
|
||||
|
||||
using Serilog.Events;
|
||||
using Serilog.Formatting;
|
||||
using Serilog.Formatting.Json;
|
||||
using Serilog.Parsing;
|
||||
using Serilog.Rendering;
|
||||
|
||||
// Customized Serilog JSON output for PluralKit
|
||||
|
||||
// Copyright 2013-2015 Serilog Contributors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
namespace PluralKit.Core;
|
||||
|
||||
static class Guard
|
||||
{
|
||||
public static T AgainstNull<T>(
|
||||
T? argument,
|
||||
[CallerArgumentExpression("argument")] string? paramName = null)
|
||||
where T : class
|
||||
{
|
||||
if (argument is null)
|
||||
{
|
||||
throw new ArgumentNullException(paramName);
|
||||
}
|
||||
|
||||
return argument;
|
||||
}
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Formats log events in a simple JSON structure. Instances of this class
|
||||
/// are safe for concurrent access by multiple threads.
|
||||
/// </summary>
|
||||
/// <remarks>New code should prefer formatters from <c>Serilog.Formatting.Compact</c>, or <c>ExpressionTemplate</c> from
|
||||
/// <c>Serilog.Expressions</c>.</remarks>
|
||||
public sealed class CustomJsonFormatter: ITextFormatter
|
||||
{
|
||||
readonly JsonValueFormatter _jsonValueFormatter = new();
|
||||
readonly string _component;
|
||||
|
||||
/// <summary>
|
||||
/// Construct a <see cref="JsonFormatter"/>.
|
||||
/// </summary>
|
||||
/// <param name="closingDelimiter">A string that will be written after each log event is formatted.
|
||||
/// If null, <see cref="Environment.NewLine"/> will be used.</param>
|
||||
/// <param name="renderMessage">If <see langword="true"/>, the message will be rendered and written to the output as a
|
||||
/// property named RenderedMessage.</param>
|
||||
/// <param name="formatProvider">Supplies culture-specific formatting information, or null.</param>
|
||||
public CustomJsonFormatter(string component)
|
||||
{
|
||||
_component = component;
|
||||
}
|
||||
|
||||
private string CustomLevelString(LogEventLevel level)
|
||||
{
|
||||
switch (level)
|
||||
{
|
||||
case LogEventLevel.Verbose:
|
||||
return "TRACE";
|
||||
case LogEventLevel.Debug:
|
||||
return "DEBUG";
|
||||
case LogEventLevel.Information:
|
||||
return "INFO";
|
||||
case LogEventLevel.Warning:
|
||||
return "WARN";
|
||||
case LogEventLevel.Error:
|
||||
return "ERROR";
|
||||
case LogEventLevel.Fatal:
|
||||
return "FATAL";
|
||||
};
|
||||
|
||||
return "UNKNOWN";
|
||||
}
|
||||
|
||||
/// <summary>
|
||||
/// Format the log event into the output.
|
||||
/// </summary>
|
||||
/// <param name="logEvent">The event to format.</param>
|
||||
/// <param name="output">The output.</param>
|
||||
/// <exception cref="ArgumentNullException">When <paramref name="logEvent"/> is <code>null</code></exception>
|
||||
/// <exception cref="ArgumentNullException">When <paramref name="output"/> is <code>null</code></exception>
|
||||
public void Format(LogEvent logEvent, TextWriter output)
|
||||
{
|
||||
Guard.AgainstNull(logEvent);
|
||||
Guard.AgainstNull(output);
|
||||
|
||||
output.Write("{\"component\":\"");
|
||||
output.Write(_component);
|
||||
output.Write("\",\"timestamp\":\"");
|
||||
output.Write(logEvent.Timestamp.ToString("O").Replace("+00:00", "Z"));
|
||||
output.Write("\",\"level\":\"");
|
||||
output.Write(CustomLevelString(logEvent.Level));
|
||||
|
||||
output.Write("\",\"message\":");
|
||||
var message = logEvent.MessageTemplate.Render(logEvent.Properties);
|
||||
JsonValueFormatter.WriteQuotedJsonString(message, output);
|
||||
|
||||
if (logEvent.TraceId != null)
|
||||
{
|
||||
output.Write(",\"TraceId\":");
|
||||
JsonValueFormatter.WriteQuotedJsonString(logEvent.TraceId.ToString()!, output);
|
||||
}
|
||||
|
||||
if (logEvent.SpanId != null)
|
||||
{
|
||||
output.Write(",\"SpanId\":");
|
||||
JsonValueFormatter.WriteQuotedJsonString(logEvent.SpanId.ToString()!, output);
|
||||
}
|
||||
|
||||
if (logEvent.Exception != null)
|
||||
{
|
||||
output.Write(",\"Exception\":");
|
||||
JsonValueFormatter.WriteQuotedJsonString(logEvent.Exception.ToString(), output);
|
||||
}
|
||||
|
||||
if (logEvent.Properties.Count != 0)
|
||||
{
|
||||
output.Write(",\"Properties\":{");
|
||||
|
||||
char? propertyDelimiter = null;
|
||||
foreach (var property in logEvent.Properties)
|
||||
{
|
||||
if (propertyDelimiter != null)
|
||||
output.Write(propertyDelimiter.Value);
|
||||
else
|
||||
propertyDelimiter = ',';
|
||||
|
||||
JsonValueFormatter.WriteQuotedJsonString(property.Key, output);
|
||||
output.Write(':');
|
||||
_jsonValueFormatter.Format(property.Value, output);
|
||||
}
|
||||
|
||||
output.Write('}');
|
||||
}
|
||||
|
||||
var tokensWithFormat = logEvent.MessageTemplate.Tokens
|
||||
.OfType<PropertyToken>()
|
||||
.Where(pt => pt.Format != null)
|
||||
.GroupBy(pt => pt.PropertyName)
|
||||
.ToArray();
|
||||
|
||||
if (tokensWithFormat.Length != 0)
|
||||
{
|
||||
output.Write(",\"Renderings\":{");
|
||||
WriteRenderingsValues(tokensWithFormat, logEvent.Properties, output);
|
||||
output.Write('}');
|
||||
}
|
||||
|
||||
output.Write('}');
|
||||
output.Write("\n");
|
||||
}
|
||||
|
||||
void WriteRenderingsValues(IEnumerable<IGrouping<string, PropertyToken>> tokensWithFormat, IReadOnlyDictionary<string, LogEventPropertyValue> properties, TextWriter output)
|
||||
{
|
||||
static void WriteNameValuePair(string name, string value, ref char? precedingDelimiter, TextWriter output)
|
||||
{
|
||||
if (precedingDelimiter != null)
|
||||
output.Write(precedingDelimiter.Value);
|
||||
|
||||
JsonValueFormatter.WriteQuotedJsonString(name, output);
|
||||
output.Write(':');
|
||||
JsonValueFormatter.WriteQuotedJsonString(value, output);
|
||||
precedingDelimiter = ',';
|
||||
}
|
||||
|
||||
char? propertyDelimiter = null;
|
||||
foreach (var propertyFormats in tokensWithFormat)
|
||||
{
|
||||
if (propertyDelimiter != null)
|
||||
output.Write(propertyDelimiter.Value);
|
||||
else
|
||||
propertyDelimiter = ',';
|
||||
|
||||
output.Write('"');
|
||||
output.Write(propertyFormats.Key);
|
||||
output.Write("\":[");
|
||||
|
||||
char? formatDelimiter = null;
|
||||
foreach (var format in propertyFormats)
|
||||
{
|
||||
if (formatDelimiter != null)
|
||||
output.Write(formatDelimiter.Value);
|
||||
|
||||
formatDelimiter = ',';
|
||||
|
||||
output.Write('{');
|
||||
char? elementDelimiter = null;
|
||||
|
||||
// Caller ensures that `tokensWithFormat` contains only property tokens that have non-null `Format`s.
|
||||
WriteNameValuePair("Format", format.Format!, ref elementDelimiter, output);
|
||||
|
||||
using var sw = ReusableStringWriter.GetOrCreate();
|
||||
MessageTemplateRenderer.RenderPropertyToken(format, properties, sw, null, isLiteral: true, isJson: false);
|
||||
WriteNameValuePair("Rendering", sw.ToString(), ref elementDelimiter, output);
|
||||
|
||||
output.Write('}');
|
||||
}
|
||||
|
||||
output.Write(']');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -198,20 +198,14 @@
|
|||
"Npgsql": "9.0.2"
|
||||
}
|
||||
},
|
||||
"Serilog": {
|
||||
"type": "Direct",
|
||||
"requested": "[4.2.0, )",
|
||||
"resolved": "4.2.0",
|
||||
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
|
||||
},
|
||||
"Serilog.Extensions.Logging": {
|
||||
"type": "Direct",
|
||||
"requested": "[9.0.0, )",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
|
||||
"requested": "[8.0.0, )",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Logging": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Formatting.Compact": {
|
||||
|
|
@ -722,6 +716,9 @@
|
|||
"Microsoft.NETCore.Targets": "1.1.0",
|
||||
"System.Runtime": "4.3.0"
|
||||
}
|
||||
},
|
||||
"serilog": {
|
||||
"type": "Project"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -128,6 +128,11 @@
|
|||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1"
|
||||
}
|
||||
},
|
||||
"CavemanTcp": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.0.5",
|
||||
"contentHash": "90wywmGpjrj26HMAkufYZwuZI8sVYB1mRwEdqugSR3kgDnPX+3l0jO86gwtFKsPvsEpsS4Dn/1EbhguzUxMU8Q=="
|
||||
},
|
||||
"Dapper": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.1.35",
|
||||
|
|
@ -156,6 +161,11 @@
|
|||
"resolved": "2.14.1",
|
||||
"contentHash": "lQKvtaTDOXnoVJ20ibTuSIOf2i0uO0MPbDhd1jm238I+U/2ZnRENj0cktKZhtchBMtCUSRQ5v4xBCUbKNmyVMw=="
|
||||
},
|
||||
"IpMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.0.5",
|
||||
"contentHash": "WXNlWERj+0GN699AnMNsuJ7PfUAbU4xhOHP3nrNXLHqbOaBxybu25luSYywX1133NSlitA4YkSNmJuyPvea4sw=="
|
||||
},
|
||||
"IPNetwork2": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.667",
|
||||
|
|
@ -310,21 +320,21 @@
|
|||
},
|
||||
"Microsoft.Extensions.DependencyModel": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "saxr2XzwgDU77LaQfYFXmddEDRUKHF4DaGMZkNB3qjdVSZlax3//dGJagJkKrGMIPNZs2jVFXITyCCR6UHJNdA==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "NSmDw3K0ozNDgShSIpsZcbFIzBX4w28nDag+TfaQujkXGazBm+lid5onlWoCBy4VsLxqnnKjEBbGSJVWJMf43g==",
|
||||
"dependencies": {
|
||||
"System.Text.Encodings.Web": "9.0.0",
|
||||
"System.Text.Json": "9.0.0"
|
||||
"System.Text.Encodings.Web": "8.0.0",
|
||||
"System.Text.Json": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "1K8P7XzuzX8W8pmXcZjcrqS6x5eSSdvhQohmcpgiQNY/HlDAlnrhR9dvlURfFz428A+RTCJpUyB+aKTA6AgVcQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Options": "9.0.0",
|
||||
"System.Diagnostics.DiagnosticSource": "9.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Options": "8.0.0",
|
||||
"System.Diagnostics.DiagnosticSource": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": {
|
||||
|
|
@ -352,14 +362,14 @@
|
|||
},
|
||||
"Microsoft.Extensions.Hosting.Abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "yUKJgu81ExjvqbNWqZKshBbLntZMbMVz/P7Way2SBx7bMqA08Mfdc9O7hWDKAiSp+zPUGT6LKcSCQIPeDK+CCw==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Configuration.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "9.0.0"
|
||||
"Microsoft.Extensions.Configuration.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.FileProviders.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Microsoft.Extensions.Logging": {
|
||||
|
|
@ -510,49 +520,52 @@
|
|||
"resolved": "8.5.0",
|
||||
"contentHash": "VYYMZNitZ85UEhwOKkTQI63WEMvzUqwQc74I2mm8h/DBVAMcBBxqYPni4DmuRtbCwngmuONuK2yBJfWNRKzI+A=="
|
||||
},
|
||||
"RegexMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.0.9",
|
||||
"contentHash": "RkQGXIrqHjD5h1mqefhgCbkaSdRYNRG5rrbzyw5zeLWiS0K1wq9xR3cNhQdzYR2MsKZ3GN523yRUsEQIMPxh3Q=="
|
||||
},
|
||||
"Sentry": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.13.0",
|
||||
"contentHash": "Wfw3M1WpFcrYaGzPm7QyUTfIOYkVXQ1ry6p4WYjhbLz9fPwV23SGQZTFDpdox67NHM0V0g1aoQ4YKLm4ANtEEg=="
|
||||
},
|
||||
"Serilog": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.2.0",
|
||||
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
|
||||
},
|
||||
"Serilog.AspNetCore": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "JslDajPlBsn3Pww1554flJFTqROvK9zz9jONNQgn0D8Lx2Trw8L0A8/n6zEQK1DAZWXrJwiVLw8cnTR3YFuYsg==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "FAjtKPZ4IzqFQBqZKPv6evcXK/F0ls7RoXI/62Pnx2igkDZ6nZ/jn/C/FxVATqQbEQvtqP+KViWYIe4NZIHa2w==",
|
||||
"dependencies": {
|
||||
"Serilog": "4.2.0",
|
||||
"Serilog.Extensions.Hosting": "9.0.0",
|
||||
"Serilog.Formatting.Compact": "3.0.0",
|
||||
"Serilog.Settings.Configuration": "9.0.0",
|
||||
"Serilog.Sinks.Console": "6.0.0",
|
||||
"Serilog.Sinks.Debug": "3.0.0",
|
||||
"Serilog.Sinks.File": "6.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection": "8.0.0",
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1",
|
||||
"Serilog.Extensions.Hosting": "8.0.0",
|
||||
"Serilog.Extensions.Logging": "8.0.0",
|
||||
"Serilog.Formatting.Compact": "2.0.0",
|
||||
"Serilog.Settings.Configuration": "8.0.0",
|
||||
"Serilog.Sinks.Console": "5.0.0",
|
||||
"Serilog.Sinks.Debug": "2.0.0",
|
||||
"Serilog.Sinks.File": "5.0.0"
|
||||
}
|
||||
},
|
||||
"Serilog.Extensions.Hosting": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "u2TRxuxbjvTAldQn7uaAwePkWxTHIqlgjelekBtilAGL5sYyF3+65NWctN4UrwwGLsDC7c3Vz3HnOlu+PcoxXg==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "db0OcbWeSCvYQkHWu6n0v40N4kKaTAXNjlM3BKvcbwvNzYphQFcBR+36eQ/7hMMwOkJvAyLC2a9/jNdUL5NjtQ==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Hosting.Abstractions": "9.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "9.0.0",
|
||||
"Serilog": "4.2.0",
|
||||
"Serilog.Extensions.Logging": "9.0.0"
|
||||
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Hosting.Abstractions": "8.0.0",
|
||||
"Microsoft.Extensions.Logging.Abstractions": "8.0.0",
|
||||
"Serilog": "3.1.1",
|
||||
"Serilog.Extensions.Logging": "8.0.0"
|
||||
}
|
||||
},
|
||||
"Serilog.Extensions.Logging": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Logging": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Logging": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Formatting.Compact": {
|
||||
|
|
@ -582,12 +595,12 @@
|
|||
},
|
||||
"Serilog.Settings.Configuration": {
|
||||
"type": "Transitive",
|
||||
"resolved": "9.0.0",
|
||||
"contentHash": "4/Et4Cqwa+F88l5SeFeNZ4c4Z6dEAIKbu3MaQb2Zz9F/g27T5a3wvfMcmCOaAiACjfUb4A6wrlTVfyYUZk3RRQ==",
|
||||
"resolved": "8.0.0",
|
||||
"contentHash": "nR0iL5HwKj5v6ULo3/zpP8NMcq9E2pxYA6XKTSWCbugVs4YqPyvaqaKOY+OMpPivKp7zMEpax2UKHnDodbRB0Q==",
|
||||
"dependencies": {
|
||||
"Microsoft.Extensions.Configuration.Binder": "9.0.0",
|
||||
"Microsoft.Extensions.DependencyModel": "9.0.0",
|
||||
"Serilog": "4.2.0"
|
||||
"Microsoft.Extensions.Configuration.Binder": "8.0.0",
|
||||
"Microsoft.Extensions.DependencyModel": "8.0.0",
|
||||
"Serilog": "3.1.1"
|
||||
}
|
||||
},
|
||||
"Serilog.Sinks.Async": {
|
||||
|
|
@ -608,10 +621,10 @@
|
|||
},
|
||||
"Serilog.Sinks.Debug": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.0",
|
||||
"contentHash": "4BzXcdrgRX7wde9PmHuYd9U6YqycCC28hhpKonK7hx0wb19eiuRj16fPcPSVp0o/Y1ipJuNLYQ00R3q2Zs8FDA==",
|
||||
"resolved": "2.0.0",
|
||||
"contentHash": "Y6g3OBJ4JzTyyw16fDqtFcQ41qQAydnEvEqmXjhwhgjsnG/FaJ8GUqF5ldsC/bVkK8KYmqrPhDO+tm4dF6xx4A==",
|
||||
"dependencies": {
|
||||
"Serilog": "4.0.0"
|
||||
"Serilog": "2.10.0"
|
||||
}
|
||||
},
|
||||
"Serilog.Sinks.Elasticsearch": {
|
||||
|
|
@ -887,6 +900,37 @@
|
|||
"System.Runtime": "4.3.0"
|
||||
}
|
||||
},
|
||||
"Timestamps": {
|
||||
"type": "Transitive",
|
||||
"resolved": "1.0.11",
|
||||
"contentHash": "SnWhXm3FkEStQGgUTfWMh9mKItNW032o/v8eAtFrOGqG0/ejvPPA1LdLZx0N/qqoY0TH3x11+dO00jeVcM8xNQ=="
|
||||
},
|
||||
"UrlMatcher": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.1",
|
||||
"contentHash": "hHBZVzFSfikrx4XsRsnCIwmGLgbNKtntnlqf4z+ygcNA6Y/L/J0x5GiZZWfXdTfpxhy5v7mlt2zrZs/L9SvbOA=="
|
||||
},
|
||||
"Watson.Core": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.3.5",
|
||||
"contentHash": "Y5YxKOCSLe2KDmfwvI/J0qApgmmZR77LwyoufRVfKH7GLdHiE7fY0IfoNxWTG7nNv8knBfgwyOxdehRm+4HaCg==",
|
||||
"dependencies": {
|
||||
"IpMatcher": "1.0.5",
|
||||
"RegexMatcher": "1.0.9",
|
||||
"System.Text.Json": "8.0.5",
|
||||
"Timestamps": "1.0.11",
|
||||
"UrlMatcher": "3.0.1"
|
||||
}
|
||||
},
|
||||
"Watson.Lite": {
|
||||
"type": "Transitive",
|
||||
"resolved": "6.3.5",
|
||||
"contentHash": "YF8+se3IVenn8YlyNeb4wSJK6QMnVD0QHIOEiZ22wS4K2wkwoSDzWS+ZAjk1MaPeB+XO5gRoENUN//pOc+wI2g==",
|
||||
"dependencies": {
|
||||
"CavemanTcp": "2.0.5",
|
||||
"Watson.Core": "6.3.5"
|
||||
}
|
||||
},
|
||||
"xunit.abstractions": {
|
||||
"type": "Transitive",
|
||||
"resolved": "2.0.3",
|
||||
|
|
@ -930,9 +974,11 @@
|
|||
"myriad": {
|
||||
"type": "Project",
|
||||
"dependencies": {
|
||||
"NodaTime": "[3.2.0, )",
|
||||
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||
"Polly": "[8.5.0, )",
|
||||
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
|
||||
"Serilog": "[4.2.0, )",
|
||||
"Serilog": "[4.1.0, )",
|
||||
"StackExchange.Redis": "[2.8.22, )",
|
||||
"System.Linq.Async": "[6.0.1, )"
|
||||
}
|
||||
|
|
@ -945,7 +991,7 @@
|
|||
"Microsoft.AspNetCore.Mvc.Versioning.ApiExplorer": "[5.1.0, )",
|
||||
"PluralKit.Core": "[1.0.0, )",
|
||||
"Sentry": "[4.13.0, )",
|
||||
"Serilog.AspNetCore": "[9.0.0, )"
|
||||
"Serilog.AspNetCore": "[8.0.0, )"
|
||||
}
|
||||
},
|
||||
"pluralkit.bot": {
|
||||
|
|
@ -954,7 +1000,8 @@
|
|||
"Humanizer.Core": "[2.14.1, )",
|
||||
"Myriad": "[1.0.0, )",
|
||||
"PluralKit.Core": "[1.0.0, )",
|
||||
"Sentry": "[4.13.0, )"
|
||||
"Sentry": "[4.13.0, )",
|
||||
"Watson.Lite": "[6.3.5, )"
|
||||
}
|
||||
},
|
||||
"pluralkit.core": {
|
||||
|
|
@ -980,8 +1027,8 @@
|
|||
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
|
||||
"Npgsql": "[9.0.2, )",
|
||||
"Npgsql.NodaTime": "[9.0.2, )",
|
||||
"Serilog": "[4.2.0, )",
|
||||
"Serilog.Extensions.Logging": "[9.0.0, )",
|
||||
"Serilog": "[4.1.0, )",
|
||||
"Serilog.Extensions.Logging": "[8.0.0, )",
|
||||
"Serilog.Formatting.Compact": "[3.0.0, )",
|
||||
"Serilog.NodaTime": "[3.0.0, )",
|
||||
"Serilog.Sinks.Async": "[2.1.0, )",
|
||||
|
|
@ -995,6 +1042,9 @@
|
|||
"System.Interactive.Async": "[6.0.1, )",
|
||||
"ipnetwork2": "[3.0.667, )"
|
||||
}
|
||||
},
|
||||
"serilog": {
|
||||
"type": "Project"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
72
README.md
72
README.md
|
|
@ -5,75 +5,35 @@ PluralKit is a Discord bot meant for plural communities. It has features like me
|
|||
|
||||
PluralKit has a Discord server for support, feedback, and discussion: https://discord.gg/PczBt78
|
||||
|
||||
# Requirements
|
||||
Running the bot requires [.NET 5](https://dotnet.microsoft.com/download), a PostgreSQL database and a Redis database. It should function on any system where the prerequisites are set up (including Windows).
|
||||
|
||||
Optionally, it can integrate with [Sentry](https://sentry.io/welcome/) for error reporting and [InfluxDB](https://www.influxdata.com/products/influxdb-overview/) for aggregate statistics.
|
||||
|
||||
# Configuration
|
||||
Configuring the bot is done through a JSON configuration file. An example of the configuration format can be seen in [`pluralkit.conf.example`](https://github.com/PluralKit/PluralKit/blob/master/pluralkit.conf.example).
|
||||
The configuration file needs to be placed in the bot's working directory (usually the repository root) and must be called `pluralkit.conf`.
|
||||
|
||||
The configuration file is in JSON format (albeit with a `.conf` extension). The following keys are available (using `.` to indicate a nested object level), bolded key names are required:
|
||||
* **`PluralKit.Bot.Token`**: the Discord bot token to connect with
|
||||
* **`PluralKit.Database`**: the URI of the database to connect to (in [ADO.NET Npgsql format](https://www.connectionstrings.com/npgsql/))
|
||||
* **`PluralKit.RedisAddr`**: the `host:port` of a Redis database to connect to
|
||||
* `PluralKit.Bot.Prefixes`: an array of command prefixes to use (default `["pk;", "pk!"]`).
|
||||
* **`PluralKit.Bot.ClientId`**: the ID of the bot's user account, used for calculating the bot's own permissions and for the link in `pk;invite`.
|
||||
* `PluralKit.SentryUrl` *(optional)*: the [Sentry](https://sentry.io/welcome/) client key/DSN to report runtime errors to. If absent, disables Sentry integration.
|
||||
* `PluralKit.InfluxUrl` *(optional)*: the URL to an [InfluxDB](https://www.influxdata.com/products/influxdb-overview/) server to report aggregate statistics to. An example of these stats can be seen on [the public stats page](https://stats.pluralkit.me).
|
||||
* `PluralKit.InfluxDb` *(optional)*: the name of an InfluxDB database to report statistics to. If either this field or `PluralKit.InfluxUrl` are absent, InfluxDB reporting will be disabled.
|
||||
* `PluralKit.LogDir` *(optional)*: the directory to save information and error logs to. If left blank, will default to `logs/` in the current working directory.
|
||||
|
||||
The bot can also take configuration from environment variables, which will override the values read from the file. Here, use `:` (colon) or `__` (double underscore) as a level separator (eg. `export PluralKit__Bot__Token=foobar123`) as per [ASP.NET config](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/configuration/?view=aspnetcore-3.1#environment-variables).
|
||||
|
||||
# Running
|
||||
In production, we run PluralKit using Kubernetes. The configuration can be found in the infra repo.
|
||||
|
||||
## Docker
|
||||
The easiest way to get the bot running is with Docker. The repository contains a `docker-compose.yml` file ready to use.
|
||||
For self-hosting, it's simpler to use Docker, with the provided [docker-compose](./docker-compose.yml) file.
|
||||
|
||||
* Clone this repository: `git clone https://github.com/PluralKit/PluralKit`
|
||||
* Create a `pluralkit.conf` file in the same directory as `docker-compose.yml` containing at least `PluralKit.Bot.Token` and `PluralKit.Bot.ClientId` fields
|
||||
* (`PluralKit.Database` is overridden in `docker-compose.yml` to point to the Postgres container)
|
||||
* Build the bot: `docker-compose build`
|
||||
* Run the bot: `docker-compose up`
|
||||
|
||||
In other words:
|
||||
Create a `.env` file with the Discord client ID and bot token:
|
||||
```
|
||||
$ git clone https://github.com/PluralKit/PluralKit
|
||||
$ cd PluralKit
|
||||
$ cp pluralkit.conf.example pluralkit.conf
|
||||
$ nano pluralkit.conf # (or vim, or whatever)
|
||||
$ docker-compose up -d
|
||||
CLIENT_ID=198622483471925248
|
||||
BOT_TOKEN=MTk4NjIyNDgzNDcxOTI1MjQ4.Cl2FMQ.ZnCjm1XVW7vRze4b7Cq4se7kKWs
|
||||
```
|
||||
|
||||
## Manually
|
||||
* Install the .NET 6 SDK (see https://dotnet.microsoft.com/download)
|
||||
* Clone this repository: `git clone https://github.com/PluralKit/PluralKit`
|
||||
* Create and fill in a `pluralkit.conf` file in the same directory as `docker-compose.yml`
|
||||
* Run the bot: `dotnet run --project PluralKit.Bot`
|
||||
* Alternatively, `dotnet build -c Release -o build/`, then `dotnet build/PluralKit.Bot.dll`
|
||||
If you want to use `pk;admin` commands (to raise member limits and such), set `ADMIN_ROLE` to a Discord role ID:
|
||||
|
||||
(tip: use `scripts/run-test-db.sh` to run a temporary PostgreSQL database on your local system. Requires Docker.)
|
||||
```
|
||||
ADMIN_ROLE=682632767057428509
|
||||
```
|
||||
|
||||
## Scheduled Tasks worker
|
||||
*If you didn't clone the repository with submodules, run `git submodule update --init` first to pull the required submodules.*
|
||||
Run `docker compose build`, then `docker compose up -d`.
|
||||
|
||||
There is a scheduled tasks worker that needs to be ran separately from the bot. This handles cleaning up the database, and updating statistics (system/member/etc counts, shown in the `pk;stats` embed).
|
||||
To view logs, use `docker compose logs`.
|
||||
|
||||
Note: This worker is *not required*, and the bot will function correctly without it.
|
||||
Postgres data is stored in a `pluralkit_data` [Docker volume](https://docs.docker.com/engine/storage/volumes/).
|
||||
|
||||
If you are running the bot via docker-compose, this is set up automatically.
|
||||
|
||||
If you run the bot manually you can run the worker as such:
|
||||
* `dotnet run --project PluralKit.ScheduledTasks`
|
||||
* or if you used `dotnet build` rather than `dotnet run` to run the bot: `dotnet build/PluralKit.ScheduledTasks.dll`
|
||||
|
||||
# Upgrading database from legacy version
|
||||
If you have an instance of the Python version of the bot (from the `legacy` branch), you may need to take extra database migration steps.
|
||||
For more information, see [LEGACYMIGRATE.md](./LEGACYMIGRATE.md).
|
||||
# Development
|
||||
See [the dev-docs/ directory](./dev-docs/README.md)
|
||||
|
||||
# User documentation
|
||||
See [the docs/ directory](./docs/README.md)
|
||||
|
||||
# License
|
||||
This project is under the GNU Affero General Public License, Version 3. It is available at the following link: https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
This project is under the GNU Affero General Public License, Version 3. It is available at the following link: https://www.gnu.org/licenses/agpl-3.0.en.html
|
||||
|
|
|
|||
1
Serilog
Submodule
1
Serilog
Submodule
|
|
@ -0,0 +1 @@
|
|||
Subproject commit f5eb991cb4c4a0c1e2407de7504c543536786598
|
||||
|
|
@ -10,6 +10,7 @@ COPY PluralKit.Bot/PluralKit.Bot.csproj /app/PluralKit.Bot/
|
|||
COPY PluralKit.Core/PluralKit.Core.csproj /app/PluralKit.Core/
|
||||
COPY PluralKit.Tests/PluralKit.Tests.csproj /app/PluralKit.Tests/
|
||||
COPY .git/ /app/.git
|
||||
COPY Serilog/ /app/Serilog/
|
||||
RUN dotnet restore PluralKit.sln
|
||||
|
||||
# Copy the rest of the code and build
|
||||
|
|
@ -25,18 +25,22 @@ COPY Cargo.lock /build/
|
|||
|
||||
COPY crates/ /build/crates
|
||||
|
||||
RUN cargo build --bin migrate --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin api --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin dispatch --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin gateway --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin avatars --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin avatar_cleanup --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin scheduled_tasks --release --target x86_64-unknown-linux-musl
|
||||
RUN cargo build --bin gdpr_worker --release --target x86_64-unknown-linux-musl
|
||||
|
||||
FROM scratch
|
||||
FROM alpine:latest
|
||||
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/migrate /migrate
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/api /api
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/dispatch /dispatch
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gateway /gateway
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/avatars /avatars
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/avatar_cleanup /avatar_cleanup
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/scheduled_tasks /scheduled_tasks
|
||||
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gdpr_worker /gdpr_worker
|
||||
|
|
|
|||
|
|
@ -37,8 +37,10 @@ EOF
|
|||
}
|
||||
|
||||
# add rust binaries here to build
|
||||
build migrate
|
||||
build api
|
||||
build dispatch
|
||||
build gateway
|
||||
build avatars "COPY .docker-bin/avatar_cleanup /bin/avatar_cleanup"
|
||||
build scheduled_tasks
|
||||
build gdpr_worker
|
||||
|
|
|
|||
48
crates/api/src/auth.rs
Normal file
48
crates/api/src/auth.rs
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
|
||||
|
||||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
system_id: Option<i32>,
|
||||
app_id: Option<i32>,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
|
||||
Self { system_id, app_id }
|
||||
}
|
||||
|
||||
pub fn system_id(&self) -> Option<i32> {
|
||||
self.system_id
|
||||
}
|
||||
|
||||
pub fn app_id(&self) -> Option<i32> {
|
||||
self.app_id
|
||||
}
|
||||
|
||||
pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel {
|
||||
if self
|
||||
.system_id
|
||||
.map(|id| id == a.authable_system_id())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
PrivacyLevel::Private
|
||||
} else {
|
||||
PrivacyLevel::Public
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// authable trait/impls
|
||||
|
||||
pub trait Authable {
|
||||
fn authable_system_id(&self) -> SystemId;
|
||||
}
|
||||
|
||||
impl Authable for PKSystem {
|
||||
fn authable_system_id(&self) -> SystemId {
|
||||
self.id
|
||||
}
|
||||
}
|
||||
|
|
@ -1 +1,2 @@
|
|||
pub mod private;
|
||||
pub mod system;
|
||||
|
|
|
|||
|
|
@ -52,7 +52,7 @@ use axum::{
|
|||
};
|
||||
use hyper::StatusCode;
|
||||
use libpk::config;
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig};
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
|
||||
use reqwest::ClientBuilder;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
|
|
@ -151,14 +151,12 @@ pub async fn discord_callback(
|
|||
.await
|
||||
.expect("failed to query");
|
||||
|
||||
if system.is_none() {
|
||||
let Some(system) = system else {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"user does not have a system registered".to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let system = system.unwrap();
|
||||
};
|
||||
|
||||
let system_config: Option<PKSystemConfig> = sqlx::query_as(
|
||||
r#"
|
||||
|
|
@ -179,7 +177,7 @@ pub async fn discord_callback(
|
|||
(
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"system": system.to_json(),
|
||||
"system": system.to_json(PrivacyLevel::Private),
|
||||
"config": system_config.to_json(),
|
||||
"user": user,
|
||||
"token": token,
|
||||
|
|
|
|||
69
crates/api/src/endpoints/system.rs
Normal file
69
crates/api/src/endpoints/system.rs
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
Extension, Json,
|
||||
};
|
||||
use serde_json::json;
|
||||
use sqlx::Postgres;
|
||||
use tracing::error;
|
||||
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
|
||||
|
||||
use crate::{auth::AuthState, util::json_err, ApiContext};
|
||||
|
||||
pub async fn get_system_settings(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
Extension(system): Extension<PKSystem>,
|
||||
State(ctx): State<ApiContext>,
|
||||
) -> Response {
|
||||
let access_level = auth.access_level_for(&system);
|
||||
|
||||
let mut config = match sqlx::query_as::<Postgres, PKSystemConfig>(
|
||||
"select * from system_config where system = $1",
|
||||
)
|
||||
.bind(system.id)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(config)) => config,
|
||||
Ok(None) => {
|
||||
error!(
|
||||
system = system.id,
|
||||
"failed to find system config for existing system"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query system config");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
// fix this
|
||||
if config.name_format.is_none() {
|
||||
config.name_format = Some("{name} {tag}".to_string());
|
||||
}
|
||||
|
||||
Json(&match access_level {
|
||||
PrivacyLevel::Private => config.to_json(),
|
||||
PrivacyLevel::Public => json!({
|
||||
"pings_enabled": config.pings_enabled,
|
||||
"latch_timeout": config.latch_timeout,
|
||||
"case_sensitive_proxy_tags": config.case_sensitive_proxy_tags,
|
||||
"proxy_error_message_enabled": config.proxy_error_message_enabled,
|
||||
"hid_display_split": config.hid_display_split,
|
||||
"hid_display_caps": config.hid_display_caps,
|
||||
"hid_list_padding": config.hid_list_padding,
|
||||
"proxy_switch": config.proxy_switch,
|
||||
"name_format": config.name_format,
|
||||
}),
|
||||
})
|
||||
.into_response()
|
||||
}
|
||||
|
|
@ -1,10 +1,13 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
http::{Response, StatusCode, Uri},
|
||||
response::IntoResponse,
|
||||
routing::{delete, get, patch, post},
|
||||
Router,
|
||||
Extension, Router,
|
||||
};
|
||||
use hyper_util::{
|
||||
client::legacy::{connect::HttpConnector, Client},
|
||||
|
|
@ -12,6 +15,7 @@ use hyper_util::{
|
|||
};
|
||||
use tracing::{error, info};
|
||||
|
||||
mod auth;
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod middleware;
|
||||
|
|
@ -27,6 +31,7 @@ pub struct ApiContext {
|
|||
}
|
||||
|
||||
async fn rproxy(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
State(ctx): State<ApiContext>,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Result<Response<Body>, StatusCode> {
|
||||
|
|
@ -41,12 +46,25 @@ async fn rproxy(
|
|||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
let headers = req.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
|
||||
Ok(ctx
|
||||
.rproxy_client
|
||||
.request(req)
|
||||
.await
|
||||
.map_err(|err| {
|
||||
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
|
||||
.map_err(|error| {
|
||||
error!(?error, "failed to serve reverse proxy to dotnet-api");
|
||||
StatusCode::BAD_GATEWAY
|
||||
})?
|
||||
.into_response())
|
||||
|
|
@ -57,52 +75,52 @@ async fn rproxy(
|
|||
fn router(ctx: ApiContext) -> Router {
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
Router::new()
|
||||
.route("/v2/systems/:system_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", get(rproxy))
|
||||
.route("/v2/systems/:system_id/settings", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
|
||||
.route("/v2/systems/{system_id}/settings", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/members", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/members", get(rproxy))
|
||||
.route("/v2/members", post(rproxy))
|
||||
.route("/v2/members/:member_id", get(rproxy))
|
||||
.route("/v2/members/:member_id", patch(rproxy))
|
||||
.route("/v2/members/:member_id", delete(rproxy))
|
||||
.route("/v2/members/{member_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}", patch(rproxy))
|
||||
.route("/v2/members/{member_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/groups", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/groups", get(rproxy))
|
||||
.route("/v2/groups", post(rproxy))
|
||||
.route("/v2/groups/:group_id", get(rproxy))
|
||||
.route("/v2/groups/:group_id", patch(rproxy))
|
||||
.route("/v2/groups/:group_id", delete(rproxy))
|
||||
.route("/v2/groups/{group_id}", get(rproxy))
|
||||
.route("/v2/groups/{group_id}", patch(rproxy))
|
||||
.route("/v2/groups/{group_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/groups/:group_id/members", get(rproxy))
|
||||
.route("/v2/groups/:group_id/members/add", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/remove", post(rproxy))
|
||||
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/add", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/groups", get(rproxy))
|
||||
.route("/v2/members/:member_id/groups/add", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/remove", post(rproxy))
|
||||
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups", get(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/add", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches", post(rproxy))
|
||||
.route("/v2/systems/:system_id/fronters", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", post(rproxy))
|
||||
.route("/v2/systems/{system_id}/fronters", get(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
|
||||
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
|
||||
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
|
||||
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
|
||||
|
||||
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
|
||||
|
||||
.route("/v2/messages/:message_id", get(rproxy))
|
||||
.route("/v2/messages/{message_id}", get(rproxy))
|
||||
|
||||
.route("/private/bulk_privacy/member", post(rproxy))
|
||||
.route("/private/bulk_privacy/group", post(rproxy))
|
||||
|
|
@ -111,16 +129,19 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
|
||||
.route("/private/stats", get(endpoints::private::meta))
|
||||
|
||||
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
|
||||
.route("/v2/members/:member_id/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::cors))
|
||||
.layer(axum::middleware::from_fn(middleware::logger))
|
||||
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::logger::logger))
|
||||
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::params::params))
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
|
||||
|
||||
.layer(axum::middleware::from_fn(middleware::cors::cors))
|
||||
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
|
||||
|
||||
.with_state(ctx)
|
||||
|
|
@ -128,8 +149,8 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
|
||||
}
|
||||
|
||||
libpk::main!("api");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_data_db().await?;
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
|
|
|
|||
62
crates/api/src/middleware/auth.rs
Normal file
62
crates/api/src/middleware/auth.rs
Normal file
|
|
@ -0,0 +1,62 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use tracing::error;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
|
||||
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let mut authed_system_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<i32> = None;
|
||||
|
||||
// fetch user authorization
|
||||
if let Some(system_auth_header) = req
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
{
|
||||
authed_system_id = Some(system_id);
|
||||
}
|
||||
|
||||
// fetch app authorization
|
||||
// todo: actually fetch it from db
|
||||
if let Some(app_auth_header) = req
|
||||
.headers()
|
||||
.get("x-pluralkit-app")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(config_token2) = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
.as_ref()
|
||||
// this is NOT how you validate tokens
|
||||
// but this is low abuse risk so we're keeping it for now
|
||||
&& app_auth_header == config_token2
|
||||
{
|
||||
authed_app_id = Some(1);
|
||||
}
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(AuthState::new(authed_system_id, authed_app_id));
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
@ -1,45 +0,0 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderValue,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::ApiContext;
|
||||
|
||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
headers.remove("x-pluralkit-systemid");
|
||||
let auth_header = headers
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten();
|
||||
let mut authenticated = false;
|
||||
if let Some(auth_header) = auth_header {
|
||||
if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
None
|
||||
}
|
||||
}
|
||||
{
|
||||
headers.append(
|
||||
"x-pluralkit-systemid",
|
||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
||||
);
|
||||
authenticated = true;
|
||||
}
|
||||
}
|
||||
let mut response = next.run(request).await;
|
||||
if authenticated {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
||||
}
|
||||
response
|
||||
}
|
||||
|
|
@ -4,27 +4,30 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
|
|||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
use crate::{auth::AuthState, util::header_or_unknown};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
|
||||
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_span = span!(
|
||||
|
|
@ -37,25 +40,26 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let authenticated = {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
};
|
||||
let system_id = auth
|
||||
.system_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
let app_id = auth
|
||||
.app_id()
|
||||
.map(|v| v.to_string())
|
||||
.unwrap_or("none".to_string());
|
||||
|
||||
counter!(
|
||||
"pluralkit_api_requests",
|
||||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
|
|
@ -63,7 +67,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
|
|
@ -81,7 +86,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,6 @@
|
|||
mod cors;
|
||||
pub use cors::cors;
|
||||
|
||||
mod logger;
|
||||
pub use logger::logger;
|
||||
|
||||
mod ignore_invalid_routes;
|
||||
pub use ignore_invalid_routes::ignore_invalid_routes;
|
||||
|
||||
pub mod auth;
|
||||
pub mod cors;
|
||||
pub mod ignore_invalid_routes;
|
||||
pub mod logger;
|
||||
pub mod params;
|
||||
pub mod ratelimit;
|
||||
|
||||
mod authnz;
|
||||
pub use authnz::authnz;
|
||||
|
|
|
|||
139
crates/api/src/middleware/params.rs
Normal file
139
crates/api/src/middleware/params.rs
Normal file
|
|
@ -0,0 +1,139 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
routing::url_params::UrlParams,
|
||||
};
|
||||
|
||||
use sqlx::{types::Uuid, Postgres};
|
||||
use tracing::error;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
use pluralkit_models::PKSystem;
|
||||
|
||||
// move this somewhere else
|
||||
fn parse_hid(hid: &str) -> String {
|
||||
if hid.len() > 7 || hid.len() < 5 {
|
||||
hid.to_string()
|
||||
} else {
|
||||
hid.to_lowercase().replace("-", "")
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let pms = match req.extensions().get::<UrlParams>() {
|
||||
None => Vec::new(),
|
||||
Some(UrlParams::Params(pms)) => pms.clone(),
|
||||
_ => {
|
||||
return json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
|
||||
)
|
||||
.into()
|
||||
}
|
||||
};
|
||||
|
||||
for (key, value) in pms {
|
||||
match key.as_ref() {
|
||||
"system_id" => match value.as_str() {
|
||||
"@me" => {
|
||||
let Some(system_id) = req
|
||||
.extensions()
|
||||
.get::<AuthState>()
|
||||
.expect("missing auth state")
|
||||
.system_id()
|
||||
else {
|
||||
return json_err(
|
||||
StatusCode::UNAUTHORIZED,
|
||||
r#"{"message":"401: Missing or invalid Authorization header","code": 0}"#.to_string(),
|
||||
)
|
||||
.into();
|
||||
};
|
||||
|
||||
match sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where id = $1",
|
||||
)
|
||||
.bind(system_id)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(system)) => {
|
||||
req.extensions_mut().insert(system);
|
||||
}
|
||||
Ok(None) => {
|
||||
error!(
|
||||
?system_id,
|
||||
"could not find previously authenticated system in db"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
Err(err) => {
|
||||
error!(
|
||||
?err,
|
||||
"failed to query previously authenticated system in db"
|
||||
);
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
id => {
|
||||
println!("a {id}");
|
||||
match match Uuid::parse_str(id) {
|
||||
Ok(uuid) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where uuid = $1",
|
||||
)
|
||||
.bind(uuid),
|
||||
Err(_) => match id.parse::<i64>() {
|
||||
Ok(parsed) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where id = (select system from accounts where uid = $1)"
|
||||
)
|
||||
.bind(parsed),
|
||||
Err(_) => sqlx::query_as::<Postgres, PKSystem>(
|
||||
"select * from systems where hid = $1",
|
||||
)
|
||||
.bind(parse_hid(id))
|
||||
},
|
||||
}
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
{
|
||||
Ok(Some(system)) => {
|
||||
req.extensions_mut().insert(system);
|
||||
}
|
||||
Ok(None) => {
|
||||
return json_err(
|
||||
StatusCode::NOT_FOUND,
|
||||
r#"{"message":"System not found.","code":20001}"#.to_string(),
|
||||
)
|
||||
}
|
||||
Err(err) => {
|
||||
error!(?err, ?id, "failed to query system from path in db");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"member_id" => {}
|
||||
"group_id" => {}
|
||||
"switch_id" => {}
|
||||
"guild_id" => {}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
@ -10,7 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
|
|||
use metrics::counter;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
use crate::{
|
||||
auth::AuthState,
|
||||
util::{header_or_unknown, json_err},
|
||||
};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
|
|
@ -50,7 +53,7 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
|||
.await
|
||||
{
|
||||
Ok(_) => info!("connected to redis for request rate limiting"),
|
||||
Err(err) => error!("could not load redis script: {}", err),
|
||||
Err(error) => error!(?error, "could not load redis script"),
|
||||
}
|
||||
} else {
|
||||
error!("could not wait for connection to load redis script!");
|
||||
|
|
@ -103,37 +106,28 @@ pub async fn do_request_ratelimited(
|
|||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/53667
|
||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if let Some(token2) = &libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
{
|
||||
if header.to_str().unwrap_or("invalid") == token2 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let rlimit = if is_temp_token2 {
|
||||
let auth = extensions
|
||||
.get::<AuthState>()
|
||||
.expect("should always have AuthState");
|
||||
|
||||
// looks like this chooses the tokens/sec by app_id or endpoint
|
||||
// then chooses the key by system_id or source_ip
|
||||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if let Some(app_id) = auth.app_id()
|
||||
&& app_id == 1
|
||||
{
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
|
|
@ -145,12 +139,12 @@ pub async fn do_request_ratelimited(
|
|||
|
||||
let rl_key = format!(
|
||||
"{}:{}",
|
||||
if authenticated_system_id != "unknown"
|
||||
if let Some(system_id) = auth.system_id()
|
||||
&& matches!(rlimit, RatelimitType::GenericUpdate)
|
||||
{
|
||||
authenticated_system_id
|
||||
system_id.to_string()
|
||||
} else {
|
||||
source_ip
|
||||
source_ip.to_string()
|
||||
},
|
||||
rlimit.key()
|
||||
);
|
||||
|
|
@ -224,8 +218,8 @@ pub async fn do_request_ratelimited(
|
|||
|
||||
return response;
|
||||
}
|
||||
Err(err) => {
|
||||
tracing::error!("error getting ratelimit info: {}", err);
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "error getting ratelimit info");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
|
|||
match value.to_str() {
|
||||
Ok(v) => v,
|
||||
Err(err) => {
|
||||
error!("failed to parse header value {:#?}: {:#?}", value, err);
|
||||
error!(?err, ?value, "failed to parse header value");
|
||||
"failed to parse"
|
||||
}
|
||||
}
|
||||
|
|
@ -34,11 +34,7 @@ where
|
|||
.unwrap(),
|
||||
),
|
||||
None => {
|
||||
error!(
|
||||
"error in handler {}: {:#?}",
|
||||
std::any::type_name::<F>(),
|
||||
error
|
||||
);
|
||||
error!(?error, "error in handler {}", std::any::type_name::<F>(),);
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
|
|
@ -48,14 +44,15 @@ where
|
|||
}
|
||||
}
|
||||
|
||||
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
error!("caught panic from handler: {:#?}", err);
|
||||
pub fn handle_panic(error: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
|
||||
error!(?error, "caught panic from handler");
|
||||
json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
// todo: make 500 not duplicated
|
||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||
let mut response = (code, text).into_response();
|
||||
let headers = response.headers_mut();
|
||||
|
|
|
|||
|
|
@ -4,8 +4,8 @@ use sqlx::prelude::FromRow;
|
|||
use std::{sync::Arc, time::Duration};
|
||||
use tracing::{error, info};
|
||||
|
||||
libpk::main!("avatar_cleanup");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
|
|
@ -13,7 +13,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
region: "auto".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
|
|
@ -38,8 +38,8 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
match cleanup_job(pool.clone(), bucket.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(err) => {
|
||||
error!("failed to run avatar cleanup job: {}", err);
|
||||
Err(error) => {
|
||||
error!(?error, "failed to run avatar cleanup job");
|
||||
// sentry
|
||||
}
|
||||
}
|
||||
|
|
@ -55,9 +55,10 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
|
|||
let mut tx = pool.begin().await?;
|
||||
|
||||
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
|
||||
// no timestamp checking here
|
||||
// images are only added to the table after 24h
|
||||
r#"
|
||||
select id from image_cleanup_jobs
|
||||
where ts < now() - interval '1 day'
|
||||
for update skip locked limit 1;"#,
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
|
|
@ -72,6 +73,7 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
|
|||
|
||||
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
|
||||
if image_data.is_none() {
|
||||
// unsure how this can happen? there is a FK reference
|
||||
info!("image {image_id} was already deleted, skipping");
|
||||
sqlx::query("delete from image_cleanup_jobs where id = $1")
|
||||
.bind(image_id)
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ async fn pull(
|
|||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
if !req.force {
|
||||
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
|
||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
// remove any pending image cleanup
|
||||
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
||||
|
|
@ -170,8 +170,8 @@ pub struct AppState {
|
|||
config: Arc<AvatarsConfig>,
|
||||
}
|
||||
|
||||
libpk::main!("avatars");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let config = libpk::config
|
||||
.avatars
|
||||
.as_ref()
|
||||
|
|
@ -179,7 +179,7 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
|
||||
let bucket = {
|
||||
let region = s3::Region::Custom {
|
||||
region: "s3".to_string(),
|
||||
region: "auto".to_string(),
|
||||
endpoint: config.s3.endpoint.to_string(),
|
||||
};
|
||||
|
||||
|
|
@ -232,26 +232,11 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
struct AppError(anyhow::Error);
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct ErrorResponse {
|
||||
error: String,
|
||||
}
|
||||
|
||||
impl IntoResponse for AppError {
|
||||
fn into_response(self) -> Response {
|
||||
error!("error handling request: {}", self.0);
|
||||
(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(ErrorResponse {
|
||||
error: self.0.to_string(),
|
||||
}),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl IntoResponse for PKAvatarError {
|
||||
fn into_response(self) -> Response {
|
||||
let status_code = match self {
|
||||
|
|
@ -278,12 +263,3 @@ impl IntoResponse for PKAvatarError {
|
|||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
impl<E> From<E> for AppError
|
||||
where
|
||||
E: Into<anyhow::Error>,
|
||||
{
|
||||
fn from(err: E) -> Self {
|
||||
Self(err.into())
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -129,9 +129,9 @@ pub async fn worker(worker_id: u32, state: Arc<AppState>) {
|
|||
Ok(()) => {}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"error in migrate worker {}: {}",
|
||||
worker_id,
|
||||
e.source().unwrap_or(&e)
|
||||
error = e.source().unwrap_or(&e)
|
||||
?worker_id,
|
||||
"error in migrate worker",
|
||||
);
|
||||
tokio::time::sleep(Duration::from_secs(5)).await;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarEr
|
|||
} else {
|
||||
reader.decode().map_err(|e| {
|
||||
// print the ugly error, return the nice error
|
||||
error!("error decoding image: {}", e);
|
||||
error!(error = format!("{e:#?}"), "error decoding image");
|
||||
PKAvatarError::ImageFormatError(e)
|
||||
})?
|
||||
};
|
||||
|
|
|
|||
|
|
@ -41,7 +41,11 @@ pub async fn pull(
|
|||
}
|
||||
}
|
||||
|
||||
error!("network error for {}: {}", parsed_url.full_url, s);
|
||||
error!(
|
||||
url = parsed_url.full_url,
|
||||
error = s,
|
||||
"network error pulling image"
|
||||
);
|
||||
PKAvatarError::NetworkErrorString(s)
|
||||
})?;
|
||||
let time_after_headers = Instant::now();
|
||||
|
|
@ -82,7 +86,22 @@ pub async fn pull(
|
|||
.map(|x| x.to_string());
|
||||
|
||||
let body = response.bytes().await.map_err(|e| {
|
||||
error!("network error for {}: {}", parsed_url.full_url, e);
|
||||
// terrible
|
||||
let mut s = format!("{}", e);
|
||||
if let Some(src) = e.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
let mut err = src;
|
||||
while let Some(src) = err.source() {
|
||||
let _ = write!(s, ": {}", src);
|
||||
err = src;
|
||||
}
|
||||
}
|
||||
|
||||
error!(
|
||||
url = parsed_url.full_url,
|
||||
error = s,
|
||||
"network error pulling image"
|
||||
);
|
||||
PKAvatarError::NetworkError(e)
|
||||
})?;
|
||||
if body.len() != size as usize {
|
||||
|
|
@ -137,6 +156,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
|
|||
|
||||
match (url.scheme(), url.domain()) {
|
||||
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
||||
("https", Some("serve.apparyllis.com")) => {
|
||||
return Ok(ParsedUrl {
|
||||
channel_id: 0,
|
||||
attachment_id: 0,
|
||||
filename: "".to_string(),
|
||||
full_url: url.to_string(),
|
||||
})
|
||||
}
|
||||
_ => anyhow::bail!("not a discord cdn url"),
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ edition = "2021"
|
|||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -19,17 +19,8 @@ use axum::{extract::State, http::Uri, routing::post, Json, Router};
|
|||
|
||||
mod logger;
|
||||
|
||||
// this package does not currently use libpk
|
||||
|
||||
#[tokio::main]
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
info!("hello world");
|
||||
|
||||
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
|
||||
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
|
||||
let (client, bg) = AsyncClient::connect(stream).await?;
|
||||
|
|
@ -86,11 +77,11 @@ async fn dispatch(
|
|||
let uri = match req.url.parse::<Uri>() {
|
||||
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to parse uri {}", req.url);
|
||||
error!(?error, uri = req.url, "failed to parse uri");
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
_ => {
|
||||
error!("uri {} is invalid", req.url);
|
||||
error!(uri = req.url, "uri is invalid");
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
};
|
||||
|
|
|
|||
|
|
@ -13,8 +13,9 @@ futures = { workspace = true }
|
|||
lazy_static = { workspace = true }
|
||||
libpk = { path = "../libpk" }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
signal-hook = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
|
||||
|
|
|
|||
|
|
@ -1,19 +1,24 @@
|
|||
use axum::{
|
||||
extract::{Path, State},
|
||||
extract::{ConnectInfo, Path, State},
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Response},
|
||||
routing::get,
|
||||
routing::{delete, get, post},
|
||||
Router,
|
||||
};
|
||||
use libpk::runtime_config::RuntimeConfig;
|
||||
use serde_json::{json, to_string};
|
||||
use tracing::{error, info};
|
||||
use twilight_model::id::Id;
|
||||
use twilight_model::id::{marker::ChannelMarker, Id};
|
||||
|
||||
use crate::discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
use crate::{
|
||||
discord::{
|
||||
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
|
||||
gateway::cluster_config,
|
||||
shard_state::ShardStateManager,
|
||||
},
|
||||
event_awaiter::{AwaitEventRequest, EventAwaiter},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
use std::{net::SocketAddr, sync::Arc};
|
||||
|
||||
fn status_code(code: StatusCode, body: String) -> Response {
|
||||
(code, body).into_response()
|
||||
|
|
@ -21,10 +26,15 @@ fn status_code(code: StatusCode, body: String) -> Response {
|
|||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
||||
pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
|
||||
// hacky fix for `move`
|
||||
let runtime_config_for_post = runtime_config.clone();
|
||||
let runtime_config_for_delete = runtime_config.clone();
|
||||
let awaiter_for_clear = awaiter.clone();
|
||||
|
||||
let app = Router::new()
|
||||
.route(
|
||||
"/guilds/:guild_id",
|
||||
"/guilds/{guild_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild(Id::new(guild_id)) {
|
||||
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
|
||||
|
|
@ -33,7 +43,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/members/@me",
|
||||
"/guilds/{guild_id}/members/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
|
||||
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
|
||||
|
|
@ -42,7 +52,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/@me",
|
||||
"/guilds/{guild_id}/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
|
||||
Ok(val) => {
|
||||
|
|
@ -56,7 +66,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/permissions/:user_id",
|
||||
"/guilds/{guild_id}/permissions/{user_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
|
||||
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
|
||||
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
|
||||
|
|
@ -69,7 +79,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/channels",
|
||||
"/guilds/{guild_id}/channels",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
|
||||
Some(channels) => channels.to_owned(),
|
||||
|
|
@ -95,7 +105,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id",
|
||||
"/guilds/{guild_id}/channels/{channel_id}",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
|
||||
|
|
@ -107,7 +117,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
})
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
|
||||
"/guilds/{guild_id}/channels/{channel_id}/permissions/@me",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
|
||||
if guild_id == 0 {
|
||||
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
|
||||
|
|
@ -122,16 +132,19 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
}),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
|
||||
"/guilds/{guild_id}/channels/{channel_id}/permissions/{user_id}",
|
||||
get(|| async { "todo" }),
|
||||
)
|
||||
.route(
|
||||
"/guilds/:guild_id/channels/:channel_id/last_message",
|
||||
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
|
||||
"/guilds/{guild_id}/channels/{channel_id}/last_message",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path((_guild_id, channel_id)): Path<(u64, Id<ChannelMarker>)>| async move {
|
||||
let lm = cache.get_last_message(channel_id).await;
|
||||
status_code(StatusCode::FOUND, to_string(&lm).unwrap())
|
||||
}),
|
||||
)
|
||||
|
||||
.route(
|
||||
"/guilds/:guild_id/roles",
|
||||
"/guilds/{guild_id}/roles",
|
||||
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
|
||||
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
|
||||
Some(roles) => roles.to_owned(),
|
||||
|
|
@ -171,13 +184,45 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
|
|||
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
|
||||
}))
|
||||
|
||||
.route("/runtime_config", get(|| async move {
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
.route("/runtime_config/{key}", post(|Path(key): Path<String>, body: String| async move {
|
||||
let runtime_config = runtime_config_for_post;
|
||||
runtime_config.set(key, body).await.expect("failed to update runtime config");
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
.route("/runtime_config/{key}", delete(|Path(key): Path<String>| async move {
|
||||
let runtime_config = runtime_config_for_delete;
|
||||
runtime_config.delete(key).await.expect("failed to update runtime config");
|
||||
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
|
||||
}))
|
||||
|
||||
.route("/await_event", post(|ConnectInfo(addr): ConnectInfo<SocketAddr>, body: String| async move {
|
||||
info!("got request: {body} from: {addr}");
|
||||
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
|
||||
return status_code(StatusCode::BAD_REQUEST, "".to_string());
|
||||
};
|
||||
|
||||
awaiter.handle_request(req, addr).await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
.route("/clear_awaiter", post(|| async move {
|
||||
awaiter_for_clear.clear().await;
|
||||
status_code(StatusCode::NO_CONTENT, "".to_string())
|
||||
}))
|
||||
|
||||
.route("/shard_status", get(|| async move {
|
||||
status_code(StatusCode::FOUND, to_string(&shard_state.get().await).unwrap())
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn(crate::logger::logger))
|
||||
.with_state(cache);
|
||||
|
||||
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1,6 +1,7 @@
|
|||
use anyhow::format_err;
|
||||
use lazy_static::lazy_static;
|
||||
use std::sync::Arc;
|
||||
use serde::Serialize;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
use tokio::sync::RwLock;
|
||||
use twilight_cache_inmemory::{
|
||||
model::CachedMember,
|
||||
|
|
@ -8,11 +9,12 @@ use twilight_cache_inmemory::{
|
|||
traits::CacheableChannel,
|
||||
InMemoryCache, ResourceType,
|
||||
};
|
||||
use twilight_gateway::Event;
|
||||
use twilight_model::{
|
||||
channel::{Channel, ChannelType},
|
||||
guild::{Guild, Member, Permissions},
|
||||
id::{
|
||||
marker::{ChannelMarker, GuildMarker, UserMarker},
|
||||
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
|
|
@ -123,16 +125,134 @@ pub fn new() -> DiscordCache {
|
|||
.build(),
|
||||
);
|
||||
|
||||
DiscordCache(cache, client, RwLock::new(Vec::new()))
|
||||
DiscordCache(
|
||||
cache,
|
||||
client,
|
||||
RwLock::new(Vec::new()),
|
||||
RwLock::new(HashMap::new()),
|
||||
)
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct CachedMessage {
|
||||
id: Id<MessageMarker>,
|
||||
referenced_message: Option<Id<MessageMarker>>,
|
||||
author_username: String,
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize)]
|
||||
pub struct LastMessageCacheEntry {
|
||||
pub current: CachedMessage,
|
||||
pub previous: Option<CachedMessage>,
|
||||
}
|
||||
|
||||
pub struct DiscordCache(
|
||||
pub Arc<InMemoryCache>,
|
||||
pub Arc<twilight_http::Client>,
|
||||
pub RwLock<Vec<u32>>,
|
||||
pub RwLock<HashMap<Id<ChannelMarker>, LastMessageCacheEntry>>,
|
||||
);
|
||||
|
||||
impl DiscordCache {
|
||||
pub async fn get_last_message(
|
||||
&self,
|
||||
channel: Id<ChannelMarker>,
|
||||
) -> Option<LastMessageCacheEntry> {
|
||||
self.3.read().await.get(&channel).cloned()
|
||||
}
|
||||
|
||||
pub async fn update(&self, event: &twilight_gateway::Event) {
|
||||
self.0.update(event);
|
||||
|
||||
match event {
|
||||
Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) {
|
||||
std::collections::hash_map::Entry::Occupied(mut e) => {
|
||||
let cur = e.get();
|
||||
e.insert(LastMessageCacheEntry {
|
||||
current: CachedMessage {
|
||||
id: m.id,
|
||||
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
|
||||
author_username: m.author.name.clone(),
|
||||
},
|
||||
previous: Some(cur.current.clone()),
|
||||
});
|
||||
}
|
||||
std::collections::hash_map::Entry::Vacant(e) => {
|
||||
e.insert(LastMessageCacheEntry {
|
||||
current: CachedMessage {
|
||||
id: m.id,
|
||||
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
|
||||
author_username: m.author.name.clone(),
|
||||
},
|
||||
previous: None,
|
||||
});
|
||||
}
|
||||
},
|
||||
Event::MessageDelete(m) => {
|
||||
self.handle_message_deletion(m.channel_id, vec![m.id]).await;
|
||||
}
|
||||
Event::MessageDeleteBulk(m) => {
|
||||
self.handle_message_deletion(m.channel_id, m.ids.clone())
|
||||
.await;
|
||||
}
|
||||
_ => {}
|
||||
};
|
||||
}
|
||||
|
||||
async fn handle_message_deletion(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
mids: Vec<Id<MessageMarker>>,
|
||||
) {
|
||||
let mut lm = self.3.write().await;
|
||||
|
||||
let Some(entry) = lm.get(&channel_id) else {
|
||||
return;
|
||||
};
|
||||
|
||||
let mut entry = entry.clone();
|
||||
|
||||
// if none of the deleted messages are relevant, just return
|
||||
if !mids.contains(&entry.current.id)
|
||||
&& entry
|
||||
.previous
|
||||
.clone()
|
||||
.map(|v| !mids.contains(&v.id))
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return;
|
||||
}
|
||||
|
||||
// remove "previous" entry if it was deleted
|
||||
if let Some(prev) = entry.previous.clone()
|
||||
&& mids.contains(&prev.id)
|
||||
{
|
||||
entry.previous = None;
|
||||
}
|
||||
|
||||
// set "current" entry to "previous" if current entry was deleted
|
||||
// (if the "previous" entry still exists, it was not deleted)
|
||||
if let Some(prev) = entry.previous.clone()
|
||||
&& mids.contains(&entry.current.id)
|
||||
{
|
||||
entry.current = prev;
|
||||
entry.previous = None;
|
||||
}
|
||||
|
||||
// if the current entry was already deleted, but previous wasn't,
|
||||
// we would've set current to previous
|
||||
// so if current is deleted this means both current and previous have
|
||||
// been deleted
|
||||
// so just drop the cache entry here
|
||||
if mids.contains(&entry.current.id) && entry.previous.is_none() {
|
||||
lm.remove(&channel_id);
|
||||
return;
|
||||
}
|
||||
|
||||
// ok, update the entry
|
||||
lm.insert(channel_id, entry.clone());
|
||||
}
|
||||
|
||||
pub async fn guild_permissions(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
|
|
@ -356,12 +476,14 @@ impl DiscordCache {
|
|||
system_channel_flags: guild.system_channel_flags(),
|
||||
system_channel_id: guild.system_channel_id(),
|
||||
threads: vec![],
|
||||
unavailable: false,
|
||||
unavailable: Some(false),
|
||||
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
|
||||
verification_level: guild.verification_level(),
|
||||
voice_states: vec![],
|
||||
widget_channel_id: guild.widget_channel_id(),
|
||||
widget_enabled: guild.widget_enabled(),
|
||||
guild_scheduled_events: guild.guild_scheduled_events().to_vec(),
|
||||
max_stage_video_channel_users: guild.max_stage_video_channel_users(),
|
||||
}
|
||||
})
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
use anyhow::anyhow;
|
||||
use futures::StreamExt;
|
||||
use libpk::_config::ClusterSettings;
|
||||
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent};
|
||||
use metrics::counter;
|
||||
use std::sync::{mpsc::Sender, Arc};
|
||||
use std::sync::Arc;
|
||||
use tokio::sync::mpsc::Sender;
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{
|
||||
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
|
||||
|
|
@ -12,9 +14,12 @@ use twilight_model::gateway::{
|
|||
Intents,
|
||||
};
|
||||
|
||||
use crate::discord::identify_queue::{self, RedisQueue};
|
||||
use crate::{
|
||||
discord::identify_queue::{self, RedisQueue},
|
||||
RUNTIME_CONFIG_KEY_EVENT_TARGET,
|
||||
};
|
||||
|
||||
use super::{cache::DiscordCache, shard_state::ShardStateManager};
|
||||
use super::cache::DiscordCache;
|
||||
|
||||
pub fn cluster_config() -> ClusterSettings {
|
||||
libpk::config
|
||||
|
|
@ -44,6 +49,12 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
|
||||
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
|
||||
warn!("we have less than 16 shards, assuming single gateway process");
|
||||
if cluster_settings.node_id != 0 {
|
||||
return Err(anyhow!(
|
||||
"expecting to be node 0 in single-process mode, but we are node {}",
|
||||
cluster_settings.node_id
|
||||
));
|
||||
}
|
||||
(0, (cluster_settings.total_shards - 1).into())
|
||||
} else {
|
||||
(
|
||||
|
|
@ -52,6 +63,13 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
)
|
||||
};
|
||||
|
||||
let prefix = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_prefix_for_gateway
|
||||
.clone();
|
||||
|
||||
let shards = create_iterator(
|
||||
start_shard..end_shard + 1,
|
||||
cluster_settings.total_shards,
|
||||
|
|
@ -64,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
.to_owned(),
|
||||
intents,
|
||||
)
|
||||
.presence(presence("pk;help", false))
|
||||
.presence(presence(format!("{prefix}help").as_str(), false))
|
||||
.queue(queue.clone())
|
||||
.build(),
|
||||
|_, builder| builder.build(),
|
||||
|
|
@ -76,15 +94,23 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
|
|||
Ok(shards_vec)
|
||||
}
|
||||
|
||||
#[tracing::instrument(fields(shard = %shard.id()), skip_all)]
|
||||
pub async fn runner(
|
||||
mut shard: Shard<RedisQueue>,
|
||||
_tx: Sender<(ShardId, String)>,
|
||||
shard_state: ShardStateManager,
|
||||
tx: Sender<(ShardId, Event, String)>,
|
||||
tx_state: Sender<(ShardId, ShardStateEvent, Option<Event>, Option<i32>)>,
|
||||
cache: Arc<DiscordCache>,
|
||||
runtime_config: Arc<RuntimeConfig>,
|
||||
) {
|
||||
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
|
||||
let shard_id = shard.id().number();
|
||||
|
||||
let our_user_id = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.client_id;
|
||||
|
||||
info!("waiting for events");
|
||||
while let Some(item) = shard.next().await {
|
||||
let raw_event = match item {
|
||||
|
|
@ -105,7 +131,9 @@ pub async fn runner(
|
|||
)
|
||||
.increment(1);
|
||||
|
||||
if let Err(error) = shard_state.socket_closed(shard_id).await {
|
||||
if let Err(error) =
|
||||
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
|
||||
{
|
||||
error!("failed to update shard state for socket closure: {error}");
|
||||
}
|
||||
|
||||
|
|
@ -127,7 +155,7 @@ pub async fn runner(
|
|||
continue;
|
||||
}
|
||||
Err(error) => {
|
||||
error!("shard {shard_id} failed to parse gateway event: {error}");
|
||||
error!(?error, ?shard_id, "failed to parse gateway event");
|
||||
continue;
|
||||
}
|
||||
};
|
||||
|
|
@ -147,14 +175,31 @@ pub async fn runner(
|
|||
.increment(1);
|
||||
|
||||
// update shard state and discord cache
|
||||
if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await {
|
||||
tracing::error!(?error, "error updating redis state");
|
||||
if matches!(event, Event::Ready(_)) || matches!(event, Event::Resumed) {
|
||||
if let Err(error) = tx_state.try_send((
|
||||
shard.id(),
|
||||
ShardStateEvent::Other,
|
||||
Some(event.clone()),
|
||||
None,
|
||||
)) {
|
||||
tracing::error!(?error, "error updating shard state");
|
||||
}
|
||||
}
|
||||
// need to do heartbeat separately, to get the latency
|
||||
let latency_num = shard
|
||||
.latency()
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
if let Event::GatewayHeartbeatAck = event
|
||||
&& let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await
|
||||
&& let Err(error) = tx_state.try_send((
|
||||
shard.id(),
|
||||
ShardStateEvent::Heartbeat,
|
||||
Some(event.clone()),
|
||||
Some(latency_num),
|
||||
))
|
||||
{
|
||||
tracing::error!(?error, "error updating redis state for latency");
|
||||
tracing::error!(?error, "error updating shard state for latency");
|
||||
}
|
||||
|
||||
if let Event::Ready(_) = event {
|
||||
|
|
@ -162,10 +207,28 @@ pub async fn runner(
|
|||
cache.2.write().await.push(shard_id);
|
||||
}
|
||||
}
|
||||
cache.0.update(&event);
|
||||
cache.update(&event).await;
|
||||
|
||||
// okay, we've handled the event internally, let's send it to consumers
|
||||
// tx.send((shard.id(), raw_event)).unwrap();
|
||||
|
||||
// some basic filtering here is useful
|
||||
// we can't use if matching using the | operator, so anything matched does nothing
|
||||
// and the default match skips the next block (continues to the next event)
|
||||
match event {
|
||||
Event::InteractionCreate(_) => {}
|
||||
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
|
||||
Event::MessageUpdate(ref m) if m.author.id != our_user_id && !m.author.bot => {}
|
||||
Event::MessageDelete(_) => {}
|
||||
Event::MessageDeleteBulk(_) => {}
|
||||
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
|
||||
_ => {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
|
||||
tx.send((shard.id(), event, raw_event)).await.unwrap();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -78,8 +78,8 @@ async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: on
|
|||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
Err(error) => {
|
||||
error!(?error, ?shard_id, ?bucket, "error getting shard allowance")
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1,49 +1,63 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use metrics::{counter, gauge};
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
use twilight_gateway::{Event, Latency};
|
||||
use twilight_gateway::Event;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use libpk::state::ShardState;
|
||||
|
||||
#[derive(Clone)]
|
||||
use super::gateway::cluster_config;
|
||||
|
||||
pub struct ShardStateManager {
|
||||
redis: RedisPool,
|
||||
shards: RwLock<HashMap<u32, ShardState>>,
|
||||
}
|
||||
|
||||
pub fn new(redis: RedisPool) -> ShardStateManager {
|
||||
ShardStateManager { redis }
|
||||
ShardStateManager {
|
||||
redis: redis,
|
||||
shards: RwLock::new(HashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
impl ShardStateManager {
|
||||
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
|
||||
match event {
|
||||
// also update gateway.rs with event types
|
||||
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
|
||||
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
|
||||
_ => Ok(()),
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
|
||||
let data: Option<String> = self.redis.hget("pluralkit:shardstatus", shard_id).await?;
|
||||
match data {
|
||||
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
|
||||
None => Ok(ShardState::default()),
|
||||
async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
|
||||
{
|
||||
let mut shards = self.shards.write().await;
|
||||
shards.insert(id, state.clone());
|
||||
}
|
||||
}
|
||||
|
||||
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(
|
||||
"pluralkit:shardstatus",
|
||||
(
|
||||
shard_id.to_string(),
|
||||
serde_json::to_string(&info).expect("could not serialize shard"),
|
||||
id.to_string(),
|
||||
serde_json::to_string(&state).expect("could not serialize shard"),
|
||||
),
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_shard(&self, id: u32) -> Option<ShardState> {
|
||||
let shards = self.shards.read().await;
|
||||
shards.get(&id).cloned()
|
||||
}
|
||||
|
||||
pub async fn get(&self) -> Vec<ShardState> {
|
||||
self.shards.read().await.values().cloned().collect()
|
||||
}
|
||||
|
||||
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
|
||||
info!(
|
||||
"shard {} {}",
|
||||
|
|
@ -57,32 +71,52 @@ impl ShardStateManager {
|
|||
)
|
||||
.increment(1);
|
||||
gauge!("pluralkit_gateway_shard_up").increment(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.up = true;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_up").decrement(1);
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.up = false;
|
||||
info.disconnection_count += 1;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
|
||||
let mut info = self.get_shard(shard_id).await?;
|
||||
pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
|
||||
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
|
||||
|
||||
let mut info = self
|
||||
.get_shard(shard_id)
|
||||
.await
|
||||
.unwrap_or(ShardState::default());
|
||||
|
||||
info.shard_id = shard_id as i32;
|
||||
info.cluster_id = Some(cluster_config().node_id as i32);
|
||||
info.up = true;
|
||||
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
|
||||
info.latency = latency
|
||||
.recent()
|
||||
.first()
|
||||
.map_or_else(|| 0, |d| d.as_millis()) as i32;
|
||||
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string())
|
||||
.set(info.latency);
|
||||
info.latency = latency;
|
||||
|
||||
self.save_shard(shard_id, info).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
|||
242
crates/gateway/src/event_awaiter.rs
Normal file
242
crates/gateway/src/event_awaiter.rs
Normal file
|
|
@ -0,0 +1,242 @@
|
|||
// - reaction: (message_id, user_id)
|
||||
// - message: (author_id, channel_id, ?options)
|
||||
// - interaction: (custom_id where not_includes "help-menu")
|
||||
|
||||
use std::{
|
||||
collections::{hash_map::Entry, HashMap},
|
||||
net::{IpAddr, SocketAddr},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use serde::Deserialize;
|
||||
use tokio::{sync::RwLock, time::Instant};
|
||||
use tracing::info;
|
||||
use twilight_gateway::Event;
|
||||
use twilight_model::{
|
||||
application::interaction::InteractionData,
|
||||
id::{
|
||||
marker::{ChannelMarker, MessageMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
};
|
||||
|
||||
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(untagged)]
|
||||
pub enum AwaitEventRequest {
|
||||
Reaction {
|
||||
message_id: Id<MessageMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
Message {
|
||||
channel_id: Id<ChannelMarker>,
|
||||
author_id: Id<UserMarker>,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
options: Option<Vec<String>>,
|
||||
},
|
||||
Interaction {
|
||||
id: String,
|
||||
target: String,
|
||||
timeout: Option<u64>,
|
||||
},
|
||||
}
|
||||
|
||||
pub struct EventAwaiter {
|
||||
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
|
||||
messages: RwLock<
|
||||
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
|
||||
>,
|
||||
interactions: RwLock<HashMap<String, (Instant, String)>>,
|
||||
}
|
||||
|
||||
impl EventAwaiter {
|
||||
pub fn new() -> Self {
|
||||
let v = Self {
|
||||
reactions: RwLock::new(HashMap::new()),
|
||||
messages: RwLock::new(HashMap::new()),
|
||||
interactions: RwLock::new(HashMap::new()),
|
||||
};
|
||||
|
||||
v
|
||||
}
|
||||
|
||||
pub async fn cleanup_loop(&self) {
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(30)).await;
|
||||
info!("running event_awaiter cleanup loop");
|
||||
let mut counts = (0, 0, 0);
|
||||
let now = Instant::now();
|
||||
{
|
||||
let mut reactions = self.reactions.write().await;
|
||||
for key in reactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = reactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.0 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut messages = self.messages.write().await;
|
||||
for key in messages.clone().keys() {
|
||||
if let Entry::Occupied(entry) = messages.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.1 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
{
|
||||
let mut interactions = self.interactions.write().await;
|
||||
for key in interactions.clone().keys() {
|
||||
if let Entry::Occupied(entry) = interactions.entry(key.clone())
|
||||
&& entry.get().0 < now
|
||||
{
|
||||
counts.2 += 1;
|
||||
entry.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn target_for_event(&self, event: Event) -> Option<String> {
|
||||
match event {
|
||||
Event::MessageCreate(message) => {
|
||||
let mut messages = self.messages.write().await;
|
||||
|
||||
messages
|
||||
.remove(&(message.channel_id, message.author.id))
|
||||
.map(|(timeout, target, options)| {
|
||||
if let Some(options) = options
|
||||
&& !options.contains(&message.content.to_lowercase())
|
||||
{
|
||||
messages.insert(
|
||||
(message.channel_id, message.author.id),
|
||||
(timeout, target, Some(options)),
|
||||
);
|
||||
return None;
|
||||
}
|
||||
Some((*target).to_string())
|
||||
})?
|
||||
}
|
||||
Event::ReactionAdd(reaction)
|
||||
if let Some((_, target)) = self
|
||||
.reactions
|
||||
.write()
|
||||
.await
|
||||
.remove(&(reaction.message_id, reaction.user_id)) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
Event::InteractionCreate(interaction)
|
||||
if let Some(data) = interaction.data.clone()
|
||||
&& let InteractionData::MessageComponent(component) = data
|
||||
&& !component.custom_id.contains("help-menu")
|
||||
&& let Some((_, target)) =
|
||||
self.interactions.write().await.remove(&component.custom_id) =>
|
||||
{
|
||||
Some((*target).to_string())
|
||||
}
|
||||
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn handle_request(&self, req: AwaitEventRequest, addr: SocketAddr) {
|
||||
match req {
|
||||
AwaitEventRequest::Reaction {
|
||||
message_id,
|
||||
user_id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.reactions.write().await.insert(
|
||||
(message_id, user_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Message {
|
||||
channel_id,
|
||||
author_id,
|
||||
target,
|
||||
timeout,
|
||||
options,
|
||||
} => {
|
||||
self.messages.write().await.insert(
|
||||
(channel_id, author_id),
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
options,
|
||||
),
|
||||
);
|
||||
}
|
||||
AwaitEventRequest::Interaction {
|
||||
id,
|
||||
target,
|
||||
timeout,
|
||||
} => {
|
||||
self.interactions.write().await.insert(
|
||||
id,
|
||||
(
|
||||
Instant::now()
|
||||
.checked_add(
|
||||
timeout
|
||||
.map(|i| Duration::from_secs(i))
|
||||
.unwrap_or(DEFAULT_TIMEOUT),
|
||||
)
|
||||
.expect("invalid time"),
|
||||
target_or_addr(target, addr),
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn clear(&self) {
|
||||
self.reactions.write().await.clear();
|
||||
self.messages.write().await.clear();
|
||||
self.interactions.write().await.clear();
|
||||
}
|
||||
}
|
||||
|
||||
fn target_or_addr(target: String, addr: SocketAddr) -> String {
|
||||
if target == "source-addr" {
|
||||
let ip_str = match addr.ip() {
|
||||
IpAddr::V4(v4) => v4.to_string(),
|
||||
IpAddr::V6(v6) => {
|
||||
if let Some(v4) = v6.to_ipv4_mapped() {
|
||||
v4.to_string()
|
||||
} else {
|
||||
format!("[{v6}]")
|
||||
}
|
||||
}
|
||||
};
|
||||
format!("http://{ip_str}:5002/events")
|
||||
} else {
|
||||
target
|
||||
}
|
||||
}
|
||||
|
|
@ -1,39 +1,79 @@
|
|||
#![feature(let_chains)]
|
||||
#![feature(if_let_guard)]
|
||||
#![feature(duration_constructors)]
|
||||
|
||||
use chrono::Timelike;
|
||||
use discord::gateway::cluster_config;
|
||||
use event_awaiter::EventAwaiter;
|
||||
use fred::{clients::RedisPool, interfaces::*};
|
||||
use signal_hook::{
|
||||
consts::{SIGINT, SIGTERM},
|
||||
iterator::Signals,
|
||||
use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
|
||||
use reqwest::{ClientBuilder, StatusCode};
|
||||
use std::{sync::Arc, time::Duration, vec::Vec};
|
||||
use tokio::{
|
||||
signal::unix::{signal, SignalKind},
|
||||
sync::mpsc::channel,
|
||||
task::JoinSet,
|
||||
};
|
||||
use std::{
|
||||
sync::{mpsc::channel, Arc},
|
||||
time::Duration,
|
||||
vec::Vec,
|
||||
};
|
||||
use tokio::task::JoinSet;
|
||||
use tracing::{info, warn};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_gateway::{MessageSender, ShardId};
|
||||
use twilight_model::gateway::payload::outgoing::UpdatePresence;
|
||||
|
||||
mod cache_api;
|
||||
mod api;
|
||||
mod discord;
|
||||
mod event_awaiter;
|
||||
mod logger;
|
||||
|
||||
libpk::main!("gateway");
|
||||
async fn real_main() -> anyhow::Result<()> {
|
||||
let (shutdown_tx, shutdown_rx) = channel::<()>();
|
||||
let shutdown_tx = Arc::new(shutdown_tx);
|
||||
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let shard_state = discord::shard_state::new(redis.clone());
|
||||
let runtime_config = Arc::new(
|
||||
RuntimeConfig::new(
|
||||
redis.clone(),
|
||||
format!(
|
||||
"{}:{}",
|
||||
libpk::config.runtime_config_key.as_ref().unwrap(),
|
||||
cluster_config().node_id
|
||||
),
|
||||
)
|
||||
.await?,
|
||||
);
|
||||
|
||||
// hacky, but needed for selfhost for now
|
||||
if let Some(target) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.unwrap()
|
||||
.gateway_target
|
||||
.clone()
|
||||
{
|
||||
runtime_config
|
||||
.set(RUNTIME_CONFIG_KEY_EVENT_TARGET.to_string(), target)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let cache = Arc::new(discord::cache::new());
|
||||
let awaiter = Arc::new(EventAwaiter::new());
|
||||
tokio::spawn({
|
||||
let awaiter = awaiter.clone();
|
||||
async move { awaiter.cleanup_loop().await }
|
||||
});
|
||||
|
||||
let shards = discord::gateway::create_shards(redis.clone())?;
|
||||
|
||||
let (event_tx, _event_rx) = channel();
|
||||
// arbitrary
|
||||
// todo: make sure this doesn't fill up
|
||||
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
|
||||
|
||||
// todo: make sure this doesn't fill up
|
||||
let (state_tx, mut state_rx) = channel::<(
|
||||
ShardId,
|
||||
ShardStateEvent,
|
||||
Option<twilight_gateway::Event>,
|
||||
Option<i32>,
|
||||
)>(1000);
|
||||
|
||||
let mut senders = Vec::new();
|
||||
let mut signal_senders = Vec::new();
|
||||
|
|
@ -45,61 +85,160 @@ async fn real_main() -> anyhow::Result<()> {
|
|||
set.spawn(tokio::spawn(discord::gateway::runner(
|
||||
shard,
|
||||
event_tx.clone(),
|
||||
shard_state.clone(),
|
||||
state_tx.clone(),
|
||||
cache.clone(),
|
||||
runtime_config.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
|
||||
|
||||
set.spawn(tokio::spawn({
|
||||
let shard_state = shard_state.clone();
|
||||
|
||||
async move {
|
||||
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
|
||||
match state_event {
|
||||
ShardStateEvent::Heartbeat => {
|
||||
if !latency.is_none()
|
||||
&& let Err(error) = shard_state
|
||||
.heartbeated(shard_id.number(), latency.unwrap())
|
||||
.await
|
||||
{
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
ShardStateEvent::Closed => {
|
||||
if let Err(error) = shard_state.socket_closed(shard_id.number()).await {
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
ShardStateEvent::Other => {
|
||||
if let Err(error) = shard_state
|
||||
.handle_event(
|
||||
shard_id.number(),
|
||||
parsed_event.expect("shard state event not provided!"),
|
||||
)
|
||||
.await
|
||||
{
|
||||
error!("failed to update shard state for heartbeat: {error}")
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
set.spawn(tokio::spawn({
|
||||
let runtime_config = runtime_config.clone();
|
||||
let awaiter = awaiter.clone();
|
||||
|
||||
async move {
|
||||
let client = Arc::new(
|
||||
ClientBuilder::new()
|
||||
.connect_timeout(Duration::from_secs(1))
|
||||
.timeout(Duration::from_secs(1))
|
||||
.build()
|
||||
.expect("error making client"),
|
||||
);
|
||||
|
||||
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
|
||||
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
|
||||
info!(target = ?target, "sending event to awaiter");
|
||||
Some(target)
|
||||
} else if let Some(target) =
|
||||
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
|
||||
{
|
||||
Some(target)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(target) = target {
|
||||
tokio::spawn({
|
||||
let client = client.clone();
|
||||
async move {
|
||||
match client
|
||||
.post(format!("{target}/{}", shard_id.number()))
|
||||
.body(raw_event)
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) => {
|
||||
if res.status() != StatusCode::OK {
|
||||
error!(
|
||||
status = ?res.status(),
|
||||
target = ?target,
|
||||
"got non-200 from bot while sending event",
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to request event target");
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}));
|
||||
|
||||
set.spawn(tokio::spawn(
|
||||
async move { scheduled_task(redis, senders).await },
|
||||
));
|
||||
|
||||
// todo: probably don't do it this way
|
||||
let api_shutdown_tx = shutdown_tx.clone();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
match cache_api::run_server(cache).await {
|
||||
match api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
|
||||
Err(error) => {
|
||||
tracing::error!(?error, "failed to serve cache api");
|
||||
let _ = api_shutdown_tx.send(());
|
||||
error!(?error, "failed to serve cache api");
|
||||
}
|
||||
_ => unreachable!(),
|
||||
}
|
||||
}));
|
||||
|
||||
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
|
||||
|
||||
set.spawn(tokio::spawn(async move {
|
||||
for sig in signals.forever() {
|
||||
info!("received signal {:?}", sig);
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
let _ = shutdown_tx.send(());
|
||||
break;
|
||||
}
|
||||
signal(SignalKind::interrupt()).unwrap().recv().await;
|
||||
info!("got SIGINT");
|
||||
}));
|
||||
|
||||
let _ = shutdown_rx.recv();
|
||||
set.spawn(tokio::spawn(async move {
|
||||
signal(SignalKind::terminate()).unwrap().recv().await;
|
||||
info!("got SIGTERM");
|
||||
}));
|
||||
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
set.join_next().await;
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
|
||||
let presence = UpdatePresence {
|
||||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence("Restarting... (please wait)", true),
|
||||
};
|
||||
|
||||
for sender in signal_senders.iter() {
|
||||
let presence = presence.clone();
|
||||
let _ = sender.command(&presence);
|
||||
}
|
||||
|
||||
set.abort_all();
|
||||
|
||||
info!("gateway exiting, have a nice day!");
|
||||
// sleep 500ms to allow everything to clean up properly
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
|
||||
let prefix = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_prefix_for_gateway
|
||||
.clone();
|
||||
|
||||
println!("{prefix}");
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(Duration::from_secs(
|
||||
(60 - chrono::offset::Utc::now().second()).into(),
|
||||
|
|
@ -119,9 +258,9 @@ async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>
|
|||
op: twilight_model::gateway::OpCode::PresenceUpdate,
|
||||
d: discord::gateway::presence(
|
||||
if let Some(status) = status {
|
||||
format!("pk;help | {}", status)
|
||||
format!("{prefix}help | {status}")
|
||||
} else {
|
||||
"pk;help".to_string()
|
||||
format!("{prefix}help")
|
||||
}
|
||||
.as_str(),
|
||||
false,
|
||||
|
|
|
|||
15
crates/gdpr_worker/Cargo.toml
Normal file
15
crates/gdpr_worker/Cargo.toml
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
[package]
|
||||
name = "gdpr_worker"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
libpk = { path = "../libpk" }
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
twilight-model = { workspace = true }
|
||||
149
crates/gdpr_worker/src/main.rs
Normal file
149
crates/gdpr_worker/src/main.rs
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use sqlx::prelude::FromRow;
|
||||
use std::{sync::Arc, time::Duration};
|
||||
use tracing::{error, info, warn};
|
||||
use twilight_http::api_error::{ApiError, GeneralApiError};
|
||||
use twilight_model::id::{
|
||||
marker::{ChannelMarker, MessageMarker},
|
||||
Id,
|
||||
};
|
||||
|
||||
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null);
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_messages_db().await?;
|
||||
|
||||
let mut client_builder = twilight_http::Client::builder()
|
||||
.token(
|
||||
libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.bot_token
|
||||
.clone(),
|
||||
)
|
||||
.timeout(Duration::from_secs(30));
|
||||
|
||||
if let Some(base_url) = libpk::config
|
||||
.discord
|
||||
.as_ref()
|
||||
.expect("missing discord config")
|
||||
.api_base_url
|
||||
.clone()
|
||||
{
|
||||
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
|
||||
}
|
||||
|
||||
let client = Arc::new(client_builder.build());
|
||||
|
||||
loop {
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
|
||||
match run_job(db.clone(), client.clone()).await {
|
||||
Ok(()) => {}
|
||||
Err(error) => {
|
||||
error!(?error, "failed to run messages gdpr job");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(FromRow)]
|
||||
struct GdprJobEntry {
|
||||
mid: i64,
|
||||
channel_id: i64,
|
||||
}
|
||||
|
||||
async fn run_job(pool: sqlx::PgPool, discord: Arc<twilight_http::Client>) -> anyhow::Result<()> {
|
||||
let mut tx = pool.begin().await?;
|
||||
|
||||
let message: Option<GdprJobEntry> = sqlx::query_as(
|
||||
"select mid, channel_id from messages_gdpr_jobs for update skip locked limit 1;",
|
||||
)
|
||||
.fetch_optional(&mut *tx)
|
||||
.await?;
|
||||
|
||||
let Some(message) = message else {
|
||||
info!("no job to run, sleeping for 1 minute");
|
||||
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
|
||||
return Ok(());
|
||||
};
|
||||
|
||||
info!("got mid={}, cleaning up...", message.mid);
|
||||
|
||||
// naively delete message on discord's end
|
||||
let res = discord
|
||||
.delete_message(
|
||||
Id::<ChannelMarker>::new(message.channel_id as u64),
|
||||
Id::<MessageMarker>::new(message.mid as u64),
|
||||
)
|
||||
.await;
|
||||
|
||||
if res.is_ok() {
|
||||
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
|
||||
.bind(message.mid)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
|
||||
if let Err(err) = res {
|
||||
if let twilight_http::error::ErrorType::Response { error, status, .. } = err.kind()
|
||||
&& let ApiError::General(GeneralApiError { code, .. }) = error
|
||||
{
|
||||
match (status.get(), code) {
|
||||
(403, _) => {
|
||||
warn!(
|
||||
"got 403 while deleting message in channel {}, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 10003) => {
|
||||
warn!(
|
||||
"deleting message in channel {}: channel not found, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 10008) => {
|
||||
warn!("deleting message {}: message not found", message.mid);
|
||||
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
|
||||
.bind(message.mid)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
(_, 50083) => {
|
||||
warn!(
|
||||
"could not delete message in thread {}: thread is archived, failing fast",
|
||||
message.channel_id
|
||||
);
|
||||
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
|
||||
.bind(message.channel_id)
|
||||
.execute(&mut *tx)
|
||||
.await?;
|
||||
}
|
||||
_ => {
|
||||
error!(
|
||||
?status,
|
||||
?code,
|
||||
message_id = message.mid,
|
||||
"got unknown error deleting message",
|
||||
);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
return Err(err.into());
|
||||
}
|
||||
}
|
||||
|
||||
tx.commit().await?;
|
||||
|
||||
return Ok(());
|
||||
}
|
||||
0
crates/h
Normal file
0
crates/h
Normal file
|
|
@ -8,6 +8,7 @@ anyhow = { workspace = true }
|
|||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
pk_macros = { path = "../macros" }
|
||||
sentry = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -12,10 +12,16 @@ pub struct ClusterSettings {
|
|||
pub total_nodes: u32,
|
||||
}
|
||||
|
||||
fn _default_bot_prefix() -> String {
|
||||
"pk;".to_string()
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct DiscordConfig {
|
||||
pub client_id: Id<UserMarker>,
|
||||
pub bot_token: String,
|
||||
#[serde(default = "_default_bot_prefix")]
|
||||
pub bot_prefix_for_gateway: String,
|
||||
pub client_secret: String,
|
||||
pub max_concurrency: u32,
|
||||
#[serde(default)]
|
||||
|
|
@ -24,6 +30,9 @@ pub struct DiscordConfig {
|
|||
|
||||
#[serde(default = "_default_api_addr")]
|
||||
pub cache_api_addr: String,
|
||||
|
||||
#[serde(default)]
|
||||
pub gateway_target: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
|
|
@ -85,6 +94,7 @@ pub struct ScheduledTasksConfig {
|
|||
pub set_guild_count: bool,
|
||||
pub expected_gateway_count: usize,
|
||||
pub gateway_url: String,
|
||||
pub prometheus_url: String,
|
||||
}
|
||||
|
||||
fn _metrics_default() -> bool {
|
||||
|
|
@ -113,6 +123,9 @@ pub struct PKConfig {
|
|||
#[serde(default = "_json_log_default")]
|
||||
pub(crate) json_log: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub runtime_config_key: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub sentry_url: Option<String>,
|
||||
}
|
||||
|
|
@ -132,10 +145,15 @@ impl PKConfig {
|
|||
lazy_static! {
|
||||
#[derive(Debug)]
|
||||
pub static ref CONFIG: Arc<PKConfig> = {
|
||||
// hacks
|
||||
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var);
|
||||
}
|
||||
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
|
||||
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
|
||||
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
|
||||
}
|
||||
|
||||
Arc::new(Config::builder()
|
||||
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
|
||||
|
|
|
|||
|
|
@ -8,12 +8,15 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
|
|||
use sentry_tracing::event_from_event;
|
||||
|
||||
pub mod db;
|
||||
pub mod runtime_config;
|
||||
pub mod state;
|
||||
|
||||
pub mod _config;
|
||||
pub use crate::_config::CONFIG as config;
|
||||
|
||||
// functions in this file are only used by the main function below
|
||||
// functions in this file are only used by the main function in macros/entrypoint.rs
|
||||
|
||||
pub use pk_macros::main;
|
||||
|
||||
pub fn init_logging(component: &str) {
|
||||
let sentry_layer =
|
||||
|
|
@ -42,6 +45,7 @@ pub fn init_logging(component: &str) {
|
|||
tracing_subscriber::registry()
|
||||
.with(sentry_layer)
|
||||
.with(tracing_subscriber::fmt::layer())
|
||||
.with(EnvFilter::from_default_env())
|
||||
.init();
|
||||
}
|
||||
}
|
||||
|
|
@ -66,28 +70,3 @@ pub fn init_sentry() -> sentry::ClientInitGuard {
|
|||
..Default::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! main {
|
||||
($component:expr) => {
|
||||
fn main() -> anyhow::Result<()> {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
// we might also be able to use env!("CARGO_CRATE_NAME") here
|
||||
libpk::init_logging($component);
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(err) = libpk::init_metrics() {
|
||||
tracing::error!("failed to init metrics collector: {err}");
|
||||
};
|
||||
tracing::info!("hello world");
|
||||
if let Err(err) = real_main().await {
|
||||
tracing::error!("failed to run service: {err}");
|
||||
};
|
||||
});
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
}
|
||||
|
|
|
|||
72
crates/libpk/src/runtime_config.rs
Normal file
72
crates/libpk/src/runtime_config.rs
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
use fred::{clients::RedisPool, interfaces::HashesInterface};
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::RwLock;
|
||||
use tracing::info;
|
||||
|
||||
pub struct RuntimeConfig {
|
||||
redis: RedisPool,
|
||||
settings: RwLock<HashMap<String, String>>,
|
||||
redis_key: String,
|
||||
}
|
||||
|
||||
impl RuntimeConfig {
|
||||
pub async fn new(redis: RedisPool, component_key: String) -> anyhow::Result<Self> {
|
||||
let redis_key = format!("remote_config:{component_key}");
|
||||
|
||||
let mut c = RuntimeConfig {
|
||||
redis,
|
||||
settings: RwLock::new(HashMap::new()),
|
||||
redis_key,
|
||||
};
|
||||
|
||||
c.load().await?;
|
||||
|
||||
Ok(c)
|
||||
}
|
||||
|
||||
pub async fn load(&mut self) -> anyhow::Result<()> {
|
||||
let redis_config: HashMap<String, String> = self.redis.hgetall(&self.redis_key).await?;
|
||||
|
||||
let mut settings = self.settings.write().await;
|
||||
|
||||
for (key, value) in redis_config {
|
||||
settings.insert(key, value);
|
||||
}
|
||||
|
||||
info!("starting with runtime config: {:?}", settings);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set(&self, key: String, value: String) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hset::<(), &str, (String, String)>(&self.redis_key, (key.clone(), value.clone()))
|
||||
.await?;
|
||||
self.settings
|
||||
.write()
|
||||
.await
|
||||
.insert(key.clone(), value.clone());
|
||||
info!("updated runtime config: {key}={value}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete(&self, key: String) -> anyhow::Result<()> {
|
||||
self.redis
|
||||
.hdel::<(), &str, String>(&self.redis_key, key.clone())
|
||||
.await?;
|
||||
self.settings.write().await.remove(&key.clone());
|
||||
info!("updated runtime config: {key} removed");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get(&self, key: &str) -> Option<String> {
|
||||
self.settings.read().await.get(key).cloned()
|
||||
}
|
||||
|
||||
pub async fn exists(&self, key: &str) -> bool {
|
||||
self.settings.read().await.contains_key(key)
|
||||
}
|
||||
|
||||
pub async fn get_all(&self) -> HashMap<String, String> {
|
||||
self.settings.read().await.clone()
|
||||
}
|
||||
}
|
||||
|
|
@ -1,4 +1,4 @@
|
|||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
|
||||
#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)]
|
||||
pub struct ShardState {
|
||||
pub shard_id: i32,
|
||||
pub up: bool,
|
||||
|
|
@ -10,3 +10,9 @@ pub struct ShardState {
|
|||
pub last_connection: i32,
|
||||
pub cluster_id: Option<i32>,
|
||||
}
|
||||
|
||||
pub enum ShardStateEvent {
|
||||
Closed,
|
||||
Heartbeat,
|
||||
Other,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
[package]
|
||||
name = "model_macros"
|
||||
name = "pk_macros"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
41
crates/macros/src/entrypoint.rs
Normal file
41
crates/macros/src/entrypoint.rs
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
use proc_macro::{Delimiter, TokenTree};
|
||||
use quote::quote;
|
||||
|
||||
pub fn macro_impl(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
// yes, this ignores everything except the codeblock
|
||||
// it's fine.
|
||||
let body = match input.into_iter().last().expect("empty") {
|
||||
TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => group.stream(),
|
||||
_ => panic!("invalid function"),
|
||||
};
|
||||
|
||||
let body = proc_macro2::TokenStream::from(body);
|
||||
|
||||
return quote! {
|
||||
fn main() {
|
||||
let _sentry_guard = libpk::init_sentry();
|
||||
libpk::init_logging(env!("CARGO_CRATE_NAME"));
|
||||
tokio::runtime::Builder::new_multi_thread()
|
||||
.enable_all()
|
||||
.build()
|
||||
.unwrap()
|
||||
.block_on(async {
|
||||
if let Err(error) = libpk::init_metrics() {
|
||||
tracing::error!(?error, "failed to init metrics collector");
|
||||
};
|
||||
|
||||
tracing::info!("hello world");
|
||||
|
||||
let result: anyhow::Result<()> = async { #body }.await;
|
||||
|
||||
if let Err(error) = result {
|
||||
tracing::error!(?error, "failed to run service");
|
||||
};
|
||||
});
|
||||
}
|
||||
}
|
||||
.into();
|
||||
}
|
||||
14
crates/macros/src/lib.rs
Normal file
14
crates/macros/src/lib.rs
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
use proc_macro::TokenStream;
|
||||
|
||||
mod entrypoint;
|
||||
mod model;
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
entrypoint::macro_impl(args, input)
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(args: TokenStream, input: TokenStream) -> TokenStream {
|
||||
model::macro_impl(args, input)
|
||||
}
|
||||
|
|
@ -16,6 +16,7 @@ struct ModelField {
|
|||
patch: ElemPatchability,
|
||||
json: Option<Expr>,
|
||||
is_privacy: bool,
|
||||
privacy: Option<Expr>,
|
||||
default: Option<Expr>,
|
||||
}
|
||||
|
||||
|
|
@ -26,6 +27,7 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
patch: ElemPatchability::None,
|
||||
json: None,
|
||||
is_privacy: false,
|
||||
privacy: None,
|
||||
default: None,
|
||||
};
|
||||
|
||||
|
|
@ -61,6 +63,12 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
}
|
||||
f.json = Some(nv.value.clone());
|
||||
}
|
||||
"privacy" => {
|
||||
if f.privacy.is_some() {
|
||||
panic!("cannot set privacy multiple times for same field");
|
||||
}
|
||||
f.privacy = Some(nv.value.clone());
|
||||
}
|
||||
"default" => {
|
||||
if f.default.is_some() {
|
||||
panic!("cannot set default multiple times for same field");
|
||||
|
|
@ -84,8 +92,7 @@ fn parse_field(field: syn::Field) -> ModelField {
|
|||
f
|
||||
}
|
||||
|
||||
#[proc_macro_attribute]
|
||||
pub fn pk_model(
|
||||
pub fn macro_impl(
|
||||
_args: proc_macro::TokenStream,
|
||||
input: proc_macro::TokenStream,
|
||||
) -> proc_macro::TokenStream {
|
||||
|
|
@ -108,8 +115,6 @@ pub fn pk_model(
|
|||
panic!("fields of a struct must be named");
|
||||
};
|
||||
|
||||
// println!("{}: {:#?}", tname, fields);
|
||||
|
||||
let tfields = mk_tfields(fields.clone());
|
||||
let from_json = mk_tfrom_json(fields.clone());
|
||||
let _from_sql = mk_tfrom_sql(fields.clone());
|
||||
|
|
@ -138,9 +143,7 @@ pub fn pk_model(
|
|||
#from_json
|
||||
}
|
||||
|
||||
pub fn to_json(self) -> serde_json::Value {
|
||||
#to_json
|
||||
}
|
||||
#to_json
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
|
|
@ -189,19 +192,28 @@ fn mk_tfrom_sql(_fields: Vec<ModelField>) -> TokenStream {
|
|||
quote! { unimplemented!(); }
|
||||
}
|
||||
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
||||
// todo: check privacy access
|
||||
let has_privacy = fields.iter().any(|f| f.privacy.is_some());
|
||||
let fielddefs: TokenStream = fields
|
||||
.iter()
|
||||
.filter_map(|f| {
|
||||
f.json.as_ref().map(|v| {
|
||||
let tname = f.name.clone();
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
let maybepriv = if let Some(privacy) = f.privacy.as_ref() {
|
||||
quote! {
|
||||
#v: self.#tname.unwrap_or(#default),
|
||||
#v: crate::_util::privacy_lookup!(self.#tname, self.#privacy, lookup_level)
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#v: self.#tname,
|
||||
#v: self.#tname
|
||||
}
|
||||
};
|
||||
if let Some(default) = f.default.as_ref() {
|
||||
quote! {
|
||||
#maybepriv.unwrap_or(#default),
|
||||
}
|
||||
} else {
|
||||
quote! {
|
||||
#maybepriv,
|
||||
}
|
||||
}
|
||||
})
|
||||
|
|
@ -223,13 +235,35 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
|
|||
})
|
||||
.collect();
|
||||
|
||||
quote! {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
"privacy": {
|
||||
#privacyfielddefs
|
||||
let privdef = if has_privacy {
|
||||
quote! {
|
||||
, lookup_level: crate::PrivacyLevel
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
||||
let privacy_fielddefs = if has_privacy {
|
||||
quote! {
|
||||
"privacy": if matches!(lookup_level, crate::PrivacyLevel::Private) {
|
||||
Some(serde_json::json!({
|
||||
#privacyfielddefs
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
} else {
|
||||
quote! {}
|
||||
};
|
||||
|
||||
quote! {
|
||||
pub fn to_json(self #privdef) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
#fielddefs
|
||||
#privacy_fielddefs
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue