Merge remote-tracking branch 'upstream/main' into rust-command-parser

This commit is contained in:
dusk 2025-08-09 17:38:44 +03:00
commit f721b850d4
No known key found for this signature in database
183 changed files with 5121 additions and 1909 deletions

View file

@ -4,6 +4,7 @@
# Include project code and build files
!PluralKit.*/
!Myriad/
!Serilog/
!.git
!dashboard
!crates/

View file

@ -2,7 +2,9 @@ name: Build and push Docker image
on:
push:
paths:
- '.github/workflows/docker.yml'
- '.dockerignore'
- '.github/workflows/dotnet-docker.yml'
- 'ci/Dockerfile.dotnet'
- 'ci/dotnet-version.sh'
- 'Myriad/**'
- 'PluralKit.API/**'
@ -23,6 +25,9 @@ jobs:
username: ${{ github.actor }}
password: ${{ secrets.CR_PAT }}
- uses: actions/checkout@v2
with:
submodules: true
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" | sed 's|/|-|g' >> $GITHUB_ENV
- name: Extract Docker metadata
@ -41,6 +46,7 @@ jobs:
with:
# https://github.com/docker/build-push-action/issues/378
context: .
file: ci/Dockerfile.dotnet
push: true
tags: ${{ steps.meta.outputs.tags }}
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}

View file

@ -3,6 +3,7 @@ on:
push:
paths:
- 'crates/**'
- '.dockerignore'
- '.github/workflows/rust.yml'
- 'ci/Dockerfile.rust'
- 'ci/rust-docker-target.sh'

4
.gitmodules vendored
View file

@ -0,0 +1,4 @@
[submodule "Serilog"]
path = Serilog
url = https://github.com/pluralkit/serilog
branch = f5eb991cb4c4a0c1e2407de7504c543536786598

2722
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -6,7 +6,6 @@ members = [
[workspace.dependencies]
anyhow = "1"
axum = "0.7.5"
axum-macros = "0.4.1"
bytes = "1.6.0"
chrono = "0.4"
@ -18,21 +17,22 @@ reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-t
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.117"
signal-hook = "0.3.17"
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "chrono", "macros", "uuid"] }
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "macros", "uuid"] }
tokio = { version = "1.36.0", features = ["full"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
uuid = { version = "1.7.0", features = ["serde"] }
twilight-gateway = { git = "https://github.com/pluralkit/twilight" }
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
twilight-util = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
twilight-model = { git = "https://github.com/pluralkit/twilight" }
twilight-http = { git = "https://github.com/pluralkit/twilight", default-features = false, features = ["rustls-native-roots"] }
axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" }
#twilight-gateway = { path = "../twilight/twilight-gateway" }
#twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] }
#twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
#twilight-model = { path = "../twilight/twilight-model" }
#twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-native-roots"] }
twilight-gateway = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef" }
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", features = ["permission-calculator"] }
twilight-util = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", features = ["permission-calculator"] }
twilight-model = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef" }
twilight-http = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-70105ef", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] }
# twilight-gateway = { path = "../twilight/twilight-gateway" }
# twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] }
# twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
# twilight-model = { path = "../twilight/twilight-model" }
# twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] }

View file

@ -1,7 +1,10 @@
using Serilog;
using System.Net;
using System.Text;
using System.Text.Json;
using NodaTime;
using Myriad.Serialization;
using Myriad.Types;
@ -11,7 +14,8 @@ public class HttpDiscordCache: IDiscordCache
{
private readonly ILogger _logger;
private readonly HttpClient _client;
private readonly Uri _cacheEndpoint;
private readonly string _cacheEndpoint;
private readonly string? _eventTarget;
private readonly int _shardCount;
private readonly ulong _ownUserId;
@ -21,11 +25,12 @@ public class HttpDiscordCache: IDiscordCache
public EventHandler<(bool?, string)> OnDebug;
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, int shardCount, ulong ownUserId, bool useInnerCache)
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, string? eventTarget, int shardCount, ulong ownUserId, bool useInnerCache)
{
_logger = logger;
_client = client;
_cacheEndpoint = new Uri(cacheEndpoint);
_cacheEndpoint = cacheEndpoint;
_eventTarget = eventTarget;
_shardCount = shardCount;
_ownUserId = ownUserId;
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
@ -47,13 +52,12 @@ public class HttpDiscordCache: IDiscordCache
private async Task<T?> QueryCache<T>(string endpoint, ulong guildId)
{
var cluster = _cacheEndpoint.Authority;
// todo: there should not be infra-specific code here
if (cluster.Contains(".service.consul") || cluster.Contains("process.pluralkit-gateway.internal"))
// int(((guild_id >> 22) % shard_count) / 16)
cluster = $"cluster{(int)(((guildId >> 22) % (ulong)_shardCount) / 16)}.{cluster}";
var cluster = _cacheEndpoint;
var response = await _client.GetAsync($"{_cacheEndpoint.Scheme}://{cluster}{endpoint}");
if (cluster.Contains("{clusterid}"))
cluster = cluster.Replace("{clusterid}", $"{(int)(((guildId >> 22) % (ulong)_shardCount) / 16)}");
var response = await _client.GetAsync($"http://{cluster}{endpoint}");
if (response.StatusCode == HttpStatusCode.NotFound)
return default;
@ -65,6 +69,70 @@ public class HttpDiscordCache: IDiscordCache
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
}
public Task<T> GetLastMessage<T>(ulong guildId, ulong channelId)
=> QueryCache<T>($"/guilds/{guildId}/channels/{channelId}/last_message", guildId);
private Task AwaitEvent(ulong guildId, object data)
=> AwaitEventShard((int)((guildId >> 22) % (ulong)_shardCount), data);
private async Task AwaitEventShard(int shardId, object data)
{
if (_eventTarget == null)
throw new Exception("missing event target for remote await event");
var cluster = _cacheEndpoint;
if (cluster.Contains("{clusterid}"))
cluster = cluster.Replace("{clusterid}", $"{(int)(shardId / 16)}");
var response = await _client.PostAsync(
$"http://{cluster}/await_event",
new StringContent(JsonSerializer.Serialize(data), Encoding.UTF8)
);
if (response.StatusCode != HttpStatusCode.NoContent)
throw new Exception($"failed to await event from gateway: {response.StatusCode}");
}
public async Task AwaitReaction(ulong guildId, ulong messageId, ulong userId, Duration? timeout)
{
var obj = new
{
message_id = messageId,
user_id = userId,
target = _eventTarget!,
timeout = timeout?.TotalSeconds,
};
await AwaitEvent(guildId, obj);
}
public async Task AwaitMessage(ulong guildId, ulong channelId, ulong authorId, Duration? timeout, string[] options = null)
{
var obj = new
{
channel_id = channelId,
author_id = authorId,
target = _eventTarget!,
timeout = timeout?.TotalSeconds,
options = options,
};
await AwaitEvent(guildId, obj);
}
public async Task AwaitInteraction(int shardId, string id, Duration? timeout)
{
var obj = new
{
id = id,
target = _eventTarget!,
timeout = timeout?.TotalSeconds,
};
await AwaitEventShard(shardId, obj);
}
public async Task<Guild?> TryGetGuild(ulong guildId)
{
var hres = await QueryCache<Guild?>($"/guilds/{guildId}", guildId);

View file

@ -10,6 +10,8 @@ public record MessageUpdateEvent(ulong Id, ulong ChannelId): IGatewayEvent
public Optional<GuildMemberPartial> Member { get; init; }
public Optional<Message.Attachment[]> Attachments { get; init; }
public Message.MessageType Type { get; init; }
public Optional<ulong?> GuildId { get; init; }
// TODO: lots of partials
}

View file

@ -21,9 +21,14 @@
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Serilog\src\Serilog\Serilog.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="NodaTime" Version="3.2.0" />
<PackageReference Include="NodaTime.Serialization.JsonNet" Version="3.1.0" />
<PackageReference Include="Polly" Version="8.5.0" />
<PackageReference Include="Polly.Contrib.WaitAndRetry" Version="1.1.1" />
<PackageReference Include="Serilog" Version="4.2.0" />
<PackageReference Include="StackExchange.Redis" Version="2.8.22" />
<PackageReference Include="System.Linq.Async" Version="6.0.1" />
</ItemGroup>

View file

@ -2,6 +2,22 @@
"version": 1,
"dependencies": {
"net8.0": {
"NodaTime": {
"type": "Direct",
"requested": "[3.2.0, )",
"resolved": "3.2.0",
"contentHash": "yoRA3jEJn8NM0/rQm78zuDNPA3DonNSZdsorMUj+dltc1D+/Lc5h9YXGqbEEZozMGr37lAoYkcSM/KjTVqD0ow=="
},
"NodaTime.Serialization.JsonNet": {
"type": "Direct",
"requested": "[3.1.0, )",
"resolved": "3.1.0",
"contentHash": "eEr9lXUz50TYr4rpeJG4TDAABkpxjIKr5mDSi/Zav8d6Njy6fH7x4ZtNwWFj0Vd+vIvEZNrHFQ4Gfy8j4BqRGg==",
"dependencies": {
"Newtonsoft.Json": "13.0.3",
"NodaTime": "[3.0.0, 4.0.0)"
}
},
"Polly": {
"type": "Direct",
"requested": "[8.5.0, )",
@ -17,12 +33,6 @@
"resolved": "1.1.1",
"contentHash": "1MUQLiSo4KDkQe6nzQRhIU05lm9jlexX5BVsbuw0SL82ynZ+GzAHQxJVDPVBboxV37Po3SG077aX8DuSy8TkaA=="
},
"Serilog": {
"type": "Direct",
"requested": "[4.2.0, )",
"resolved": "4.2.0",
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
},
"StackExchange.Redis": {
"type": "Direct",
"requested": "[2.8.22, )",
@ -52,6 +62,11 @@
"resolved": "6.0.0",
"contentHash": "/HggWBbTwy8TgebGSX5DBZ24ndhzi93sHUBDvP1IxbZD7FDokYzdAr6+vbWGjw2XAfR2EJ1sfKUotpjHnFWPxA=="
},
"Newtonsoft.Json": {
"type": "Transitive",
"resolved": "13.0.3",
"contentHash": "HrC5BXdl00IP9zeV+0Z848QWPAoCr9P3bDEZguI+gkLcBKAOxix/tLEAAHC+UvDNPv4a2d18lOReHMOagPa+zQ=="
},
"Pipelines.Sockets.Unofficial": {
"type": "Transitive",
"resolved": "2.2.8",
@ -69,6 +84,9 @@
"type": "Transitive",
"resolved": "5.0.1",
"contentHash": "qEePWsaq9LoEEIqhbGe6D5J8c9IqQOUuTzzV6wn1POlfdLkJliZY3OlB0j0f17uMWlqZYjH7txj+2YbyrIA8Yg=="
},
"serilog": {
"type": "Project"
}
}
}

View file

@ -6,4 +6,6 @@ public class ApiConfig
public string? ClientId { get; set; }
public string? ClientSecret { get; set; }
public bool TrustAuth { get; set; } = false;
public string? AvatarServiceUrl { get; set; }
public bool SearchGuildSettings = false;
}

View file

@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware
&& int.TryParse(sidHeaders[0], out var systemId))
ctx.Items.Add("SystemId", new SystemId(systemId));
if (cfg.TrustAuth
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)
&& aidHeaders.Count > 0
&& int.TryParse(aidHeaders[0], out var appId))
ctx.Items.Add("AppId", appId);
await _next.Invoke(ctx);
}
}

View file

@ -20,7 +20,7 @@ public class DiscordControllerV2: PKControllerBase
if (ContextFor(system) != LookupContext.ByOwner)
throw Errors.GenericMissingPermissions;
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false);
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings);
if (settings == null)
throw Errors.SystemGuildNotFound;
@ -34,7 +34,7 @@ public class DiscordControllerV2: PKControllerBase
if (ContextFor(system) != LookupContext.ByOwner)
throw Errors.GenericMissingPermissions;
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false);
var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings);
if (settings == null)
throw Errors.SystemGuildNotFound;
@ -58,7 +58,7 @@ public class DiscordControllerV2: PKControllerBase
if (member.System != system.Id)
throw Errors.NotOwnMemberError;
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false);
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null);
if (settings == null)
throw Errors.MemberGuildNotFound;
@ -75,7 +75,7 @@ public class DiscordControllerV2: PKControllerBase
if (member.System != system.Id)
throw Errors.NotOwnMemberError;
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false);
var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null);
if (settings == null)
throw Errors.MemberGuildNotFound;

View file

@ -1,3 +1,7 @@
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using Microsoft.AspNetCore.Mvc;
using Newtonsoft.Json.Linq;
@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase
if (patch.Errors.Count > 0)
throw new ModelParseError(patch.Errors);
if (patch.AvatarUrl.Value != null)
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
using var conn = await _db.Obtain();
using var tx = await conn.BeginTransactionAsync();
@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase
if (patch.Errors.Count > 0)
throw new ModelParseError(patch.Errors);
if (patch.AvatarUrl.Value != null)
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
var newMember = await _repo.UpdateMember(member.Id, patch);
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
}
@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase
return NoContent();
}
private async Task<string> TryUploadAvatar(string avatarUrl, PKSystem system)
{
if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl;
if (_config.AvatarServiceUrl == null) return avatarUrl;
if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl;
using var client = new HttpClient();
var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull",
new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() });
if (response.StatusCode != HttpStatusCode.OK)
{
var error = await response.Content.ReadFromJsonAsync<ErrorResponse>();
throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}");
}
var success = await response.Content.ReadFromJsonAsync<SuccessResponse>();
return success.Url;
}
public record ErrorResponse(string Error);
public record SuccessResponse(string Url, bool New);
}

View file

@ -32,6 +32,6 @@
<PackageReference Include="Microsoft.AspNetCore.Mvc.Versioning" Version="5.1.0" />
<PackageReference Include="Microsoft.AspNetCore.Mvc.Versioning.ApiExplorer" Version="5.1.0" />
<PackageReference Include="Sentry" Version="4.13.0" />
<PackageReference Include="Serilog.AspNetCore" Version="9.0.0" />
<PackageReference Include="Serilog.AspNetCore" Version="8.0.0" />
</ItemGroup>
</Project>

View file

@ -35,7 +35,7 @@ public class Startup
builder.RegisterInstance(InitUtils.BuildConfiguration(Environment.GetCommandLineArgs()).Build())
.As<IConfiguration>();
builder.RegisterModule(new ConfigModule<ApiConfig>("API"));
builder.RegisterModule(new LoggingModule("api",
builder.RegisterModule(new LoggingModule("dotnet-api",
cfg: new LoggerConfiguration().Filter.ByExcluding(
exc => exc.Exception is PKError || exc.Exception.IsUserError()
)));

View file

@ -36,17 +36,20 @@
},
"Serilog.AspNetCore": {
"type": "Direct",
"requested": "[9.0.0, )",
"resolved": "9.0.0",
"contentHash": "JslDajPlBsn3Pww1554flJFTqROvK9zz9jONNQgn0D8Lx2Trw8L0A8/n6zEQK1DAZWXrJwiVLw8cnTR3YFuYsg==",
"requested": "[8.0.0, )",
"resolved": "8.0.0",
"contentHash": "FAjtKPZ4IzqFQBqZKPv6evcXK/F0ls7RoXI/62Pnx2igkDZ6nZ/jn/C/FxVATqQbEQvtqP+KViWYIe4NZIHa2w==",
"dependencies": {
"Serilog": "4.2.0",
"Serilog.Extensions.Hosting": "9.0.0",
"Serilog.Formatting.Compact": "3.0.0",
"Serilog.Settings.Configuration": "9.0.0",
"Serilog.Sinks.Console": "6.0.0",
"Serilog.Sinks.Debug": "3.0.0",
"Serilog.Sinks.File": "6.0.0"
"Microsoft.Extensions.DependencyInjection": "8.0.0",
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1",
"Serilog.Extensions.Hosting": "8.0.0",
"Serilog.Extensions.Logging": "8.0.0",
"Serilog.Formatting.Compact": "2.0.0",
"Serilog.Settings.Configuration": "8.0.0",
"Serilog.Sinks.Console": "5.0.0",
"Serilog.Sinks.Debug": "2.0.0",
"Serilog.Sinks.File": "5.0.0"
}
},
"App.Metrics": {
@ -296,21 +299,21 @@
},
"Microsoft.Extensions.DependencyModel": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "saxr2XzwgDU77LaQfYFXmddEDRUKHF4DaGMZkNB3qjdVSZlax3//dGJagJkKrGMIPNZs2jVFXITyCCR6UHJNdA==",
"resolved": "8.0.0",
"contentHash": "NSmDw3K0ozNDgShSIpsZcbFIzBX4w28nDag+TfaQujkXGazBm+lid5onlWoCBy4VsLxqnnKjEBbGSJVWJMf43g==",
"dependencies": {
"System.Text.Encodings.Web": "9.0.0",
"System.Text.Json": "9.0.0"
"System.Text.Encodings.Web": "8.0.0",
"System.Text.Json": "8.0.0"
}
},
"Microsoft.Extensions.Diagnostics.Abstractions": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "1K8P7XzuzX8W8pmXcZjcrqS6x5eSSdvhQohmcpgiQNY/HlDAlnrhR9dvlURfFz428A+RTCJpUyB+aKTA6AgVcQ==",
"resolved": "8.0.0",
"contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==",
"dependencies": {
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Options": "9.0.0",
"System.Diagnostics.DiagnosticSource": "9.0.0"
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Options": "8.0.0",
"System.Diagnostics.DiagnosticSource": "8.0.0"
}
},
"Microsoft.Extensions.FileProviders.Abstractions": {
@ -338,14 +341,14 @@
},
"Microsoft.Extensions.Hosting.Abstractions": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "yUKJgu81ExjvqbNWqZKshBbLntZMbMVz/P7Way2SBx7bMqA08Mfdc9O7hWDKAiSp+zPUGT6LKcSCQIPeDK+CCw==",
"resolved": "8.0.0",
"contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==",
"dependencies": {
"Microsoft.Extensions.Configuration.Abstractions": "9.0.0",
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0",
"Microsoft.Extensions.FileProviders.Abstractions": "9.0.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.0"
"Microsoft.Extensions.Configuration.Abstractions": "8.0.0",
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0",
"Microsoft.Extensions.FileProviders.Abstractions": "8.0.0",
"Microsoft.Extensions.Logging.Abstractions": "8.0.0"
}
},
"Microsoft.Extensions.Logging": {
@ -461,30 +464,25 @@
"System.IO.Pipelines": "5.0.1"
}
},
"Serilog": {
"type": "Transitive",
"resolved": "4.2.0",
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
},
"Serilog.Extensions.Hosting": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "u2TRxuxbjvTAldQn7uaAwePkWxTHIqlgjelekBtilAGL5sYyF3+65NWctN4UrwwGLsDC7c3Vz3HnOlu+PcoxXg==",
"resolved": "8.0.0",
"contentHash": "db0OcbWeSCvYQkHWu6n0v40N4kKaTAXNjlM3BKvcbwvNzYphQFcBR+36eQ/7hMMwOkJvAyLC2a9/jNdUL5NjtQ==",
"dependencies": {
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Hosting.Abstractions": "9.0.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.0",
"Serilog": "4.2.0",
"Serilog.Extensions.Logging": "9.0.0"
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Hosting.Abstractions": "8.0.0",
"Microsoft.Extensions.Logging.Abstractions": "8.0.0",
"Serilog": "3.1.1",
"Serilog.Extensions.Logging": "8.0.0"
}
},
"Serilog.Extensions.Logging": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
"resolved": "8.0.0",
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
"dependencies": {
"Microsoft.Extensions.Logging": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Formatting.Compact": {
@ -514,12 +512,12 @@
},
"Serilog.Settings.Configuration": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "4/Et4Cqwa+F88l5SeFeNZ4c4Z6dEAIKbu3MaQb2Zz9F/g27T5a3wvfMcmCOaAiACjfUb4A6wrlTVfyYUZk3RRQ==",
"resolved": "8.0.0",
"contentHash": "nR0iL5HwKj5v6ULo3/zpP8NMcq9E2pxYA6XKTSWCbugVs4YqPyvaqaKOY+OMpPivKp7zMEpax2UKHnDodbRB0Q==",
"dependencies": {
"Microsoft.Extensions.Configuration.Binder": "9.0.0",
"Microsoft.Extensions.DependencyModel": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Configuration.Binder": "8.0.0",
"Microsoft.Extensions.DependencyModel": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Sinks.Async": {
@ -540,10 +538,10 @@
},
"Serilog.Sinks.Debug": {
"type": "Transitive",
"resolved": "3.0.0",
"contentHash": "4BzXcdrgRX7wde9PmHuYd9U6YqycCC28hhpKonK7hx0wb19eiuRj16fPcPSVp0o/Y1ipJuNLYQ00R3q2Zs8FDA==",
"resolved": "2.0.0",
"contentHash": "Y6g3OBJ4JzTyyw16fDqtFcQ41qQAydnEvEqmXjhwhgjsnG/FaJ8GUqF5ldsC/bVkK8KYmqrPhDO+tm4dF6xx4A==",
"dependencies": {
"Serilog": "4.0.0"
"Serilog": "2.10.0"
}
},
"Serilog.Sinks.Elasticsearch": {
@ -837,8 +835,8 @@
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
"Npgsql": "[9.0.2, )",
"Npgsql.NodaTime": "[9.0.2, )",
"Serilog": "[4.2.0, )",
"Serilog.Extensions.Logging": "[9.0.0, )",
"Serilog": "[4.1.0, )",
"Serilog.Extensions.Logging": "[8.0.0, )",
"Serilog.Formatting.Compact": "[3.0.0, )",
"Serilog.NodaTime": "[3.0.0, )",
"Serilog.Sinks.Async": "[2.1.0, )",
@ -852,6 +850,9 @@
"System.Interactive.Async": "[6.0.1, )",
"ipnetwork2": "[3.0.667, )"
}
},
"serilog": {
"type": "Project"
}
}
}

