[WIP] feat: scoped api keys

This commit is contained in:
Iris System 2025-08-17 02:47:01 -07:00
parent e7ee593a85
commit 06cb160f95
45 changed files with 1264 additions and 154 deletions

54
Cargo.lock generated
View file

@ -82,9 +82,12 @@ version = "0.1.0"
dependencies = [
"anyhow",
"axum 0.8.4",
"base64 0.22.1",
"chrono",
"fred",
"hyper 1.6.0",
"hyper-util",
"jsonwebtoken",
"lazy_static",
"libpk",
"metrics",
@ -102,6 +105,7 @@ dependencies = [
"tower-http",
"tracing",
"twilight-http",
"uuid",
]
[[package]]
@ -1963,6 +1967,21 @@ dependencies = [
"serde",
]
[[package]]
name = "jsonwebtoken"
version = "9.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde"
dependencies = [
"base64 0.22.1",
"js-sys",
"pem",
"ring 0.17.14",
"serde",
"serde_json",
"simple_asn1",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
@ -2008,6 +2027,7 @@ dependencies = [
"config",
"fred",
"json-subscriber",
"jsonwebtoken",
"lazy_static",
"metrics",
"metrics-exporter-prometheus",
@ -2260,6 +2280,16 @@ dependencies = [
"winapi",
]
[[package]]
name = "num-bigint"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9"
dependencies = [
"num-integer",
"num-traits",
]
[[package]]
name = "num-bigint-dig"
version = "0.8.4"
@ -2435,6 +2465,16 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3"
[[package]]
name = "pem"
version = "3.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3"
dependencies = [
"base64 0.22.1",
"serde",
]
[[package]]
name = "pem-rfc7468"
version = "0.7.0"
@ -2568,7 +2608,9 @@ checksum = "7edddbd0b52d732b21ad9a5fab5c704c14cd949e5e9a1ec5929a24fded1b904c"
name = "pluralkit_models"
version = "0.1.0"
dependencies = [
"base64 0.22.1",
"chrono",
"jsonwebtoken",
"pk_macros",
"sea-query",
"serde",
@ -3684,6 +3726,18 @@ version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e"
[[package]]
name = "simple_asn1"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "297f631f50729c8c99b84667867963997ec0b50f32b2a7dbcab828ef0541e8bb"
dependencies = [
"num-bigint",
"num-traits",
"thiserror 2.0.12",
"time",
]
[[package]]
name = "sketches-ddsketch"
version = "0.2.2"

View file

@ -7,10 +7,12 @@ resolver = "2"
[workspace.dependencies]
anyhow = "1"
axum-macros = "0.4.1"
base64 = "0.22.1"
bytes = "1.6.0"
chrono = "0.4"
fred = { version = "9.3.0", default-features = false, features = ["tracing", "i-keys", "i-hashes", "i-scripts", "sha-1"] }
futures = "0.3.30"
jsonwebtoken = { version = "9.3.0", features = ["pem"] }
lazy_static = "1.4.0"
metrics = "0.23.0"
reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-tls", "trust-dns"]}

View file

@ -15,11 +15,18 @@ public class AuthorizationTokenHandlerMiddleware
public async Task Invoke(HttpContext ctx, IDatabase db, ApiConfig cfg)
{
if (cfg.TrustAuth
&& ctx.Request.Headers.TryGetValue("X-PluralKit-SystemId", out var sidHeaders)
&& sidHeaders.Count > 0
&& int.TryParse(sidHeaders[0], out var systemId))
ctx.Items.Add("SystemId", new SystemId(systemId));
if (cfg.TrustAuth)
{
if (ctx.Request.Headers.TryGetValue("X-PluralKit-SystemId", out var sidHeaders)
&& sidHeaders.Count > 0
&& int.TryParse(sidHeaders[0], out var systemId))
ctx.Items.Add("SystemId", new SystemId(systemId));
if (ctx.Request.Headers.TryGetValue("X-PluralKit-PrivacyLevel", out var levelHeaders)
&& levelHeaders.Count > 0)
ctx.Items.Add("LookupContext",
levelHeaders[0].ToLower().Trim() == "private" ? LookupContext.ByOwner : LookupContext.ByNonOwner);
}
if (cfg.TrustAuth
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)

View file

@ -101,24 +101,48 @@ public class PKControllerBase: ControllerBase
return null;
}
protected bool IsAuthenticatedAs(SystemId system)
{
HttpContext.Items.TryGetValue("SystemId", out var systemId);
return systemId != null && (SystemId)systemId == system;
}
protected LookupContext ContextFor(PKSystem system)
{
HttpContext.Items.TryGetValue("SystemId", out var systemId);
if (systemId == null) return LookupContext.ByNonOwner;
return (SystemId)systemId == system.Id ? LookupContext.ByOwner : LookupContext.ByNonOwner;
HttpContext.Items.TryGetValue("LookupContext", out var lookupCtx);
if (systemId != null && (SystemId)systemId == system.Id)
{
if (lookupCtx != null) return (LookupContext)lookupCtx;
return LookupContext.ByOwner;
}
return LookupContext.ByNonOwner;
}
protected LookupContext ContextFor(PKMember member)
{
HttpContext.Items.TryGetValue("SystemId", out var systemId);
if (systemId == null) return LookupContext.ByNonOwner;
return (SystemId)systemId == member.System ? LookupContext.ByOwner : LookupContext.ByNonOwner;
HttpContext.Items.TryGetValue("LookupContext", out var lookupCtx);
if (systemId != null && (SystemId)systemId == member.System)
{
if (lookupCtx != null) return (LookupContext)lookupCtx;
return LookupContext.ByOwner;
}
return LookupContext.ByNonOwner;
}
protected LookupContext ContextFor(PKGroup group)
{
HttpContext.Items.TryGetValue("SystemId", out var systemId);
if (systemId == null) return LookupContext.ByNonOwner;
return (SystemId)systemId == group.System ? LookupContext.ByOwner : LookupContext.ByNonOwner;
HttpContext.Items.TryGetValue("LookupContext", out var lookupCtx);
if (systemId != null && (SystemId)systemId == group.System)
{
if (lookupCtx != null) return (LookupContext)lookupCtx;
return LookupContext.ByOwner;
}
return LookupContext.ByNonOwner;
}
}

View file

@ -21,7 +21,7 @@ public class GroupControllerV2: PKControllerBase
var ctx = ContextFor(system);
if (!system.GroupListPrivacy.CanAccess(ContextFor(system)))
if (!IsAuthenticatedAs(system.Id) && !system.GroupListPrivacy.CanAccess(ContextFor(system)))
throw Errors.UnauthorizedGroupList;
var groups = _repo.GetSystemGroups(system.Id);

View file

@ -19,10 +19,9 @@ public class GroupMemberControllerV2: PKControllerBase
if (group == null)
throw Errors.GroupNotFound;
var ctx = ContextFor(group);
if (!group.ListPrivacy.CanAccess(ctx))
if (!IsAuthenticatedAs(group.System) && !group.ListPrivacy.CanAccess(ctx))
throw Errors.UnauthorizedGroupMemberList;
var system = await _repo.GetSystem(group.System);
@ -154,7 +153,7 @@ public class GroupMemberControllerV2: PKControllerBase
var ctx = ContextFor(member);
var system = await _repo.GetSystem(member.System);
if (!system.GroupListPrivacy.CanAccess(ctx))
if (!IsAuthenticatedAs(member.System) && !system.GroupListPrivacy.CanAccess(ctx))
throw Errors.UnauthorizedGroupList;
var groups = _repo.GetMemberGroups(member.Id).Where(g => g.Visibility.CanAccess(ctx));

View file

@ -26,7 +26,7 @@ public class MemberControllerV2: PKControllerBase
var ctx = ContextFor(system);
if (!system.MemberListPrivacy.CanAccess(ContextFor(system)))
if (!IsAuthenticatedAs(system.Id) && !system.MemberListPrivacy.CanAccess(ContextFor(system)))
throw Errors.UnauthorizedMemberList;
var members = _repo.GetSystemMembers(system.Id);

View file

@ -28,7 +28,7 @@ public class SwitchControllerV2: PKControllerBase
var ctx = ContextFor(system);
if (!system.FrontHistoryPrivacy.CanAccess(ctx))
if (!IsAuthenticatedAs(system.Id) && !system.FrontHistoryPrivacy.CanAccess(ctx))
throw Errors.UnauthorizedFrontHistory;
if (before == null)
@ -59,7 +59,7 @@ public class SwitchControllerV2: PKControllerBase
var ctx = ContextFor(system);
if (!system.FrontPrivacy.CanAccess(ctx))
if (!IsAuthenticatedAs(system.Id) && !system.FrontPrivacy.CanAccess(ctx))
throw Errors.UnauthorizedCurrentFronters;
var sw = await _repo.GetLatestSwitch(system.Id);
@ -145,7 +145,7 @@ public class SwitchControllerV2: PKControllerBase
var ctx = ContextFor(system);
if (!system.FrontHistoryPrivacy.CanAccess(ctx))
if (!IsAuthenticatedAs(system.Id) && !system.FrontHistoryPrivacy.CanAccess(ctx))
throw Errors.SwitchNotFoundPublic;
var members = _db.Execute(conn => _repo.GetSwitchMembers(conn, sw.Id));

View file

@ -33,6 +33,11 @@ public partial class CommandTree
public static Command ConfigProxySwitch = new Command("config proxyswitch", "config proxyswitch [new|add|off]", "Switching behavior when proxy tags are used");
public static Command ConfigNameFormat = new Command("config nameformat", "config nameformat [format]", "Changes your system's username formatting");
public static Command ConfigServerNameFormat = new Command("config servernameformat", "config servernameformat [format]", "Changes your system's username formatting in the current server");
public static Command ApiKeyCreate = new Command("system apikey new", "system apikey new <name> <type>", "Create a new API key");
public static Command ApiKeyList = new Command("system apikey list", "system apikey list", "Show current API keys");
public static Command ApiKeyRename = new Command("system apikey <key> rename", "system apikey <key> rename <name>", "Rename an existing API key");
public static Command ApiKeyDelete = new Command("system apikey <key> delete", "system apikey <key> delete", "Delete an existing API key");
public static Command ApiKeyDeleteAll = new Command("system apikey deleteall", "system apikey deleteall", "Delete all existing API keys");
public static Command AutoproxySet = new Command("autoproxy", "autoproxy [off|front|latch|member]", "Sets your system's autoproxy mode for the current server");
public static Command AutoproxyOff = new Command("autoproxy off", "autoproxy off", "Disables autoproxying for your system in the current server");
public static Command AutoproxyFront = new Command("autoproxy front", "autoproxy front", "Sets your system's autoproxy in this server to proxy the first member currently registered as front");

View file

@ -218,6 +218,8 @@ public partial class CommandTree
// todo: these aren't deprecated but also shouldn't be here
else if (ctx.Match("webhook", "hook"))
await ctx.Execute<Api>(null, m => m.SystemWebhook(ctx));
else if (ctx.Match("apikey", "apikeys", "apitoken", "apitokens"))
await HandleSystemApiKeyCommand(ctx);
else if (ctx.Match("proxy"))
await ctx.Execute<SystemEdit>(SystemProxy, m => m.SystemProxy(ctx));
@ -322,6 +324,42 @@ public partial class CommandTree
await ctx.CheckSystem(target).Execute<Random>(MemberRandom, m => m.Member(ctx, target));
}
private async Task HandleSystemApiKeyCommand(Context ctx)
{
ctx.CheckSystem();
if (ctx.Match("new", "n", "add", "create", "register"))
await ctx.Execute<Api>(ApiKeyCreate, c => c.ApiKeyCreate(ctx));
else if (ctx.Match("list", "ls", "l"))
await ctx.Execute<Api>(ApiKeyList, c => c.ApiKeyList(ctx));
else if (ctx.Match("deleteall", "removeall", "destroyall", "eraseall", "revokeall", "yeetall"))
await ctx.Execute<Api>(ApiKeyDeleteAll, c => c.ApiKeyDeleteAll(ctx));
else if (!ctx.HasNext())
await PrintCommandExpectedError(ctx, ApiKeyCreate, ApiKeyList, ApiKeyRename, ApiKeyDelete, ApiKeyDeleteAll);
else
{
PKApiKey? key = null!;
var input = ctx.PeekArgument();
if (Guid.TryParse(input, out var keyId))
key = await ctx.Repository.GetApiKey(keyId);
else if (await ctx.Repository.GetApiKeyByName(ctx.System.Id, input) is PKApiKey keyByName)
key = keyByName;
if (key == null || key.System != ctx.System.Id)
{
await ctx.Reply($"{Emojis.Error} API key with name \"{ctx.PopArgument()}\" not found.");
return;
}
ctx.PopArgument();
if (ctx.Match("rename", "name", "changename", "setname", "rn"))
await ctx.Execute<Api>(ApiKeyRename, c => c.ApiKeyRename(ctx, key));
else if (ctx.Match("delete", "remove", "destroy", "erase", "yeet"))
await ctx.Execute<Api>(ApiKeyDelete, c => c.ApiKeyDelete(ctx, key));
else
await PrintCommandNotFoundError(ctx, ApiKeyRename, ApiKeyDelete);
}
}
private async Task HandleMemberCommand(Context ctx)
{
if (ctx.Match("new", "n", "add", "create", "register"))

View file

@ -1,28 +1,41 @@
using System.Text;
using System.Text.RegularExpressions;
using Myriad.Builders;
using Myriad.Extensions;
using Myriad.Rest.Exceptions;
using Myriad.Rest.Types;
using Myriad.Rest.Types.Requests;
using Myriad.Types;
using NodaTime;
using SqlKata;
using PluralKit.Core;
namespace PluralKit.Bot;
public class Api
{
private record PaginatedApiKey(Guid Id, string Name, string[] Scopes, string? AppName, Instant Created);
private static readonly Regex _webhookRegex =
new("https://(?:\\w+.)?discord(?:app)?.com/api(?:/v.*)?/webhooks/(.*)");
private readonly BotConfig _botConfig;
private readonly DispatchService _dispatch;
private readonly InteractionDispatchService _interactions;
private readonly PrivateChannelService _dmCache;
private readonly ApiKeyService _apiKey;
public Api(BotConfig botConfig, DispatchService dispatch, PrivateChannelService dmCache)
public Api(BotConfig botConfig, DispatchService dispatch, InteractionDispatchService interactions, PrivateChannelService dmCache, ApiKeyService apiKey)
{
_botConfig = botConfig;
_dispatch = dispatch;
_interactions = interactions;
_dmCache = dmCache;
_apiKey = apiKey;
}
public async Task GetToken(Context ctx)
@ -172,4 +185,167 @@ public class Api
await ctx.Reply($"{Emojis.Success} Successfully the new webhook URL for your system.");
}
public async Task ApiKeyCreate(Context ctx)
{
if (!ctx.HasNext())
throw new PKSyntaxError($"An API key name must be provided.");
var rawScopes = ctx.MatchFlag("scopes", "scope");
var keyName = ctx.PopArgument();
List<string> keyScopes = new();
if (!ctx.HasNext())
throw new PKSyntaxError($"A list of API key scopes must be provided.");
var scopestr = ctx.RemainderOrNull()!.NormalizeLineEndSpacing().Trim();
if (rawScopes)
keyScopes = scopestr.Split(" ").Distinct().ToList();
else
keyScopes.Add(scopestr switch
{
"full" => "write:all",
"read private" => "read:all",
"read public" => "readpublic:all",
"identify" => "identify",
_ => throw new PKError(
$"Couldn't find a scope preset named {scopestr}."),
});
string? check = null!;
try
{
check = await _apiKey.CreateUserApiKey(ctx.System.Id, keyName, keyScopes.ToArray(), check: true);
if (check != null)
throw new PKError("API key validation failed: unknown error");
}
catch (Exception ex)
{
if (ex.Message.StartsWith("API key"))
throw new PKError(ex.Message);
throw;
}
async Task cb(InteractionContext ictx)
{
if (ictx.User.Id != ctx.Author.Id)
{
await ictx.Ignore();
return;
}
var newKey = await _apiKey.CreateUserApiKey(ctx.System.Id, keyName, keyScopes.ToArray());
await ictx.Reply($"Your new API key is below. You will only be shown this once, so please save it!\n\n||`{newKey}`||");
await ctx.Rest.EditMessage(ictx.ChannelId, ictx.MessageId!.Value, new MessageEditRequest
{
Components = new MessageComponent[] { },
});
}
var content =
$"Ready to create a new API key named **{keyName}**, "
+ $"with these scopes: {(String.Join(", ", keyScopes.Select(x => x.AsCode())))}\n"
+ "To create this API key, press the button below.";
await ctx.Rest.CreateMessage(ctx.Channel.Id, new MessageRequest
{
Content = content,
AllowedMentions = new() { Parse = new AllowedMentions.ParseType[] { }, RepliedUser = false },
Components = new[] {
new MessageComponent
{
Type = ComponentType.ActionRow,
Components = new[]
{
new MessageComponent
{
Type = ComponentType.Button,
Style = ButtonStyle.Primary,
Label = "Create API key",
CustomId = _interactions.Register(cb),
},
}
}
},
});
}
public async Task ApiKeyList(Context ctx)
{
var keys = await ctx.Repository.GetSystemApiKeys(ctx.System.Id)
.Select(k => new PaginatedApiKey(k.Id, k.Name, k.Scopes, null, k.Created))
.ToListAsync();
await ctx.Paginate<PaginatedApiKey>(
keys.ToAsyncEnumerable(),
keys.Count,
10,
"Current API keys for your system",
ctx.System.Color,
(eb, l) =>
{
var description = new StringBuilder();
foreach (var item in l)
{
description.Append($"**{item.Name}** (`{item.Id}`)");
description.AppendLine();
description.Append("- Scopes: ");
description.Append(String.Join(", ", item.Scopes.Select(sc => $"`{sc}`")));
description.AppendLine();
description.Append("- Created: ");
description.Append(item.Created.FormatZoned(ctx.Zone));
description.AppendLine();
description.AppendLine();
}
eb.Description(description.ToString());
return Task.CompletedTask;
}
);
}
public async Task ApiKeyRename(Context ctx, PKApiKey key)
{
if (!ctx.HasNext())
throw new PKError("You must provide a new name for this API key.");
var name = ctx.RemainderOrNull(false).NormalizeLineEndSpacing();
await ctx.Repository.UpdateApiKey(key.Id, new ApiKeyPatch { Name = name });
await ctx.Reply($"{Emojis.Success} API key renamed.");
}
public async Task ApiKeyDelete(Context ctx, PKApiKey key)
{
if (!await ctx.PromptYesNo($"Really delete API key **{key.Name}** `{key.Id}`?", "Delete", matchFlag: false))
{
await ctx.Reply($"{Emojis.Error} Deletion cancelled.");
return;
}
await ctx.Repository.DeleteApiKey(key.Id);
await ctx.Reply($"{Emojis.Success} Successfully deleted API key.");
}
public async Task ApiKeyDeleteAll(Context ctx)
{
if (!await ctx.PromptYesNo($"Really delete *all manually-created* API keys for your system?", "Delete", matchFlag: false))
{
await ctx.Reply($"{Emojis.Error} Deletion cancelled.");
return;
}
await ctx.BusyIndicator(async () =>
{
var query = new Query("api_keys")
.AsDelete()
.WhereRaw("[kind]::text not in ( 'dashboard', 'external_app' )")
.Where("system", ctx.System.Id);
await ctx.Database.ExecuteQuery(query);
});
await ctx.Reply($"{Emojis.Success} Successfully deleted all manually-created API keys.");
}
}

View file

@ -162,6 +162,7 @@ public class BotModule: Module
builder.RegisterType<AvatarHostingService>().AsSelf().SingleInstance();
builder.RegisterType<HttpListenerService>().AsSelf().SingleInstance();
builder.RegisterType<RuntimeConfigService>().AsSelf().SingleInstance();
builder.RegisterType<ApiKeyService>().AsSelf().SingleInstance();
// Sentry stuff
builder.Register(_ => new Scope(null)).AsSelf().InstancePerLifetimeScope();

View file

@ -16,6 +16,8 @@ public class CoreConfig
public string? SeqLogUrl { get; set; }
public string? DispatchProxyUrl { get; set; }
public string? DispatchProxyToken { get; set; }
public string? InternalApiBaseUrl { get; set; }
public string? InternalApiToken { get; set; }
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Debug;
public LogEventLevel ElasticLogLevel { get; set; } = LogEventLevel.Information;

View file

@ -0,0 +1,55 @@
using Dapper;
using SqlKata;
namespace PluralKit.Core;
public partial class ModelRepository
{
public async Task<PKApiKey?> GetApiKey(Guid id)
{
var query = new Query("api_keys")
.Select("id", "system", "scopes", "app", "name", "created")
.SelectRaw("[kind]::text")
.Where("id", id);
return await _db.QueryFirst<PKApiKey?>(query);
}
public async Task<PKApiKey?> GetApiKeyByName(SystemId system, string name)
{
var query = new Query("api_keys")
.Select("id", "system", "scopes", "app", "name", "created")
.SelectRaw("[kind]::text")
.Where("system", system)
.WhereRaw("lower(name) = lower(?)", name.ToLower());
return await _db.QueryFirst<PKApiKey?>(query);
}
public IAsyncEnumerable<PKApiKey> GetSystemApiKeys(SystemId system)
{
var query = new Query("api_keys")
.Select("id", "system", "scopes", "app", "name", "created")
.SelectRaw("[kind]::text")
.Where("system", system)
.WhereRaw("[kind]::text not in ( 'dashboard' )")
.OrderByDesc("created");
return _db.QueryStream<PKApiKey>(query);
}
public async Task UpdateApiKey(Guid id, ApiKeyPatch patch)
{
_logger.Information("Updated API key {keyId}: {@ApiKeyPatch}", id, patch);
var query = patch.Apply(new Query("api_keys").Where("id", id));
await _db.ExecuteQuery(query, "returning *");
}
public async Task DeleteApiKey(Guid id)
{
var query = new Query("api_keys").AsDelete().Where("id", id);
await _db.ExecuteQuery(query);
_logger.Information("Deleted ApiKey {keyId}", id);
}
}

View file

@ -0,0 +1,15 @@
using NodaTime;
namespace PluralKit.Core;
public class PKApiKey
{
public Guid Id { get; private set; }
public SystemId System { get; private set; }
public string Kind { get; private set; }
public string[] Scopes { get; private set; }
public Guid? App { get; private set; }
public string Name { get; private set; }
public Instant Created { get; private set; }
}

View file

@ -0,0 +1,24 @@
using Newtonsoft.Json.Linq;
using SqlKata;
namespace PluralKit.Core;
public class ApiKeyPatch: PatchObject
{
public Partial<string> Name { get; set; }
public override Query Apply(Query q) => q.ApplyPatch(wrapper => wrapper
.With("name", Name)
);
public JObject ToJson()
{
var o = new JObject();
if (Name.IsPresent)
o.Add("name", Name.Value);
return o;
}
}

View file

@ -0,0 +1,72 @@
using Autofac;
using System.Text;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using NodaTime;
using Serilog;
namespace PluralKit.Core;
public class ApiKeyService
{
private readonly HttpClient _client;
private readonly ILogger _logger;
private readonly CoreConfig _cfg;
private readonly ILifetimeScope _provider;
public ApiKeyService(ILogger logger, ILifetimeScope provider, CoreConfig cfg)
{
_logger = logger;
_cfg = cfg;
_provider = provider;
_client = new HttpClient();
_client.DefaultRequestHeaders.Add("User-Agent", "PluralKitInternal");
}
public async Task<string?> CreateUserApiKey(SystemId systemId, string keyName, string[] keyScopes, bool check = false)
{
if (_cfg.InternalApiBaseUrl == null || _cfg.InternalApiToken == null)
throw new Exception("internal API config not set!");
if (!Uri.TryCreate(new Uri(_cfg.InternalApiBaseUrl), "/internal/apikey/user", out var uri))
throw new Exception("internal API base invalid!?");
var repo = _provider.Resolve<ModelRepository>();
var system = await repo.GetSystem(systemId);
if (system == null)
return null;
var reqData = new JObject();
reqData.Add("check", check);
reqData.Add("system", system.Id.Value);
reqData.Add("name", keyName);
reqData.Add("scopes", new JArray(keyScopes));
var req = new HttpRequestMessage()
{
RequestUri = uri,
Method = HttpMethod.Post,
Content = new StringContent(JsonConvert.SerializeObject(reqData), Encoding.UTF8, "application/json"),
};
req.Headers.Add("X-Pluralkit-InternalAuth", _cfg.InternalApiToken);
var res = await _client.SendAsync(req);
var data = JsonConvert.DeserializeObject<JObject>(await res.Content.ReadAsStringAsync());
if (data.ContainsKey("error"))
throw new Exception($"API key validation failed: {(data.Value<string>("error"))}");
if (data.Value<bool>("valid") != true)
throw new Exception("API key validation failed: unknown error");
if (!data.ContainsKey("token"))
return null;
return data.Value<string>("token");
}
}

View file

@ -10,8 +10,12 @@ libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
base64 = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
jsonwebtoken = { workspace = true }
lazy_static = { workspace = true }
uuid = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }

View file

@ -1,20 +1,45 @@
use uuid::Uuid;
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
pub const INTERNAL_TOKENID_HEADER: &'static str = "x-pluralkit-tid";
pub const INTERNAL_PRIVACYLEVEL_HEADER: &'static str = "x-pluralkit-privacylevel";
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum AccessLevel {
None = 0,
PublicRead,
PrivateRead,
Full,
}
impl AccessLevel {
pub fn privacy_level(&self) -> PrivacyLevel {
match self {
Self::None | Self::PublicRead => PrivacyLevel::Public,
Self::PrivateRead | Self::Full => PrivacyLevel::Private,
}
}
}
#[derive(Clone)]
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
app_id: Option<Uuid>,
api_key_id: Option<Uuid>,
access_level: AccessLevel,
internal: bool,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>, internal: bool) -> Self {
pub fn new(system_id: Option<i32>, app_id: Option<Uuid>, api_key_id: Option<Uuid>, access_level: AccessLevel, internal: bool) -> Self {
Self {
system_id,
app_id,
api_key_id,
access_level,
internal,
}
}
@ -23,10 +48,18 @@ impl AuthState {
self.system_id
}
pub fn app_id(&self) -> Option<i32> {
pub fn app_id(&self) -> Option<Uuid> {
self.app_id
}
pub fn api_key_id(&self) -> Option<Uuid> {
self.api_key_id
}
pub fn access_level(&self) -> AccessLevel {
self.access_level.clone()
}
pub fn internal(&self) -> bool {
self.internal
}
@ -37,7 +70,7 @@ impl AuthState {
.map(|id| id == a.authable_system_id())
.unwrap_or(false)
{
PrivacyLevel::Private
self.access_level.privacy_level()
} else {
PrivacyLevel::Public
}

View file

@ -0,0 +1,114 @@
use crate::{util::json_err, AuthState, ApiContext};
use pluralkit_models::{ApiKeyType, PKApiKey, PKSystem, SystemId};
use pk_macros::api_internal_endpoint;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
Extension,
};
use sqlx::Postgres;
#[derive(serde::Deserialize)]
pub struct NewApiKeyRequestData {
#[serde(default)]
check: bool,
system: SystemId,
name: Option<String>,
scopes: Vec<String>,
}
#[api_internal_endpoint]
pub async fn create_api_key_user(
State(ctx): State<ApiContext>,
Extension(auth): Extension<AuthState>,
Json(req): Json<NewApiKeyRequestData>,
) -> Response {
let system: Option<PKSystem> = sqlx::query_as("select * from systems where id = $1")
.bind(req.system)
.fetch_optional(&ctx.db)
.await
.expect("failed to query system");
if system.is_none() {
return Ok(json_err(
StatusCode::BAD_REQUEST,
r#"{"message": "no system found!?", "internal": true}"#.to_string(),
));
}
let system = system.unwrap();
// sanity check requested scopes
if req.scopes.len() < 1 {
return Ok(json_err(
StatusCode::BAD_REQUEST,
r#"{"message": "no scopes provided", "internal": true}"#.to_string(),
));
}
for scope in req.scopes.iter() {
let parts = scope.split(":").collect::<Vec<&str>>();
let ok = match &parts[..] {
["identify"] => true,
["publicread", n] | ["read", n] | ["write", n] => match *n {
"all" => true,
"system" => true,
"members" => true,
"groups" => true,
"fronters" => true,
"switches" => true,
_ => false,
},
_ => false,
};
if !ok {
return Err(crate::error::GENERIC_BAD_REQUEST);
}
}
if req.check {
return Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"valid": true,
}))
.expect("should not error"),
).into_response());
}
let token: PKApiKey = sqlx::query_as(
r#"
insert into api_keys
(
system,
kind,
scopes,
name
)
values
($1, $2::api_key_type, $3::text[], $4)
returning *
"#,
)
.bind(system.id)
.bind(ApiKeyType::UserCreated)
.bind(req.scopes)
.bind(req.name)
.fetch_one(&ctx.db)
.await
.expect("failed to create token");
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"valid": true,
"token": token,
}))
.expect("should not error"),
).into_response())
}

