[WIP] feat: scoped api keys

This commit is contained in:
Iris System 2025-08-17 02:47:01 -07:00
parent e7ee593a85
commit 06cb160f95
45 changed files with 1264 additions and 154 deletions

View file

@ -10,8 +10,12 @@ libpk = { path = "../libpk" }
anyhow = { workspace = true }
axum = { workspace = true }
base64 = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
jsonwebtoken = { workspace = true }
lazy_static = { workspace = true }
uuid = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }

View file

@ -1,20 +1,45 @@
use uuid::Uuid;
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
pub const INTERNAL_TOKENID_HEADER: &'static str = "x-pluralkit-tid";
pub const INTERNAL_PRIVACYLEVEL_HEADER: &'static str = "x-pluralkit-privacylevel";
#[derive(Debug, Clone, PartialEq, PartialOrd)]
pub enum AccessLevel {
None = 0,
PublicRead,
PrivateRead,
Full,
}
impl AccessLevel {
pub fn privacy_level(&self) -> PrivacyLevel {
match self {
Self::None | Self::PublicRead => PrivacyLevel::Public,
Self::PrivateRead | Self::Full => PrivacyLevel::Private,
}
}
}
#[derive(Clone)]
pub struct AuthState {
system_id: Option<i32>,
app_id: Option<i32>,
app_id: Option<Uuid>,
api_key_id: Option<Uuid>,
access_level: AccessLevel,
internal: bool,
}
impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<i32>, internal: bool) -> Self {
pub fn new(system_id: Option<i32>, app_id: Option<Uuid>, api_key_id: Option<Uuid>, access_level: AccessLevel, internal: bool) -> Self {
Self {
system_id,
app_id,
api_key_id,
access_level,
internal,
}
}
@ -23,10 +48,18 @@ impl AuthState {
self.system_id
}
pub fn app_id(&self) -> Option<i32> {
pub fn app_id(&self) -> Option<Uuid> {
self.app_id
}
pub fn api_key_id(&self) -> Option<Uuid> {
self.api_key_id
}
pub fn access_level(&self) -> AccessLevel {
self.access_level.clone()
}
pub fn internal(&self) -> bool {
self.internal
}
@ -37,7 +70,7 @@ impl AuthState {
.map(|id| id == a.authable_system_id())
.unwrap_or(false)
{
PrivacyLevel::Private
self.access_level.privacy_level()
} else {
PrivacyLevel::Public
}

View file

@ -0,0 +1,114 @@
use crate::{util::json_err, AuthState, ApiContext};
use pluralkit_models::{ApiKeyType, PKApiKey, PKSystem, SystemId};
use pk_macros::api_internal_endpoint;
use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
Extension,
};
use sqlx::Postgres;
#[derive(serde::Deserialize)]
pub struct NewApiKeyRequestData {
#[serde(default)]
check: bool,
system: SystemId,
name: Option<String>,
scopes: Vec<String>,
}
#[api_internal_endpoint]
pub async fn create_api_key_user(
State(ctx): State<ApiContext>,
Extension(auth): Extension<AuthState>,
Json(req): Json<NewApiKeyRequestData>,
) -> Response {
let system: Option<PKSystem> = sqlx::query_as("select * from systems where id = $1")
.bind(req.system)
.fetch_optional(&ctx.db)
.await
.expect("failed to query system");
if system.is_none() {
return Ok(json_err(
StatusCode::BAD_REQUEST,
r#"{"message": "no system found!?", "internal": true}"#.to_string(),
));
}
let system = system.unwrap();
// sanity check requested scopes
if req.scopes.len() < 1 {
return Ok(json_err(
StatusCode::BAD_REQUEST,
r#"{"message": "no scopes provided", "internal": true}"#.to_string(),
));
}
for scope in req.scopes.iter() {
let parts = scope.split(":").collect::<Vec<&str>>();
let ok = match &parts[..] {
["identify"] => true,
["publicread", n] | ["read", n] | ["write", n] => match *n {
"all" => true,
"system" => true,
"members" => true,
"groups" => true,
"fronters" => true,
"switches" => true,
_ => false,
},
_ => false,
};
if !ok {
return Err(crate::error::GENERIC_BAD_REQUEST);
}
}
if req.check {
return Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"valid": true,
}))
.expect("should not error"),
).into_response());
}
let token: PKApiKey = sqlx::query_as(
r#"
insert into api_keys
(
system,
kind,
scopes,
name
)
values
($1, $2::api_key_type, $3::text[], $4)
returning *
"#,
)
.bind(system.id)
.bind(ApiKeyType::UserCreated)
.bind(req.scopes)
.bind(req.name)
.fetch_one(&ctx.db)
.await
.expect("failed to create token");
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"valid": true,
"token": token,
}))
.expect("should not error"),
).into_response())
}