View file

@ -33,7 +33,10 @@ public class ApplicationCommandProxiedMessage
var messageId = ctx.Event.Data!.TargetId!.Value;
var msg = await ctx.Repository.GetFullMessage(messageId);
if (msg == null)
throw Errors.MessageNotFound(messageId);
{
await QueryCommandMessage(ctx);
return;
}
var showContent = true;
var channel = await _rest.GetChannelOrNull(msg.Message.Channel);
@ -58,6 +61,20 @@ public class ApplicationCommandProxiedMessage
await ctx.Reply(embeds: embeds.ToArray());
}
private async Task QueryCommandMessage(InteractionContext ctx)
{
var messageId = ctx.Event.Data!.TargetId!.Value;
var msg = await ctx.Repository.GetCommandMessage(messageId);
if (msg == null)
throw Errors.MessageNotFound(messageId);
var embeds = new List<Embed>();
embeds.Add(await _embeds.CreateCommandMessageInfoEmbed(msg, true));
await ctx.Reply(embeds: embeds.ToArray());
}
public async Task DeleteMessage(InteractionContext ctx)
{
var messageId = ctx.Event.Data!.TargetId!.Value;

View file

@ -32,13 +32,14 @@ public class Bot
private readonly DiscordApiClient _rest;
private readonly RedisService _redis;
private readonly ILifetimeScope _services;
private readonly RuntimeConfigService _runtimeConfig;
private Timer _periodicTask; // Never read, just kept here for GC reasons
public Bot(ILifetimeScope services, ILogger logger, PeriodicStatCollector collector, IMetrics metrics,
BotConfig config, RedisService redis,
ErrorMessageService errorMessageService, CommandMessageService commandMessageService,
Cluster cluster, DiscordApiClient rest, IDiscordCache cache)
Cluster cluster, DiscordApiClient rest, IDiscordCache cache, RuntimeConfigService runtimeConfig)
{
_logger = logger.ForContext<Bot>();
_services = services;
@ -51,6 +52,7 @@ public class Bot
_rest = rest;
_redis = redis;
_cache = cache;
_runtimeConfig = runtimeConfig;
}
private string BotStatus => $"{(_config.Prefixes ?? BotConfig.DefaultPrefixes)[0]}help"
@ -97,13 +99,15 @@ public class Bot
private async Task OnEventReceived(int shardId, IGatewayEvent evt)
{
if (_runtimeConfig.Exists("disable_events")) return;
// we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent
await _cache.HandleGatewayEvent(evt);
await _cache.TryUpdateSelfMember(_config.ClientId, evt);
await OnEventReceivedInner(shardId, evt);
}
private async Task OnEventReceivedInner(int shardId, IGatewayEvent evt)
public async Task OnEventReceivedInner(int shardId, IGatewayEvent evt)
{
// HandleEvent takes a type parameter, automatically inferred by the event type
// It will then look up an IEventHandler<TypeOfEvent> in the DI container and call that object's handler method
@ -278,7 +282,7 @@ public class Bot
_logger.Debug("Running once-per-minute scheduled tasks");
// Check from a new custom status from Redis and update Discord accordingly
if (true)
if (!_config.DisableGateway)
{
var newStatus = await _redis.Connection.GetDatabase().StringGetAsync("pluralkit:botstatus");
if (newStatus != CustomStatusMessage)

View file

@ -24,6 +24,10 @@ public class BotConfig
public string? HttpCacheUrl { get; set; }
public bool HttpUseInnerCache { get; set; } = false;
public string? HttpListenerAddr { get; set; }
public bool DisableGateway { get; set; } = false;
public string? EventAwaiterTarget { get; set; }
public string? DiscordBaseUrl { get; set; }
public string? AvatarServiceUrl { get; set; }

View file

@ -108,7 +108,7 @@ public partial class CommandTree
ctx.Reply(
$"{Emojis.Error} Parsed command {ctx.Parameters.Callback().AsCode()} not implemented in PluralKit.Bot!"),
};
if (ctx.Match("system", "s"))
if (ctx.Match("system", "s", "account", "acc"))
return HandleSystemCommand(ctx);
if (ctx.Match("member", "m"))
return HandleMemberCommand(ctx);
@ -554,6 +554,8 @@ public partial class CommandTree
case "system":
case "systems":
case "s":
case "account":
case "acc":
await PrintCommandList(ctx, "systems", SystemCommands);
break;
case "member":

View file

@ -80,7 +80,7 @@ public class Context
internal readonly ModelRepository Repository;
internal readonly RedisService Redis;
public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null)
public async Task<Message> Reply(string text = null, Embed embed = null, AllowedMentions? mentions = null, MultipartFile[]? files = null)
{
var botPerms = await BotPermissions;
@ -91,20 +91,28 @@ public class Context
if (embed != null && !botPerms.HasFlag(PermissionSet.EmbedLinks))
throw new PKError("PluralKit does not have permission to send embeds in this channel. Please ensure I have the **Embed Links** permission enabled.");
if (files != null && !botPerms.HasFlag(PermissionSet.AttachFiles))
throw new PKError("PluralKit does not have permission to attach files in this channel. Please ensure I have the **Attach Files** permission enabled.");
var msg = await Rest.CreateMessage(Channel.Id, new MessageRequest
{
Content = text,
Embeds = embed != null ? new[] { embed } : null,
// Default to an empty allowed mentions object instead of null (which means no mentions allowed)
AllowedMentions = mentions ?? new AllowedMentions()
});
}, files: files);
// if (embed != null)
// {
// Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example)
// but since we can, we just store all sent messages for possible deletion
await _commandMessageService.RegisterMessage(msg.Id, Guild?.Id ?? 0, msg.ChannelId, Author.Id);
// }
// store log of sent message, so it can be queried or deleted later
// skip DMs as DM messages can always be deleted
if (Guild != null)
await Repository.AddCommandMessage(new Core.CommandMessage
{
Mid = msg.Id,
Guild = Guild!.Id,
Channel = Channel.Id,
Sender = Author.Id,
OriginalMid = Message.Id,
});
return msg;
}

View file

@ -443,10 +443,11 @@ public class Groups
await ctx.Reply(embed: new EmbedBuilder()
.Title("Group color")
.Color(target.Color.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Description($"This group's color is **#{target.Color}**."
+ (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}group {target.Reference(ctx)} color -clear`." : ""))
.Build());
.Build(),
files: [MiscUtils.GenerateColorPreview(target.Color)]);
return;
}
@ -471,8 +472,9 @@ public class Groups
await ctx.Reply(embed: new EmbedBuilder()
.Title($"{Emojis.Success} Group color changed.")
.Color(color.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{color}/?text=%20"))
.Build());
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Build(),
files: [MiscUtils.GenerateColorPreview(color)]);
}
}

View file

@ -308,10 +308,11 @@ public class MemberEdit
await ctx.Reply(embed: new EmbedBuilder()
.Title("Member color")
.Color(target.Color.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Description($"This member's color is **#{target.Color}**."
+ (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}member {target.Reference(ctx)} color -clear`." : ""))
.Build());
.Build(),
files: [MiscUtils.GenerateColorPreview(target.Color)]);
return;
}
@ -336,8 +337,9 @@ public class MemberEdit
await ctx.Reply(embed: new EmbedBuilder()
.Title($"{Emojis.Success} Member color changed.")
.Color(color.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{color}/?text=%20"))
.Build());
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Build(),
files: [MiscUtils.GenerateColorPreview(color)]);
}
}

View file

@ -305,7 +305,7 @@ public class ProxiedMessage
throw new PKError(error);
}
var lastMessage = _lastMessageCache.GetLastMessage(ctx.Message.ChannelId);
var lastMessage = await _lastMessageCache.GetLastMessage(ctx.Message.GuildId ?? 0, ctx.Message.ChannelId);
var isLatestMessage = lastMessage?.Current.Id == ctx.Message.Id
? lastMessage?.Previous?.Id == msg.Mid
@ -347,13 +347,8 @@ public class ProxiedMessage
var message = await ctx.Repository.GetFullMessage(messageId.Value);
if (message == null)
{
if (isDelete)
{
await DeleteCommandMessage(ctx, messageId.Value);
return;
}
throw Errors.MessageNotFound(messageId.Value);
await GetCommandMessage(ctx, messageId.Value, isDelete);
return;
}
var showContent = true;
@ -448,20 +443,35 @@ public class ProxiedMessage
await ctx.Reply(embed: await _embeds.CreateMessageInfoEmbed(message, showContent, ctx.Config));
}
private async Task DeleteCommandMessage(Context ctx, ulong messageId)
private async Task GetCommandMessage(Context ctx, ulong messageId, bool isDelete)
{
var cmessage = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
if (cmessage == null)
var msg = await _repo.GetCommandMessage(messageId);
if (msg == null)
throw Errors.MessageNotFound(messageId);
if (cmessage!.AuthorId != ctx.Author.Id)
throw new PKError("You can only delete command messages queried by this account.");
if (isDelete)
{
if (msg.Sender != ctx.Author.Id)
throw new PKError("You can only delete command messages queried by this account.");
await ctx.Rest.DeleteMessage(cmessage.ChannelId, messageId);
await ctx.Rest.DeleteMessage(msg.Channel, messageId);
if (ctx.Guild != null)
await ctx.Rest.DeleteMessage(ctx.Message);
else
await ctx.Rest.CreateReaction(ctx.Message.ChannelId, ctx.Message.Id, new Emoji { Name = Emojis.Success });
if (ctx.Guild != null)
await ctx.Rest.DeleteMessage(ctx.Message);
else
await ctx.Rest.CreateReaction(ctx.Message.ChannelId, ctx.Message.Id, new Emoji { Name = Emojis.Success });
return;
}
var showContent = true;
var channel = await _rest.GetChannelOrNull(msg.Channel);
if (channel == null)
showContent = false;
else if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel))
showContent = false;
await ctx.Reply(embed: await _embeds.CreateCommandMessageInfoEmbed(msg, showContent));
}
}

View file

@ -37,10 +37,13 @@ public class System
.Field(new Embed.Field("Getting Started",
"New to PK? Check out our Getting Started guide on setting up members and proxies: https://pluralkit.me/start\n" +
$"Otherwise, type `{ctx.DefaultPrefix}system` to view your system and `{ctx.DefaultPrefix}system help` for more information about commands you can use."))
.Field(new Embed.Field($"{Emojis.Warn} Notice {Emojis.Warn}", "PluralKit is a bot meant to help you share information about your system. " +
.Field(new Embed.Field($"{Emojis.Warn} Notice: Public By Default {Emojis.Warn}", "PluralKit is a bot meant to help you share information about your system. " +
"Member descriptions are meant to be the equivalent to a Discord About Me. Because of this, any info you put in PK is **public by default**.\n" +
"Note that this does **not** include message content, only member fields. For more information, check out " +
"[the privacy section of the user guide](https://pluralkit.me/guide/#privacy). "))
.Field(new Embed.Field($"{Emojis.Warn} Notice: Implicit Acceptance of ToS {Emojis.Warn}", "By using the PluralKit bot you implicitly agree to our " +
"[Terms of Service](https://pluralkit.me/terms-of-service/). For questions please ask in our [support server](<https://discord.gg/PczBt78>) or " +
"email legal@pluralkit.me"))
.Field(new Embed.Field("System Recovery", "In the case of your Discord account getting lost or deleted, the PluralKit staff can help you recover your system. " +
"In order to do so, we will need your **PluralKit token**. This is the *only* way you can prove ownership so we can help you recover your system. " +
$"To get it, run `{ctx.DefaultPrefix}token` and then store it in a safe place.\n\n" +

View file

@ -233,8 +233,9 @@ public class SystemEdit
await ctx.Reply(embed: new EmbedBuilder()
.Title($"{Emojis.Success} System color changed.")
.Color(newColor.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{newColor}/?text=%20"))
.Build());
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Build(),
files: [MiscUtils.GenerateColorPreview(color)]););
}
public async Task ClearColor(Context ctx, PKSystem target, bool flagConfirmYes)
@ -274,10 +275,11 @@ public class SystemEdit
await ctx.Reply(embed: new EmbedBuilder()
.Title("System color")
.Color(target.Color.ToDiscordColor())
.Thumbnail(new Embed.EmbedThumbnail($"https://fakeimg.pl/256x256/{target.Color}/?text=%20"))
.Thumbnail(new Embed.EmbedThumbnail($"attachment://color.gif"))
.Description(
$"This system's color is **#{target.Color}**." + (isOwnSystem ? $" To clear it, type `{ctx.DefaultPrefix}s color -clear`." : ""))
.Build());
.Build(),
files: [MiscUtils.GenerateColorPreview(target.Color)]);
}
public async Task ClearTag(Context ctx, PKSystem target, bool flagConfirmYes)
@ -475,7 +477,7 @@ public class SystemEdit
else
str +=
" Member names will now use the global system tag when proxied in the current server, if there is one set."
+ "\n\nTo check or change where your tag appears in your name use the command `{ctx.DefaultPrefix}cfg name format`.";
+ $"\n\nTo check or change where your tag appears in your name use the command `{ctx.DefaultPrefix}cfg name format`.";
}
}

View file

@ -55,7 +55,9 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId);
private bool IsDuplicateMessage(Message msg) =>
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
_lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
// use only the local cache here
// http gateway sets last message before forwarding the message here, so this will always return true
_lastMessageCache._GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
public async Task Handle(int shardId, MessageCreateEvent evt)
{

View file

@ -65,7 +65,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
var guild = await _cache.TryGetGuild(channel.GuildId!.Value);
if (guild == null)
throw new Exception("could not find self guild in MessageEdited event");
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
var lastMessage = (await _lastMessageCache.GetLastMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId))?.Current;
// Only react to the last message in the channel
if (lastMessage?.Id != evt.Id)
@ -107,10 +107,6 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
? new Message.Reference(channel.GuildId, evt.ChannelId, lastMessage.ReferencedMessage.Value)
: null;
var messageType = lastMessage.ReferencedMessage != null
? Message.MessageType.Reply
: Message.MessageType.Default;
// TODO: is this missing anything?
var equivalentEvt = new MessageCreateEvent
{
@ -123,7 +119,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
Attachments = evt.Attachments.Value ?? Array.Empty<Message.Attachment>(),
MessageReference = messageReference,
ReferencedMessage = referencedMessage,
Type = messageType,
Type = evt.Type,
};
return equivalentEvt;
}

View file

@ -111,6 +111,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
case "\U0001F514": // Bell
case "\U0001F6CE": // Bellhop bell
case "\U0001F3D3": // Ping pong paddle (lol)
case "\U0001FAD1": // Bell pepper
case "\u23F0": // Alarm clock
case "\u2757": // Exclamation mark
{

View file

@ -57,16 +57,6 @@ public class Init
var cache = services.Resolve<IDiscordCache>();
if (config.Cluster == null)
{
// "Connect to the database" (ie. set off database migrations and ensure state)
logger.Information("Connecting to database");
await services.Resolve<IDatabase>().ApplyMigrations();
// Clear shard status from Redis
await redis.Connection.GetDatabase().KeyDeleteAsync("pluralkit:shardstatus");
}
logger.Information("Initializing bot");
var bot = services.Resolve<Bot>();
@ -76,10 +66,19 @@ public class Init
// Init the bot instance itself, register handlers and such to the client before beginning to connect
bot.Init();
// Start the Discord shards themselves (handlers already set up)
logger.Information("Connecting to Discord");
await StartCluster(services);
// load runtime config from redis
await services.Resolve<RuntimeConfigService>().LoadConfig();
// Start HTTP server
if (config.HttpListenerAddr != null)
services.Resolve<HttpListenerService>().Start(config.HttpListenerAddr);
// Start the Discord shards themselves (handlers already set up)
if (!config.DisableGateway)
{
logger.Information("Connecting to Discord");
await StartCluster(services);
}
logger.Information("Connected! All is good (probably).");
// Lastly, we just... wait. Everything else is handled in the DiscordClient event loop
@ -149,7 +148,7 @@ public class Init
var builder = new ContainerBuilder();
builder.RegisterInstance(config);
builder.RegisterModule(new ConfigModule<BotConfig>("Bot"));
builder.RegisterModule(new LoggingModule("bot"));
builder.RegisterModule(new LoggingModule("dotnet-bot"));
builder.RegisterModule(new MetricsModule());
builder.RegisterModule<DataStoreModule>();
builder.RegisterModule<BotModule>();

View file

@ -28,7 +28,7 @@ public abstract class BaseInteractive
ButtonStyle style = ButtonStyle.Secondary, bool disabled = false)
{
var dispatch = _ctx.Services.Resolve<InteractionDispatchService>();
var customId = dispatch.Register(handler, Timeout);
var customId = dispatch.Register(_ctx.ShardId, handler, Timeout);
var button = new Button
{
@ -89,7 +89,7 @@ public abstract class BaseInteractive
{
var dispatch = ctx.Services.Resolve<InteractionDispatchService>();
foreach (var button in _buttons)
button.CustomId = dispatch.Register(button.Handler, Timeout);
button.CustomId = dispatch.Register(_ctx.ShardId, button.Handler, Timeout);
}
public abstract Task Start();

View file

@ -1,5 +1,6 @@
using Autofac;
using Myriad.Cache;
using Myriad.Gateway;
using Myriad.Rest.Types;
using Myriad.Types;
@ -69,6 +70,9 @@ public class YesNoPrompt: BaseInteractive
return true;
}
// no need to reawait message
// gateway will already have sent us only matching messages
return false;
}
@ -88,6 +92,17 @@ public class YesNoPrompt: BaseInteractive
{
try
{
// check if http gateway and set listener
// todo: this one needs to handle options for message
if (_ctx.Cache is HttpDiscordCache)
await (_ctx.Cache as HttpDiscordCache).AwaitMessage(
_ctx.Guild?.Id ?? 0,
_ctx.Channel.Id,
_ctx.Author.Id,
Timeout,
options: new[] { "yes", "y", "no", "n" }
);
await queue.WaitFor(MessagePredicate, Timeout, cts.Token);
}
catch (TimeoutException e)

View file

@ -49,8 +49,15 @@ public class BotModule: Module
if (botConfig.HttpCacheUrl != null)
{
var cache = new HttpDiscordCache(c.Resolve<ILogger>(),
c.Resolve<HttpClient>(), botConfig.HttpCacheUrl, botConfig.Cluster?.TotalShards ?? 1, botConfig.ClientId, botConfig.HttpUseInnerCache);
var cache = new HttpDiscordCache(
c.Resolve<ILogger>(),
c.Resolve<HttpClient>(),
botConfig.HttpCacheUrl,
botConfig.EventAwaiterTarget,
botConfig.Cluster?.TotalShards ?? 1,
botConfig.ClientId,
botConfig.HttpUseInnerCache
);
var metrics = c.Resolve<IMetrics>();
@ -153,6 +160,8 @@ public class BotModule: Module
builder.RegisterType<CommandMessageService>().AsSelf().SingleInstance();
builder.RegisterType<InteractionDispatchService>().AsSelf().SingleInstance();
builder.RegisterType<AvatarHostingService>().AsSelf().SingleInstance();
builder.RegisterType<HttpListenerService>().AsSelf().SingleInstance();
builder.RegisterType<RuntimeConfigService>().AsSelf().SingleInstance();
// Sentry stuff
builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope();

View file

@ -25,5 +25,6 @@
<ItemGroup>
<PackageReference Include="Humanizer.Core" Version="2.14.1" />
<PackageReference Include="Sentry" Version="4.13.0" />
<PackageReference Include="Watson.Lite" Version="6.3.5" />
</ItemGroup>
</Project>

View file

@ -246,7 +246,7 @@ public class ProxyService
ChannelId = rootChannel.Id,
ThreadId = threadId,
MessageId = trigger.Id,
Name = await FixSameName(messageChannel.Id, ctx, match.Member),
Name = await FixSameName(trigger.GuildId!.Value, messageChannel.Id, ctx, match.Member),
AvatarUrl = AvatarUtils.TryRewriteCdnUrl(match.Member.ProxyAvatar(ctx)),
Content = content,
Attachments = trigger.Attachments,
@ -458,11 +458,11 @@ public class ProxyService
};
}
private async Task<string> FixSameName(ulong channelId, MessageContext ctx, ProxyMember member)
private async Task<string> FixSameName(ulong guildId, ulong channelId, MessageContext ctx, ProxyMember member)
{
var proxyName = member.ProxyName(ctx);
var lastMessage = _lastMessage.GetLastMessage(channelId)?.Previous;
var lastMessage = (await _lastMessage.GetLastMessage(guildId, channelId))?.Previous;
if (lastMessage == null)
// cache is out of date or channel is empty.
return proxyName;

View file

@ -9,35 +9,35 @@ namespace PluralKit.Bot;
public class CommandMessageService
{
private readonly RedisService _redis;
private readonly ModelRepository _repo;
private readonly ILogger _logger;
private static readonly TimeSpan CommandMessageRetention = TimeSpan.FromHours(24);
public CommandMessageService(RedisService redis, IClock clock, ILogger logger)
public CommandMessageService(RedisService redis, ModelRepository repo, IClock clock, ILogger logger)
{
_redis = redis;
_repo = repo;
_logger = logger.ForContext<CommandMessageService>();
}
public async Task RegisterMessage(ulong messageId, ulong guildId, ulong channelId, ulong authorId)
{
if (_redis.Connection == null) return;
_logger.Debug(
"Registering command response {MessageId} from author {AuthorId} in {ChannelId}",
messageId, authorId, channelId
);
await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}-{guildId}", expiry: CommandMessageRetention);
}
public async Task<CommandMessage?> GetCommandMessage(ulong messageId)
{
var repoMsg = await _repo.GetCommandMessage(messageId);
if (repoMsg != null)
return new CommandMessage(repoMsg.Sender, repoMsg.Channel, repoMsg.Guild);
var str = await _redis.Connection.GetDatabase().StringGetAsync(messageId.ToString());
if (str.HasValue)
{
var split = ((string)str).Split("-");
return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2]));
}
str = await _redis.Connection.GetDatabase().StringGetAsync("command_message:" + messageId.ToString());
if (str.HasValue)
{
var split = ((string)str).Split("-");
return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2]));
}
return null;
}
}