View file

@ -1,2 +1,3 @@
pub mod internal;
pub mod private;
pub mod system;

View file

@ -1,18 +1,20 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use crate::{util::json_err, ApiContext};
use libpk::config;
use pluralkit_models::{PrivacyLevel, PKApiKey, PKSystem, PKSystemConfig};
use axum::{
extract::{self, State},
response::{IntoResponse, Json, Response},
};
use fred::interfaces::*;
use hyper::StatusCode;
use libpk::state::ShardState;
use pk_macros::api_endpoint;
use reqwest::ClientBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
pub guild_count: i32,
pub channel_count: i32,
}
use std::time::Duration;
#[api_endpoint]
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
@ -43,18 +45,6 @@ pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
Ok(Json(stats))
}
use std::time::Duration;
use crate::util::json_err;
use axum::{
extract,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
pub struct CallbackRequestData {
redirect_domain: String,
@ -71,6 +61,7 @@ struct CallbackDiscordData {
code: String,
}
#[api_endpoint]
pub async fn discord_callback(
State(ctx): State<ApiContext>,
extract::Json(request_data): extract::Json<CallbackRequestData>,
@ -107,7 +98,7 @@ pub async fn discord_callback(
};
if !discord_data.contains_key("access_token") {
return json_err(
return Ok(json_err(
StatusCode::BAD_REQUEST,
format!(
"{{\"error\":\"{}\"\"}}",
@ -116,7 +107,7 @@ pub async fn discord_callback(
.expect("missing error_description from discord")
.to_string()
),
);
));
};
let token = format!(
@ -152,10 +143,10 @@ pub async fn discord_callback(
.expect("failed to query");
let Some(system) = system else {
return json_err(
return Ok(json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
r#"{"message": "user does not have a system registered", "code": 0}"#.to_string(),
));
};
let system_config: Option<PKSystemConfig> = sqlx::query_as(
@ -170,11 +161,38 @@ pub async fn discord_callback(
let system_config = system_config.unwrap();
// create dashboard token for system
let token: PKApiKey = sqlx::query_as(
r#"
insert into api_keys
(
system,
kind,
discord_id,
discord_access_token,
discord_refresh_token,
discord_expires_at
)
values
($1, $2::api_key_type, $3, $4, $5, $6)
returning *
"#,
)
.bind(system.id)
.bind("dashboard")
.bind(user.id.get() as i64)
.bind(discord_data.get("access_token").unwrap().as_str())
.bind(discord_data.get("refresh_token").unwrap().as_str())
.bind(
chrono::Utc::now()
+ chrono::Duration::seconds(discord_data.get("expires_in").unwrap().as_i64().unwrap()),
)
.fetch_one(&ctx.db)
.await
.expect("failed to create token");
let token = system.clone().token;
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
(
Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(PrivacyLevel::Private),
@ -183,6 +201,5 @@ pub async fn discord_callback(
"token": token,
}))
.expect("should not error"),
)
.into_response()
).into_response())
}

View file

@ -83,4 +83,6 @@ macro_rules! define_error {
}
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
// define_error! { GENERIC_UNAUTHORIZED, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" }
define_error! { FORBIDDEN_INTERNAL_ROUTE, StatusCode::FORBIDDEN, 0, "403: Forbidden to access this endpoint" }
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }

View file

@ -1,10 +1,10 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER, INTERNAL_TOKENID_HEADER, INTERNAL_PRIVACYLEVEL_HEADER};
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::Uri,
http::{HeaderValue, Uri},
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
Extension, Router,
@ -13,8 +13,8 @@ use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use tracing::info;
use jsonwebtoken::{DecodingKey, EncodingKey};
use tracing::{error, info};
use pk_macros::api_endpoint;
mod auth;
@ -30,6 +30,9 @@ pub struct ApiContext {
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
token_privatekey: EncodingKey,
token_publickey: DecodingKey,
}
#[api_endpoint]
@ -53,14 +56,21 @@ async fn rproxy(
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
headers.remove(INTERNAL_TOKENID_HEADER);
headers.remove(INTERNAL_PRIVACYLEVEL_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
headers.append(INTERNAL_PRIVACYLEVEL_HEADER, HeaderValue::from_str(&auth.access_level().privacy_level().to_string())?);
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
headers.append(INTERNAL_APPID_HEADER, HeaderValue::from_str(&format!("{}", aid))?);
}
if let Some(tid) = auth.api_key_id() {
headers.append(INTERNAL_TOKENID_HEADER, HeaderValue::from_str(&format!("{}", tid))?);
}
Ok(ctx.rproxy_client.request(req).await?.into_response())
}
@ -124,11 +134,13 @@ fn router(ctx: ApiContext) -> Router {
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.route("/internal/apikey/user", post(endpoints::internal::create_api_key_user))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(ctx.clone(), middleware::ratelimit::do_request_ratelimited))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::logger::logger))
@ -149,14 +161,9 @@ async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let rproxy_uri = Uri::from_static(
&libpk::config
.api
.as_ref()
.expect("missing api config")
.remote_url,
)
.to_string();
let cfg = libpk::config.api.as_ref().expect("missing api config");
let rproxy_uri = Uri::from_static(cfg.remote_url.as_str()).to_string();
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
@ -166,16 +173,16 @@ async fn main() -> anyhow::Result<()> {
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
token_privatekey: EncodingKey::from_ec_pem(cfg.token_privatekey.as_bytes())
.expect("failed to load private key"),
token_publickey: DecodingKey::from_ec_pem(cfg.token_publickey.as_bytes())
.expect("failed to load public key"),
};
let app = router(ctx);
let addr: &str = libpk::config
.api
.as_ref()
.expect("missing api config")
.addr
.as_ref();
let addr: &str = cfg.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);