View file

@ -1,2 +1,3 @@
pub mod internal;
pub mod private;
pub mod system;

View file

@ -1,18 +1,20 @@
use crate::ApiContext;
use axum::{extract::State, response::Json};
use crate::{util::json_err, ApiContext};
use libpk::config;
use pluralkit_models::{PrivacyLevel, PKApiKey, PKSystem, PKSystemConfig};
use axum::{
extract::{self, State},
response::{IntoResponse, Json, Response},
};
use fred::interfaces::*;
use hyper::StatusCode;
use libpk::state::ShardState;
use pk_macros::api_endpoint;
use reqwest::ClientBuilder;
use serde::Deserialize;
use serde_json::{json, Value};
use std::collections::HashMap;
#[derive(Deserialize)]
#[serde(rename_all = "PascalCase")]
struct ClusterStats {
pub guild_count: i32,
pub channel_count: i32,
}
use std::time::Duration;
#[api_endpoint]
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
@ -43,18 +45,6 @@ pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
Ok(Json(stats))
}
use std::time::Duration;
use crate::util::json_err;
use axum::{
extract,
response::{IntoResponse, Response},
};
use hyper::StatusCode;
use libpk::config;
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
use reqwest::ClientBuilder;
#[derive(serde::Deserialize, Debug)]
pub struct CallbackRequestData {
redirect_domain: String,
@ -71,6 +61,7 @@ struct CallbackDiscordData {
code: String,
}
#[api_endpoint]
pub async fn discord_callback(
State(ctx): State<ApiContext>,
extract::Json(request_data): extract::Json<CallbackRequestData>,
@ -107,7 +98,7 @@ pub async fn discord_callback(
};
if !discord_data.contains_key("access_token") {
return json_err(
return Ok(json_err(
StatusCode::BAD_REQUEST,
format!(
"{{\"error\":\"{}\"\"}}",
@ -116,7 +107,7 @@ pub async fn discord_callback(
.expect("missing error_description from discord")
.to_string()
),
);
));
};
let token = format!(
@ -152,10 +143,10 @@ pub async fn discord_callback(
.expect("failed to query");
let Some(system) = system else {
return json_err(
return Ok(json_err(
StatusCode::BAD_REQUEST,
"user does not have a system registered".to_string(),
);
r#"{"message": "user does not have a system registered", "code": 0}"#.to_string(),
));
};
let system_config: Option<PKSystemConfig> = sqlx::query_as(
@ -170,11 +161,38 @@ pub async fn discord_callback(
let system_config = system_config.unwrap();
// create dashboard token for system
let token: PKApiKey = sqlx::query_as(
r#"
insert into api_keys
(
system,
kind,
discord_id,
discord_access_token,
discord_refresh_token,
discord_expires_at
)
values
($1, $2::api_key_type, $3, $4, $5, $6)
returning *
"#,
)
.bind(system.id)
.bind("dashboard")
.bind(user.id.get() as i64)
.bind(discord_data.get("access_token").unwrap().as_str())
.bind(discord_data.get("refresh_token").unwrap().as_str())
.bind(
chrono::Utc::now()
+ chrono::Duration::seconds(discord_data.get("expires_in").unwrap().as_i64().unwrap()),
)
.fetch_one(&ctx.db)
.await
.expect("failed to create token");
let token = system.clone().token;
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
(
Ok((
StatusCode::OK,
serde_json::to_string(&serde_json::json!({
"system": system.to_json(PrivacyLevel::Private),
@ -183,6 +201,5 @@ pub async fn discord_callback(
"token": token,
}))
.expect("should not error"),
)
.into_response()
).into_response())
}

View file

@ -83,4 +83,6 @@ macro_rules! define_error {
}
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
// define_error! { GENERIC_UNAUTHORIZED, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" }
define_error! { FORBIDDEN_INTERNAL_ROUTE, StatusCode::FORBIDDEN, 0, "403: Forbidden to access this endpoint" }
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }

View file

@ -1,10 +1,10 @@
#![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER, INTERNAL_TOKENID_HEADER, INTERNAL_PRIVACYLEVEL_HEADER};
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},
http::Uri,
http::{HeaderValue, Uri},
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
Extension, Router,
@ -13,8 +13,8 @@ use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor,
};
use tracing::info;
use jsonwebtoken::{DecodingKey, EncodingKey};
use tracing::{error, info};
use pk_macros::api_endpoint;
mod auth;
@ -30,6 +30,9 @@ pub struct ApiContext {
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
token_privatekey: EncodingKey,
token_publickey: DecodingKey,
}
#[api_endpoint]
@ -53,14 +56,21 @@ async fn rproxy(
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
headers.remove(INTERNAL_TOKENID_HEADER);
headers.remove(INTERNAL_PRIVACYLEVEL_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
headers.append(INTERNAL_PRIVACYLEVEL_HEADER, HeaderValue::from_str(&auth.access_level().privacy_level().to_string())?);
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
headers.append(INTERNAL_APPID_HEADER, HeaderValue::from_str(&format!("{}", aid))?);
}
if let Some(tid) = auth.api_key_id() {
headers.append(INTERNAL_TOKENID_HEADER, HeaderValue::from_str(&format!("{}", tid))?);
}
Ok(ctx.rproxy_client.request(req).await?.into_response())
}
@ -124,11 +134,13 @@ fn router(ctx: ApiContext) -> Router {
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.route("/internal/apikey/user", post(endpoints::internal::create_api_key_user))
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(ctx.clone(), middleware::ratelimit::do_request_ratelimited))
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
.layer(axum::middleware::from_fn(middleware::logger::logger))
@ -149,14 +161,9 @@ async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let rproxy_uri = Uri::from_static(
&libpk::config
.api
.as_ref()
.expect("missing api config")
.remote_url,
)
.to_string();
let cfg = libpk::config.api.as_ref().expect("missing api config");
let rproxy_uri = Uri::from_static(cfg.remote_url.as_str()).to_string();
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
@ -166,16 +173,16 @@ async fn main() -> anyhow::Result<()> {
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
token_privatekey: EncodingKey::from_ec_pem(cfg.token_privatekey.as_bytes())
.expect("failed to load private key"),
token_publickey: DecodingKey::from_ec_pem(cfg.token_publickey.as_bytes())
.expect("failed to load public key"),
};
let app = router(ctx);
let addr: &str = libpk::config
.api
.as_ref()
.expect("missing api config")
.addr
.as_ref();
let addr: &str = cfg.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);

View file

@ -1,20 +1,129 @@
use axum::{
extract::{Request, State},
extract::{Request, State, MatchedPath},
http::StatusCode,
middleware::Next,
response::Response,
};
use uuid::Uuid;
use subtle::ConstantTimeEq;
use tracing::error;
use sqlx::Postgres;
use crate::auth::AuthState;
use pluralkit_models::{ApiKeyType, PKApiKey};
use crate::auth::{AccessLevel, AuthState};
use crate::{util::json_err, ApiContext};
pub fn is_part_path<'a, 'b>(part: &'a str, endpoint: &'b str) -> bool {
if !endpoint.starts_with("/v2/") {
return false;
}
let path_frags = endpoint[4..].split("/").collect::<Vec<&str>>();
match part {
"system" => match &path_frags[..] {
["systems", _] => true,
["systems", _, "settings"] => true,
["systems", _, "autoproxy"] => true,
["systems", _, "guilds", ..] => true,
_ => false,
},
"members" => match &path_frags[..] {
["systems", _, "members"] => true,
["members"] => true,
["members", _, "groups"] => false,
["members", _, "groups", ..] => false,
["members", ..] => true,
_ => false,
},
"groups" => match &path_frags[..] {
["systems", _, "groups"] => true,
["groups"] => true,
["groups", ..] => true,
["members", _, "groups"] => true,
["members", _, "groups", ..] => true,
_ => false,
},
"fronters" => match &path_frags[..] {
["systems", _, "fronters"] => true,
_ => false,
},
"switches" => match &path_frags[..] {
// switches implies fronters
["systems", _, "fronters"] => true,
["systems", _, "switches"] => true,
["systems", _, "switches", ..] => true,
_ => false,
},
_ => false,
}
}
pub fn apikey_can_access(token: &PKApiKey, method: String, endpoint: String) -> AccessLevel {
if token.kind == ApiKeyType::Dashboard {
return AccessLevel::Full;
}
let mut access = AccessLevel::None;
for rscope in token.scopes.iter() {
let scope = rscope.split(":").collect::<Vec<&str>>();
let na = match (method.as_str(), &scope[..]) {
("GET", ["identify"]) => {
if &endpoint == "/v2/systems/:system_id" {
AccessLevel::PublicRead
} else {
AccessLevel::None
}
}
("GET", ["publicread", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::PublicRead
} else {
AccessLevel::None
}
}
("GET", ["read", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::PrivateRead
} else {
AccessLevel::None
}
}
(_, ["write", part]) => {
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
AccessLevel::Full
} else {
AccessLevel::None
}
}
_ => AccessLevel::None,
};
if na > access {
access = na;
}
}
access
}
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let endpoint = req
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
let mut authed_app_id: Option<Uuid> = None;
let mut authed_api_key_id: Option<Uuid> = None;
let mut access_level = AccessLevel::None;
// fetch user authorization
if let Some(system_auth_header) = req
@ -22,7 +131,24 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.get("authorization")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(system_id) =
{
if system_auth_header.starts_with("Bearer ")
&& let Some(tid) =
PKApiKey::parse_header_str(system_auth_header[7..].to_string(), &ctx.token_publickey)
&& let Some(token) =
sqlx::query_as::<Postgres, PKApiKey>("select * from api_keys where id = $1")
.bind(&tid)
.fetch_optional(&ctx.db)
.await
.expect("failed to query apitoken in postgres")
{
authed_api_key_id = Some(tid);
access_level = apikey_can_access(&token, req.method().to_string(), endpoint.clone());
if access_level != AccessLevel::None {
authed_system_id = Some(token.system);
}
}
else if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val,
Err(err) => {
@ -33,29 +159,31 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
);
}
}
{
authed_system_id = Some(system_id);
}
{
authed_system_id = Some(system_id);
access_level = AccessLevel::Full;
}
}
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = req
.headers()
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(config_token2) = libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
.as_ref()
&& app_auth_header
.as_bytes()
.ct_eq(config_token2.as_bytes())
.into()
&& let Some(app_id) =
match libpk::db::repository::app_token_auth(&ctx.db, app_auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
}
{
authed_app_id = Some(1);
authed_app_id = Some(app_id);
}
// todo: fix syntax
@ -74,7 +202,7 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
};
req.extensions_mut()
.insert(AuthState::new(authed_system_id, authed_app_id, internal));
.insert(AuthState::new(authed_system_id, authed_app_id, authed_api_key_id, access_level, internal));
next.run(req).await
}