View file

@ -420,6 +420,24 @@ public class EmbedService
return eb.Build();
}
public async Task<Embed> CreateCommandMessageInfoEmbed(Core.CommandMessage msg, bool showContent)
{
var content = "*(command message deleted or inaccessible)*";
if (showContent)
{
var discordMessage = await _rest.GetMessageOrNull(msg.Channel, msg.OriginalMid);
if (discordMessage != null)
content = discordMessage.Content;
}
return new EmbedBuilder()
.Title("Command response message")
.Description(content)
.Field(new("Original message", $"https://discord.com/channels/{msg.Guild}/{msg.Channel}/{msg.OriginalMid}", true))
.Field(new("Sent by", $"<@{msg.Sender}>", true))
.Build();
}
public Task<Embed> CreateFrontPercentEmbed(FrontBreakdown breakdown, PKSystem system, PKGroup group,
DateTimeZone tz, LookupContext ctx, string embedTitle,
bool ignoreNoFronters, bool showFlat)

View file

@ -0,0 +1,146 @@
using System.Text;
using System.Text.Json;
using Serilog;
using WatsonWebserver.Lite;
using WatsonWebserver.Core;
using Myriad.Gateway;
using Myriad.Serialization;
namespace PluralKit.Bot;
public class HttpListenerService
{
private readonly ILogger _logger;
private readonly RuntimeConfigService _runtimeConfig;
private readonly Bot _bot;
public HttpListenerService(ILogger logger, RuntimeConfigService runtimeConfig, Bot bot)
{
_logger = logger.ForContext<HttpListenerService>();
_runtimeConfig = runtimeConfig;
_bot = bot;
}
public void Start(string host)
{
var hosts = new[] { host };
if (host == "allv4v6")
{
hosts = new[] { "[::]", "0.0.0.0" };
}
foreach (var h in hosts)
{
var server = new WebserverLite(new WebserverSettings(h, 5002), DefaultRoute);
server.Routes.PreAuthentication.Static.Add(WatsonWebserver.Core.HttpMethod.GET, "/runtime_config", RuntimeConfigGet);
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.POST, "/runtime_config/{key}", RuntimeConfigSet);
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.DELETE, "/runtime_config/{key}", RuntimeConfigDelete);
server.Routes.PreAuthentication.Parameter.Add(WatsonWebserver.Core.HttpMethod.POST, "/events/{shard_id}", GatewayEvent);
server.Start();
}
}
private async Task DefaultRoute(HttpContextBase ctx)
=> await ctx.Response.Send("hellorld");
private async Task RuntimeConfigGet(HttpContextBase ctx)
{
var config = _runtimeConfig.GetAll();
ctx.Response.Headers.Add("content-type", "application/json");
await ctx.Response.Send(JsonSerializer.Serialize(config));
}
private async Task RuntimeConfigSet(HttpContextBase ctx)
{
var key = ctx.Request.Url.Parameters["key"];
var value = ReadStream(ctx.Request.Data, ctx.Request.ContentLength);
await _runtimeConfig.Set(key, value);
await RuntimeConfigGet(ctx);
}
private async Task RuntimeConfigDelete(HttpContextBase ctx)
{
var key = ctx.Request.Url.Parameters["key"];
await _runtimeConfig.Delete(key);
await RuntimeConfigGet(ctx);
}
private JsonSerializerOptions _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
private async Task GatewayEvent(HttpContextBase ctx)
{
var shardIdString = ctx.Request.Url.Parameters["shard_id"];
if (!int.TryParse(shardIdString, out var shardId)) return;
var packet = JsonSerializer.Deserialize<GatewayPacket>(ReadStream(ctx.Request.Data, ctx.Request.ContentLength), _jsonSerializerOptions);
var evt = DeserializeEvent(shardId, packet.EventType!, (JsonElement)packet.Payload!);
if (evt != null)
{
await _bot.OnEventReceivedInner(shardId, evt);
}
await ctx.Response.Send("a");
}
private IGatewayEvent? DeserializeEvent(int shardId, string eventType, JsonElement payload)
{
if (!IGatewayEvent.EventTypes.TryGetValue(eventType, out var clrType))
{
_logger.Debug("Shard {ShardId}: Received unknown event type {EventType}", shardId, eventType);
return null;
}
try
{
_logger.Verbose("Shard {ShardId}: Deserializing {EventType} to {ClrType}", shardId, eventType,
clrType);
return JsonSerializer.Deserialize(payload.GetRawText(), clrType, _jsonSerializerOptions)
as IGatewayEvent;
}
catch (JsonException e)
{
_logger.Error(e, "Shard {ShardId}: Error deserializing event {EventType} to {ClrType}", shardId,
eventType, clrType);
return null;
}
}
//temporary re-implementation of the ReadStream function found in WatsonWebserver.Lite, but with handling for closed connections
//https://github.com/dotnet/WatsonWebserver/issues/171
private static string ReadStream(Stream input, long contentLength)
{
if (input == null) throw new ArgumentNullException(nameof(input));
if (!input.CanRead) throw new InvalidOperationException("Input stream is not readable");
if (contentLength < 1) return "";
byte[] buffer = new byte[65536];
long bytesRemaining = contentLength;
using (MemoryStream ms = new MemoryStream())
{
int read;
while (bytesRemaining > 0)
{
read = input.Read(buffer, 0, buffer.Length);
if (read > 0)
{
ms.Write(buffer, 0, read);
bytesRemaining -= read;
}
else
{
throw new IOException("Connection closed before reading end of stream.");
}
}
if (ms.Length < 1) return null;
var str = Encoding.Default.GetString(ms.ToArray());
return str;
}
}
}

View file

@ -1,5 +1,7 @@
using System.Collections.Concurrent;
using Myriad.Cache;
using NodaTime;
using Serilog;
@ -16,9 +18,12 @@ public class InteractionDispatchService: IDisposable
private readonly ConcurrentDictionary<Guid, RegisteredInteraction> _handlers = new();
private readonly ILogger _logger;
public InteractionDispatchService(IClock clock, ILogger logger)
private readonly IDiscordCache _cache;
public InteractionDispatchService(IClock clock, ILogger logger, IDiscordCache cache)
{
_clock = clock;
_cache = cache;
_logger = logger.ForContext<InteractionDispatchService>();
_cleanupWorker = CleanupLoop(_cts.Token);
@ -50,9 +55,15 @@ public class InteractionDispatchService: IDisposable
_handlers.TryRemove(customIdGuid, out _);
}
public string Register(Func<InteractionContext, Task> callback, Duration? expiry = null)
public string Register(int shardId, Func<InteractionContext, Task> callback, Duration? expiry = null)
{
var key = Guid.NewGuid();
// if http_cache, return RegisterRemote
// not awaited here, it's probably fine
if (_cache is HttpDiscordCache)
(_cache as HttpDiscordCache).AwaitInteraction(shardId, key.ToString(), expiry);
var handler = new RegisteredInteraction
{
Callback = callback,

View file

@ -1,6 +1,7 @@
#nullable enable
using System.Collections.Concurrent;
using Myriad.Cache;
using Myriad.Types;
namespace PluralKit.Bot;
@ -9,9 +10,18 @@ public class LastMessageCacheService
{
private readonly IDictionary<ulong, CacheEntry> _cache = new ConcurrentDictionary<ulong, CacheEntry>();
private readonly IDiscordCache _maybeHttp;
public LastMessageCacheService(IDiscordCache cache)
{
_maybeHttp = cache;
}
public void AddMessage(Message msg)
{
var previous = GetLastMessage(msg.ChannelId);
if (_maybeHttp is HttpDiscordCache) return;
var previous = _GetLastMessage(msg.ChannelId);
var current = ToCachedMessage(msg);
_cache[msg.ChannelId] = new CacheEntry(current, previous?.Current);
}
@ -19,12 +29,26 @@ public class LastMessageCacheService
private CachedMessage ToCachedMessage(Message msg) =>
new(msg.Id, msg.ReferencedMessage.Value?.Id, msg.Author.Username);
public CacheEntry? GetLastMessage(ulong channel) =>
_cache.TryGetValue(channel, out var message) ? message : null;
public async Task<CacheEntry?> GetLastMessage(ulong guild, ulong channel)
{
if (_maybeHttp is HttpDiscordCache)
return await (_maybeHttp as HttpDiscordCache).GetLastMessage<CacheEntry>(guild, channel);
return _cache.TryGetValue(channel, out var message) ? message : null;
}
public CacheEntry? _GetLastMessage(ulong channel)
{
if (_maybeHttp is HttpDiscordCache) return null;
return _cache.TryGetValue(channel, out var message) ? message : null;
}
public void HandleMessageDeletion(ulong channel, ulong message)
{
var storedMessage = GetLastMessage(channel);
if (_maybeHttp is HttpDiscordCache) return;
var storedMessage = _GetLastMessage(channel);
if (storedMessage == null)
return;
@ -39,7 +63,9 @@ public class LastMessageCacheService
public void HandleMessageDeletion(ulong channel, List<ulong> messages)
{
var storedMessage = GetLastMessage(channel);
if (_maybeHttp is HttpDiscordCache) return;
var storedMessage = _GetLastMessage(channel);
if (storedMessage == null)
return;

View file

@ -45,6 +45,7 @@ public class LoggerCleanService
private static readonly Regex _AnnabelleRegex = new("```\n(\\d{17,19})\n```");
private static readonly Regex _AnnabelleRegexFuzzy = new("\\<t:(\\d+)\\> A message from \\*\\*[\\w.]{2,32}\\*\\* \\(`(\\d{17,19})`\\) was deleted in <#\\d{17,19}>");
private static readonly Regex _koiraRegex = new("ID:\\*\\* (\\d{17,19})");
private static readonly Regex _zeppelinRegex = new("🗑 Message \\(`(\\d{17,19})`\\)");
private static readonly Regex _VortexRegex =
new("`\\[(\\d\\d:\\d\\d:\\d\\d)\\]` .* \\(ID:(\\d{17,19})\\).* <#\\d{17,19}>:");
@ -83,7 +84,8 @@ public class LoggerCleanService
new LoggerBot("Dozer", 356535250932858885, ExtractDozer),
new LoggerBot("Skyra", 266624760782258186, ExtractSkyra),
new LoggerBot("Annabelle", 231241068383961088, ExtractAnnabelle, fuzzyExtractFunc: ExtractAnnabelleFuzzy),
new LoggerBot("Koira", 1247013404569239624, ExtractKoira)
new LoggerBot("Koira", 1247013404569239624, ExtractKoira),
new LoggerBot("Zeppelin", 473868086773153793, ExtractZeppelin) // webhook
}.ToDictionary(b => b.Id);
private static Dictionary<ulong, LoggerBot> _botsByApplicationId
@ -441,6 +443,23 @@ public class LoggerCleanService
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
}
private static ulong? ExtractZeppelin(Message msg)
{
// zeppelin uses a non-embed format by default but can be configured to use a customizable embed
// if it's an embed, assume the footer contains the message ID
var embed = msg.Embeds?.FirstOrDefault();
if (embed == null)
{
var match = _zeppelinRegex.Match(msg.Content ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
}
else
{
var match = _basicRegex.Match(embed.Footer?.Text ?? "");
return match.Success ? ulong.Parse(match.Groups[1].Value) : null;
}
}
public class LoggerBot
{
public ulong Id;

View file

@ -0,0 +1,58 @@
using Newtonsoft.Json;
using Serilog;
using StackExchange.Redis;
using PluralKit.Core;
namespace PluralKit.Bot;
public class RuntimeConfigService
{
private readonly RedisService _redis;
private readonly ILogger _logger;
private Dictionary<string, string> settings = new();
private string RedisKey;
public RuntimeConfigService(ILogger logger, RedisService redis, BotConfig config)
{
_logger = logger.ForContext<RuntimeConfigService>();
_redis = redis;
var clusterId = config.Cluster?.NodeIndex ?? 0;
RedisKey = $"remote_config:dotnet_bot:{clusterId}";
}
public async Task LoadConfig()
{
var redisConfig = await _redis.Connection.GetDatabase().HashGetAllAsync(RedisKey);
foreach (var entry in redisConfig)
settings.Add(entry.Name, entry.Value);
var configStr = JsonConvert.SerializeObject(settings);
_logger.Information($"starting with runtime config: {configStr}");
}
public async Task Set(string key, string value)
{
await _redis.Connection.GetDatabase().HashSetAsync(RedisKey, new[] { new HashEntry(key, new RedisValue(value)) });
settings.Add(key, value);
_logger.Information($"updated runtime config: {key}={value}");
}
public async Task Delete(string key)
{
await _redis.Connection.GetDatabase().HashDeleteAsync(RedisKey, key);
settings.Remove(key);
_logger.Information($"updated runtime config: {key} removed");
}
public object? Get(string key) => settings.GetValueOrDefault(key);
public bool Exists(string key) => settings.ContainsKey(key);
public Dictionary<string, string> GetAll() => settings;
}

View file

@ -1,6 +1,7 @@
using Autofac;
using Myriad.Builders;
using Myriad.Cache;
using Myriad.Gateway;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Types.Requests;
@ -40,8 +41,12 @@ public static class ContextUtils
}
public static async Task<MessageReactionAddEvent> AwaitReaction(this Context ctx, Message message,
User user = null, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
User user, Func<MessageReactionAddEvent, bool> predicate = null, Duration? timeout = null)
{
// check if http gateway and set listener
if (ctx.Cache is HttpDiscordCache)
await (ctx.Cache as HttpDiscordCache).AwaitReaction(ctx.Guild?.Id ?? 0, message.Id, user!.Id, timeout);
bool ReactionPredicate(MessageReactionAddEvent evt)
{
if (message.Id != evt.MessageId) return false; // Ignore reactions for different messages
@ -57,11 +62,17 @@ public static class ContextUtils
public static async Task<bool> ConfirmWithReply(this Context ctx, string expectedReply, bool treatAsHid = false)
{
var timeout = Duration.FromMinutes(1);
// check if http gateway and set listener
if (ctx.Cache is HttpDiscordCache)
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
bool Predicate(MessageCreateEvent e) =>
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
.WaitFor(Predicate, Duration.FromMinutes(1));
.WaitFor(Predicate, timeout);
var content = msg.Content;
if (treatAsHid)
@ -96,11 +107,17 @@ public static class ContextUtils
async Task<int> PromptPageNumber()
{
var timeout = Duration.FromMinutes(0.5);
// check if http gateway and set listener
if (ctx.Cache is HttpDiscordCache)
await (ctx.Cache as HttpDiscordCache).AwaitMessage(ctx.Guild?.Id ?? 0, ctx.Channel.Id, ctx.Author.Id, timeout);
bool Predicate(MessageCreateEvent e) =>
e.Author.Id == ctx.Author.Id && e.ChannelId == ctx.Channel.Id;
var msg = await ctx.Services.Resolve<HandlerQueue<MessageCreateEvent>>()
.WaitFor(Predicate, Duration.FromMinutes(0.5));
.WaitFor(Predicate, timeout);
int.TryParse(msg.Content, out var num);

View file

@ -1,7 +1,8 @@
using System.Net;
using System.Net.Sockets;
using System.Globalization;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Types;
using Newtonsoft.Json;
@ -102,4 +103,26 @@ public static class MiscUtils
return true;
}
public static MultipartFile GenerateColorPreview(string color)
{
//generate a 128x128 solid color gif from bytes
//image data is a 1x1 pixel, using the background color to fill the rest of the canvas
var imgBytes = new byte[]
{
0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // Header
0x80, 0x00, 0x80, 0x00, 0x80, 0x00, 0x00, // Logical Screen Descriptor
0xFF, 0xFF, 0xFF, 0x00, 0x00, 0x00, // Global Color Table
0x21, 0xF9, 0x04, 0x08, 0x00, 0x00, 0x00, 0x00, // Graphics Control Extension
0x2C, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, // Image Descriptor
0x02, 0x02, 0x4C, 0x01, 0x00, // Image Data
0x3B // Trailer
}; //indices 13, 14 and 15 are the R, G, and B values respectively
imgBytes[13] = byte.Parse(color.Substring(0, 2), NumberStyles.HexNumber);
imgBytes[14] = byte.Parse(color.Substring(2, 2), NumberStyles.HexNumber);
imgBytes[15] = byte.Parse(color.Substring(4, 2), NumberStyles.HexNumber);
return new MultipartFile("color.gif", new MemoryStream(imgBytes), null, null, null);
}
}

View file

@ -14,6 +14,16 @@
"resolved": "4.13.0",
"contentHash": "Wfw3M1WpFcrYaGzPm7QyUTfIOYkVXQ1ry6p4WYjhbLz9fPwV23SGQZTFDpdox67NHM0V0g1aoQ4YKLm4ANtEEg=="
},
"Watson.Lite": {
"type": "Direct",
"requested": "[6.3.5, )",
"resolved": "6.3.5",
"contentHash": "YF8+se3IVenn8YlyNeb4wSJK6QMnVD0QHIOEiZ22wS4K2wkwoSDzWS+ZAjk1MaPeB+XO5gRoENUN//pOc+wI2g==",
"dependencies": {
"CavemanTcp": "2.0.5",
"Watson.Core": "6.3.5"
}
},
"App.Metrics": {
"type": "Transitive",
"resolved": "4.3.0",
@ -107,6 +117,11 @@
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1"
}
},
"CavemanTcp": {
"type": "Transitive",
"resolved": "2.0.5",
"contentHash": "90wywmGpjrj26HMAkufYZwuZI8sVYB1mRwEdqugSR3kgDnPX+3l0jO86gwtFKsPvsEpsS4Dn/1EbhguzUxMU8Q=="
},
"Dapper": {
"type": "Transitive",
"resolved": "2.1.35",
@ -130,6 +145,11 @@
"System.Diagnostics.DiagnosticSource": "5.0.0"
}
},
"IpMatcher": {
"type": "Transitive",
"resolved": "1.0.5",
"contentHash": "WXNlWERj+0GN699AnMNsuJ7PfUAbU4xhOHP3nrNXLHqbOaBxybu25luSYywX1133NSlitA4YkSNmJuyPvea4sw=="
},
"IPNetwork2": {
"type": "Transitive",
"resolved": "3.0.667",
@ -391,18 +411,18 @@
"resolved": "8.5.0",
"contentHash": "VYYMZNitZ85UEhwOKkTQI63WEMvzUqwQc74I2mm8h/DBVAMcBBxqYPni4DmuRtbCwngmuONuK2yBJfWNRKzI+A=="
},
"Serilog": {
"RegexMatcher": {
"type": "Transitive",
"resolved": "4.2.0",
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
"resolved": "1.0.9",
"contentHash": "RkQGXIrqHjD5h1mqefhgCbkaSdRYNRG5rrbzyw5zeLWiS0K1wq9xR3cNhQdzYR2MsKZ3GN523yRUsEQIMPxh3Q=="
},
"Serilog.Extensions.Logging": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
"resolved": "8.0.0",
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
"dependencies": {
"Microsoft.Extensions.Logging": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Formatting.Compact": {
@ -714,12 +734,36 @@
"System.Runtime": "4.3.0"
}
},
"Timestamps": {
"type": "Transitive",
"resolved": "1.0.11",
"contentHash": "SnWhXm3FkEStQGgUTfWMh9mKItNW032o/v8eAtFrOGqG0/ejvPPA1LdLZx0N/qqoY0TH3x11+dO00jeVcM8xNQ=="
},
"UrlMatcher": {
"type": "Transitive",
"resolved": "3.0.1",
"contentHash": "hHBZVzFSfikrx4XsRsnCIwmGLgbNKtntnlqf4z+ygcNA6Y/L/J0x5GiZZWfXdTfpxhy5v7mlt2zrZs/L9SvbOA=="
},
"Watson.Core": {
"type": "Transitive",
"resolved": "6.3.5",
"contentHash": "Y5YxKOCSLe2KDmfwvI/J0qApgmmZR77LwyoufRVfKH7GLdHiE7fY0IfoNxWTG7nNv8knBfgwyOxdehRm+4HaCg==",
"dependencies": {
"IpMatcher": "1.0.5",
"RegexMatcher": "1.0.9",
"System.Text.Json": "8.0.5",
"Timestamps": "1.0.11",
"UrlMatcher": "3.0.1"
}
},
"myriad": {
"type": "Project",
"dependencies": {
"NodaTime": "[3.2.0, )",
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
"Polly": "[8.5.0, )",
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
"Serilog": "[4.2.0, )",
"Serilog": "[4.1.0, )",
"StackExchange.Redis": "[2.8.22, )",
"System.Linq.Async": "[6.0.1, )"
}
@ -747,8 +791,8 @@
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
"Npgsql": "[9.0.2, )",
"Npgsql.NodaTime": "[9.0.2, )",
"Serilog": "[4.2.0, )",
"Serilog.Extensions.Logging": "[9.0.0, )",
"Serilog": "[4.1.0, )",
"Serilog.Extensions.Logging": "[8.0.0, )",
"Serilog.Formatting.Compact": "[3.0.0, )",
"Serilog.NodaTime": "[3.0.0, )",
"Serilog.Sinks.Async": "[2.1.0, )",
@ -762,6 +806,9 @@
"System.Interactive.Async": "[6.0.1, )",
"ipnetwork2": "[3.0.667, )"
}
},
"serilog": {
"type": "Project"
}
}
}

