mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-09 23:37:54 +00:00
[WIP] feat: scoped api keys
This commit is contained in:
parent
e7ee593a85
commit
06cb160f95
45 changed files with 1264 additions and 154 deletions
|
|
@ -10,8 +10,12 @@ libpk = { path = "../libpk" }
|
|||
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
chrono = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
jsonwebtoken = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
uuid = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
|
|
|
|||
|
|
@ -1,20 +1,45 @@
|
|||
use uuid::Uuid;
|
||||
|
||||
use pluralkit_models::{PKSystem, PrivacyLevel, SystemId};
|
||||
|
||||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
pub const INTERNAL_TOKENID_HEADER: &'static str = "x-pluralkit-tid";
|
||||
pub const INTERNAL_PRIVACYLEVEL_HEADER: &'static str = "x-pluralkit-privacylevel";
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, PartialOrd)]
|
||||
pub enum AccessLevel {
|
||||
None = 0,
|
||||
PublicRead,
|
||||
PrivateRead,
|
||||
Full,
|
||||
}
|
||||
|
||||
impl AccessLevel {
|
||||
pub fn privacy_level(&self) -> PrivacyLevel {
|
||||
match self {
|
||||
Self::None | Self::PublicRead => PrivacyLevel::Public,
|
||||
Self::PrivateRead | Self::Full => PrivacyLevel::Private,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AuthState {
|
||||
system_id: Option<i32>,
|
||||
app_id: Option<i32>,
|
||||
app_id: Option<Uuid>,
|
||||
api_key_id: Option<Uuid>,
|
||||
access_level: AccessLevel,
|
||||
internal: bool,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
pub fn new(system_id: Option<i32>, app_id: Option<i32>, internal: bool) -> Self {
|
||||
pub fn new(system_id: Option<i32>, app_id: Option<Uuid>, api_key_id: Option<Uuid>, access_level: AccessLevel, internal: bool) -> Self {
|
||||
Self {
|
||||
system_id,
|
||||
app_id,
|
||||
api_key_id,
|
||||
access_level,
|
||||
internal,
|
||||
}
|
||||
}
|
||||
|
|
@ -23,10 +48,18 @@ impl AuthState {
|
|||
self.system_id
|
||||
}
|
||||
|
||||
pub fn app_id(&self) -> Option<i32> {
|
||||
pub fn app_id(&self) -> Option<Uuid> {
|
||||
self.app_id
|
||||
}
|
||||
|
||||
pub fn api_key_id(&self) -> Option<Uuid> {
|
||||
self.api_key_id
|
||||
}
|
||||
|
||||
pub fn access_level(&self) -> AccessLevel {
|
||||
self.access_level.clone()
|
||||
}
|
||||
|
||||
pub fn internal(&self) -> bool {
|
||||
self.internal
|
||||
}
|
||||
|
|
@ -37,7 +70,7 @@ impl AuthState {
|
|||
.map(|id| id == a.authable_system_id())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
PrivacyLevel::Private
|
||||
self.access_level.privacy_level()
|
||||
} else {
|
||||
PrivacyLevel::Public
|
||||
}
|
||||
|
|
|
|||
114
crates/api/src/endpoints/internal.rs
Normal file
114
crates/api/src/endpoints/internal.rs
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
use crate::{util::json_err, AuthState, ApiContext};
|
||||
use pluralkit_models::{ApiKeyType, PKApiKey, PKSystem, SystemId};
|
||||
use pk_macros::api_internal_endpoint;
|
||||
|
||||
use axum::{
|
||||
extract::State,
|
||||
http::StatusCode,
|
||||
response::{IntoResponse, Json, Response},
|
||||
Extension,
|
||||
};
|
||||
|
||||
use sqlx::Postgres;
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
pub struct NewApiKeyRequestData {
|
||||
#[serde(default)]
|
||||
check: bool,
|
||||
system: SystemId,
|
||||
name: Option<String>,
|
||||
scopes: Vec<String>,
|
||||
}
|
||||
|
||||
#[api_internal_endpoint]
|
||||
pub async fn create_api_key_user(
|
||||
State(ctx): State<ApiContext>,
|
||||
Extension(auth): Extension<AuthState>,
|
||||
Json(req): Json<NewApiKeyRequestData>,
|
||||
) -> Response {
|
||||
let system: Option<PKSystem> = sqlx::query_as("select * from systems where id = $1")
|
||||
.bind(req.system)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query system");
|
||||
|
||||
if system.is_none() {
|
||||
return Ok(json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message": "no system found!?", "internal": true}"#.to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
let system = system.unwrap();
|
||||
|
||||
// sanity check requested scopes
|
||||
if req.scopes.len() < 1 {
|
||||
return Ok(json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
r#"{"message": "no scopes provided", "internal": true}"#.to_string(),
|
||||
));
|
||||
}
|
||||
for scope in req.scopes.iter() {
|
||||
let parts = scope.split(":").collect::<Vec<&str>>();
|
||||
let ok = match &parts[..] {
|
||||
["identify"] => true,
|
||||
["publicread", n] | ["read", n] | ["write", n] => match *n {
|
||||
"all" => true,
|
||||
"system" => true,
|
||||
"members" => true,
|
||||
"groups" => true,
|
||||
"fronters" => true,
|
||||
"switches" => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
};
|
||||
|
||||
if !ok {
|
||||
return Err(crate::error::GENERIC_BAD_REQUEST);
|
||||
}
|
||||
}
|
||||
|
||||
if req.check {
|
||||
return Ok((
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"valid": true,
|
||||
}))
|
||||
.expect("should not error"),
|
||||
).into_response());
|
||||
}
|
||||
|
||||
let token: PKApiKey = sqlx::query_as(
|
||||
r#"
|
||||
insert into api_keys
|
||||
(
|
||||
system,
|
||||
kind,
|
||||
scopes,
|
||||
name
|
||||
)
|
||||
values
|
||||
($1, $2::api_key_type, $3::text[], $4)
|
||||
returning *
|
||||
"#,
|
||||
)
|
||||
.bind(system.id)
|
||||
.bind(ApiKeyType::UserCreated)
|
||||
.bind(req.scopes)
|
||||
.bind(req.name)
|
||||
.fetch_one(&ctx.db)
|
||||
.await
|
||||
.expect("failed to create token");
|
||||
|
||||
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
|
||||
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"valid": true,
|
||||
"token": token,
|
||||
}))
|
||||
.expect("should not error"),
|
||||
).into_response())
|
||||
}
|
||||
|
|
@ -1,2 +1,3 @@
|
|||
pub mod internal;
|
||||
pub mod private;
|
||||
pub mod system;
|
||||
|
|
|
|||
|
|
@ -1,18 +1,20 @@
|
|||
use crate::ApiContext;
|
||||
use axum::{extract::State, response::Json};
|
||||
use crate::{util::json_err, ApiContext};
|
||||
use libpk::config;
|
||||
use pluralkit_models::{PrivacyLevel, PKApiKey, PKSystem, PKSystemConfig};
|
||||
|
||||
use axum::{
|
||||
extract::{self, State},
|
||||
response::{IntoResponse, Json, Response},
|
||||
};
|
||||
use fred::interfaces::*;
|
||||
use hyper::StatusCode;
|
||||
use libpk::state::ShardState;
|
||||
use pk_macros::api_endpoint;
|
||||
use reqwest::ClientBuilder;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{json, Value};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "PascalCase")]
|
||||
struct ClusterStats {
|
||||
pub guild_count: i32,
|
||||
pub channel_count: i32,
|
||||
}
|
||||
use std::time::Duration;
|
||||
|
||||
#[api_endpoint]
|
||||
pub async fn discord_state(State(ctx): State<ApiContext>) -> Json<Value> {
|
||||
|
|
@ -43,18 +45,6 @@ pub async fn meta(State(ctx): State<ApiContext>) -> Json<Value> {
|
|||
Ok(Json(stats))
|
||||
}
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use crate::util::json_err;
|
||||
use axum::{
|
||||
extract,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use hyper::StatusCode;
|
||||
use libpk::config;
|
||||
use pluralkit_models::{PKSystem, PKSystemConfig, PrivacyLevel};
|
||||
use reqwest::ClientBuilder;
|
||||
|
||||
#[derive(serde::Deserialize, Debug)]
|
||||
pub struct CallbackRequestData {
|
||||
redirect_domain: String,
|
||||
|
|
@ -71,6 +61,7 @@ struct CallbackDiscordData {
|
|||
code: String,
|
||||
}
|
||||
|
||||
#[api_endpoint]
|
||||
pub async fn discord_callback(
|
||||
State(ctx): State<ApiContext>,
|
||||
extract::Json(request_data): extract::Json<CallbackRequestData>,
|
||||
|
|
@ -107,7 +98,7 @@ pub async fn discord_callback(
|
|||
};
|
||||
|
||||
if !discord_data.contains_key("access_token") {
|
||||
return json_err(
|
||||
return Ok(json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
format!(
|
||||
"{{\"error\":\"{}\"\"}}",
|
||||
|
|
@ -116,7 +107,7 @@ pub async fn discord_callback(
|
|||
.expect("missing error_description from discord")
|
||||
.to_string()
|
||||
),
|
||||
);
|
||||
));
|
||||
};
|
||||
|
||||
let token = format!(
|
||||
|
|
@ -152,10 +143,10 @@ pub async fn discord_callback(
|
|||
.expect("failed to query");
|
||||
|
||||
let Some(system) = system else {
|
||||
return json_err(
|
||||
return Ok(json_err(
|
||||
StatusCode::BAD_REQUEST,
|
||||
"user does not have a system registered".to_string(),
|
||||
);
|
||||
r#"{"message": "user does not have a system registered", "code": 0}"#.to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
let system_config: Option<PKSystemConfig> = sqlx::query_as(
|
||||
|
|
@ -170,11 +161,38 @@ pub async fn discord_callback(
|
|||
|
||||
let system_config = system_config.unwrap();
|
||||
|
||||
// create dashboard token for system
|
||||
let token: PKApiKey = sqlx::query_as(
|
||||
r#"
|
||||
insert into api_keys
|
||||
(
|
||||
system,
|
||||
kind,
|
||||
discord_id,
|
||||
discord_access_token,
|
||||
discord_refresh_token,
|
||||
discord_expires_at
|
||||
)
|
||||
values
|
||||
($1, $2::api_key_type, $3, $4, $5, $6)
|
||||
returning *
|
||||
"#,
|
||||
)
|
||||
.bind(system.id)
|
||||
.bind("dashboard")
|
||||
.bind(user.id.get() as i64)
|
||||
.bind(discord_data.get("access_token").unwrap().as_str())
|
||||
.bind(discord_data.get("refresh_token").unwrap().as_str())
|
||||
.bind(
|
||||
chrono::Utc::now()
|
||||
+ chrono::Duration::seconds(discord_data.get("expires_in").unwrap().as_i64().unwrap()),
|
||||
)
|
||||
.fetch_one(&ctx.db)
|
||||
.await
|
||||
.expect("failed to create token");
|
||||
|
||||
let token = system.clone().token;
|
||||
let token = token.to_header_str(system.clone().uuid, &ctx.token_privatekey);
|
||||
|
||||
(
|
||||
Ok((
|
||||
StatusCode::OK,
|
||||
serde_json::to_string(&serde_json::json!({
|
||||
"system": system.to_json(PrivacyLevel::Private),
|
||||
|
|
@ -183,6 +201,5 @@ pub async fn discord_callback(
|
|||
"token": token,
|
||||
}))
|
||||
.expect("should not error"),
|
||||
)
|
||||
.into_response()
|
||||
).into_response())
|
||||
}
|
||||
|
|
|
|||
|
|
@ -83,4 +83,6 @@ macro_rules! define_error {
|
|||
}
|
||||
|
||||
define_error! { GENERIC_BAD_REQUEST, StatusCode::BAD_REQUEST, 0, "400: Bad Request" }
|
||||
// define_error! { GENERIC_UNAUTHORIZED, StatusCode::UNAUTHORIZED, 0, "401: Missing or invalid Authorization header" }
|
||||
define_error! { FORBIDDEN_INTERNAL_ROUTE, StatusCode::FORBIDDEN, 0, "403: Forbidden to access this endpoint" }
|
||||
define_error! { GENERIC_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR, 0, "500: Internal Server Error" }
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER, INTERNAL_TOKENID_HEADER, INTERNAL_PRIVACYLEVEL_HEADER};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
http::Uri,
|
||||
http::{HeaderValue, Uri},
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, patch, post},
|
||||
Extension, Router,
|
||||
|
|
@ -13,8 +13,8 @@ use hyper_util::{
|
|||
client::legacy::{connect::HttpConnector, Client},
|
||||
rt::TokioExecutor,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use jsonwebtoken::{DecodingKey, EncodingKey};
|
||||
use tracing::{error, info};
|
||||
use pk_macros::api_endpoint;
|
||||
|
||||
mod auth;
|
||||
|
|
@ -30,6 +30,9 @@ pub struct ApiContext {
|
|||
|
||||
rproxy_uri: String,
|
||||
rproxy_client: Client<HttpConnector, Body>,
|
||||
|
||||
token_privatekey: EncodingKey,
|
||||
token_publickey: DecodingKey,
|
||||
}
|
||||
|
||||
#[api_endpoint]
|
||||
|
|
@ -53,14 +56,21 @@ async fn rproxy(
|
|||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
headers.remove(INTERNAL_TOKENID_HEADER);
|
||||
headers.remove(INTERNAL_PRIVACYLEVEL_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
headers.append(INTERNAL_PRIVACYLEVEL_HEADER, HeaderValue::from_str(&auth.access_level().privacy_level().to_string())?);
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
headers.append(INTERNAL_APPID_HEADER, HeaderValue::from_str(&format!("{}", aid))?);
|
||||
}
|
||||
|
||||
if let Some(tid) = auth.api_key_id() {
|
||||
headers.append(INTERNAL_TOKENID_HEADER, HeaderValue::from_str(&format!("{}", tid))?);
|
||||
}
|
||||
|
||||
Ok(ctx.rproxy_client.request(req).await?.into_response())
|
||||
}
|
||||
|
|
@ -124,11 +134,13 @@ fn router(ctx: ApiContext) -> Router {
|
|||
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
|
||||
.route("/private/stats", get(endpoints::private::meta))
|
||||
|
||||
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
|
||||
.route("/internal/apikey/user", post(endpoints::internal::create_api_key_user))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(middleware::ratelimit::do_request_ratelimited)) // this sucks
|
||||
.route("/v2/systems/:system_id/oembed.json", get(rproxy))
|
||||
.route("/v2/members/:member_id/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/:group_id/oembed.json", get(rproxy))
|
||||
|
||||
.layer(middleware::ratelimit::ratelimiter(ctx.clone(), middleware::ratelimit::do_request_ratelimited))
|
||||
|
||||
.layer(axum::middleware::from_fn(middleware::ignore_invalid_routes::ignore_invalid_routes))
|
||||
.layer(axum::middleware::from_fn(middleware::logger::logger))
|
||||
|
|
@ -149,14 +161,9 @@ async fn main() -> anyhow::Result<()> {
|
|||
let db = libpk::db::init_data_db().await?;
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let rproxy_uri = Uri::from_static(
|
||||
&libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.remote_url,
|
||||
)
|
||||
.to_string();
|
||||
let cfg = libpk::config.api.as_ref().expect("missing api config");
|
||||
|
||||
let rproxy_uri = Uri::from_static(cfg.remote_url.as_str()).to_string();
|
||||
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
|
||||
|
|
@ -166,16 +173,16 @@ async fn main() -> anyhow::Result<()> {
|
|||
|
||||
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
|
||||
rproxy_client,
|
||||
|
||||
token_privatekey: EncodingKey::from_ec_pem(cfg.token_privatekey.as_bytes())
|
||||
.expect("failed to load private key"),
|
||||
token_publickey: DecodingKey::from_ec_pem(cfg.token_publickey.as_bytes())
|
||||
.expect("failed to load public key"),
|
||||
};
|
||||
|
||||
let app = router(ctx);
|
||||
|
||||
let addr: &str = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.addr
|
||||
.as_ref();
|
||||
let addr: &str = cfg.addr.as_ref();
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
|
|
|
|||
|
|
@ -1,20 +1,129 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
extract::{Request, State, MatchedPath},
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
|
||||
use uuid::Uuid;
|
||||
use subtle::ConstantTimeEq;
|
||||
|
||||
use tracing::error;
|
||||
use sqlx::Postgres;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
use pluralkit_models::{ApiKeyType, PKApiKey};
|
||||
use crate::auth::{AccessLevel, AuthState};
|
||||
use crate::{util::json_err, ApiContext};
|
||||
|
||||
pub fn is_part_path<'a, 'b>(part: &'a str, endpoint: &'b str) -> bool {
|
||||
if !endpoint.starts_with("/v2/") {
|
||||
return false;
|
||||
}
|
||||
|
||||
let path_frags = endpoint[4..].split("/").collect::<Vec<&str>>();
|
||||
match part {
|
||||
"system" => match &path_frags[..] {
|
||||
["systems", _] => true,
|
||||
["systems", _, "settings"] => true,
|
||||
["systems", _, "autoproxy"] => true,
|
||||
["systems", _, "guilds", ..] => true,
|
||||
_ => false,
|
||||
},
|
||||
"members" => match &path_frags[..] {
|
||||
["systems", _, "members"] => true,
|
||||
["members"] => true,
|
||||
["members", _, "groups"] => false,
|
||||
["members", _, "groups", ..] => false,
|
||||
["members", ..] => true,
|
||||
_ => false,
|
||||
},
|
||||
"groups" => match &path_frags[..] {
|
||||
["systems", _, "groups"] => true,
|
||||
["groups"] => true,
|
||||
["groups", ..] => true,
|
||||
["members", _, "groups"] => true,
|
||||
["members", _, "groups", ..] => true,
|
||||
_ => false,
|
||||
},
|
||||
"fronters" => match &path_frags[..] {
|
||||
["systems", _, "fronters"] => true,
|
||||
_ => false,
|
||||
},
|
||||
"switches" => match &path_frags[..] {
|
||||
// switches implies fronters
|
||||
["systems", _, "fronters"] => true,
|
||||
["systems", _, "switches"] => true,
|
||||
["systems", _, "switches", ..] => true,
|
||||
_ => false,
|
||||
},
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn apikey_can_access(token: &PKApiKey, method: String, endpoint: String) -> AccessLevel {
|
||||
if token.kind == ApiKeyType::Dashboard {
|
||||
return AccessLevel::Full;
|
||||
}
|
||||
|
||||
let mut access = AccessLevel::None;
|
||||
for rscope in token.scopes.iter() {
|
||||
let scope = rscope.split(":").collect::<Vec<&str>>();
|
||||
let na = match (method.as_str(), &scope[..]) {
|
||||
("GET", ["identify"]) => {
|
||||
if &endpoint == "/v2/systems/:system_id" {
|
||||
AccessLevel::PublicRead
|
||||
} else {
|
||||
AccessLevel::None
|
||||
}
|
||||
}
|
||||
|
||||
("GET", ["publicread", part]) => {
|
||||
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
|
||||
AccessLevel::PublicRead
|
||||
} else {
|
||||
AccessLevel::None
|
||||
}
|
||||
}
|
||||
|
||||
("GET", ["read", part]) => {
|
||||
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
|
||||
AccessLevel::PrivateRead
|
||||
} else {
|
||||
AccessLevel::None
|
||||
}
|
||||
}
|
||||
|
||||
(_, ["write", part]) => {
|
||||
if *part == "all" || is_part_path(part.as_ref(), endpoint.as_ref()) {
|
||||
AccessLevel::Full
|
||||
} else {
|
||||
AccessLevel::None
|
||||
}
|
||||
}
|
||||
|
||||
_ => AccessLevel::None,
|
||||
};
|
||||
|
||||
if na > access {
|
||||
access = na;
|
||||
}
|
||||
}
|
||||
|
||||
access
|
||||
}
|
||||
|
||||
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
|
||||
let endpoint = req
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let mut authed_system_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<Uuid> = None;
|
||||
let mut authed_api_key_id: Option<Uuid> = None;
|
||||
let mut access_level = AccessLevel::None;
|
||||
|
||||
// fetch user authorization
|
||||
if let Some(system_auth_header) = req
|
||||
|
|
@ -22,7 +131,24 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
|
|||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(system_id) =
|
||||
{
|
||||
if system_auth_header.starts_with("Bearer ")
|
||||
&& let Some(tid) =
|
||||
PKApiKey::parse_header_str(system_auth_header[7..].to_string(), &ctx.token_publickey)
|
||||
&& let Some(token) =
|
||||
sqlx::query_as::<Postgres, PKApiKey>("select * from api_keys where id = $1")
|
||||
.bind(&tid)
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query apitoken in postgres")
|
||||
{
|
||||
authed_api_key_id = Some(tid);
|
||||
access_level = apikey_can_access(&token, req.method().to_string(), endpoint.clone());
|
||||
if access_level != AccessLevel::None {
|
||||
authed_system_id = Some(token.system);
|
||||
}
|
||||
}
|
||||
else if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
|
|
@ -33,29 +159,31 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
|
|||
);
|
||||
}
|
||||
}
|
||||
{
|
||||
authed_system_id = Some(system_id);
|
||||
}
|
||||
{
|
||||
authed_system_id = Some(system_id);
|
||||
access_level = AccessLevel::Full;
|
||||
}
|
||||
}
|
||||
|
||||
// fetch app authorization
|
||||
// todo: actually fetch it from db
|
||||
if let Some(app_auth_header) = req
|
||||
.headers()
|
||||
.get("x-pluralkit-app")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(config_token2) = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
.as_ref()
|
||||
&& app_auth_header
|
||||
.as_bytes()
|
||||
.ct_eq(config_token2.as_bytes())
|
||||
.into()
|
||||
&& let Some(app_id) =
|
||||
match libpk::db::repository::app_token_auth(&ctx.db, app_auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
{
|
||||
authed_app_id = Some(1);
|
||||
authed_app_id = Some(app_id);
|
||||
}
|
||||
|
||||
// todo: fix syntax
|
||||
|
|
@ -74,7 +202,7 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
|
|||
};
|
||||
|
||||
req.extensions_mut()
|
||||
.insert(AuthState::new(authed_system_id, authed_app_id, internal));
|
||||
.insert(AuthState::new(authed_system_id, authed_app_id, authed_api_key_id, access_level, internal));
|
||||
|
||||
next.run(req).await
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ fn add_cors_headers(headers: &mut HeaderMap) {
|
|||
headers.append("Access-Control-Allow-Methods", HeaderValue::from_static("*"));
|
||||
headers.append("Access-Control-Allow-Credentials", HeaderValue::from_static("true"));
|
||||
headers.append("Access-Control-Allow-Headers", HeaderValue::from_static("Content-Type, Authorization, sentry-trace, User-Agent"));
|
||||
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
|
||||
headers.append("Access-Control-Expose-Headers", HeaderValue::from_static("X-PluralKit-Version, X-PluralKit-Authentication, X-RateLimit-Limit, X-RateLimit-Remaining, X-RateLimit-Reset, X-RateLimit-Scope"));
|
||||
headers.append("Access-Control-Max-Age", HeaderValue::from_static("86400"));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -42,6 +42,7 @@ pub async fn ignore_invalid_routes(request: Request, next: Next) -> Response {
|
|||
// we ignored v1 routes earlier, now let's ignore all non-v2 routes
|
||||
else if !request.uri().clone().path().starts_with("/v2")
|
||||
&& !request.uri().clone().path().starts_with("/private")
|
||||
&& !request.uri().clone().path().starts_with("/internal")
|
||||
{
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
|
|
|
|||
|
|
@ -8,12 +8,15 @@ use axum::{
|
|||
};
|
||||
use fred::{clients::RedisPool, interfaces::ClientLike, prelude::LuaInterface, util::sha1_hash};
|
||||
use metrics::counter;
|
||||
use sqlx::Postgres;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
ApiContext,
|
||||
auth::AuthState,
|
||||
util::{header_or_unknown, json_err},
|
||||
};
|
||||
use pluralkit_models::PKExternalApp;
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
|
|
@ -22,7 +25,10 @@ lazy_static::lazy_static! {
|
|||
}
|
||||
|
||||
// this is awful but it works
|
||||
pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
||||
pub fn ratelimiter<F, T>(
|
||||
ctx: ApiContext,
|
||||
f: F,
|
||||
) -> FromFnLayer<F, (ApiContext, Option<RedisPool>), T> {
|
||||
let redis = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
|
|
@ -52,14 +58,14 @@ pub fn ratelimiter<F, T>(f: F) -> FromFnLayer<F, Option<RedisPool>, T> {
|
|||
warn!("running without request rate limiting!");
|
||||
}
|
||||
|
||||
axum::middleware::from_fn_with_state(redis, f)
|
||||
axum::middleware::from_fn_with_state((ctx, redis), f)
|
||||
}
|
||||
|
||||
enum RatelimitType {
|
||||
GenericGet,
|
||||
GenericUpdate,
|
||||
Message,
|
||||
TempCustom,
|
||||
AppCustom(i32),
|
||||
}
|
||||
|
||||
impl RatelimitType {
|
||||
|
|
@ -68,7 +74,7 @@ impl RatelimitType {
|
|||
RatelimitType::GenericGet => "generic_get",
|
||||
RatelimitType::GenericUpdate => "generic_update",
|
||||
RatelimitType::Message => "message",
|
||||
RatelimitType::TempCustom => "token2", // this should be "app_custom" or something
|
||||
RatelimitType::AppCustom(_) => "app_custom",
|
||||
}
|
||||
.to_string()
|
||||
}
|
||||
|
|
@ -78,21 +84,41 @@ impl RatelimitType {
|
|||
RatelimitType::GenericGet => 10,
|
||||
RatelimitType::GenericUpdate => 3,
|
||||
RatelimitType::Message => 10,
|
||||
RatelimitType::TempCustom => 20,
|
||||
RatelimitType::AppCustom(n) => *n,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn do_request_ratelimited(
|
||||
State(redis): State<Option<RedisPool>>,
|
||||
State((ctx, redis)): State<(ApiContext, Option<RedisPool>)>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
if headers.get("x-pluralkit-internal").is_some() {
|
||||
// bypass ratelimiting entirely for internal requests
|
||||
return next.run(request).await;
|
||||
}
|
||||
|
||||
let extensions = request.extensions().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
|
||||
let mut app_rate: Option<i32> = None;
|
||||
if let Some(app_header) = request.headers().clone().get("x-pluralkit-app") {
|
||||
let app_token = app_header.to_str().unwrap_or("invalid");
|
||||
if app_token.starts_with("pkap2:")
|
||||
&& let Some(app) = sqlx::query_as::<Postgres, PKExternalApp>(
|
||||
"select * from external_apps where api_rl_token = $1",
|
||||
)
|
||||
.bind(&app_token[6..])
|
||||
.fetch_optional(&ctx.db)
|
||||
.await
|
||||
.expect("failed to query external app in postgres")
|
||||
{
|
||||
app_rate = Some(app.api_rl_rate.expect("external app has no api_rl_rate"));
|
||||
}
|
||||
};
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
|
|
@ -109,11 +135,8 @@ pub async fn do_request_ratelimited(
|
|||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if let Some(app_id) = auth.app_id()
|
||||
&& app_id == 1
|
||||
{
|
||||
RatelimitType::TempCustom
|
||||
let rlimit = if let Some(r) = app_rate {
|
||||
RatelimitType::AppCustom(r)
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
} else if request.method() == Method::GET {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue