diff --git a/PluralKit.API/ApiConfig.cs b/PluralKit.API/ApiConfig.cs index 46556a79..8da70196 100644 --- a/PluralKit.API/ApiConfig.cs +++ b/PluralKit.API/ApiConfig.cs @@ -7,4 +7,5 @@ public class ApiConfig public string? ClientSecret { get; set; } public bool TrustAuth { get; set; } = false; public string? AvatarServiceUrl { get; set; } + public bool SearchGuildSettings = false; } \ No newline at end of file diff --git a/PluralKit.API/Controllers/v2/DiscordControllerV2.cs b/PluralKit.API/Controllers/v2/DiscordControllerV2.cs index 340c8698..7547a751 100644 --- a/PluralKit.API/Controllers/v2/DiscordControllerV2.cs +++ b/PluralKit.API/Controllers/v2/DiscordControllerV2.cs @@ -20,7 +20,7 @@ public class DiscordControllerV2: PKControllerBase if (ContextFor(system) != LookupContext.ByOwner) throw Errors.GenericMissingPermissions; - var settings = await _repo.GetSystemGuild(guild_id, system.Id, false); + var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings); if (settings == null) throw Errors.SystemGuildNotFound; @@ -34,7 +34,7 @@ public class DiscordControllerV2: PKControllerBase if (ContextFor(system) != LookupContext.ByOwner) throw Errors.GenericMissingPermissions; - var settings = await _repo.GetSystemGuild(guild_id, system.Id, false); + var settings = await _repo.GetSystemGuild(guild_id, system.Id, false, _config.SearchGuildSettings); if (settings == null) throw Errors.SystemGuildNotFound; @@ -58,7 +58,7 @@ public class DiscordControllerV2: PKControllerBase if (member.System != system.Id) throw Errors.NotOwnMemberError; - var settings = await _repo.GetMemberGuild(guild_id, member.Id, false); + var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null); if (settings == null) throw Errors.MemberGuildNotFound; @@ -75,7 +75,7 @@ public class DiscordControllerV2: PKControllerBase if (member.System != system.Id) throw Errors.NotOwnMemberError; - var settings = await _repo.GetMemberGuild(guild_id, member.Id, false); + var settings = await _repo.GetMemberGuild(guild_id, member.Id, false, _config.SearchGuildSettings ? system.Id : null); if (settings == null) throw Errors.MemberGuildNotFound; diff --git a/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs index f0d70900..361a2bf3 100644 --- a/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs +++ b/PluralKit.Core/Database/Repository/ModelRepository.Guild.cs @@ -19,16 +19,40 @@ public partial class ModelRepository } - public Task GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true) + public async Task GetSystemGuild(ulong guild, SystemId system, bool defaultInsert = true, bool search = false) { if (!defaultInsert) - return _db.QueryFirst(new Query("system_guild") + { + var simpleRes = await _db.QueryFirst(new Query("system_guild") .Where("guild", guild) .Where("system", system) ); + if (simpleRes != null || !search) + return simpleRes; + + var accounts = await GetSystemAccounts(system); + + var searchRes = await _db.QueryFirst( + "select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))", + new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() }, + queryName: "find_system_from_commands", + messages: true + ); + + if (!searchRes) + searchRes = await _db.QueryFirst( + "select exists(select 1 from command_messages where guild = @guild and sender = any(@accounts))", + new { guild = guild, accounts = accounts.Select(u => (long)u).ToArray() }, + queryName: "find_system_from_messages", + messages: true + ); + + if (!searchRes) + return null; + } var query = new Query("system_guild").AsInsert(new { guild, system }); - return _db.QueryFirst(query, + return await _db.QueryFirst(query, "on conflict (guild, system) do update set guild = $1, system = $2 returning *" ); } @@ -42,16 +66,25 @@ public partial class ModelRepository return settings; } - public Task GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true) + public async Task GetMemberGuild(ulong guild, MemberId member, bool defaultInsert = true, SystemId? search = null) { if (!defaultInsert) - return _db.QueryFirst(new Query("member_guild") + { + var simpleRes = await _db.QueryFirst(new Query("member_guild") .Where("guild", guild) .Where("member", member) ); + if (simpleRes != null || !search.HasValue) + return simpleRes; + + var systemConfig = await GetSystemGuild(guild, search.Value, defaultInsert: false, search: true); + + if (systemConfig == null) + return null; + } var query = new Query("member_guild").AsInsert(new { guild, member }); - return _db.QueryFirst(query, + return await _db.QueryFirst(query, "on conflict (guild, member) do update set guild = $1, member = $2 returning *" ); } diff --git a/PluralKit.Core/Database/Utils/DatabaseMigrator.cs b/PluralKit.Core/Database/Utils/DatabaseMigrator.cs index 11c6e2a9..d4d58093 100644 --- a/PluralKit.Core/Database/Utils/DatabaseMigrator.cs +++ b/PluralKit.Core/Database/Utils/DatabaseMigrator.cs @@ -9,7 +9,7 @@ namespace PluralKit.Core; internal class DatabaseMigrator { private const string RootPath = "PluralKit.Core.Database"; // "resource path" root for SQL files - private const int TargetSchemaVersion = 51; + private const int TargetSchemaVersion = 52; private readonly ILogger _logger; public DatabaseMigrator(ILogger logger)