View file

@ -19,5 +19,4 @@ public class CoreConfig
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Debug;
public LogEventLevel ElasticLogLevel { get; set; } = LogEventLevel.Information;
public LogEventLevel FileLogLevel { get; set; } = LogEventLevel.Information;
}

View file

@ -19,16 +19,40 @@ public partial class ModelRepository
}
public Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true)
public async Task<SystemGuildSettings> GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true, bool search = false)
{
if (!defaultInsert)
return _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
{
var simpleRes = await _db.QueryFirst<SystemGuildSettings>(new Query("system_guild")
.Where("guild", guild)
.Where("system", system)
);
if (simpleRes != null || !search)
return simpleRes;
var accounts = await GetSystemAccounts(system);
var searchRes = await _db.QueryFirst<bool>(
"select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))",
new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() },
queryName: "find_system_from_commands",
messages: true
);
if (!searchRes)
searchRes = await _db.QueryFirst<bool>(
"select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))",
new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() },
queryName: "find_system_from_messages",
messages: true
);
if (!searchRes)
return null;
}
var query = new Query("system_guild").AsInsert(new { guild, system });
return _db.QueryFirst<SystemGuildSettings>(query,
return await _db.QueryFirst<SystemGuildSettings>(query,
"on conflict (guild, system) do update set guild = $1, system = $2 returning *"
);
}
@ -42,16 +66,25 @@ public partial class ModelRepository
return settings;
}
public Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true)
public async Task<MemberGuildSettings> GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true, SystemId? search = null)
{
if (!defaultInsert)
return _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
{
var simpleRes = await _db.QueryFirst<MemberGuildSettings>(new Query("member_guild")
.Where("guild", guild)
.Where("member", member)
);
if (simpleRes != null || !search.HasValue)
return simpleRes;
var systemConfig = await GetSystemGuild(guild, search.Value, defaultInsert: false, search: true);
if (systemConfig == null)
return null;
}
var query = new Query("member_guild").AsInsert(new { guild, member });
return _db.QueryFirst<MemberGuildSettings>(query,
return await _db.QueryFirst<MemberGuildSettings>(query,
"on conflict (guild, member) do update set guild = $1, member = $2 returning *"
);
}

View file

@ -42,12 +42,37 @@ public partial class ModelRepository
};
}
public async Task AddCommandMessage(CommandMessage msg)
{
var query = new Query("command_messages").AsInsert(new
{
mid = msg.Mid,
guild = msg.Guild,
channel = msg.Channel,
sender = msg.Sender,
original_mid = msg.OriginalMid
});
await _db.ExecuteQuery(query, messages: true);
_logger.Debug("Stored command message {@StoredMessage} in channel {Channel}", msg, msg.Channel);
}
public Task<CommandMessage?> GetCommandMessage(ulong id)
=> _db.QueryFirst<CommandMessage?>(new Query("command_messages").Where("mid", id), messages: true);
public async Task DeleteMessage(ulong id)
{
var query = new Query("messages").AsDelete().Where("mid", id);
var rowCount = await _db.ExecuteQuery(query, messages: true);
if (rowCount > 0)
_logger.Information("Deleted message {MessageId} from database", id);
else
{
var cquery = new Query("command_messages").AsDelete().Where("mid", id);
var crowCount = await _db.ExecuteQuery(query, messages: true);
if (crowCount > 0)
_logger.Information("Deleted command message {MessageId} from database", id);
}
}
public async Task DeleteMessagesBulk(IReadOnlyCollection<ulong> ids)
@ -59,5 +84,19 @@ public partial class ModelRepository
if (rowCount > 0)
_logger.Information("Bulk deleted messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
var cquery = new Query("command_messages").AsDelete().WhereIn("mid", ids.Select(id => (long)id).ToArray());
var crowCount = await _db.ExecuteQuery(query, messages: true);
if (crowCount > 0)
_logger.Information("Bulk deleted command messages ({FoundCount} found) from database: {MessageIds}", rowCount,
ids);
}
}
public class CommandMessage
{
public ulong Mid { get; set; }
public ulong Guild { get; set; }
public ulong Channel { get; set; }
public ulong Sender { get; set; }
public ulong OriginalMid { get; set; }
}

View file

@ -9,7 +9,7 @@ namespace PluralKit.Core;
internal class DatabaseMigrator
{
private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files
private const int TargetSchemaVersion = 51;
private const int TargetSchemaVersion = 52;
private readonly ILogger _logger;
public DatabaseMigrator(ILogger logger)

View file

@ -10,8 +10,7 @@ using NodaTime;
using Serilog;
using Serilog.Events;
using Serilog.Formatting.Compact;
using Serilog.Sinks.Seq;
using Serilog.Formatting.Json;
using Serilog.Sinks.SystemConsole.Themes;
using ILogger = Serilog.ILogger;
@ -50,14 +49,9 @@ public class LoggingModule: Module
private ILogger InitLogger(CoreConfig config)
{
var consoleTemplate = "[{Timestamp:HH:mm:ss.fff}] {Level:u3} {Message:lj}{NewLine}{Exception}";
var outputTemplate = "[{Timestamp:yyyy-MM-dd HH:mm:ss.ffffff}] {Level:u3} {Message:lj}{NewLine}{Exception}";
var logCfg = _cfg
.Enrich.FromLogContext()
.Enrich.WithProperty("GitCommitHash", BuildInfoService.FullVersion)
.ConfigureForNodaTime(DateTimeZoneProviders.Tzdb)
.Enrich.WithProperty("Component", _component)
.MinimumLevel.Is(config.ConsoleLogLevel)
// Don't want App.Metrics/D#+ spam
@ -73,35 +67,10 @@ public class LoggingModule: Module
.Destructure.AsScalar<SwitchId>()
.Destructure.ByTransforming<ProxyTag>(t => new { t.Prefix, t.Suffix })
.Destructure.With<PatchObjectDestructuring>()
.WriteTo.Async(a =>
{
// Both the same output, except one is raw compact JSON and one is plain text.
// Output simultaneously. May remove the JSON formatter later, keeping it just in cast.
// Flush interval is 50ms (down from 10s) to make "tail -f" easier. May be too low?
a.File(
(config.LogDir ?? "logs") + $"/pluralkit.{_component}.log",
outputTemplate: outputTemplate,
retainedFileCountLimit: 10,
rollingInterval: RollingInterval.Day,
fileSizeLimitBytes: null,
flushToDiskInterval: TimeSpan.FromMilliseconds(50),
restrictedToMinimumLevel: config.FileLogLevel,
formatProvider: new UTCTimestampFormatProvider(),
buffered: true);
a.File(
new RenderedCompactJsonFormatter(new ScalarFormatting.JsonValue()),
(config.LogDir ?? "logs") + $"/pluralkit.{_component}.json",
rollingInterval: RollingInterval.Day,
flushToDiskInterval: TimeSpan.FromMilliseconds(50),
restrictedToMinimumLevel: config.FileLogLevel,
buffered: true);
})
.WriteTo.Async(a =>
a.Console(
theme: AnsiConsoleTheme.Code,
outputTemplate: consoleTemplate,
restrictedToMinimumLevel: config.ConsoleLogLevel));
new CustomJsonFormatter(_component),
config.ConsoleLogLevel));
if (config.ElasticUrl != null)
{
@ -113,15 +82,6 @@ public class LoggingModule: Module
);
}
if (config.SeqLogUrl != null)
{
logCfg.WriteTo.Seq(
config.SeqLogUrl,
restrictedToMinimumLevel: LogEventLevel.Verbose
);
}
_fn.Invoke(logCfg);
return Log.Logger = logCfg.CreateLogger();
}

View file

@ -15,6 +15,10 @@
<RestorePackagesWithLockFile>true</RestorePackagesWithLockFile>
</PropertyGroup>
<ItemGroup>
<ProjectReference Include="..\Serilog\src\Serilog\Serilog.csproj" />
</ItemGroup>
<ItemGroup>
<PackageReference Include="App.Metrics" Version="4.3.0" />
<PackageReference Include="App.Metrics.Reporting.InfluxDB" Version="4.3.0" />
@ -37,8 +41,7 @@
<PackageReference Include="NodaTime.Serialization.JsonNet" Version="3.1.0" />
<PackageReference Include="Npgsql" Version="9.0.2" />
<PackageReference Include="Npgsql.NodaTime" Version="9.0.2" />
<PackageReference Include="Serilog" Version="4.2.0" />
<PackageReference Include="Serilog.Extensions.Logging" Version="9.0.0" />
<PackageReference Include="Serilog.Extensions.Logging" Version="8.0.0" />
<PackageReference Include="Serilog.Formatting.Compact" Version="3.0.0" />
<PackageReference Include="Serilog.NodaTime" Version="3.0.0" />
<PackageReference Include="Serilog.Sinks.Async" Version="2.1.0" />

View file

@ -205,7 +205,7 @@ public partial class BulkImporter
? existingSwitches.Select(sw => sw.Id).Max()
: (SwitchId?)null;
if (switches.Count > 10000)
if (switches.Count > 100000)
throw new ImportException("Too many switches present in import file.");
// Import switch definitions

View file

@ -0,0 +1,215 @@
using System.Runtime.CompilerServices;
using Serilog.Events;
using Serilog.Formatting;
using Serilog.Formatting.Json;
using Serilog.Parsing;
using Serilog.Rendering;
// Customized Serilog JSON output for PluralKit
// Copyright 2013-2015 Serilog Contributors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
namespace PluralKit.Core;
static class Guard
{
public static T AgainstNull<T>(
T? argument,
[CallerArgumentExpression("argument")] string? paramName = null)
where T : class
{
if (argument is null)
{
throw new ArgumentNullException(paramName);
}
return argument;
}
}
/// <summary>
/// Formats log events in a simple JSON structure. Instances of this class
/// are safe for concurrent access by multiple threads.
/// </summary>
/// <remarks>New code should prefer formatters from <c>Serilog.Formatting.Compact</c>, or <c>ExpressionTemplate</c> from
/// <c>Serilog.Expressions</c>.</remarks>
public sealed class CustomJsonFormatter: ITextFormatter
{
readonly JsonValueFormatter _jsonValueFormatter = new();
readonly string _component;
/// <summary>
/// Construct a <see cref="JsonFormatter"/>.
/// </summary>
/// <param name="closingDelimiter">A string that will be written after each log event is formatted.
/// If null, <see cref="Environment.NewLine"/> will be used.</param>
/// <param name="renderMessage">If <see langword="true"/>, the message will be rendered and written to the output as a
/// property named RenderedMessage.</param>
/// <param name="formatProvider">Supplies culture-specific formatting information, or null.</param>
public CustomJsonFormatter(string component)
{
_component = component;
}
private string CustomLevelString(LogEventLevel level)
{
switch (level)
{
case LogEventLevel.Verbose:
return "TRACE";
case LogEventLevel.Debug:
return "DEBUG";
case LogEventLevel.Information:
return "INFO";
case LogEventLevel.Warning:
return "WARN";
case LogEventLevel.Error:
return "ERROR";
case LogEventLevel.Fatal:
return "FATAL";
};
return "UNKNOWN";
}
/// <summary>
/// Format the log event into the output.
/// </summary>
/// <param name="logEvent">The event to format.</param>
/// <param name="output">The output.</param>
/// <exception cref="ArgumentNullException">When <paramref name="logEvent"/> is <code>null</code></exception>
/// <exception cref="ArgumentNullException">When <paramref name="output"/> is <code>null</code></exception>
public void Format(LogEvent logEvent, TextWriter output)
{
Guard.AgainstNull(logEvent);
Guard.AgainstNull(output);
output.Write("{\"component\":\"");
output.Write(_component);
output.Write("\",\"timestamp\":\"");
output.Write(logEvent.Timestamp.ToString("O").Replace("+00:00", "Z"));
output.Write("\",\"level\":\"");
output.Write(CustomLevelString(logEvent.Level));
output.Write("\",\"message\":");
var message = logEvent.MessageTemplate.Render(logEvent.Properties);
JsonValueFormatter.WriteQuotedJsonString(message, output);
if (logEvent.TraceId != null)
{
output.Write(",\"TraceId\":");
JsonValueFormatter.WriteQuotedJsonString(logEvent.TraceId.ToString()!, output);
}
if (logEvent.SpanId != null)
{
output.Write(",\"SpanId\":");
JsonValueFormatter.WriteQuotedJsonString(logEvent.SpanId.ToString()!, output);
}
if (logEvent.Exception != null)
{
output.Write(",\"Exception\":");
JsonValueFormatter.WriteQuotedJsonString(logEvent.Exception.ToString(), output);
}
if (logEvent.Properties.Count != 0)
{
output.Write(",\"Properties\":{");
char? propertyDelimiter = null;
foreach (var property in logEvent.Properties)
{
if (propertyDelimiter != null)
output.Write(propertyDelimiter.Value);
else
propertyDelimiter = ',';
JsonValueFormatter.WriteQuotedJsonString(property.Key, output);
output.Write(':');
_jsonValueFormatter.Format(property.Value, output);
}
output.Write('}');
}
var tokensWithFormat = logEvent.MessageTemplate.Tokens
.OfType<PropertyToken>()
.Where(pt => pt.Format != null)
.GroupBy(pt => pt.PropertyName)
.ToArray();
if (tokensWithFormat.Length != 0)
{
output.Write(",\"Renderings\":{");
WriteRenderingsValues(tokensWithFormat, logEvent.Properties, output);
output.Write('}');
}
output.Write('}');
output.Write("\n");
}
void WriteRenderingsValues(IEnumerable<IGrouping<string, PropertyToken>> tokensWithFormat, IReadOnlyDictionary<string, LogEventPropertyValue> properties, TextWriter output)
{
static void WriteNameValuePair(string name, string value, ref char? precedingDelimiter, TextWriter output)
{
if (precedingDelimiter != null)
output.Write(precedingDelimiter.Value);
JsonValueFormatter.WriteQuotedJsonString(name, output);
output.Write(':');
JsonValueFormatter.WriteQuotedJsonString(value, output);
precedingDelimiter = ',';
}
char? propertyDelimiter = null;
foreach (var propertyFormats in tokensWithFormat)
{
if (propertyDelimiter != null)
output.Write(propertyDelimiter.Value);
else
propertyDelimiter = ',';
output.Write('"');
output.Write(propertyFormats.Key);
output.Write("\":[");
char? formatDelimiter = null;
foreach (var format in propertyFormats)
{
if (formatDelimiter != null)
output.Write(formatDelimiter.Value);
formatDelimiter = ',';
output.Write('{');
char? elementDelimiter = null;
// Caller ensures that `tokensWithFormat` contains only property tokens that have non-null `Format`s.
WriteNameValuePair("Format", format.Format!, ref elementDelimiter, output);
using var sw = ReusableStringWriter.GetOrCreate();
MessageTemplateRenderer.RenderPropertyToken(format, properties, sw, null, isLiteral: true, isJson: false);
WriteNameValuePair("Rendering", sw.ToString(), ref elementDelimiter, output);
output.Write('}');
}
output.Write(']');
}
}
}

View file

@ -198,20 +198,14 @@
"Npgsql": "9.0.2"
}
},
"Serilog": {
"type": "Direct",
"requested": "[4.2.0, )",
"resolved": "4.2.0",
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
},
"Serilog.Extensions.Logging": {
"type": "Direct",
"requested": "[9.0.0, )",
"resolved": "9.0.0",
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
"requested": "[8.0.0, )",
"resolved": "8.0.0",
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
"dependencies": {
"Microsoft.Extensions.Logging": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Formatting.Compact": {
@ -722,6 +716,9 @@
"Microsoft.NETCore.Targets": "1.1.0",
"System.Runtime": "4.3.0"
}
},
"serilog": {
"type": "Project"
}
}
}

View file