View file

@ -1,20 +1,129 @@
use axum::{
extract::{Request, State},
extract::{Request, State, MatchedPath},
http::StatusCode,
middleware::Next,
response::Response,
};
use uuid::Uuid;
use subtle::ConstantTimeEq;
use tracing::error;
use sqlx::Postgres;
use crate::auth::AuthState;
use pluralkit_models::{ApiKeyType, PKApiKey};
use crate::auth::{AccessLevel, AuthState};
use crate::{util::json_err, ApiContext};
pub fn is_part_path<'a, 'b>(part: &'a str, endpoint: &'b str) -> bool {
if !endpoint.starts_with("/v2/") {
return false;
}
let path_frags = endpoint[4..].split("/").collect::<Vec<&str>>();
match part {
"system" => match &path_frags[..] {
["systems", _] => true,
["systems", _, "settings"] => true,
["systems", _, "autoproxy"] => true,
["systems", _, "guilds", ..] => true,
_ => false,
},
"members" => match &path_frags[..] {
["systems", _, "members"] => true,
["members"] => true,
["members", _, "groups"] => false,
["members", _, "groups", ..] => false,
["members", ..] => true,
_ => false,
},
"groups" => match &path_frags[..] {
["systems", _, "groups"] => true,
["groups"] => true,
["groups", ..] => true,
["members", _, "groups"] => true,
["members", _, "groups", ..] => true,
_ => false,
},
"fronters" => match &path_frags[..] {
["systems", _, "fronters"] => true,
_ => false,
},
"switches" => match &path_frags[..] {
// switches implies fronters
["systems", _, "fronters"] => true,
["systems", _, "switches"] => true,
["systems", _, "switches", ..] => true,
_ => false,
},
_ => false,
}
}
pub fn apikey_can_access(token: &PKApiKey, method: String, endpoint: String) -> AccessLevel {
if token.kind == ApiKeyType::Dashboard {
return AccessLevel::Full;
}
let mut access = AccessLevel::None;
for rscope in token.scopes.iter() {
let scope = rscope.split(":").collect::<Vec<&str>>();
let na = match (method.as_str(), &scope[..]) {
("GET", ["identify"]) => {
if &endpoint == "/v2/systems/:system_id" {
AccessLevel::PublicRead
} else {
AccessLevel::None
}
}
("GET", ["publicread", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::PublicRead
} else {
AccessLevel::None
}
}
("GET", ["read", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::PrivateRead
} else {
AccessLevel::None
}
}
(_, ["write", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::Full
} else {
AccessLevel::None
}
}
_ => AccessLevel::None,
};
if na > access {
access = na;
}
}
access
}
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let endpoint = req
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
let mut authed_app_id: Option<Uuid> = None;
let mut authed_api_key_id: Option<Uuid> = None;
let mut access_level = AccessLevel::None;
// fetch user authorization
if let Some(system_auth_header) = req
@ -22,7 +131,24 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.get("authorization")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(system_id) =
{
if system_auth_header.starts_with("Bearer ")
&& let Some(tid) =
PKApiKey::parse_header_str(system_auth_header[7..].to_string(), &ctx.token_publickey)
&& let Some(token) =
sqlx::query_as::<Postgres, PKApiKey>("select * from api_keys where id = $1")
.bind(&tid)
.fetch_optional(&ctx.db)
.await
.expect("failed to query apitoken in postgres")
{
authed_api_key_id = Some(tid);
access_level = apikey_can_access(&token, req.method().to_string(), endpoint.clone());
if access_level != AccessLevel::None {
authed_system_id = Some(token.system);
}
}
else if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val,
Err(err) => {
@ -33,29 +159,31 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
);
}
}
{
authed_system_id = Some(system_id);
}
{
authed_system_id = Some(system_id);
access_level = AccessLevel::Full;
}
}
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = req
.headers()
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(config_token2) = libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
.as_ref()
&& app_auth_header
.as_bytes()
.ct_eq(config_token2.as_bytes())
.into()
&& let Some(app_id) =
match libpk::db::repository::app_token_auth(&ctx.db, app_auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
}
{
authed_app_id = Some(1);
authed_app_id = Some(app_id);
}
// todo: fix syntax
@ -74,7 +202,7 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
};
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id, internal));
.insert(AuthState::new(authed_system_id, authed_app_id, authed_api_key_id, access_level, internal));
next.run(req).await
}