View file

@ -11,7 +11,7 @@ fn add_cors_headers(headers: &mut HeaderMap) {
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-PluralKit-Authentication, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
}

View file

@ -42,6 +42,7 @@ pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
else if !request.uri().clone().path().starts_with("/v2")
&& !request.uri().clone().path().starts_with("/private")
&& !request.uri().clone().path().starts_with("/internal")
{
return (
StatusCode::BAD_REQUEST,

View file

@ -8,12 +8,15 @@ use axum::{
};
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
use metrics::counter;
use sqlx::Postgres;
use tracing::{debug, error, info, warn};
use crate::{
ApiContext,
auth::AuthState,
util::{header_or_unknown, json_err},
};
use pluralkit_models::PKExternalApp;
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
@ -22,7 +25,10 @@ lazy_static::lazy_static! {
}
// this is awful but it works
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
pub fn ratelimiter<F, T>(
ctx: ApiContext,
f: F,
) -> FromFnLayer<F, (ApiContext, Option<RedisPool>), T> {
let redis = libpk::config
.api
.as_ref()
@ -52,14 +58,14 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
warn!("running without request rate limiting!");
}
axum::middleware::from_fn_with_state(redis, f)
axum::middleware::from_fn_with_state((ctx, redis), f)
}
enum RatelimitType {
GenericGet,
GenericUpdate,
Message,
TempCustom,
AppCustom(i32),
}
impl RatelimitType {
@ -68,7 +74,7 @@ impl RatelimitType {
RatelimitType::GenericGet => "generic_get",
RatelimitType::GenericUpdate => "generic_update",
RatelimitType::Message => "message",
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
RatelimitType::AppCustom(_) => "app_custom",
}
.to_string()
}
@ -78,21 +84,41 @@ impl RatelimitType {
RatelimitType::GenericGet => 10,
RatelimitType::GenericUpdate => 3,
RatelimitType::Message => 10,
RatelimitType::TempCustom => 20,
RatelimitType::AppCustom(n) => *n,
}
}
}
pub async fn do_request_ratelimited(
State(redis): State<Option<RedisPool>>,
State((ctx, redis)): State<(ApiContext, Option<RedisPool>)>,
request: Request,
next: Next,
) -> Response {
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
if headers.get("x-pluralkit-internal").is_some() {
// bypass ratelimiting entirely for internal requests
return next.run(request).await;
}
let extensions = request.extensions().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let mut app_rate: Option<i32> = None;
if let Some(app_header) = request.headers().clone().get("x-pluralkit-app") {
let app_token = app_header.to_str().unwrap_or("invalid");
if app_token.starts_with("pkap2:")
&& let Some(app) = sqlx::query_as::<Postgres, PKExternalApp>(
"select * from external_apps where api_rl_token = $1",
)
.bind(&app_token[6..])
.fetch_optional(&ctx.db)
.await
.expect("failed to query external app in postgres")
{
app_rate = Some(app.api_rl_rate.expect("external app has no api_rl_rate"));
}
};
let endpoint = extensions
.get::<MatchedPath>()
@ -109,11 +135,8 @@ pub async fn do_request_ratelimited(
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if let Some(app_id) = auth.app_id()
&& app_id == 1
{
RatelimitType::TempCustom
let rlimit = if let Some(r) = app_rate {
RatelimitType::AppCustom(r)
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message
} else if request.method() == Method::GET {

View file

@ -6,6 +6,7 @@ edition = "2021"
[dependencies]
anyhow = { workspace = true }
fred = { workspace = true }
jsonwebtoken = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
pk_macros = { path = "../macros" }

View file

@ -62,6 +62,11 @@ pub struct ApiConfig {
#[serde(default)]
pub temp_token2: Option<String>,
pub token_privatekey: String,
pub token_publickey: String,
pub internal_request_secret: String,
}
#[derive(Deserialize, Clone, Debug)]

View file

@ -1,3 +1,5 @@
use uuid::Uuid;
pub async fn legacy_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
@ -18,3 +20,24 @@ pub async fn legacy_token_auth(
struct LegacyTokenDbResponse {
id: i32,
}
pub async fn app_token_auth(
pool: &sqlx::postgres::PgPool,
token: &str,
) -> anyhow::Result<Option<Uuid>> {
let mut app: Vec<AppTokenDbResponse> =
sqlx::query_as("select id from external_apps where api_rl_token = $1")
.bind(token)
.fetch_all(pool)
.await?;
Ok(if let Some(app) = app.pop() {
Some(app.id)
} else {
None
})
}
#[derive(sqlx::FromRow)]
struct AppTokenDbResponse {
id: Uuid,
}

View file

@ -9,6 +9,7 @@ fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
pub fn macro_impl(
_args: proc_macro::TokenStream,
input: proc_macro::TokenStream,
is_internal: bool,
) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as ItemFn);
@ -34,6 +35,16 @@ pub fn macro_impl(
})
.collect();
let internal_res = if is_internal {
quote! {
if !auth.internal() {
return crate::error::FORBIDDEN_INTERNAL_ROUTE.into_response();
}
}
} else {
quote!()
};
let res = quote! {
#[allow(unused_mut)]
pub async fn #fn_name(#fn_params) -> axum::response::Response {
@ -41,6 +52,7 @@ pub fn macro_impl(
#fn_body
}
#internal_res
match inner(#(#pms),*).await {
Ok(res) => res.into_response(),
Err(err) => err.into_response(),

View file

@ -6,7 +6,12 @@ mod model;
#[proc_macro_attribute]
pub fn api_endpoint(args: TokenStream, input: TokenStream) -> TokenStream {
api::macro_impl(args, input)
api::macro_impl(args, input, false)
}
#[proc_macro_attribute]
pub fn api_internal_endpoint(args: TokenStream, input: TokenStream) -> TokenStream {
api::macro_impl(args, input, true)
}
#[proc_macro_attribute]

View file

@ -18,4 +18,4 @@ create table command_messages (
create index command_messages_by_original on command_messages(original_mid);
create index command_messages_by_sender on command_messages(sender);
update info set schema_version = 52;
update info set schema_version = 52;

View file

@ -0,0 +1,38 @@
-- database version 53
--
-- scoped API keys + skeleton for oauth2 for third-party apps
create table external_apps (
id uuid primary key default gen_random_uuid(),
name text not null,
homepage_url text not null,
oauth2_secret text,
oauth2_allowed_redirects text[] not null default array[]::text[],
oauth2_scopes text[] not null default array[]::text[],
api_rl_token text,
api_rl_rate int
);
create type api_key_type as enum (
'dashboard',
'user_created',
'external_app'
);
create table api_keys (
id uuid primary key default gen_random_uuid(),
system int references systems(id) on delete cascade,
kind api_key_type not null,
scopes text[] not null default array[]::text[],
app uuid references external_apps(id) on delete cascade,
name text,
discord_id bigint,
discord_access_token text,
discord_refresh_token text,
discord_expires_at timestamp,
created timestamp with time zone not null default (current_timestamp at time zone 'utc')
);
update info set schema_version = 53;

View file

@ -4,8 +4,10 @@ version = "0.1.0"
edition = "2021"
[dependencies]
base64 = { workspace = true }
chrono = { workspace = true, features = ["serde"] }
pk_macros = { path = "../macros" }
jsonwebtoken = { workspace = true }
sea-query = "0.32.1"
serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] }

View file

@ -6,9 +6,9 @@
// note: caller needs to implement From<i32> for their type
macro_rules! fake_enum_impls {
($n:ident) => {
impl Type<Postgres> for $n {
fn type_info() -> PgTypeInfo {
PgTypeInfo::with_name("INT4")
impl ::sqlx::Type<::sqlx::Postgres> for $n {
fn type_info() -> ::sqlx::postgres::PgTypeInfo {
::sqlx::postgres::PgTypeInfo::with_name("INT4")
}
}
@ -18,14 +18,14 @@ macro_rules! fake_enum_impls {
}
}
impl<'r, DB: Database> Decode<'r, DB> for $n
impl<'r, DB: ::sqlx::Database> ::sqlx::Decode<'r, DB> for $n
where
i32: Decode<'r, DB>,
i32: ::sqlx::Decode<'r, DB>,
{
fn decode(
value: <DB as Database>::ValueRef<'r>,
) -> Result<Self, Box<dyn Error + 'static + Send + Sync>> {
let value = <i32 as Decode<DB>>::decode(value)?;
value: <DB as ::sqlx::Database>::ValueRef<'r>,
) -> Result<Self, Box<dyn ::std::error::Error + 'static + Send + Sync>> {
let value = <i32 as ::sqlx::Decode<DB>>::decode(value)?;
Ok(Self::from(value))
}
}

View file

@ -0,0 +1,104 @@
use pk_macros::pk_model;
use chrono::{DateTime, Utc, NaiveDateTime};
use uuid::Uuid;
use base64::{prelude::BASE64_STANDARD, Engine};
use jsonwebtoken::{
crypto::{sign, verify},
DecodingKey, EncodingKey,
};
use crate::SystemId;
#[derive(sqlx::Type, Debug, Clone, PartialEq, serde::Serialize)]
#[serde(rename_all = "snake_case")]
#[sqlx(rename_all = "snake_case")]
#[sqlx(type_name = "api_key_type")]
pub enum ApiKeyType {
Dashboard,
UserCreated,
ExternalApp,
}
#[pk_model]
struct ApiKey {
#[json = "id"]
id: Uuid,
system: SystemId,
#[json = "type"]
kind: ApiKeyType,
#[json = "scopes"]
scopes: Vec<String>,
#[json = "app"]
app: Option<Uuid>,
#[json = "name"]
#[patchable]
name: Option<String>,
#[json = "discord_id"]
discord_id: Option<i64>,
#[private_patchable]
discord_access_token: Option<String>,
#[private_patchable]
discord_refresh_token: Option<String>,
#[private_patchable]
discord_expires_at: Option<NaiveDateTime>,
#[json = "created"]
created: DateTime<Utc>,
}
const SIGNATURE_ALGORITHM: jsonwebtoken::Algorithm = jsonwebtoken::Algorithm::ES256;
impl PKApiKey {
pub fn to_header_str(self, system_uuid: Uuid, key: &EncodingKey) -> String {
let b64 = BASE64_STANDARD.encode(
serde_json::to_vec(&serde_json::json!({
"tid": self.id.to_string(),
"sid": system_uuid.to_string(),
"type": self.kind,
"scopes": self.scopes,
}))
.expect("should not fail"),
);
let signature = sign(b64.as_bytes(), key, SIGNATURE_ALGORITHM).expect("should not fail");
format!("pkapi:{b64}:{signature}")
}
/// Parse a header string into a token uuid
pub fn parse_header_str(token: String, key: &DecodingKey) -> Option<Uuid> {
let mut parts = token.split(":");
let pkapi = parts.next();
if pkapi.is_none_or(|v| v != "pkapi") {
return None;
}
let Some(jsonblob) = parts.next() else {
return None;
};
let Some(sig) = parts.next() else {
return None;
};
// verify signature before doing anything else
let valid = verify(sig, jsonblob.as_bytes(), key, SIGNATURE_ALGORITHM);
if valid.is_err() || matches!(valid, Ok(false)) {
return None;
}
let Ok(bytes) = BASE64_STANDARD.decode(jsonblob) else {
return None;
};
let Ok(obj) = serde_json::from_slice::<serde_json::Value>(bytes.as_slice()) else {
return None;
};
obj.get("tid")
.map(|v| v.as_str().map(|f| Uuid::parse_str(f).ok()))
.flatten()
.flatten()
}
}

View file

@ -1,14 +1,5 @@
mod _util;
macro_rules! model {
($n:ident) => {
mod $n;
pub use $n::*;
};
}
model!(system);
model!(system_config);
use _util::fake_enum_impls;
#[derive(serde::Serialize, Debug, Clone)]
#[serde(rename_all = "snake_case")]
@ -17,10 +8,7 @@ pub enum PrivacyLevel {
Private,
}
// this sucks, put it somewhere else
use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type};
use std::error::Error;
_util::fake_enum_impls!(PrivacyLevel);
fake_enum_impls!(PrivacyLevel);
impl From<i32> for PrivacyLevel {
fn from(value: i32) -> Self {
@ -31,3 +19,24 @@ impl From<i32> for PrivacyLevel {
}
}
}
impl PrivacyLevel {
pub fn to_string(&self) -> String {
match self {
PrivacyLevel::Public => "public".into(),
PrivacyLevel::Private => "private".into(),
}
}
}
macro_rules! model {
($n:ident) => {
mod $n;
pub use $n::*;
};
}
model!(api_key);
model!(oauth2_app);
model!(system);
model!(system_config);

View file

@ -0,0 +1,28 @@
use pk_macros::pk_model;
use uuid::Uuid;
#[pk_model]
struct ExternalApp {
#[json = "id"]
id: Uuid,
#[json = "name"]
#[patchable]
name: String,
#[json = "homepage_url"]
#[patchable]
homepage_url: String,
#[private_patchable]
oauth2_secret: Option<String>,
#[json = "oauth2_allowed_redirects"]
#[patchable]
oauth2_allowed_redirects: Vec<String>,
#[json = "oauth2_scopes"]
#[patchable]
oauth2_scopes: Vec<String>,
#[private_patchable]
api_rl_token: Option<String>,
#[private_patchable]
api_rl_rate: Option<i32>,
}

View file

@ -1,10 +1,9 @@
use pk_macros::pk_model;
use crate::PrivacyLevel;
use chrono::NaiveDateTime;
use uuid::Uuid;
use crate::PrivacyLevel;
// todo: fix this
pub type SystemId = i32;