@ -128,6 +128,11 @@
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.1"
}
},
"CavemanTcp": {
"type": "Transitive",
"resolved": "2.0.5",
"contentHash": "90wywmGpjrj26HMAkufYZwuZI8sVYB1mRwEdqugSR3kgDnPX+3l0jO86gwtFKsPvsEpsS4Dn/1EbhguzUxMU8Q=="
},
"Dapper": {
"type": "Transitive",
"resolved": "2.1.35",
@ -156,6 +161,11 @@
"resolved": "2.14.1",
"contentHash": "lQKvtaTDOXnoVJ20ibTuSIOf2i0uO0MPbDhd1jm238I+U/2ZnRENj0cktKZhtchBMtCUSRQ5v4xBCUbKNmyVMw=="
},
"IpMatcher": {
"type": "Transitive",
"resolved": "1.0.5",
"contentHash": "WXNlWERj+0GN699AnMNsuJ7PfUAbU4xhOHP3nrNXLHqbOaBxybu25luSYywX1133NSlitA4YkSNmJuyPvea4sw=="
},
"IPNetwork2": {
"type": "Transitive",
"resolved": "3.0.667",
@ -310,21 +320,21 @@
},
"Microsoft.Extensions.DependencyModel": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "saxr2XzwgDU77LaQfYFXmddEDRUKHF4DaGMZkNB3qjdVSZlax3//dGJagJkKrGMIPNZs2jVFXITyCCR6UHJNdA==",
"resolved": "8.0.0",
"contentHash": "NSmDw3K0ozNDgShSIpsZcbFIzBX4w28nDag+TfaQujkXGazBm+lid5onlWoCBy4VsLxqnnKjEBbGSJVWJMf43g==",
"dependencies": {
"System.Text.Encodings.Web": "9.0.0",
"System.Text.Json": "9.0.0"
"System.Text.Encodings.Web": "8.0.0",
"System.Text.Json": "8.0.0"
}
},
"Microsoft.Extensions.Diagnostics.Abstractions": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "1K8P7XzuzX8W8pmXcZjcrqS6x5eSSdvhQohmcpgiQNY/HlDAlnrhR9dvlURfFz428A+RTCJpUyB+aKTA6AgVcQ==",
"resolved": "8.0.0",
"contentHash": "JHYCQG7HmugNYUhOl368g+NMxYE/N/AiclCYRNlgCY9eVyiBkOHMwK4x60RYMxv9EL3+rmj1mqHvdCiPpC+D4Q==",
"dependencies": {
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Options": "9.0.0",
"System.Diagnostics.DiagnosticSource": "9.0.0"
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Options": "8.0.0",
"System.Diagnostics.DiagnosticSource": "8.0.0"
}
},
"Microsoft.Extensions.FileProviders.Abstractions": {
@ -352,14 +362,14 @@
},
"Microsoft.Extensions.Hosting.Abstractions": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "yUKJgu81ExjvqbNWqZKshBbLntZMbMVz/P7Way2SBx7bMqA08Mfdc9O7hWDKAiSp+zPUGT6LKcSCQIPeDK+CCw==",
"resolved": "8.0.0",
"contentHash": "AG7HWwVRdCHlaA++1oKDxLsXIBxmDpMPb3VoyOoAghEWnkUvEAdYQUwnV4jJbAaa/nMYNiEh5ByoLauZBEiovg==",
"dependencies": {
"Microsoft.Extensions.Configuration.Abstractions": "9.0.0",
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Diagnostics.Abstractions": "9.0.0",
"Microsoft.Extensions.FileProviders.Abstractions": "9.0.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.0"
"Microsoft.Extensions.Configuration.Abstractions": "8.0.0",
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Diagnostics.Abstractions": "8.0.0",
"Microsoft.Extensions.FileProviders.Abstractions": "8.0.0",
"Microsoft.Extensions.Logging.Abstractions": "8.0.0"
}
},
"Microsoft.Extensions.Logging": {
@ -510,49 +520,52 @@
"resolved": "8.5.0",
"contentHash": "VYYMZNitZ85UEhwOKkTQI63WEMvzUqwQc74I2mm8h/DBVAMcBBxqYPni4DmuRtbCwngmuONuK2yBJfWNRKzI+A=="
},
"RegexMatcher": {
"type": "Transitive",
"resolved": "1.0.9",
"contentHash": "RkQGXIrqHjD5h1mqefhgCbkaSdRYNRG5rrbzyw5zeLWiS0K1wq9xR3cNhQdzYR2MsKZ3GN523yRUsEQIMPxh3Q=="
},
"Sentry": {
"type": "Transitive",
"resolved": "4.13.0",
"contentHash": "Wfw3M1WpFcrYaGzPm7QyUTfIOYkVXQ1ry6p4WYjhbLz9fPwV23SGQZTFDpdox67NHM0V0g1aoQ4YKLm4ANtEEg=="
},
"Serilog": {
"type": "Transitive",
"resolved": "4.2.0",
"contentHash": "gmoWVOvKgbME8TYR+gwMf7osROiWAURterc6Rt2dQyX7wtjZYpqFiA/pY6ztjGQKKV62GGCyOcmtP1UKMHgSmA=="
},
"Serilog.AspNetCore": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "JslDajPlBsn3Pww1554flJFTqROvK9zz9jONNQgn0D8Lx2Trw8L0A8/n6zEQK1DAZWXrJwiVLw8cnTR3YFuYsg==",
"resolved": "8.0.0",
"contentHash": "FAjtKPZ4IzqFQBqZKPv6evcXK/F0ls7RoXI/62Pnx2igkDZ6nZ/jn/C/FxVATqQbEQvtqP+KViWYIe4NZIHa2w==",
"dependencies": {
"Serilog": "4.2.0",
"Serilog.Extensions.Hosting": "9.0.0",
"Serilog.Formatting.Compact": "3.0.0",
"Serilog.Settings.Configuration": "9.0.0",
"Serilog.Sinks.Console": "6.0.0",
"Serilog.Sinks.Debug": "3.0.0",
"Serilog.Sinks.File": "6.0.0"
"Microsoft.Extensions.DependencyInjection": "8.0.0",
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1",
"Serilog.Extensions.Hosting": "8.0.0",
"Serilog.Extensions.Logging": "8.0.0",
"Serilog.Formatting.Compact": "2.0.0",
"Serilog.Settings.Configuration": "8.0.0",
"Serilog.Sinks.Console": "5.0.0",
"Serilog.Sinks.Debug": "2.0.0",
"Serilog.Sinks.File": "5.0.0"
}
},
"Serilog.Extensions.Hosting": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "u2TRxuxbjvTAldQn7uaAwePkWxTHIqlgjelekBtilAGL5sYyF3+65NWctN4UrwwGLsDC7c3Vz3HnOlu+PcoxXg==",
"resolved": "8.0.0",
"contentHash": "db0OcbWeSCvYQkHWu6n0v40N4kKaTAXNjlM3BKvcbwvNzYphQFcBR+36eQ/7hMMwOkJvAyLC2a9/jNdUL5NjtQ==",
"dependencies": {
"Microsoft.Extensions.DependencyInjection.Abstractions": "9.0.0",
"Microsoft.Extensions.Hosting.Abstractions": "9.0.0",
"Microsoft.Extensions.Logging.Abstractions": "9.0.0",
"Serilog": "4.2.0",
"Serilog.Extensions.Logging": "9.0.0"
"Microsoft.Extensions.DependencyInjection.Abstractions": "8.0.0",
"Microsoft.Extensions.Hosting.Abstractions": "8.0.0",
"Microsoft.Extensions.Logging.Abstractions": "8.0.0",
"Serilog": "3.1.1",
"Serilog.Extensions.Logging": "8.0.0"
}
},
"Serilog.Extensions.Logging": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "NwSSYqPJeKNzl5AuXVHpGbr6PkZJFlNa14CdIebVjK3k/76kYj/mz5kiTRNVSsSaxM8kAIa1kpy/qyT9E4npRQ==",
"resolved": "8.0.0",
"contentHash": "YEAMWu1UnWgf1c1KP85l1SgXGfiVo0Rz6x08pCiPOIBt2Qe18tcZLvdBUuV5o1QHvrs8FAry9wTIhgBRtjIlEg==",
"dependencies": {
"Microsoft.Extensions.Logging": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Logging": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Formatting.Compact": {
@ -582,12 +595,12 @@
},
"Serilog.Settings.Configuration": {
"type": "Transitive",
"resolved": "9.0.0",
"contentHash": "4/Et4Cqwa+F88l5SeFeNZ4c4Z6dEAIKbu3MaQb2Zz9F/g27T5a3wvfMcmCOaAiACjfUb4A6wrlTVfyYUZk3RRQ==",
"resolved": "8.0.0",
"contentHash": "nR0iL5HwKj5v6ULo3/zpP8NMcq9E2pxYA6XKTSWCbugVs4YqPyvaqaKOY+OMpPivKp7zMEpax2UKHnDodbRB0Q==",
"dependencies": {
"Microsoft.Extensions.Configuration.Binder": "9.0.0",
"Microsoft.Extensions.DependencyModel": "9.0.0",
"Serilog": "4.2.0"
"Microsoft.Extensions.Configuration.Binder": "8.0.0",
"Microsoft.Extensions.DependencyModel": "8.0.0",
"Serilog": "3.1.1"
}
},
"Serilog.Sinks.Async": {
@ -608,10 +621,10 @@
},
"Serilog.Sinks.Debug": {
"type": "Transitive",
"resolved": "3.0.0",
"contentHash": "4BzXcdrgRX7wde9PmHuYd9U6YqycCC28hhpKonK7hx0wb19eiuRj16fPcPSVp0o/Y1ipJuNLYQ00R3q2Zs8FDA==",
"resolved": "2.0.0",
"contentHash": "Y6g3OBJ4JzTyyw16fDqtFcQ41qQAydnEvEqmXjhwhgjsnG/FaJ8GUqF5ldsC/bVkK8KYmqrPhDO+tm4dF6xx4A==",
"dependencies": {
"Serilog": "4.0.0"
"Serilog": "2.10.0"
}
},
"Serilog.Sinks.Elasticsearch": {
@ -887,6 +900,37 @@
"System.Runtime": "4.3.0"
}
},
"Timestamps": {
"type": "Transitive",
"resolved": "1.0.11",
"contentHash": "SnWhXm3FkEStQGgUTfWMh9mKItNW032o/v8eAtFrOGqG0/ejvPPA1LdLZx0N/qqoY0TH3x11+dO00jeVcM8xNQ=="
},
"UrlMatcher": {
"type": "Transitive",
"resolved": "3.0.1",
"contentHash": "hHBZVzFSfikrx4XsRsnCIwmGLgbNKtntnlqf4z+ygcNA6Y/L/J0x5GiZZWfXdTfpxhy5v7mlt2zrZs/L9SvbOA=="
},
"Watson.Core": {
"type": "Transitive",
"resolved": "6.3.5",
"contentHash": "Y5YxKOCSLe2KDmfwvI/J0qApgmmZR77LwyoufRVfKH7GLdHiE7fY0IfoNxWTG7nNv8knBfgwyOxdehRm+4HaCg==",
"dependencies": {
"IpMatcher": "1.0.5",
"RegexMatcher": "1.0.9",
"System.Text.Json": "8.0.5",
"Timestamps": "1.0.11",
"UrlMatcher": "3.0.1"
}
},
"Watson.Lite": {
"type": "Transitive",
"resolved": "6.3.5",
"contentHash": "YF8+se3IVenn8YlyNeb4wSJK6QMnVD0QHIOEiZ22wS4K2wkwoSDzWS+ZAjk1MaPeB+XO5gRoENUN//pOc+wI2g==",
"dependencies": {
"CavemanTcp": "2.0.5",
"Watson.Core": "6.3.5"
}
},
"xunit.abstractions": {
"type": "Transitive",
"resolved": "2.0.3",
@ -930,9 +974,11 @@
"myriad": {
"type": "Project",
"dependencies": {
"NodaTime": "[3.2.0, )",
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
"Polly": "[8.5.0, )",
"Polly.Contrib.WaitAndRetry": "[1.1.1, )",
"Serilog": "[4.2.0, )",
"Serilog": "[4.1.0, )",
"StackExchange.Redis": "[2.8.22, )",
"System.Linq.Async": "[6.0.1, )"
}
@ -945,7 +991,7 @@
"Microsoft.AspNetCore.Mvc.Versioning.ApiExplorer": "[5.1.0, )",
"PluralKit.Core": "[1.0.0, )",
"Sentry": "[4.13.0, )",
"Serilog.AspNetCore": "[9.0.0, )"
"Serilog.AspNetCore": "[8.0.0, )"
}
},
"pluralkit.bot": {
@ -954,7 +1000,8 @@
"Humanizer.Core": "[2.14.1, )",
"Myriad": "[1.0.0, )",
"PluralKit.Core": "[1.0.0, )",
"Sentry": "[4.13.0, )"
"Sentry": "[4.13.0, )",
"Watson.Lite": "[6.3.5, )"
}
},
"pluralkit.core": {
@ -980,8 +1027,8 @@
"NodaTime.Serialization.JsonNet": "[3.1.0, )",
"Npgsql": "[9.0.2, )",
"Npgsql.NodaTime": "[9.0.2, )",
"Serilog": "[4.2.0, )",
"Serilog.Extensions.Logging": "[9.0.0, )",
"Serilog": "[4.1.0, )",
"Serilog.Extensions.Logging": "[8.0.0, )",
"Serilog.Formatting.Compact": "[3.0.0, )",
"Serilog.NodaTime": "[3.0.0, )",
"Serilog.Sinks.Async": "[2.1.0, )",
@ -995,6 +1042,9 @@
"System.Interactive.Async": "[6.0.1, )",
"ipnetwork2": "[3.0.667, )"
}
},
"serilog": {
"type": "Project"
}
}
}

View file