View file

@ -11,7 +11,7 @@ fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-PluralKit-Authentication, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}

View file

@ -42,6 +42,7 @@ pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().clone().path().starts_with("/v2")
&& !request.uri().clone().path().starts_with("/private")
&& !request.uri().clone().path().starts_with("/internal")
{
return (
StatusCode::BAD_REQUEST,

View file

@ -8,12 +8,15 @@ use axum::{
};
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
use metrics::counter;
use sqlx::Postgres;
use tracing::{debug, error, info, warn};
use crate::{
ApiContext,
auth::AuthState,
util::{header_or_unknown, json_err},
};
use pluralkit_models::PKExternalApp;
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -22,7 +25,10 @@ lazy_static::lazy_static! {
}
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
pub fn ratelimiter<F, T>(
ctx: ApiContext,
f: F,
) -> FromFnLayer<F, (ApiContext, Option<RedisPool>), T> {
let redis = libpk::config
.api
.as_ref()
@ -52,14 +58,14 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
axum::middleware::from_fn_with_state((ctx, redis), f)
}
enum RatelimitType {
GenericGet,
GenericUpdate,
Message,
TempCustom,
AppCustom(i32),
}
impl RatelimitType {
@ -68,7 +74,7 @@ impl RatelimitType {
RatelimitType::GenericGet => "generic_get",
RatelimitType::GenericUpdate => "generic_update",
RatelimitType::Message => "message",
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
RatelimitType::AppCustom(_) => "app_custom",
}
.to_string()
}
@ -78,21 +84,41 @@ impl RatelimitType {
RatelimitType::GenericGet => 10,
RatelimitType::GenericUpdate => 3,
RatelimitType::Message => 10,
RatelimitType::TempCustom => 20,
RatelimitType::AppCustom(n) => *n,
}
}
}
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
State((ctx, redis)): State<(ApiContext, Option<RedisPool>)>,
request: Request,
next: Next,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
if headers.get("x-pluralkit-internal").is_some() {
// bypass ratelimiting entirely for internal requests
return next.run(request).await;
}
let extensions = request.extensions().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let mut app_rate: Option<i32> = None;
if let Some(app_header) = request.headers().clone().get("x-pluralkit-app") {
let app_token = app_header.to_str().unwrap_or("invalid");
if app_token.starts_with("pkap2:")
&& let Some(app) = sqlx::query_as::<Postgres, PKExternalApp>(
"select * from external_apps where api_rl_token = $1",
)
.bind(&app_token[6..])
.fetch_optional(&ctx.db)
.await
.expect("failed to query external app in postgres")
{
app_rate = Some(app.api_rl_rate.expect("external app has no api_rl_rate"));
}
};
let endpoint = extensions
.get::<MatchedPath>()
@ -109,11 +135,8 @@ pub async fn do_request_ratelimited(
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
let rlimit = if let Some(r) = app_rate {
RatelimitType::AppCustom(r)
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
fred = { workspace = true }
jsonwebtoken = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
pk_macros = { path = "../macros" }

View file

@ -62,6 +62,11 @@ pub struct ApiConfig {
#[serde(default)]
pub temp_token2: Option<String>,
pub token_privatekey: String,
pub token_publickey: String,
pub internal_request_secret: String,
}
#[derive(Deserialize, Clone, Debug)]

View file

@ -1,3 +1,5 @@
use uuid::Uuid;
pub async fn legacy_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
@ -18,3 +20,24 @@ pub async fn legacy_token_auth(
struct LegacyTokenDbResponse {
id: i32,
}
pub async fn app_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
) -> anyhow::Result<Option<Uuid>> {
let mut app: Vec<AppTokenDbResponse> =
sqlx::query_as("select id from external_apps where api_rl_token = $1")
.bind(token)
.fetch_all(pool)
.await?;
Ok(if let Some(app) = app.pop() {
Some(app.id)
} else {
None
})
}
#[derive(sqlx::FromRow)]
struct AppTokenDbResponse {
id: Uuid,
}

View file

@ -9,6 +9,7 @@ fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
is_internal: bool,
) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as ItemFn);
@ -34,6 +35,16 @@ pub fn macro_impl(
})
.collect();
let internal_res = if is_internal {
quote! {
if !auth.internal() {
return crate::error::FORBIDDEN_INTERNAL_ROUTE.into_response();
}
}
} else {
quote!()
};
let res = quote! {
#[allow(unused_mut)]
pub async fn #fn_name(#fn_params) -> axum::response::Response {
@ -41,6 +52,7 @@ pub fn macro_impl(
#fn_body
}
#internal_res
match inner(#(#pms),*).await {
Ok(res) => res.into_response(),
Err(err) => err.into_response(),

View file

@ -6,7 +6,12 @@ mod model;
#[proc_macro_attribute]
pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream {
api::macro_impl(args, input)
api::macro_impl(args, input, false)
}
#[proc_macro_attribute]
pub fn api_internal_endpoint(args: TokenStream, input: TokenStream) -> TokenStream {
api::macro_impl(args, input, true)
}
#[proc_macro_attribute]

View file

@ -18,4 +18,4 @@ create table command_messages (
create index command_messages_by_original on command_messages(original_mid);
create index command_messages_by_sender on command_messages(sender);
update info set schema_version = 52;
update info set schema_version = 52;

View file

@ -0,0 +1,38 @@
-- database version 53
--
-- scoped API keys + skeleton for oauth2 for third-party apps
create table external_apps (
id uuid primary key default gen_random_uuid(),
name text not null,
homepage_url text not null,
oauth2_secret text,
oauth2_allowed_redirects text[] not null default array[]::text[],
oauth2_scopes text[] not null default array[]::text[],
api_rl_token text,
api_rl_rate int
);
create type api_key_type as enum (
'dashboard',
'user_created',
'external_app'
);
create table api_keys (
id uuid primary key default gen_random_uuid(),
system int references systems(id) on delete cascade,
kind api_key_type not null,
scopes text[] not null default array[]::text[],
app uuid references external_apps(id) on delete cascade,
name text,
discord_id bigint,
discord_access_token text,
discord_refresh_token text,
discord_expires_at timestamp,
created timestamp with time zone not null default (current_timestamp at time zone 'utc')
);
update info set schema_version = 53;

View file

@ -4,8 +4,10 @@ version = "0.1.0"
edition = "2021"
[dependencies]
base64 = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
pk_macros = { path = "../macros" }
jsonwebtoken = { workspace = true }
sea-query = "0.32.1"
serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] }

View file

@ -6,9 +6,9 @@
// note: caller needs to implement From<i32> for their type
macro_rules! fake_enum_impls {
($n:ident) => {
impl Type<Postgres> for $n {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("INT4")
impl ::sqlx::Type<::sqlx::Postgres> for $n {
fn type_info() -> ::sqlx::postgres::PgTypeInfo {
::sqlx::postgres::PgTypeInfo::with_name("INT4")
}
}
@ -18,14 +18,14 @@ macro_rules! fake_enum_impls {
}
}
impl<'r, DB: Database> Decode<'r, DB> for $n
impl<'r, DB: ::sqlx::Database> ::sqlx::Decode<'r, DB> for $n
where
i32: Decode<'r, DB>,
i32: ::sqlx::Decode<'r, DB>,
{
fn decode(
value: <DB as Database>::ValueRef<'r>,
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
let value = <i32 as Decode<DB>>::decode(value)?;
value: <DB as ::sqlx::Database>::ValueRef<'r>,
) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
let value = <i32 as ::sqlx::Decode<DB>>::decode(value)?;
Ok(Self::from(value))
}
}

View file

@ -0,0 +1,104 @@
use pk_macros::pk_model;
use chrono::{DateTime, Utc, NaiveDateTime};
use uuid::Uuid;
use base64::{prelude::BASE64_STANDARD, Engine};
use jsonwebtoken::{
crypto::{sign, verify},
DecodingKey, EncodingKey,
};
use crate::SystemId;
#[derive(sqlx::Type, Debug, Clone, PartialEq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
#[sqlx(rename_all = "snake_case")]
#[sqlx(type_name = "api_key_type")]
pub enum ApiKeyType {
Dashboard,
UserCreated,
ExternalApp,
}
#[pk_model]
struct ApiKey {
#[json = "id"]
id: Uuid,
system: SystemId,
#[json = "type"]
kind: ApiKeyType,
#[json = "scopes"]
scopes: Vec<String>,
#[json = "app"]
app: Option<Uuid>,
#[json = "name"]
#[patchable]
name: Option<String>,
#[json = "discord_id"]
discord_id: Option<i64>,
#[private_patchable]
discord_access_token: Option<String>,
#[private_patchable]
discord_refresh_token: Option<String>,
#[private_patchable]
discord_expires_at: Option<NaiveDateTime>,
#[json = "created"]
created: DateTime<Utc>,
}
const SIGNATURE_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::ES256;
impl PKApiKey {
pub fn to_header_str(self, system_uuid: Uuid, key: &EncodingKey) -> String {
let b64 = BASE64_STANDARD.encode(
serde_json::to_vec(&serde_json::json!({
"tid": self.id.to_string(),
"sid": system_uuid.to_string(),
"type": self.kind,
"scopes": self.scopes,
}))
.expect("should not fail"),
);
let signature = sign(b64.as_bytes(), key, SIGNATURE_ALGORITHM).expect("should not fail");
format!("pkapi:{b64}:{signature}")
}
/// Parse a header string into a token uuid
pub fn parse_header_str(token: String, key: &DecodingKey) -> Option<Uuid> {
let mut parts = token.split(":");
let pkapi = parts.next();
if pkapi.is_none_or(|v| v != "pkapi") {
return None;
}
let Some(jsonblob) = parts.next() else {
return None;
};
let Some(sig) = parts.next() else {
return None;
};
// verify signature before doing anything else
let valid = verify(sig, jsonblob.as_bytes(), key, SIGNATURE_ALGORITHM);
if valid.is_err() || matches!(valid, Ok(false)) {
return None;
}
let Ok(bytes) = BASE64_STANDARD.decode(jsonblob) else {
return None;
};
let Ok(obj) = serde_json::from_slice::<serde_json::Value>(bytes.as_slice()) else {
return None;
};
obj.get("tid")
.map(|v| v.as_str().map(|f| Uuid::parse_str(f).ok()))
.flatten()
.flatten()
}
}

View file

@ -1,14 +1,5 @@
mod _util;
macro_rules! model {
($n:ident) => {
mod $n;
pub use $n::*;
};
}
model!(system);
model!(system_config);
use _util::fake_enum_impls;
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
@ -17,10 +8,7 @@ pub enum PrivacyLevel {
Private,
}
// this sucks, put it somewhere else
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use std::error::Error;
_util::fake_enum_impls!(PrivacyLevel);
fake_enum_impls!(PrivacyLevel);
impl From<i32> for PrivacyLevel {
fn from(value: i32) -> Self {
@ -31,3 +19,24 @@ impl From<i32> for PrivacyLevel {
}
}
}
impl PrivacyLevel {
pub fn to_string(&self) -> String {
match self {
PrivacyLevel::Public => "public".into(),
PrivacyLevel::Private => "private".into(),
}
}
}
macro_rules! model {
($n:ident) => {
mod $n;
pub use $n::*;
};
}
model!(api_key);
model!(oauth2_app);
model!(system);
model!(system_config);

View file

@ -0,0 +1,28 @@
use pk_macros::pk_model;
use uuid::Uuid;
#[pk_model]
struct ExternalApp {
#[json = "id"]
id: Uuid,
#[json = "name"]
#[patchable]
name: String,
#[json = "homepage_url"]
#[patchable]
homepage_url: String,
#[private_patchable]
oauth2_secret: Option<String>,
#[json = "oauth2_allowed_redirects"]
#[patchable]
oauth2_allowed_redirects: Vec<String>,
#[json = "oauth2_scopes"]
#[patchable]
oauth2_scopes: Vec<String>,
#[private_patchable]
api_rl_token: Option<String>,
#[private_patchable]
api_rl_rate: Option<i32>,
}

View file

@ -1,10 +1,9 @@
use pk_macros::pk_model;
use crate::PrivacyLevel;
use chrono::NaiveDateTime;
use uuid::Uuid;
use crate::PrivacyLevel;
// todo: fix this
pub type SystemId = i32;

View file

@ -77,7 +77,8 @@ module.exports = {
"/api/endpoints",
"/api/models",
"/api/errors",
"/api/dispatch"
"/api/tokens",
"/api/dispatch",
]
},
["https://discord.gg/PczBt78", "Join the support server"],

View file

@ -42,7 +42,7 @@ The PluralKit Discord bot can be configured to display short IDs in uppercase, o
|---|---|---|
|id|string||
|uuid|string||
|?system|string|id of system this member is registered in (only returned in `/members/:id` endpoint)|
|system|string|id of system this member is registered in|
|name|string|100-character limit|
|display_name|?string|100-character limit|
|color|?string|6-character hex code, no `#` at the beginning|
@ -78,7 +78,7 @@ The PluralKit Discord bot can be configured to display short IDs in uppercase, o
|---|---|---|
|id|string||
|uuid|string||
|?system|string|id of system this group is registered in (only returned in `/groups/:id` endpoint)|
|system|string|id of system this group is registered in|
|name|string|100-character limit|
|display_name|?string|100-character limit|
|description|?string|1000-character limit|
@ -96,7 +96,7 @@ The PluralKit Discord bot can be configured to display short IDs in uppercase, o
|---|---|---|
|id|uuid||
|timestamp|datetime||
| members | list of id/Member | Is sometimes in plain ID list form (eg. `GET /systems/:id/switches`), sometimes includes the full Member model (eg. `GET /systems/:id/fronters`). |
|members|list of id/Member|Is sometimes in plain ID list form (eg. `GET /systems/:id/switches`), sometimes includes the full Member model (eg. `GET /systems/:id/fronters`)|
### Message model
@ -121,6 +121,13 @@ The PluralKit Discord bot can be configured to display short IDs in uppercase, o
|member_default_private*|boolean|whether members created through the bot have privacy settings set to private by default|
|group_default_private*|boolean|whether groups created through the bot have privacy settings set to private by default|
|show_private_info|boolean|whether the bot shows the system's own private information without a `-private` flag|
|case_sensitive_proxy_tags|boolean|whether the system's member proxy tags are parsed as case sensitive|
|proxy_error_message_enabled|boolean|whether to show proxying-specific error messages|
|hid_display_split|boolean|if enabled, the system prefers 6-character IDs to be displayed with a hyphen splitting each group of 3 characters|
|hid_display_caps|boolean|if enabled, the system prefers short IDs to be displayed in all-caps|
|hid_list_padding|one of "off", "left", "right"|system preference for padding short IDs in lists|
|proxy_switch|one of "off", "new", "add"||
|name_format|string|formatting template for display names for the system's proxied messages|
|member_limit|int|read-only, defaults to 1000|
|group_limit|int|read-only, defaults to 250|
|case_sensitive_proxy_tags|bool|whether the bot will match proxy tags matching only the case used in the trigger message|

View file

@ -8,11 +8,6 @@ permalink: /api
PluralKit has a basic HTTP REST API for querying and modifying your system.
The root endpoint of the API is `https://api.pluralkit.me/v2/`.
#### Authorization header token example
```
Authorization: z865MC7JNhLtZuSq1NXQYVe+FgZJHBfeBCXOPYYRwH4liDCDrsd7zdOuR45mX257
```
Endpoints will always return all fields, using `null` when a value is missing. On `PATCH` endpoints,
missing fields from the JSON request will be ignored and preserved as is, but on `POST` endpoints will
be set to `null` or cleared.
@ -29,14 +24,12 @@ If you are developing an application exposed to the public, we would appreciate
## Authentication
Authentication is done with a simple "system token". You can get your system token by running `pk;token` using the
Discord bot, either in a channel with the bot or in DMs. Then, pass this token in the `Authorization` HTTP header
on requests that require it. Failure to do so on endpoints that require authentication will return a `401 Unauthorized`.
Authentication is done with an API key provided in the `Authorization` HTTP header - [see the API key section of the documentation for details.](/api/tokens)
Some endpoints show information that a given system may have set to private. If this is a specific field
(eg. description), the field will simply contain `null` rather than the true value. If this applies to entire endpoint
responses (eg. fronter, switches, member list), the entire request will return `403 Forbidden`. Authenticating with the
system's token (as described above) will override these privacy settings and show the full information.
Some endpoints show information that a given system may have set to private. For unauthenticated requests, and for requests authenticated with an API key that does not have permission to read private data, the following rules apply:
- For fields with specific privacy settings (e.g. descriptions), the field will simply contain `null` rather than the true value
- For entire endpoints which show private data (e.g. member/group lists), a `403 Forbidden` response will be returned
## Rate Limiting
@ -46,7 +39,7 @@ To protect against abuse and manage server resources, PluralKit's API limits the
- **10/second** for requests to the [Get Proxied Message Information](/api/endpoints/#get-proxied-message-information) endpoint (`message` scope)
- **3/second** for any `POST`, `PATCH`, or `DELETE` requests (`generic_update` scope)
We may raise the limits for individual users in a case-by-case basis; please ask [in the support server](https://discord.gg/PczBt78) if you need a higher limit.
We may raise the limits for individual API clients on a case-by-case basis; please ask [in the support server](https://discord.gg/PczBt78) if you need a higher limit.
::: tip
If you are looking to query a specific resource in your system repeatedly (polling), please consider using [Dispatch Webhooks](/api/dispatch) instead.

View file

@ -0,0 +1,79 @@
---
title: API keys / tokens
permalink: /api/tokens
---
# API keys / tokens
There are currently two types of API keys / tokens used by PluralKit - "legacy" tokens from the `pk;token` command (64 characters, a system can only have one valid token at a time); and "modern" API keys (variable length, always start with `pkapi:').
## "Legacy" tokens
"Legacy" PluralKit tokens look similar to the following:
```
LvWacQm3Yu+Jbhl8B7LR97Q4kfpAasTiB8/BY5/HJCppHFggzwOai6QBxehAJ53C
```
These tokens are supplied *as-is* in the `Authorization` HTTP header when talking to the PluralKit API (e.g. `Authorization: LvWacQm3Y...`)
Each PluralKit system can only have *one* valid "legacy" token at a time, and that token holds the keys to the entire castle - it grants full read/write privileges.
**PluralKit's API will stop accepting "legacy" tokens for authentication in the near future!** We do not yet have a deprecation plan set in stone, but there will be a significant notice period before this happens.
## "Modern" API keys
A "modern" PluralKit API key is made up of three components, separated by colons:
- The string `"pkapi"`
- A Base64-encoded JSON blob containing information about the API key
- An opaque signature
As an example:
```
pkapi:eyJ0aWQiOiI3NWEzODZlNy1mMjNlLTRmM2EtYjkwNC1jYTgwMzE0OWFmNWEiLCJzaWQiOiIyMmIwYjA3Yi00ZmE3LTRmYTEtYmYyNS1lZWI4NjY1ZjMyYzEiLCJ0eXBlIjoidXNlcl9jcmVhdGVkIiwic2NvcGVzIjpbIndyaXRlOmFsbCJdfQ==:nUjJPPtBOyPb1bYFhm24bU87N2Fb_oSaNnHEZkB-6ZSCSlAJvkyb32MTfmdEv3U6wNBlBQtQb0Fkv2nSvbNsCw
```
These tokens must be supplied with a "Bearer" prefix in the `Authorization` HTTP header when talking to the PluralKit API (e.g. `Authorization: Bearer pkapi:eyJ0aW...`).
The JSON blob in the above example API key contains the following:
```js
{
// API key ID
"tid": "75a386e7-f23e-4f3a-b904-ca803149af5a",
// UUID of the PluralKit system the token belongs to
"sid": "22b0b07b-4fa7-4fa1-bf25-eeb8665f32c1",
// "user_created" for manually generated API keys,
// "external_app" for OAuth2 user API keys (coming soon!)
"type": "user_created",
// One or more scopes (see below)
"scopes": ["write:all"]
}
```
### Scopes
In the below table, `<X>` refers to a *permission level* - one of the following:
- `publicread`: read-only access to *public* information
- `read`: read-only access to all (public *and* private) information
- `write`: read-write access to all information (implies `read`)
|scope|notes|
|---|---|
|`identify`|Read-only access to `/v2/systems/@me` - for proving the user providing the token has control of the PluralKit system|
|`<X>:system`|Access to core system data, system settings (including autoproxy), and server-specific settings|
|`<X>:members`|Access to member information, *not including group membership*|
|`<X>:groups`|Access to group information|
|`<X>:fronters`|Access to current system fronters|
|`<X>:switches`|Access to full system switch history (implies `<X>:fronters`)|
|`<X>:all`|Includes all other scopes|
### Issuing new API keys
TODO