@ -5,75 +5,35 @@ PluralKit is a Discord bot meant for plural communities. It has features like me
PluralKit has a Discord server for support, feedback, and discussion: https://discord.gg/PczBt78
# Requirements
Running the bot requires [.NET 5](https://dotnet.microsoft.com/download), a PostgreSQL database and a Redis database. It should function on any system where the prerequisites are set up (including Windows).
Optionally, it can integrate with [Sentry](https://sentry.io/welcome/) for error reporting and [InfluxDB](https://www.influxdata.com/products/influxdb-overview/) for aggregate statistics.
# Configuration
Configuring the bot is done through a JSON configuration file. An example of the configuration format can be seen in [`pluralkit.conf.example`](https://github.com/PluralKit/PluralKit/blob/master/pluralkit.conf.example).
The configuration file needs to be placed in the bot's working directory (usually the repository root) and must be called `pluralkit.conf`.
The configuration file is in JSON format (albeit with a `.conf` extension). The following keys are available (using `.` to indicate a nested object level), bolded key names are required:
* **`PluralKit.Bot.Token`**: the Discord bot token to connect with
* **`PluralKit.Database`**: the URI of the database to connect to (in [ADO.NET Npgsql format](https://www.connectionstrings.com/npgsql/))
* **`PluralKit.RedisAddr`**: the `host:port` of a Redis database to connect to
* `PluralKit.Bot.Prefixes`: an array of command prefixes to use (default `["pk;", "pk!"]`).
* **`PluralKit.Bot.ClientId`**: the ID of the bot's user account, used for calculating the bot's own permissions and for the link in `pk;invite`.
* `PluralKit.SentryUrl` *(optional)*: the [Sentry](https://sentry.io/welcome/) client key/DSN to report runtime errors to. If absent, disables Sentry integration.
* `PluralKit.InfluxUrl` *(optional)*: the URL to an [InfluxDB](https://www.influxdata.com/products/influxdb-overview/) server to report aggregate statistics to. An example of these stats can be seen on [the public stats page](https://stats.pluralkit.me).
* `PluralKit.InfluxDb` *(optional)*: the name of an InfluxDB database to report statistics to. If either this field or `PluralKit.InfluxUrl` are absent, InfluxDB reporting will be disabled.
* `PluralKit.LogDir` *(optional)*: the directory to save information and error logs to. If left blank, will default to `logs/` in the current working directory.
The bot can also take configuration from environment variables, which will override the values read from the file. Here, use `:` (colon) or `__` (double underscore) as a level separator (eg. `export PluralKit__Bot__Token=foobar123`) as per [ASP.NET config](https://docs.microsoft.com/en-us/aspnet/core/fundamentals/configuration/?view=aspnetcore-3.1#environment-variables).
# Running
In production, we run PluralKit using Kubernetes. The configuration can be found in the infra repo.
## Docker
The easiest way to get the bot running is with Docker. The repository contains a `docker-compose.yml` file ready to use.
For self-hosting, it's simpler to use Docker, with the provided [docker-compose](./docker-compose.yml) file.
* Clone this repository: `git clone https://github.com/PluralKit/PluralKit`
* Create a `pluralkit.conf` file in the same directory as `docker-compose.yml` containing at least `PluralKit.Bot.Token` and `PluralKit.Bot.ClientId` fields
* (`PluralKit.Database` is overridden in `docker-compose.yml` to point to the Postgres container)
* Build the bot: `docker-compose build`
* Run the bot: `docker-compose up`
In other words:
Create a `.env` file with the Discord client ID and bot token:
```
$ git clone https://github.com/PluralKit/PluralKit
$ cd PluralKit
$ cp pluralkit.conf.example pluralkit.conf
$ nano pluralkit.conf # (or vim, or whatever)
$ docker-compose up -d
CLIENT_ID=198622483471925248
BOT_TOKEN=MTk4NjIyNDgzNDcxOTI1MjQ4.Cl2FMQ.ZnCjm1XVW7vRze4b7Cq4se7kKWs
```
## Manually
* Install the .NET 6 SDK (see https://dotnet.microsoft.com/download)
* Clone this repository: `git clone https://github.com/PluralKit/PluralKit`
* Create and fill in a `pluralkit.conf` file in the same directory as `docker-compose.yml`
* Run the bot: `dotnet run --project PluralKit.Bot`
* Alternatively, `dotnet build -c Release -o build/`, then `dotnet build/PluralKit.Bot.dll`
If you want to use `pk;admin` commands (to raise member limits and such), set `ADMIN_ROLE` to a Discord role ID:
(tip: use `scripts/run-test-db.sh` to run a temporary PostgreSQL database on your local system. Requires Docker.)
```
ADMIN_ROLE=682632767057428509
```
## Scheduled Tasks worker
*If you didn't clone the repository with submodules, run `git submodule update --init` first to pull the required submodules.*
Run `docker compose build`, then `docker compose up -d`.
There is a scheduled tasks worker that needs to be ran separately from the bot. This handles cleaning up the database, and updating statistics (system/member/etc counts, shown in the `pk;stats` embed).
To view logs, use `docker compose logs`.
Note: This worker is *not required*, and the bot will function correctly without it.
Postgres data is stored in a `pluralkit_data` [Docker volume](https://docs.docker.com/engine/storage/volumes/).
If you are running the bot via docker-compose, this is set up automatically.
If you run the bot manually you can run the worker as such:
* `dotnet run --project PluralKit.ScheduledTasks`
* or if you used `dotnet build` rather than `dotnet run` to run the bot: `dotnet build/PluralKit.ScheduledTasks.dll`
# Upgrading database from legacy version
If you have an instance of the Python version of the bot (from the `legacy` branch), you may need to take extra database migration steps.
For more information, see [LEGACYMIGRATE.md](./LEGACYMIGRATE.md).
# Development
See [the dev-docs/ directory](./dev-docs/README.md)
# User documentation
See [the docs/ directory](./docs/README.md)
# License
This project is under the GNU Affero General Public License, Version 3. It is available at the following link: https://www.gnu.org/licenses/agpl-3.0.en.html
This project is under the GNU Affero General Public License, Version 3. It is available at the following link: https://www.gnu.org/licenses/agpl-3.0.en.html

1
Serilog Submodule

@ -0,0 +1 @@
Subproject commit f5eb991cb4c4a0c1e2407de7504c543536786598

View file

@ -10,6 +10,7 @@ COPY PluralKit.Bot/PluralKit.Bot.csproj /app/PluralKit.Bot/
COPY PluralKit.Core/PluralKit.Core.csproj /app/PluralKit.Core/
COPY PluralKit.Tests/PluralKit.Tests.csproj /app/PluralKit.Tests/
COPY .git/ /app/.git
COPY Serilog/ /app/Serilog/
RUN dotnet restore PluralKit.sln
# Copy the rest of the code and build

View file

@ -25,18 +25,22 @@ COPY Cargo.lock /build/
COPY crates/ /build/crates
RUN cargo build --bin migrate --release --target x86_64-unknown-linux-musl
RUN cargo build --bin api --release --target x86_64-unknown-linux-musl
RUN cargo build --bin dispatch --release --target x86_64-unknown-linux-musl
RUN cargo build --bin gateway --release --target x86_64-unknown-linux-musl
RUN cargo build --bin avatars --release --target x86_64-unknown-linux-musl
RUN cargo build --bin avatar_cleanup --release --target x86_64-unknown-linux-musl
RUN cargo build --bin scheduled_tasks --release --target x86_64-unknown-linux-musl
RUN cargo build --bin gdpr_worker --release --target x86_64-unknown-linux-musl
FROM scratch
FROM alpine:latest
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/migrate /migrate
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/api /api
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/dispatch /dispatch
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gateway /gateway
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/avatars /avatars
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/avatar_cleanup /avatar_cleanup
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/scheduled_tasks /scheduled_tasks
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gdpr_worker /gdpr_worker

View file

@ -37,8 +37,10 @@ EOF
}
# add rust binaries here to build
build migrate
build api
build dispatch
build gateway
build avatars "COPY .docker-bin/avatar_cleanup /bin/avatar_cleanup"
build scheduled_tasks
build gdpr_worker

48
crates/api/src/auth.rs Normal file
View file

@ -0,0 +1,48 @@
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
#[derive(Clone)]
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>) -> Self {
Self { system_id, app_id }
}
pub fn system_id(&self) -> Option<i32> {
self.system_id
}
pub fn app_id(&self) -> Option<i32> {
self.app_id
}
pub fn access_level_for(&self, a: &impl Authable) -> PrivacyLevel {
if self
.system_id
.map(|id| id == a.authable_system_id())
.unwrap_or(false)
{
PrivacyLevel::Private
} else {
PrivacyLevel::Public
}
}
}
// authable trait/impls
pub trait Authable {
fn authable_system_id(&self) -> SystemId;
}
impl Authable for PKSystem {
fn authable_system_id(&self) -> SystemId {
self.id
}
}

View file

@ -1 +1,2 @@
pub mod private;
pub mod system;

View file

@ -52,7 +52,7 @@ use axum::{
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig};
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
@ -151,14 +151,12 @@ pub async fn discord_callback(
.await
.expect("failed to query");
if system.is_none() {
let Some(system) = system else {
return json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
}
let system = system.unwrap();
};
let system_config: Option<PKSystemConfig> = sqlx::query_as(
r#"
@ -179,7 +177,7 @@ pub async fn discord_callback(
(
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(),
"system": system.to_json(PrivacyLevel::Private),
"config": system_config.to_json(),
"user": user,
"token": token,

View file

@ -0,0 +1,69 @@
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Response},
Extension, Json,
};
use serde_json::json;
use sqlx::Postgres;
use tracing::error;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use crate::{auth::AuthState, util::json_err, ApiContext};
pub async fn get_system_settings(
Extension(auth): Extension<AuthState>,
Extension(system): Extension<PKSystem>,
State(ctx): State<ApiContext>,
) -> Response {
let access_level = auth.access_level_for(&system);
let mut config = match sqlx::query_as::<Postgres, PKSystemConfig>(
"select * from system_config where system = $1",
)
.bind(system.id)
.fetch_optional(&ctx.db)
.await
{
Ok(Some(config)) => config,
Ok(None) => {
error!(
system = system.id,
"failed to find system config for existing system"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
Err(err) => {
error!(?err, "failed to query system config");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
};
// fix this
if config.name_format.is_none() {
config.name_format = Some("{name} {tag}".to_string());
}
Json(&match access_level {
PrivacyLevel::Private => config.to_json(),
PrivacyLevel::Public => json!({
"pings_enabled": config.pings_enabled,
"latch_timeout": config.latch_timeout,
"case_sensitive_proxy_tags": config.case_sensitive_proxy_tags,
"proxy_error_message_enabled": config.proxy_error_message_enabled,
"hid_display_split": config.hid_display_split,
"hid_display_caps": config.hid_display_caps,
"hid_list_padding": config.hid_list_padding,
"proxy_switch": config.proxy_switch,
"name_format": config.name_format,
}),
})
.into_response()
}

View file

@ -1,10 +1,13 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::{Response, StatusCode, Uri},
response::IntoResponse,
routing::{delete, get, patch, post},
Router,
Extension, Router,
};
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
@ -12,6 +15,7 @@ use hyper_util::{
};
use tracing::{error, info};
mod auth;
mod endpoints;
mod error;
mod middleware;
@ -27,6 +31,7 @@ pub struct ApiContext {
}
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Result<Response<Body>, StatusCode> {
@ -41,12 +46,25 @@ async fn rproxy(
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx
.rproxy_client
.request(req)
.await
.map_err(|err| {
error!("failed to serve reverse proxy to dotnet-api: {:?}", err);
.map_err(|error| {
error!(?error, "failed to serve reverse proxy to dotnet-api");
StatusCode::BAD_GATEWAY
})?
.into_response())
@ -57,52 +75,52 @@ async fn rproxy(
fn router(ctx: ApiContext) -> Router {
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/v2/systems/:system_id", get(rproxy))
.route("/v2/systems/:system_id", patch(rproxy))
.route("/v2/systems/:system_id/settings", get(rproxy))
.route("/v2/systems/:system_id/settings", patch(rproxy))
.route("/v2/systems/{system_id}", get(rproxy))
.route("/v2/systems/{system_id}", patch(rproxy))
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
.route("/v2/systems/{system_id}/settings", patch(rproxy))
.route("/v2/systems/:system_id/members", get(rproxy))
.route("/v2/systems/{system_id}/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/:member_id", get(rproxy))
.route("/v2/members/:member_id", patch(rproxy))
.route("/v2/members/:member_id", delete(rproxy))
.route("/v2/members/{member_id}", get(rproxy))
.route("/v2/members/{member_id}", patch(rproxy))
.route("/v2/members/{member_id}", delete(rproxy))
.route("/v2/systems/:system_id/groups", get(rproxy))
.route("/v2/systems/{system_id}/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/:group_id", get(rproxy))
.route("/v2/groups/:group_id", patch(rproxy))
.route("/v2/groups/:group_id", delete(rproxy))
.route("/v2/groups/{group_id}", get(rproxy))
.route("/v2/groups/{group_id}", patch(rproxy))
.route("/v2/groups/{group_id}", delete(rproxy))
.route("/v2/groups/:group_id/members", get(rproxy))
.route("/v2/groups/:group_id/members/add", post(rproxy))
.route("/v2/groups/:group_id/members/remove", post(rproxy))
.route("/v2/groups/:group_id/members/overwrite", post(rproxy))
.route("/v2/groups/{group_id}/members", get(rproxy))
.route("/v2/groups/{group_id}/members/add", post(rproxy))
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
.route("/v2/members/:member_id/groups", get(rproxy))
.route("/v2/members/:member_id/groups/add", post(rproxy))
.route("/v2/members/:member_id/groups/remove", post(rproxy))
.route("/v2/members/:member_id/groups/overwrite", post(rproxy))
.route("/v2/members/{member_id}/groups", get(rproxy))
.route("/v2/members/{member_id}/groups/add", post(rproxy))
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
.route("/v2/systems/:system_id/switches", get(rproxy))
.route("/v2/systems/:system_id/switches", post(rproxy))
.route("/v2/systems/:system_id/fronters", get(rproxy))
.route("/v2/systems/{system_id}/switches", get(rproxy))
.route("/v2/systems/{system_id}/switches", post(rproxy))
.route("/v2/systems/{system_id}/fronters", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", get(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id/members", patch(rproxy))
.route("/v2/systems/:system_id/switches/:switch_id", delete(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", get(rproxy))
.route("/v2/systems/:system_id/guilds/:guild_id", patch(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", get(rproxy))
.route("/v2/members/:member_id/guilds/:guild_id", patch(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/systems/:system_id/autoproxy", get(rproxy))
.route("/v2/systems/:system_id/autoproxy", patch(rproxy))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
.route("/v2/messages/:message_id", get(rproxy))
.route("/v2/messages/{message_id}", get(rproxy))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
@ -111,16 +129,19 @@ fn router(ctx: ApiContext) -> Router {
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::authnz))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::cors))
.layer(axum::middleware::from_fn(middleware::logger))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::logger::logger))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::params::params))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), middleware::auth::auth))
.layer(axum::middleware::from_fn(middleware::cors::cors))
.layer(tower_http::catch_panic::CatchPanicLayer::custom(util::handle_panic))
.with_state(ctx)
@ -128,8 +149,8 @@ fn router(ctx: ApiContext) -> Router {
.route("/", get(|| async { axum::response::Redirect::to("https://pluralkit.me/api") }))
}
libpk::main!("api");
async fn real_main() -> anyhow::Result<()> {
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;

View file

@ -0,0 +1,62 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
// fetch user authorization
if let Some(system_auth_header) = req
.headers()
.get("authorization")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
}
{
authed_system_id = Some(system_id);
}
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = req
.headers()
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(config_token2) = libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
.as_ref()
// this is NOT how you validate tokens
// but this is low abuse risk so we're keeping it for now
&& app_auth_header == config_token2
{
authed_app_id = Some(1);
}
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id));
next.run(req).await
}

View file

@ -1,45 +0,0 @@
use axum::{
extract::{Request, State},
http::HeaderValue,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::ApiContext;
use super::logger::DID_AUTHENTICATE_HEADER;
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove("x-pluralkit-systemid");
let auth_header = headers
.get("authorization")
.map(|h| h.to_str().ok())
.flatten();
let mut authenticated = false;
if let Some(auth_header) = auth_header {
if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
None
}
}
{
headers.append(
"x-pluralkit-systemid",
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
);
authenticated = true;
}
}
let mut response = next.run(request).await;
if authenticated {
response
.headers_mut()
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
}
response
}

View file

@ -4,27 +4,30 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::util::header_or_unknown;
use crate::{auth::AuthState, util::header_or_unknown};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let remote_ip = header_or_unknown(request.headers().get("X-PluralKit-Client-IP"));
let user_agent = header_or_unknown(request.headers().get("User-Agent"));
let endpoint = request
.extensions()
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
let uri = request.uri().clone();
let request_span = span!(
@ -37,25 +40,26 @@ pub async fn logger(request: Request, next: Next) -> Response {
);
let start = Instant::now();
let mut response = next.run(request).instrument(request_span).await;
let response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let authenticated = {
let headers = response.headers_mut();
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
headers.remove(DID_AUTHENTICATE_HEADER);
true
} else {
false
}
};
let system_id = auth
.system_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
let app_id = auth
.app_id()
.map(|v| v.to_string())
.unwrap_or("none".to_string());
counter!(
"pluralkit_api_requests",
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);
histogram!(
@ -63,7 +67,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.record(elapsed as f64 / 1_000_f64);
@ -81,7 +86,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);

View file

@ -1,13 +1,6 @@
mod cors;
pub use cors::cors;
mod logger;
pub use logger::logger;
mod ignore_invalid_routes;
pub use ignore_invalid_routes::ignore_invalid_routes;
pub mod auth;
pub mod cors;
pub mod ignore_invalid_routes;
pub mod logger;
pub mod params;
pub mod ratelimit;
mod authnz;
pub use authnz::authnz;

View file

@ -0,0 +1,139 @@
use axum::{
extract::{Request, State},
http::StatusCode,
middleware::Next,
response::Response,
routing::url_params::UrlParams,
};
use sqlx::{types::Uuid, Postgres};
use tracing::error;
use crate::auth::AuthState;
use crate::{util::json_err, ApiContext};
use pluralkit_models::PKSystem;
// move this somewhere else
fn parse_hid(hid: &str) -> String {
if hid.len() > 7 || hid.len() < 5 {
hid.to_string()
} else {
hid.to_lowercase().replace("-", "")
}
}
pub async fn params(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let pms = match req.extensions().get::<UrlParams>() {
None => Vec::new(),
Some(UrlParams::Params(pms)) => pms.clone(),
_ => {
return json_err(
StatusCode::BAD_REQUEST,
r#"{"message":"400: Bad Request","code": 0}"#.to_string(),
)
.into()
}
};
for (key, value) in pms {
match key.as_ref() {
"system_id" => match value.as_str() {
"@me" => {
let Some(system_id) = req
.extensions()
.get::<AuthState>()
.expect("missing auth state")
.system_id()
else {
return json_err(
StatusCode::UNAUTHORIZED,
r#"{"message":"401: Missing or invalid Authorization header","code": 0}"#.to_string(),
)
.into();
};
match sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where id = $1",
)
.bind(system_id)
.fetch_optional(&ctx.db)
.await
{
Ok(Some(system)) => {
req.extensions_mut().insert(system);
}
Ok(None) => {
error!(
?system_id,
"could not find previously authenticated system in db"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
Err(err) => {
error!(
?err,
"failed to query previously authenticated system in db"
);
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
}
}
id => {
println!("a {id}");
match match Uuid::parse_str(id) {
Ok(uuid) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where uuid = $1",
)
.bind(uuid),
Err(_) => match id.parse::<i64>() {
Ok(parsed) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where id = (select system from accounts where uid = $1)"
)
.bind(parsed),
Err(_) => sqlx::query_as::<Postgres, PKSystem>(
"select * from systems where hid = $1",
)
.bind(parse_hid(id))
},
}
.fetch_optional(&ctx.db)
.await
{
Ok(Some(system)) => {
req.extensions_mut().insert(system);
}
Ok(None) => {
return json_err(
StatusCode::NOT_FOUND,
r#"{"message":"System not found.","code":20001}"#.to_string(),
)
}
Err(err) => {
error!(?err, ?id, "failed to query system from path in db");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#
.to_string(),
);
}
}
}
},
"member_id" => {}
"group_id" => {}
"switch_id" => {}
"guild_id" => {}
_ => {}
}
}
next.run(req).await
}

View file

@ -10,7 +10,10 @@ use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, ut
use metrics::counter;
use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
use crate::{
auth::AuthState,
util::{header_or_unknown, json_err},
};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -50,7 +53,7 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
.await
{
Ok(_) => info!("connected to redis for request rate limiting"),
Err(err) => error!("could not load redis script: {}", err),
Err(error) => error!(?error, "could not load redis script"),
}
} else {
error!("could not wait for connection to load redis script!");
@ -103,37 +106,28 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
// https://github.com/rust-lang/rust/issues/53667
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if let Some(token2) = &libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
{
if header.to_str().unwrap_or("invalid") == token2 {
true
} else {
false
}
} else {
false
}
} else {
false
};
let extensions = request.extensions().clone();
let endpoint = request
.extensions()
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
let auth = extensions
.get::<AuthState>()
.expect("should always have AuthState");
// looks like this chooses the tokens/sec by app_id or endpoint
// then chooses the key by system_id or source_ip
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
@ -145,12 +139,12 @@ pub async fn do_request_ratelimited(
let rl_key = format!(
"{}:{}",
if authenticated_system_id != "unknown"
if let Some(system_id) = auth.system_id()
&& matches!(rlimit, RatelimitType::GenericUpdate)
{
authenticated_system_id
system_id.to_string()
} else {
source_ip
source_ip.to_string()
},
rlimit.key()
);
@ -224,8 +218,8 @@ pub async fn do_request_ratelimited(
return response;
}
Err(err) => {
tracing::error!("error getting ratelimit info: {}", err);
Err(error) => {
tracing::error!(?error, "error getting ratelimit info");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: internal server error", "code": 0}"#.to_string(),

View file

@ -11,7 +11,7 @@ pub fn header_or_unknown(header: Option<&HeaderValue>) -> &str {
match value.to_str() {
Ok(v) => v,
Err(err) => {
error!("failed to parse header value {:#?}: {:#?}", value, err);
error!(?err, ?value, "failed to parse header value");
"failed to parse"
}
}
@ -34,11 +34,7 @@ where
.unwrap(),
),
None => {
error!(
"error in handler {}: {:#?}",
std::any::type_name::<F>(),
error
);
error!(?error, "error in handler {}", std::any::type_name::<F>(),);
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
@ -48,14 +44,15 @@ where
}
}
pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!("caught panic from handler: {:#?}", err);
pub fn handle_panic(error: Box<dyn std::any::Any + Send + 'static>) -> axum::response::Response {
error!(?error, "caught panic from handler");
json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
)
}
// todo: make 500 not duplicated
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();

View file

@ -4,8 +4,8 @@ use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info};
libpk::main!("avatar_cleanup");
async fn real_main() -> anyhow::Result<()> {
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
@ -13,7 +13,7 @@ async fn real_main() -> anyhow::Result<()> {
let bucket = {
let region = s3::Region::Custom {
region: "s3".to_string(),
region: "auto".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
@ -38,8 +38,8 @@ async fn real_main() -> anyhow::Result<()> {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
match cleanup_job(pool.clone(), bucket.clone()).await {
Ok(()) => {}
Err(err) => {
error!("failed to run avatar cleanup job: {}", err);
Err(error) => {
error!(?error, "failed to run avatar cleanup job");
// sentry
}
}
@ -55,9 +55,10 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
let mut tx = pool.begin().await?;
let image_id: Option<CleanupJobEntry> = sqlx::query_as(
// no timestamp checking here
// images are only added to the table after 24h
r#"
select id from image_cleanup_jobs
where ts < now() - interval '1 day'
for update skip locked limit 1;"#,
)
.fetch_optional(&mut *tx)
@ -72,6 +73,7 @@ async fn cleanup_job(pool: sqlx::PgPool, bucket: Arc<s3::Bucket>) -> anyhow::Res
let image_data = libpk::db::repository::avatars::get_by_id(&pool, image_id.clone()).await?;
if image_data.is_none() {
// unsure how this can happen? there is a FK reference
info!("image {image_id} was already deleted, skipping");
sqlx::query("delete from image_cleanup_jobs where id = $1")
.bind(image_id)

View file

@ -93,7 +93,7 @@ async fn pull(
) -> Result<Json<PullResponse>, PKAvatarError> {
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
if !req.force {
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
// remove any pending image cleanup
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
@ -170,8 +170,8 @@ pub struct AppState {
config: Arc<AvatarsConfig>,
}
libpk::main!("avatars");
async fn real_main() -> anyhow::Result<()> {
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let config = libpk::config
.avatars
.as_ref()
@ -179,7 +179,7 @@ async fn real_main() -> anyhow::Result<()> {
let bucket = {
let region = s3::Region::Custom {
region: "s3".to_string(),
region: "auto".to_string(),
endpoint: config.s3.endpoint.to_string(),
};
@ -232,26 +232,11 @@ async fn real_main() -> anyhow::Result<()> {
Ok(())
}
struct AppError(anyhow::Error);
#[derive(Serialize)]
struct ErrorResponse {
error: String,
}
impl IntoResponse for AppError {
fn into_response(self) -> Response {
error!("error handling request: {}", self.0);
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: self.0.to_string(),
}),
)
.into_response()
}
}
impl IntoResponse for PKAvatarError {
fn into_response(self) -> Response {
let status_code = match self {
@ -278,12 +263,3 @@ impl IntoResponse for PKAvatarError {
.into_response()
}
}
impl<E> From<E> for AppError
where
E: Into<anyhow::Error>,
{
fn from(err: E) -> Self {
Self(err.into())
}
}

View file

@ -129,9 +129,9 @@ pub async fn worker(worker_id: u32, state: Arc<AppState>) {
Ok(()) => {}
Err(e) => {
error!(
"error in migrate worker {}: {}",
worker_id,
e.source().unwrap_or(&e)
error = e.source().unwrap_or(&e)
?worker_id,
"error in migrate worker",
);
tokio::time::sleep(Duration::from_secs(5)).await;
}

View file

@ -84,7 +84,7 @@ pub fn process(data: &[u8], kind: ImageKind) -> Result<ProcessOutput, PKAvatarEr
} else {
reader.decode().map_err(|e| {
// print the ugly error, return the nice error
error!("error decoding image: {}", e);
error!(error = format!("{e:#?}"), "error decoding image");
PKAvatarError::ImageFormatError(e)
})?
};

View file

@ -41,7 +41,11 @@ pub async fn pull(
}
}
error!("network error for {}: {}", parsed_url.full_url, s);
error!(
url = parsed_url.full_url,
error = s,
"network error pulling image"
);
PKAvatarError::NetworkErrorString(s)
})?;
let time_after_headers = Instant::now();
@ -82,7 +86,22 @@ pub async fn pull(
.map(|x| x.to_string());
let body = response.bytes().await.map_err(|e| {
error!("network error for {}: {}", parsed_url.full_url, e);
// terrible
let mut s = format!("{}", e);
if let Some(src) = e.source() {
let _ = write!(s, ": {}", src);
let mut err = src;
while let Some(src) = err.source() {
let _ = write!(s, ": {}", src);
err = src;
}
}
error!(
url = parsed_url.full_url,
error = s,
"network error pulling image"
);
PKAvatarError::NetworkError(e)
})?;
if body.len() != size as usize {
@ -137,6 +156,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
match (url.scheme(), url.domain()) {
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
("https", Some("serve.apparyllis.com")) => {
return Ok(ParsedUrl {
channel_id: 0,
attachment_id: 0,
filename: "".to_string(),
full_url: url.to_string(),
})
}
_ => anyhow::bail!("not a discord cdn url"),
}

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
libpk = { path = "../libpk" }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View file

@ -19,17 +19,8 @@ use axum::{extract::State, http::Uri, routing::post, Json, Router};
mod logger;
// this package does not currently use libpk
#[tokio::main]
#[libpk::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.json()
.with_env_filter(EnvFilter::from_default_env())
.init();
info!("hello world");
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
let (client, bg) = AsyncClient::connect(stream).await?;
@ -86,11 +77,11 @@ async fn dispatch(
let uri = match req.url.parse::<Uri>() {
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
Err(error) => {
error!(?error, "failed to parse uri {}", req.url);
error!(?error, uri = req.url, "failed to parse uri");
return DispatchResponse::BadData.to_string();
}
_ => {
error!("uri {} is invalid", req.url);
error!(uri = req.url, "uri is invalid");
return DispatchResponse::BadData.to_string();
}
};

View file

@ -13,8 +13,9 @@ futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../libpk" }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }

View file

@ -1,19 +1,24 @@
use axum::{
extract::{Path, State},
extract::{ConnectInfo, Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
routing::{delete, get, post},
Router,
};
use libpk::runtime_config::RuntimeConfig;
use serde_json::{json, to_string};
use tracing::{error, info};
use twilight_model::id::Id;
use twilight_model::id::{marker::ChannelMarker, Id};
use crate::discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
use crate::{
discord::{
cache::{dm_channel, DiscordCache, DM_PERMISSIONS},
gateway::cluster_config,
shard_state::ShardStateManager,
},
event_awaiter::{AwaitEventRequest, EventAwaiter},
};
use std::sync::Arc;
use std::{net::SocketAddr, sync::Arc};
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
@ -21,10 +26,15 @@ fn status_code(code: StatusCode, body: String) -> Response {
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
pub async fn run_server(cache: Arc<DiscordCache>, shard_state: Arc<ShardStateManager>, runtime_config: Arc<RuntimeConfig>, awaiter: Arc<EventAwaiter>) -> anyhow::Result<()> {
// hacky fix for `move`
let runtime_config_for_post = runtime_config.clone();
let runtime_config_for_delete = runtime_config.clone();
let awaiter_for_clear = awaiter.clone();
let app = Router::new()
.route(
"/guilds/:guild_id",
"/guilds/{guild_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
@ -33,7 +43,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/members/@me",
"/guilds/{guild_id}/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
@ -42,7 +52,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
"/guilds/{guild_id}/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.as_ref().expect("missing discord config").client_id).await {
Ok(val) => {
@ -56,7 +66,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
"/guilds/{guild_id}/permissions/{user_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
@ -69,7 +79,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
)
.route(
"/guilds/:guild_id/channels",
"/guilds/{guild_id}/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
@ -95,7 +105,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
"/guilds/{guild_id}/channels/{channel_id}",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
@ -107,7 +117,7 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
"/guilds/{guild_id}/channels/{channel_id}/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
@ -122,16 +132,19 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
"/guilds/{guild_id}/channels/{channel_id}/permissions/{user_id}",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
"/guilds/{guild_id}/channels/{channel_id}/last_message",
get(|State(cache): State<Arc<DiscordCache>>, Path((_guild_id, channel_id)): Path<(u64, Id<ChannelMarker>)>| async move {
let lm = cache.get_last_message(channel_id).await;
status_code(StatusCode::FOUND, to_string(&lm).unwrap())
}),
)
.route(
"/guilds/:guild_id/roles",
"/guilds/{guild_id}/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
@ -171,13 +184,45 @@ pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
status_code(StatusCode::FOUND, to_string(&stats).unwrap())
}))
.route("/runtime_config", get(|| async move {
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/runtime_config/{key}", post(|Path(key): Path<String>, body: String| async move {
let runtime_config = runtime_config_for_post;
runtime_config.set(key, body).await.expect("failed to update runtime config");
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/runtime_config/{key}", delete(|Path(key): Path<String>| async move {
let runtime_config = runtime_config_for_delete;
runtime_config.delete(key).await.expect("failed to update runtime config");
status_code(StatusCode::FOUND, to_string(&runtime_config.get_all().await).unwrap())
}))
.route("/await_event", post(|ConnectInfo(addr): ConnectInfo<SocketAddr>, body: String| async move {
info!("got request: {body} from: {addr}");
let Ok(req) = serde_json::from_str::<AwaitEventRequest>(&body) else {
return status_code(StatusCode::BAD_REQUEST, "".to_string());
};
awaiter.handle_request(req, addr).await;
status_code(StatusCode::NO_CONTENT, "".to_string())
}))
.route("/clear_awaiter", post(|| async move {
awaiter_for_clear.clear().await;
status_code(StatusCode::NO_CONTENT, "".to_string())
}))
.route("/shard_status", get(|| async move {
status_code(StatusCode::FOUND, to_string(&shard_state.get().await).unwrap())
}))
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.discord.as_ref().expect("missing discord config").cache_api_addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
axum::serve(listener, app.into_make_service_with_connect_info::<SocketAddr>()).await?;
Ok(())
}

View file

@ -1,6 +1,7 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use serde::Serialize;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;
use twilight_cache_inmemory::{
model::CachedMember,
@ -8,11 +9,12 @@ use twilight_cache_inmemory::{
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_gateway::Event;
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
marker::{ChannelMarker, GuildMarker, MessageMarker, UserMarker},
Id,
},
};
@ -123,16 +125,134 @@ pub fn new() -> DiscordCache {
.build(),
);
DiscordCache(cache, client, RwLock::new(Vec::new()))
DiscordCache(
cache,
client,
RwLock::new(Vec::new()),
RwLock::new(HashMap::new()),
)
}
#[derive(Clone, Serialize)]
pub struct CachedMessage {
id: Id<MessageMarker>,
referenced_message: Option<Id<MessageMarker>>,
author_username: String,
}
#[derive(Clone, Serialize)]
pub struct LastMessageCacheEntry {
pub current: CachedMessage,
pub previous: Option<CachedMessage>,
}
pub struct DiscordCache(
pub Arc<InMemoryCache>,
pub Arc<twilight_http::Client>,
pub RwLock<Vec<u32>>,
pub RwLock<HashMap<Id<ChannelMarker>, LastMessageCacheEntry>>,
);
impl DiscordCache {
pub async fn get_last_message(
&self,
channel: Id<ChannelMarker>,
) -> Option<LastMessageCacheEntry> {
self.3.read().await.get(&channel).cloned()
}
pub async fn update(&self, event: &twilight_gateway::Event) {
self.0.update(event);
match event {
Event::MessageCreate(m) => match self.3.write().await.entry(m.channel_id) {
std::collections::hash_map::Entry::Occupied(mut e) => {
let cur = e.get();
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: Some(cur.current.clone()),
});
}
std::collections::hash_map::Entry::Vacant(e) => {
e.insert(LastMessageCacheEntry {
current: CachedMessage {
id: m.id,
referenced_message: m.referenced_message.as_ref().map(|v| v.id),
author_username: m.author.name.clone(),
},
previous: None,
});
}
},
Event::MessageDelete(m) => {
self.handle_message_deletion(m.channel_id, vec![m.id]).await;
}
Event::MessageDeleteBulk(m) => {
self.handle_message_deletion(m.channel_id, m.ids.clone())
.await;
}
_ => {}
};
}
async fn handle_message_deletion(
&self,
channel_id: Id<ChannelMarker>,
mids: Vec<Id<MessageMarker>>,
) {
let mut lm = self.3.write().await;
let Some(entry) = lm.get(&channel_id) else {
return;
};
let mut entry = entry.clone();
// if none of the deleted messages are relevant, just return
if !mids.contains(&entry.current.id)
&& entry
.previous
.clone()
.map(|v| !mids.contains(&v.id))
.unwrap_or(false)
{
return;
}
// remove "previous" entry if it was deleted
if let Some(prev) = entry.previous.clone()
&& mids.contains(&prev.id)
{
entry.previous = None;
}
// set "current" entry to "previous" if current entry was deleted
// (if the "previous" entry still exists, it was not deleted)
if let Some(prev) = entry.previous.clone()
&& mids.contains(&entry.current.id)
{
entry.current = prev;
entry.previous = None;
}
// if the current entry was already deleted, but previous wasn't,
// we would've set current to previous
// so if current is deleted this means both current and previous have
// been deleted
// so just drop the cache entry here
if mids.contains(&entry.current.id) && entry.previous.is_none() {
lm.remove(&channel_id);
return;
}
// ok, update the entry
lm.insert(channel_id, entry.clone());
}
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
@ -356,12 +476,14 @@ impl DiscordCache {
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: false,
unavailable: Some(false),
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
guild_scheduled_events: guild.guild_scheduled_events().to_vec(),
max_stage_video_channel_users: guild.max_stage_video_channel_users(),
}
})
}

View file

@ -1,7 +1,9 @@
use anyhow::anyhow;
use futures::StreamExt;
use libpk::_config::ClusterSettings;
use libpk::{_config::ClusterSettings, runtime_config::RuntimeConfig, state::ShardStateEvent};
use metrics::counter;
use std::sync::{mpsc::Sender, Arc};
use std::sync::Arc;
use tokio::sync::mpsc::Sender;
use tracing::{error, info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Message, Shard, ShardId,
@ -12,9 +14,12 @@ use twilight_model::gateway::{
Intents,
};
use crate::discord::identify_queue::{self, RedisQueue};
use crate::{
discord::identify_queue::{self, RedisQueue},
RUNTIME_CONFIG_KEY_EVENT_TARGET,
};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
use super::cache::DiscordCache;
pub fn cluster_config() -> ClusterSettings {
libpk::config
@ -44,6 +49,12 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
if cluster_settings.node_id != 0 {
return Err(anyhow!(
"expecting to be node 0 in single-process mode, but we are node {}",
cluster_settings.node_id
));
}
(0, (cluster_settings.total_shards - 1).into())
} else {
(
@ -52,6 +63,13 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
)
};
let prefix = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_prefix_for_gateway
.clone();
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
@ -64,7 +82,7 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
.to_owned(),
intents,
)
.presence(presence("pk;help", false))
.presence(presence(format!("{prefix}help").as_str(), false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
@ -76,15 +94,23 @@ pub fn create_shards(redis: fred::clients::RedisPool) -> anyhow::Result<Vec<Shar
Ok(shards_vec)
}
#[tracing::instrument(fields(shard = %shard.id()), skip_all)]
pub async fn runner(
mut shard: Shard<RedisQueue>,
_tx: Sender<(ShardId, String)>,
shard_state: ShardStateManager,
tx: Sender<(ShardId, Event, String)>,
tx_state: Sender<(ShardId, ShardStateEvent, Option<Event>, Option<i32>)>,
cache: Arc<DiscordCache>,
runtime_config: Arc<RuntimeConfig>,
) {
// let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
let shard_id = shard.id().number();
let our_user_id = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.client_id;
info!("waiting for events");
while let Some(item) = shard.next().await {
let raw_event = match item {
@ -105,7 +131,9 @@ pub async fn runner(
)
.increment(1);
if let Err(error) = shard_state.socket_closed(shard_id).await {
if let Err(error) =
tx_state.try_send((shard.id(), ShardStateEvent::Closed, None, None))
{
error!("failed to update shard state for socket closure: {error}");
}
@ -127,7 +155,7 @@ pub async fn runner(
continue;
}
Err(error) => {
error!("shard {shard_id} failed to parse gateway event: {error}");
error!(?error, ?shard_id, "failed to parse gateway event");
continue;
}
};
@ -147,14 +175,31 @@ pub async fn runner(
.increment(1);
// update shard state and discord cache
if let Err(error) = shard_state.handle_event(shard_id, event.clone()).await {
tracing::error!(?error, "error updating redis state");
if matches!(event, Event::Ready(_)) || matches!(event, Event::Resumed) {
if let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Other,
Some(event.clone()),
None,
)) {
tracing::error!(?error, "error updating shard state");
}
}
// need to do heartbeat separately, to get the latency
let latency_num = shard
.latency()
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
if let Event::GatewayHeartbeatAck = event
&& let Err(error) = shard_state.heartbeated(shard_id, shard.latency()).await
&& let Err(error) = tx_state.try_send((
shard.id(),
ShardStateEvent::Heartbeat,
Some(event.clone()),
Some(latency_num),
))
{
tracing::error!(?error, "error updating redis state for latency");
tracing::error!(?error, "error updating shard state for latency");
}
if let Event::Ready(_) = event {
@ -162,10 +207,28 @@ pub async fn runner(
cache.2.write().await.push(shard_id);
}
}
cache.0.update(&event);
cache.update(&event).await;
// okay, we've handled the event internally, let's send it to consumers
// tx.send((shard.id(), raw_event)).unwrap();
// some basic filtering here is useful
// we can't use if matching using the | operator, so anything matched does nothing
// and the default match skips the next block (continues to the next event)
match event {
Event::InteractionCreate(_) => {}
Event::MessageCreate(ref m) if m.author.id != our_user_id => {}
Event::MessageUpdate(ref m) if m.author.id != our_user_id && !m.author.bot => {}
Event::MessageDelete(_) => {}
Event::MessageDeleteBulk(_) => {}
Event::ReactionAdd(ref r) if r.user_id != our_user_id => {}
_ => {
continue;
}
}
if runtime_config.exists(RUNTIME_CONFIG_KEY_EVENT_TARGET).await {
tx.send((shard.id(), event, raw_event)).await.unwrap();
}
}
}

View file

@ -78,8 +78,8 @@ async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: on
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
Err(error) => {
error!(?error, ?shard_id, ?bucket, "error getting shard allowance")
}
}

View file

@ -1,49 +1,63 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use metrics::{counter, gauge};
use tokio::sync::RwLock;
use tracing::info;
use twilight_gateway::{Event, Latency};
use twilight_gateway::Event;
use std::collections::HashMap;
use libpk::state::ShardState;
#[derive(Clone)]
use super::gateway::cluster_config;
pub struct ShardStateManager {
redis: RedisPool,
shards: RwLock<HashMap<u32, ShardState>>,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
ShardStateManager {
redis: redis,
shards: RwLock::new(HashMap::new()),
}
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
// also update gateway.rs with event types
Event::Ready(_) => self.ready_or_resumed(shard_id, false).await,
Event::Resumed => self.ready_or_resumed(shard_id, true).await,
_ => Ok(()),
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<String> = self.redis.hget("pluralkit:shardstatus", shard_id).await?;
match data {
Some(buf) => Ok(serde_json::from_str(&buf).expect("could not decode shard data!")),
None => Ok(ShardState::default()),
async fn save_shard(&self, id: u32, state: ShardState) -> anyhow::Result<()> {
{
let mut shards = self.shards.write().await;
shards.insert(id, state.clone());
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
self.redis
.hset::<(), &str, (String, String)>(
"pluralkit:shardstatus",
(
shard_id.to_string(),
serde_json::to_string(&info).expect("could not serialize shard"),
id.to_string(),
serde_json::to_string(&state).expect("could not serialize shard"),
),
)
.await?;
Ok(())
}
async fn get_shard(&self, id: u32) -> Option<ShardState> {
let shards = self.shards.read().await;
shards.get(&id).cloned()
}
pub async fn get(&self) -> Vec<ShardState> {
self.shards.read().await.values().cloned().collect()
}
async fn ready_or_resumed(&self, shard_id: u32, resumed: bool) -> anyhow::Result<()> {
info!(
"shard {} {}",
@ -57,32 +71,52 @@ impl ShardStateManager {
)
.increment(1);
gauge!("pluralkit_gateway_shard_up").increment(1);
let mut info = self.get_shard(shard_id).await?;
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_up").decrement(1);
let mut info = self.get_shard(shard_id).await?;
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
pub async fn heartbeated(&self, shard_id: u32, latency: &Latency) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
pub async fn heartbeated(&self, shard_id: u32, latency: i32) -> anyhow::Result<()> {
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string()).set(latency);
let mut info = self
.get_shard(shard_id)
.await
.unwrap_or(ShardState::default());
info.shard_id = shard_id as i32;
info.cluster_id = Some(cluster_config().node_id as i32);
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
info.latency = latency
.recent()
.first()
.map_or_else(|| 0, |d| d.as_millis()) as i32;
gauge!("pluralkit_gateway_shard_latency", "shard_id" => shard_id.to_string())
.set(info.latency);
info.latency = latency;
self.save_shard(shard_id, info).await?;
Ok(())
}

View file

@ -0,0 +1,242 @@
// - reaction: (message_id, user_id)
// - message: (author_id, channel_id, ?options)
// - interaction: (custom_id where not_includes "help-menu")
use std::{
collections::{hash_map::Entry, HashMap},
net::{IpAddr, SocketAddr},
time::Duration,
};
use serde::Deserialize;
use tokio::{sync::RwLock, time::Instant};
use tracing::info;
use twilight_gateway::Event;
use twilight_model::{
application::interaction::InteractionData,
id::{
marker::{ChannelMarker, MessageMarker, UserMarker},
Id,
},
};
static DEFAULT_TIMEOUT: Duration = Duration::from_mins(15);
#[derive(Deserialize)]
#[serde(untagged)]
pub enum AwaitEventRequest {
Reaction {
message_id: Id<MessageMarker>,
user_id: Id<UserMarker>,
target: String,
timeout: Option<u64>,
},
Message {
channel_id: Id<ChannelMarker>,
author_id: Id<UserMarker>,
target: String,
timeout: Option<u64>,
options: Option<Vec<String>>,
},
Interaction {
id: String,
target: String,
timeout: Option<u64>,
},
}
pub struct EventAwaiter {
reactions: RwLock<HashMap<(Id<MessageMarker>, Id<UserMarker>), (Instant, String)>>,
messages: RwLock<
HashMap<(Id<ChannelMarker>, Id<UserMarker>), (Instant, String, Option<Vec<String>>)>,
>,
interactions: RwLock<HashMap<String, (Instant, String)>>,
}
impl EventAwaiter {
pub fn new() -> Self {
let v = Self {
reactions: RwLock::new(HashMap::new()),
messages: RwLock::new(HashMap::new()),
interactions: RwLock::new(HashMap::new()),
};
v
}
pub async fn cleanup_loop(&self) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
info!("running event_awaiter cleanup loop");
let mut counts = (0, 0, 0);
let now = Instant::now();
{
let mut reactions = self.reactions.write().await;
for key in reactions.clone().keys() {
if let Entry::Occupied(entry) = reactions.entry(key.clone())
&& entry.get().0 < now
{
counts.0 += 1;
entry.remove();
}
}
}
{
let mut messages = self.messages.write().await;
for key in messages.clone().keys() {
if let Entry::Occupied(entry) = messages.entry(key.clone())
&& entry.get().0 < now
{
counts.1 += 1;
entry.remove();
}
}
}
{
let mut interactions = self.interactions.write().await;
for key in interactions.clone().keys() {
if let Entry::Occupied(entry) = interactions.entry(key.clone())
&& entry.get().0 < now
{
counts.2 += 1;
entry.remove();
}
}
}
info!("ran event_awaiter cleanup loop, took {}us, {} reactions, {} messages, {} interactions", Instant::now().duration_since(now).as_micros(), counts.0, counts.1, counts.2);
}
}
pub async fn target_for_event(&self, event: Event) -> Option<String> {
match event {
Event::MessageCreate(message) => {
let mut messages = self.messages.write().await;
messages
.remove(&(message.channel_id, message.author.id))
.map(|(timeout, target, options)| {
if let Some(options) = options
&& !options.contains(&message.content.to_lowercase())
{
messages.insert(
(message.channel_id, message.author.id),
(timeout, target, Some(options)),
);
return None;
}
Some((*target).to_string())
})?
}
Event::ReactionAdd(reaction)
if let Some((_, target)) = self
.reactions
.write()
.await
.remove(&(reaction.message_id, reaction.user_id)) =>
{
Some((*target).to_string())
}
Event::InteractionCreate(interaction)
if let Some(data) = interaction.data.clone()
&& let InteractionData::MessageComponent(component) = data
&& !component.custom_id.contains("help-menu")
&& let Some((_, target)) =
self.interactions.write().await.remove(&component.custom_id) =>
{
Some((*target).to_string())
}
_ => None,
}
}
pub async fn handle_request(&self, req: AwaitEventRequest, addr: SocketAddr) {
match req {
AwaitEventRequest::Reaction {
message_id,
user_id,
target,
timeout,
} => {
self.reactions.write().await.insert(
(message_id, user_id),
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
),
);
}
AwaitEventRequest::Message {
channel_id,
author_id,
target,
timeout,
options,
} => {
self.messages.write().await.insert(
(channel_id, author_id),
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
options,
),
);
}
AwaitEventRequest::Interaction {
id,
target,
timeout,
} => {
self.interactions.write().await.insert(
id,
(
Instant::now()
.checked_add(
timeout
.map(|i| Duration::from_secs(i))
.unwrap_or(DEFAULT_TIMEOUT),
)
.expect("invalid time"),
target_or_addr(target, addr),
),
);
}
}
}
pub async fn clear(&self) {
self.reactions.write().await.clear();
self.messages.write().await.clear();
self.interactions.write().await.clear();
}
}
fn target_or_addr(target: String, addr: SocketAddr) -> String {
if target == "source-addr" {
let ip_str = match addr.ip() {
IpAddr::V4(v4) => v4.to_string(),
IpAddr::V6(v6) => {
if let Some(v4) = v6.to_ipv4_mapped() {
v4.to_string()
} else {
format!("[{v6}]")
}
}
};
format!("http://{ip_str}:5002/events")
} else {
target
}
}

View file

@ -1,39 +1,79 @@
#![feature(let_chains)]
#![feature(if_let_guard)]
#![feature(duration_constructors)]
use chrono::Timelike;
use discord::gateway::cluster_config;
use event_awaiter::EventAwaiter;
use fred::{clients::RedisPool, interfaces::*};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
use libpk::{runtime_config::RuntimeConfig, state::ShardStateEvent};
use reqwest::{ClientBuilder, StatusCode};
use std::{sync::Arc, time::Duration, vec::Vec};
use tokio::{
signal::unix::{signal, SignalKind},
sync::mpsc::channel,
task::JoinSet,
};
use std::{
sync::{mpsc::channel, Arc},
time::Duration,
vec::Vec,
};
use tokio::task::JoinSet;
use tracing::{info, warn};
use tracing::{error, info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod cache_api;
mod api;
mod discord;
mod event_awaiter;
mod logger;
libpk::main!("gateway");
async fn real_main() -> anyhow::Result<()> {
let (shutdown_tx, shutdown_rx) = channel::<()>();
let shutdown_tx = Arc::new(shutdown_tx);
const RUNTIME_CONFIG_KEY_EVENT_TARGET: &'static str = "event_target";
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let runtime_config = Arc::new(
RuntimeConfig::new(
redis.clone(),
format!(
"{}:{}",
libpk::config.runtime_config_key.as_ref().unwrap(),
cluster_config().node_id
),
)
.await?,
);
// hacky, but needed for selfhost for now
if let Some(target) = libpk::config
.discord
.as_ref()
.unwrap()
.gateway_target
.clone()
{
runtime_config
.set(RUNTIME_CONFIG_KEY_EVENT_TARGET.to_string(), target)
.await?;
}
let cache = Arc::new(discord::cache::new());
let awaiter = Arc::new(EventAwaiter::new());
tokio::spawn({
let awaiter = awaiter.clone();
async move { awaiter.cleanup_loop().await }
});
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
// arbitrary
// todo: make sure this doesn't fill up
let (event_tx, mut event_rx) = channel::<(ShardId, twilight_gateway::Event, String)>(1000);
// todo: make sure this doesn't fill up
let (state_tx, mut state_rx) = channel::<(
ShardId,
ShardStateEvent,
Option<twilight_gateway::Event>,
Option<i32>,
)>(1000);
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
@ -45,61 +85,160 @@ async fn real_main() -> anyhow::Result<()> {
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
state_tx.clone(),
cache.clone(),
runtime_config.clone(),
)));
}
let shard_state = Arc::new(discord::shard_state::new(redis.clone()));
set.spawn(tokio::spawn({
let shard_state = shard_state.clone();
async move {
while let Some((shard_id, state_event, parsed_event, latency)) = state_rx.recv().await {
match state_event {
ShardStateEvent::Heartbeat => {
if !latency.is_none()
&& let Err(error) = shard_state
.heartbeated(shard_id.number(), latency.unwrap())
.await
{
error!("failed to update shard state for heartbeat: {error}")
};
}
ShardStateEvent::Closed => {
if let Err(error) = shard_state.socket_closed(shard_id.number()).await {
error!("failed to update shard state for heartbeat: {error}")
};
}
ShardStateEvent::Other => {
if let Err(error) = shard_state
.handle_event(
shard_id.number(),
parsed_event.expect("shard state event not provided!"),
)
.await
{
error!("failed to update shard state for heartbeat: {error}")
};
}
}
}
}
}));
set.spawn(tokio::spawn({
let runtime_config = runtime_config.clone();
let awaiter = awaiter.clone();
async move {
let client = Arc::new(
ClientBuilder::new()
.connect_timeout(Duration::from_secs(1))
.timeout(Duration::from_secs(1))
.build()
.expect("error making client"),
);
while let Some((shard_id, parsed_event, raw_event)) = event_rx.recv().await {
let target = if let Some(target) = awaiter.target_for_event(parsed_event).await {
info!(target = ?target, "sending event to awaiter");
Some(target)
} else if let Some(target) =
runtime_config.get(RUNTIME_CONFIG_KEY_EVENT_TARGET).await
{
Some(target)
} else {
None
};
if let Some(target) = target {
tokio::spawn({
let client = client.clone();
async move {
match client
.post(format!("{target}/{}", shard_id.number()))
.body(raw_event)
.send()
.await
{
Ok(res) => {
if res.status() != StatusCode::OK {
error!(
status = ?res.status(),
target = ?target,
"got non-200 from bot while sending event",
);
}
}
Err(error) => {
error!(?error, "failed to request event target");
}
}
}
});
}
}
}
}));
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
match api::run_server(cache, shard_state, runtime_config, awaiter.clone()).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
error!(?error, "failed to serve cache api");
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
set.spawn(tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
signal(SignalKind::interrupt()).unwrap().recv().await;
info!("got SIGINT");
}));
let _ = shutdown_rx.recv();
set.spawn(tokio::spawn(async move {
signal(SignalKind::terminate()).unwrap().recv().await;
info!("got SIGTERM");
}));
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.join_next().await;
info!("gateway exiting, have a nice day!");
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
set.abort_all();
info!("gateway exiting, have a nice day!");
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
Ok(())
}
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
let prefix = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_prefix_for_gateway
.clone();
println!("{prefix}");
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
@ -119,9 +258,9 @@ async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
format!("{prefix}help | {status}")
} else {
"pk;help".to_string()
format!("{prefix}help")
}
.as_str(),
false,

View file

@ -0,0 +1,15 @@
[package]
name = "gdpr_worker"
version = "0.1.0"
edition = "2021"
[dependencies]
libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
futures = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
twilight-model = { workspace = true }

View file

@ -0,0 +1,149 @@
#![feature(let_chains)]
use sqlx::prelude::FromRow;
use std::{sync::Arc, time::Duration};
use tracing::{error, info, warn};
use twilight_http::api_error::{ApiError, GeneralApiError};
use twilight_model::id::{
marker::{ChannelMarker, MessageMarker},
Id,
};
// create table messages_gdpr_jobs (mid bigint not null references messages(mid) on delete cascade, channel bigint not null);
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_messages_db().await?;
let mut client_builder = twilight_http::Client::builder()
.token(
libpk::config
.discord
.as_ref()
.expect("missing discord config")
.bot_token
.clone(),
)
.timeout(Duration::from_secs(30));
if let Some(base_url) = libpk::config
.discord
.as_ref()
.expect("missing discord config")
.api_base_url
.clone()
{
client_builder = client_builder.proxy(base_url, true).ratelimiter(None);
}
let client = Arc::new(client_builder.build());
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
match run_job(db.clone(), client.clone()).await {
Ok(()) => {}
Err(error) => {
error!(?error, "failed to run messages gdpr job");
}
}
}
}
#[derive(FromRow)]
struct GdprJobEntry {
mid: i64,
channel_id: i64,
}
async fn run_job(pool: sqlx::PgPool, discord: Arc<twilight_http::Client>) -> anyhow::Result<()> {
let mut tx = pool.begin().await?;
let message: Option<GdprJobEntry> = sqlx::query_as(
"select mid, channel_id from messages_gdpr_jobs for update skip locked limit 1;",
)
.fetch_optional(&mut *tx)
.await?;
let Some(message) = message else {
info!("no job to run, sleeping for 1 minute");
tokio::time::sleep(tokio::time::Duration::from_secs(60)).await;
return Ok(());
};
info!("got mid={}, cleaning up...", message.mid);
// naively delete message on discord's end
let res = discord
.delete_message(
Id::<ChannelMarker>::new(message.channel_id as u64),
Id::<MessageMarker>::new(message.mid as u64),
)
.await;
if res.is_ok() {
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
.bind(message.mid)
.execute(&mut *tx)
.await?;
}
if let Err(err) = res {
if let twilight_http::error::ErrorType::Response { error, status, .. } = err.kind()
&& let ApiError::General(GeneralApiError { code, .. }) = error
{
match (status.get(), code) {
(403, _) => {
warn!(
"got 403 while deleting message in channel {}, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
(_, 10003) => {
warn!(
"deleting message in channel {}: channel not found, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
(_, 10008) => {
warn!("deleting message {}: message not found", message.mid);
sqlx::query("delete from messages_gdpr_jobs where mid = $1")
.bind(message.mid)
.execute(&mut *tx)
.await?;
}
(_, 50083) => {
warn!(
"could not delete message in thread {}: thread is archived, failing fast",
message.channel_id
);
sqlx::query("delete from messages_gdpr_jobs where channel_id = $1")
.bind(message.channel_id)
.execute(&mut *tx)
.await?;
}
_ => {
error!(
?status,
?code,
message_id = message.mid,
"got unknown error deleting message",
);
}
}
} else {
return Err(err.into());
}
}
tx.commit().await?;
return Ok(());
}

0
crates/h Normal file
View file

View file

@ -8,6 +8,7 @@ anyhow = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
pk_macros = { path = "../macros" }
sentry = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }

View file

@ -12,10 +12,16 @@ pub struct ClusterSettings {
pub total_nodes: u32,
}
fn _default_bot_prefix() -> String {
"pk;".to_string()
}
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: Id<UserMarker>,
pub bot_token: String,
#[serde(default = "_default_bot_prefix")]
pub bot_prefix_for_gateway: String,
pub client_secret: String,
pub max_concurrency: u32,
#[serde(default)]
@ -24,6 +30,9 @@ pub struct DiscordConfig {
#[serde(default = "_default_api_addr")]
pub cache_api_addr: String,
#[serde(default)]
pub gateway_target: Option<String>,
}
#[derive(Deserialize, Debug)]
@ -85,6 +94,7 @@ pub struct ScheduledTasksConfig {
pub set_guild_count: bool,
pub expected_gateway_count: usize,
pub gateway_url: String,
pub prometheus_url: String,
}
fn _metrics_default() -> bool {
@ -113,6 +123,9 @@ pub struct PKConfig {
#[serde(default = "_json_log_default")]
pub(crate) json_log: bool,
#[serde(default)]
pub runtime_config_key: Option<String>,
#[serde(default)]
pub sentry_url: Option<String>,
}
@ -132,10 +145,15 @@ impl PKConfig {
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = {
// hacks
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
if let Ok(var) = std::env::var("STATEFULSET_NAME_FOR_INDEX")
&& std::env::var("pluralkit__discord__cluster__total_nodes").is_ok() {
std::env::set_var("pluralkit__discord__cluster__node_id", var.split("-").last().unwrap());
}
Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))

View file

@ -8,12 +8,15 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt, EnvFilte
use sentry_tracing::event_from_event;
pub mod db;
pub mod runtime_config;
pub mod state;
pub mod _config;
pub use crate::_config::CONFIG as config;
// functions in this file are only used by the main function below
// functions in this file are only used by the main function in macros/entrypoint.rs
pub use pk_macros::main;
pub fn init_logging(component: &str) {
let sentry_layer =
@ -42,6 +45,7 @@ pub fn init_logging(component: &str) {
tracing_subscriber::registry()
.with(sentry_layer)
.with(tracing_subscriber::fmt::layer())
.with(EnvFilter::from_default_env())
.init();
}
}
@ -66,28 +70,3 @@ pub fn init_sentry() -> sentry::ClientInitGuard {
..Default::default()
})
}
#[macro_export]
macro_rules! main {
($component:expr) => {
fn main() -> anyhow::Result<()> {
let _sentry_guard = libpk::init_sentry();
// we might also be able to use env!("CARGO_CRATE_NAME") here
libpk::init_logging($component);
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(err) = libpk::init_metrics() {
tracing::error!("failed to init metrics collector: {err}");
};
tracing::info!("hello world");
if let Err(err) = real_main().await {
tracing::error!("failed to run service: {err}");
};
});
Ok(())
}
};
}

View file

@ -0,0 +1,72 @@
use fred::{clients::RedisPool, interfaces::HashesInterface};
use std::collections::HashMap;
use tokio::sync::RwLock;
use tracing::info;
pub struct RuntimeConfig {
redis: RedisPool,
settings: RwLock<HashMap<String, String>>,
redis_key: String,
}
impl RuntimeConfig {
pub async fn new(redis: RedisPool, component_key: String) -> anyhow::Result<Self> {
let redis_key = format!("remote_config:{component_key}");
let mut c = RuntimeConfig {
redis,
settings: RwLock::new(HashMap::new()),
redis_key,
};
c.load().await?;
Ok(c)
}
pub async fn load(&mut self) -> anyhow::Result<()> {
let redis_config: HashMap<String, String> = self.redis.hgetall(&self.redis_key).await?;
let mut settings = self.settings.write().await;
for (key, value) in redis_config {
settings.insert(key, value);
}
info!("starting with runtime config: {:?}", settings);
Ok(())
}
pub async fn set(&self, key: String, value: String) -> anyhow::Result<()> {
self.redis
.hset::<(), &str, (String, String)>(&self.redis_key, (key.clone(), value.clone()))
.await?;
self.settings
.write()
.await
.insert(key.clone(), value.clone());
info!("updated runtime config: {key}={value}");
Ok(())
}
pub async fn delete(&self, key: String) -> anyhow::Result<()> {
self.redis
.hdel::<(), &str, String>(&self.redis_key, key.clone())
.await?;
self.settings.write().await.remove(&key.clone());
info!("updated runtime config: {key} removed");
Ok(())
}
pub async fn get(&self, key: &str) -> Option<String> {
self.settings.read().await.get(key).cloned()
}
pub async fn exists(&self, key: &str) -> bool {
self.settings.read().await.contains_key(key)
}
pub async fn get_all(&self) -> HashMap<String, String> {
self.settings.read().await.clone()
}
}

View file

@ -1,4 +1,4 @@
#[derive(serde::Serialize, serde::Deserialize, Clone, Default)]
#[derive(serde::Serialize, serde::Deserialize, Clone, Default, Debug)]
pub struct ShardState {
pub shard_id: i32,
pub up: bool,
@ -10,3 +10,9 @@ pub struct ShardState {
pub last_connection: i32,
pub cluster_id: Option<i32>,
}
pub enum ShardStateEvent {
Closed,
Heartbeat,
Other,
}

View file

@ -1,5 +1,5 @@
[package]
name = "model_macros"
name = "pk_macros"
version = "0.1.0"
edition = "2021"

View file

@ -0,0 +1,41 @@
use proc_macro::{Delimiter, TokenTree};
use quote::quote;
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
// yes, this ignores everything except the codeblock
// it's fine.
let body = match input.into_iter().last().expect("empty") {
TokenTree::Group(group) if group.delimiter() == Delimiter::Brace => group.stream(),
_ => panic!("invalid function"),
};
let body = proc_macro2::TokenStream::from(body);
return quote! {
fn main() {
let _sentry_guard = libpk::init_sentry();
libpk::init_logging(env!("CARGO_CRATE_NAME"));
tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()
.unwrap()
.block_on(async {
if let Err(error) = libpk::init_metrics() {
tracing::error!(?error, "failed to init metrics collector");
};
tracing::info!("hello world");
let result: anyhow::Result<()> = async { #body }.await;
if let Err(error) = result {
tracing::error!(?error, "failed to run service");
};
});
}
}
.into();
}

14
crates/macros/src/lib.rs Normal file
View file

@ -0,0 +1,14 @@
use proc_macro::TokenStream;
mod entrypoint;
mod model;
#[proc_macro_attribute]
pub fn main(args: TokenStream, input: TokenStream) -> TokenStream {
entrypoint::macro_impl(args, input)
}
#[proc_macro_attribute]
pub fn pk_model(args: TokenStream, input: TokenStream) -> TokenStream {
model::macro_impl(args, input)
}

View file

@ -16,6 +16,7 @@ struct ModelField {
patch: ElemPatchability,
json: Option<Expr>,
is_privacy: bool,
privacy: Option<Expr>,
default: Option<Expr>,
}
@ -26,6 +27,7 @@ fn parse_field(field: syn::Field) -> ModelField {
patch: ElemPatchability::None,
json: None,
is_privacy: false,
privacy: None,
default: None,
};
@ -61,6 +63,12 @@ fn parse_field(field: syn::Field) -> ModelField {
}
f.json = Some(nv.value.clone());
}
"privacy" => {
if f.privacy.is_some() {
panic!("cannot set privacy multiple times for same field");
}
f.privacy = Some(nv.value.clone());
}
"default" => {
if f.default.is_some() {
panic!("cannot set default multiple times for same field");
@ -84,8 +92,7 @@ fn parse_field(field: syn::Field) -> ModelField {
f
}
#[proc_macro_attribute]
pub fn pk_model(
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
) -> proc_macro::TokenStream {
@ -108,8 +115,6 @@ pub fn pk_model(
panic!("fields of a struct must be named");
};
// println!("{}: {:#?}", tname, fields);
let tfields = mk_tfields(fields.clone());
let from_json = mk_tfrom_json(fields.clone());
let _from_sql = mk_tfrom_sql(fields.clone());
@ -138,9 +143,7 @@ pub fn pk_model(
#from_json
}
pub fn to_json(self) -> serde_json::Value {
#to_json
}
#to_json
}
#[derive(Debug, Clone)]
@ -189,19 +192,28 @@ fn mk_tfrom_sql(_fields: Vec<ModelField>) -> TokenStream {
quote! { unimplemented!(); }
}
fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
// todo: check privacy access
let has_privacy = fields.iter().any(|f| f.privacy.is_some());
let fielddefs: TokenStream = fields
.iter()
.filter_map(|f| {
f.json.as_ref().map(|v| {
let tname = f.name.clone();
if let Some(default) = f.default.as_ref() {
let maybepriv = if let Some(privacy) = f.privacy.as_ref() {
quote! {
#v: self.#tname.unwrap_or(#default),
#v: crate::_util::privacy_lookup!(self.#tname, self.#privacy, lookup_level)
}
} else {
quote! {
#v: self.#tname,
#v: self.#tname
}
};
if let Some(default) = f.default.as_ref() {
quote! {
#maybepriv.unwrap_or(#default),
}
} else {
quote! {
#maybepriv,
}
}
})
@ -223,13 +235,35 @@ fn mk_tto_json(fields: Vec<ModelField>) -> TokenStream {
})
.collect();
quote! {
serde_json::json!({
#fielddefs
"privacy": {
#privacyfielddefs
let privdef = if has_privacy {
quote! {
, lookup_level: crate::PrivacyLevel
}
} else {
quote! {}
};
let privacy_fielddefs = if has_privacy {
quote! {
"privacy": if matches!(lookup_level, crate::PrivacyLevel::Private) {
Some(serde_json::json!({
#privacyfielddefs
}))
} else {
None
}
})
}
} else {
quote! {}
};
quote! {
pub fn to_json(self #privdef) -> serde_json::Value {
serde_json::json!({
#fielddefs
#privacy_fielddefs
})
}
}
}

Some files were not shown because too many files have changed in this diff Show more