Compare commits

...

2 commits

Author SHA1 Message Date
Iris System
3b2c1332c2 fixup! cargo fmt
Some checks failed
Build and push Rust service Docker images / rust docker build (push) Has been cancelled
rust checks / cargo fmt (push) Has been cancelled
2025-08-18 22:33:03 +12:00
Iris System
be218c89cc fixup! working just like it was before :) 2025-08-18 21:57:26 +12:00
14 changed files with 107 additions and 86 deletions

View file

@ -29,7 +29,7 @@ jobs:
- uses: docker/setup-buildx-action@v1 - uses: docker/setup-buildx-action@v1
# main docker build # main docker build
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" >> $GITHUB_ENV - run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" | sed 's|/|-|g' >> $GITHUB_ENV
- uses: docker/build-push-action@v2 - uses: docker/build-push-action@v2
with: with:
# https://github.com/docker/build-push-action/issues/378 # https://github.com/docker/build-push-action/issues/378

1
Cargo.lock generated
View file

@ -2616,6 +2616,7 @@ dependencies = [
"serde", "serde",
"serde_json", "serde_json",
"sqlx", "sqlx",
"tracing",
"uuid", "uuid",
] ]

View file

@ -19,7 +19,7 @@ reqwest = { version = "0.12.7" , default-features = false, features = ["rustls-t
sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls sentry = { version = "0.36.0", default-features = false, features = ["backtrace", "contexts", "panic", "debug-images", "reqwest", "rustls"] } # replace native-tls with rustls
serde = { version = "1.0.196", features = ["derive"] } serde = { version = "1.0.196", features = ["derive"] }
serde_json = "1.0.117" serde_json = "1.0.117"
sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "macros", "uuid"] } sqlx = { version = "0.8.2", features = ["runtime-tokio", "postgres", "time", "chrono", "macros", "uuid"] }
tokio = { version = "1.36.0", features = ["full"] } tokio = { version = "1.36.0", features = ["full"] }
tracing = "0.1" tracing = "0.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] } tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }

View file

@ -28,18 +28,24 @@ impl AccessLevel {
pub struct AuthState { pub struct AuthState {
system_id: Option<i32>, system_id: Option<i32>,
app_id: Option<Uuid>, app_id: Option<Uuid>,
api_key_id: Option<Uuid>, api_key_id: Option<Uuid>,
access_level: AccessLevel, access_level: AccessLevel,
internal: bool, internal: bool,
} }
impl AuthState { impl AuthState {
pub fn new(system_id: Option<i32>, app_id: Option<Uuid>, api_key_id: Option<Uuid>, access_level: AccessLevel, internal: bool) -> Self { pub fn new(
system_id: Option<i32>,
app_id: Option<Uuid>,
api_key_id: Option<Uuid>,
access_level: AccessLevel,
internal: bool,
) -> Self {
Self { Self {
system_id, system_id,
app_id, app_id,
api_key_id, api_key_id,
access_level, access_level,
internal, internal,
} }
} }
@ -56,9 +62,9 @@ impl AuthState {
self.api_key_id self.api_key_id
} }
pub fn access_level(&self) -> AccessLevel { pub fn access_level(&self) -> AccessLevel {
self.access_level.clone() self.access_level.clone()
} }
pub fn internal(&self) -> bool { pub fn internal(&self) -> bool {
self.internal self.internal

View file

@ -1,12 +1,12 @@
use crate::{util::json_err, AuthState, ApiContext}; use crate::{util::json_err, ApiContext, AuthState};
use pluralkit_models::{ApiKeyType, PKApiKey, PKSystem, SystemId};
use pk_macros::api_internal_endpoint; use pk_macros::api_internal_endpoint;
use pluralkit_models::{ApiKeyType, PKApiKey, PKSystem, SystemId};
use axum::{ use axum::{
extract::State, extract::State,
http::StatusCode, http::StatusCode,
response::{IntoResponse, Json, Response}, response::{IntoResponse, Json, Response},
Extension, Extension,
}; };
use sqlx::Postgres; use sqlx::Postgres;
@ -23,7 +23,7 @@ pub struct NewApiKeyRequestData {
#[api_internal_endpoint] #[api_internal_endpoint]
pub async fn create_api_key_user( pub async fn create_api_key_user(
State(ctx): State<ApiContext>, State(ctx): State<ApiContext>,
Extension(auth): Extension<AuthState>, Extension(auth): Extension<AuthState>,
Json(req): Json<NewApiKeyRequestData>, Json(req): Json<NewApiKeyRequestData>,
) -> Response { ) -> Response {
let system: Option<PKSystem> = sqlx::query_as("select * from systems where id = $1") let system: Option<PKSystem> = sqlx::query_as("select * from systems where id = $1")
@ -76,7 +76,8 @@ pub async fn create_api_key_user(
"valid": true, "valid": true,
})) }))
.expect("should not error"), .expect("should not error"),
).into_response()); )
.into_response());
} }
let token: PKApiKey = sqlx::query_as( let token: PKApiKey = sqlx::query_as(
@ -110,5 +111,6 @@ pub async fn create_api_key_user(
"token": token, "token": token,
})) }))
.expect("should not error"), .expect("should not error"),
).into_response()) )
.into_response())
} }

View file

@ -1,6 +1,6 @@
use crate::{util::json_err, ApiContext}; use crate::{util::json_err, ApiContext};
use libpk::config; use libpk::config;
use pluralkit_models::{PrivacyLevel, PKApiKey, PKSystem, PKSystemConfig}; use pluralkit_models::{PKApiKey, PKSystem, PKSystemConfig, PrivacyLevel};
use axum::{ use axum::{
extract::{self, State}, extract::{self, State},
@ -201,5 +201,6 @@ pub async fn discord_callback(
"token": token, "token": token,
})) }))
.expect("should not error"), .expect("should not error"),
).into_response()) )
.into_response())
} }

View file

@ -1,6 +1,9 @@
#![feature(let_chains)] #![feature(let_chains)]
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER, INTERNAL_TOKENID_HEADER, INTERNAL_PRIVACYLEVEL_HEADER}; use auth::{
AuthState, INTERNAL_APPID_HEADER, INTERNAL_PRIVACYLEVEL_HEADER, INTERNAL_SYSTEMID_HEADER,
INTERNAL_TOKENID_HEADER,
};
use axum::{ use axum::{
body::Body, body::Body,
extract::{Request as ExtractRequest, State}, extract::{Request as ExtractRequest, State},
@ -13,9 +16,10 @@ use hyper_util::{
client::legacy::{connect::HttpConnector, Client}, client::legacy::{connect::HttpConnector, Client},
rt::TokioExecutor, rt::TokioExecutor,
}; };
use jsonwebtoken::{DecodingKey, EncodingKey}; use jsonwebtoken::{DecodingKey, EncodingKey};
use tracing::{error, info};
use pk_macros::api_endpoint; use pk_macros::api_endpoint;
use tracing::{error, info};
mod auth; mod auth;
mod endpoints; mod endpoints;
@ -56,21 +60,30 @@ async fn rproxy(
headers.remove(INTERNAL_SYSTEMID_HEADER); headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER); headers.remove(INTERNAL_APPID_HEADER);
headers.remove(INTERNAL_TOKENID_HEADER); headers.remove(INTERNAL_TOKENID_HEADER);
headers.remove(INTERNAL_PRIVACYLEVEL_HEADER); headers.remove(INTERNAL_PRIVACYLEVEL_HEADER);
if let Some(sid) = auth.system_id() { if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into()); headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
headers.append(INTERNAL_PRIVACYLEVEL_HEADER, HeaderValue::from_str(&auth.access_level().privacy_level().to_string())?); headers.append(
INTERNAL_PRIVACYLEVEL_HEADER,
HeaderValue::from_str(&auth.access_level().privacy_level().to_string())?,
);
} }
if let Some(aid) = auth.app_id() { if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, HeaderValue::from_str(&format!("{}", aid))?); headers.append(
INTERNAL_APPID_HEADER,
HeaderValue::from_str(&format!("{}", aid))?,
);
}
if let Some(tid) = auth.api_key_id() {
headers.append(
INTERNAL_TOKENID_HEADER,
HeaderValue::from_str(&format!("{}", tid))?,
);
} }
if let Some(tid) = auth.api_key_id() {
headers.append(INTERNAL_TOKENID_HEADER, HeaderValue::from_str(&format!("{}", tid))?);
}
Ok(ctx.rproxy_client.request(req).await?.into_response()) Ok(ctx.rproxy_client.request(req).await?.into_response())
} }
@ -136,9 +149,9 @@ fn router(ctx: ApiContext) -> Router {
.route("/internal/apikey/user", post(endpoints::internal::create_api_key_user)) .route("/internal/apikey/user", post(endpoints::internal::create_api_key_user))
.route("/v2/systems/:system_id/oembed.json", get(rproxy)) .route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/:member_id/oembed.json", get(rproxy)) .route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/:group_id/oembed.json", get(rproxy)) .route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.layer(middleware::ratelimit::ratelimiter(ctx.clone(), middleware::ratelimit::do_request_ratelimited)) .layer(middleware::ratelimit::ratelimiter(ctx.clone(), middleware::ratelimit::do_request_ratelimited))

View file

@ -1,19 +1,19 @@
use axum::{ use axum::{
extract::{Request, State, MatchedPath}, extract::{MatchedPath, Request, State},
http::StatusCode, http::StatusCode,
middleware::Next, middleware::Next,
response::Response, response::Response,
}; };
use uuid::Uuid;
use subtle::ConstantTimeEq; use subtle::ConstantTimeEq;
use uuid::Uuid;
use tracing::error;
use sqlx::Postgres; use sqlx::Postgres;
use tracing::error;
use pluralkit_models::{ApiKeyType, PKApiKey};
use crate::auth::{AccessLevel, AuthState}; use crate::auth::{AccessLevel, AuthState};
use crate::{util::json_err, ApiContext}; use crate::{util::json_err, ApiContext};
use pluralkit_models::{ApiKeyType, PKApiKey};
pub fn is_part_path<'a, 'b>(part: &'a str, endpoint: &'b str) -> bool { pub fn is_part_path<'a, 'b>(part: &'a str, endpoint: &'b str) -> bool {
if !endpoint.starts_with("/v2/") { if !endpoint.starts_with("/v2/") {
@ -113,7 +113,7 @@ pub fn apikey_can_access(token: &PKApiKey, method: String, endpoint: String) ->
} }
pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response { pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -> Response {
let endpoint = req let endpoint = req
.extensions() .extensions()
.get::<MatchedPath>() .get::<MatchedPath>()
.cloned() .cloned()
@ -122,8 +122,8 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
let mut authed_system_id: Option<i32> = None; let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<Uuid> = None; let mut authed_app_id: Option<Uuid> = None;
let mut authed_api_key_id: Option<Uuid> = None; let mut authed_api_key_id: Option<Uuid> = None;
let mut access_level = AccessLevel::None; let mut access_level = AccessLevel::None;
// fetch user authorization // fetch user authorization
if let Some(system_auth_header) = req if let Some(system_auth_header) = req
@ -131,10 +131,12 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.get("authorization") .get("authorization")
.map(|h| h.to_str().ok()) .map(|h| h.to_str().ok())
.flatten() .flatten()
{ {
if system_auth_header.starts_with("Bearer ") if system_auth_header.starts_with("Bearer ")
&& let Some(tid) = && let Some(tid) = PKApiKey::parse_header_str(
PKApiKey::parse_header_str(system_auth_header[7..].to_string(), &ctx.token_publickey) system_auth_header[7..].to_string(),
&ctx.token_publickey,
)
&& let Some(token) = && let Some(token) =
sqlx::query_as::<Postgres, PKApiKey>("select * from api_keys where id = $1") sqlx::query_as::<Postgres, PKApiKey>("select * from api_keys where id = $1")
.bind(&tid) .bind(&tid)
@ -142,13 +144,10 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.await .await
.expect("failed to query apitoken in postgres") .expect("failed to query apitoken in postgres")
{ {
authed_api_key_id = Some(tid); authed_system_id = Some(token.system);
access_level = apikey_can_access(&token, req.method().to_string(), endpoint.clone()); authed_api_key_id = Some(tid);
if access_level != AccessLevel::None { access_level = apikey_can_access(&token, req.method().to_string(), endpoint.clone());
authed_system_id = Some(token.system); } else if let Some(system_id) =
}
}
else if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await { match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val, Ok(val) => val,
Err(err) => { Err(err) => {
@ -159,11 +158,11 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
); );
} }
} }
{ {
authed_system_id = Some(system_id); authed_system_id = Some(system_id);
access_level = AccessLevel::Full; access_level = AccessLevel::Full;
} }
} }
// fetch app authorization // fetch app authorization
if let Some(app_auth_header) = req if let Some(app_auth_header) = req
@ -172,7 +171,7 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
.map(|h| h.to_str().ok()) .map(|h| h.to_str().ok())
.flatten() .flatten()
&& let Some(app_id) = && let Some(app_id) =
match libpk::db::repository::app_token_auth(&ctx.db, app_auth_header).await { match libpk::db::repository::app_token_auth(&ctx.db, app_auth_header).await {
Ok(val) => val, Ok(val) => val,
Err(err) => { Err(err) => {
error!(?err, "failed to query authorization token in postgres"); error!(?err, "failed to query authorization token in postgres");
@ -201,8 +200,13 @@ pub async fn auth(State(ctx): State<ApiContext>, mut req: Request, next: Next) -
false false
}; };
req.extensions_mut() req.extensions_mut().insert(AuthState::new(
.insert(AuthState::new(authed_system_id, authed_app_id, authed_api_key_id, access_level, internal)); authed_system_id,
authed_app_id,
authed_api_key_id,
access_level,
internal,
));
next.run(req).await next.run(req).await
} }

View file

@ -12,9 +12,9 @@ use sqlx::Postgres;
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use crate::{ use crate::{
ApiContext,
auth::AuthState, auth::AuthState,
util::{header_or_unknown, json_err}, util::{header_or_unknown, json_err},
ApiContext,
}; };
use pluralkit_models::PKExternalApp; use pluralkit_models::PKExternalApp;

View file

@ -60,13 +60,8 @@ pub struct ApiConfig {
pub remote_url: String, pub remote_url: String,
#[serde(default)]
pub temp_token2: Option<String>,
pub token_privatekey: String, pub token_privatekey: String,
pub token_publickey: String, pub token_publickey: String,
pub internal_request_secret: String,
} }
#[derive(Deserialize, Clone, Debug)] #[derive(Deserialize, Clone, Debug)]

View file

@ -22,10 +22,10 @@ struct LegacyTokenDbResponse {
} }
pub async fn app_token_auth( pub async fn app_token_auth(
pool: &sqlx::postgres::PgPool, pool: &sqlx::postgres::PgPool,
token: &str, token: &str,
) -> anyhow::Result<Option<Uuid>> { ) -> anyhow::Result<Option<Uuid>> {
let mut app: Vec<AppTokenDbResponse> = let mut app: Vec<AppTokenDbResponse> =
sqlx::query_as("select id from external_apps where api_rl_token = $1") sqlx::query_as("select id from external_apps where api_rl_token = $1")
.bind(token) .bind(token)
.fetch_all(pool) .fetch_all(pool)

View file

@ -9,7 +9,7 @@ fn pretty_print(ts: &proc_macro2::TokenStream) -> String {
pub fn macro_impl( pub fn macro_impl(
_args: proc_macro::TokenStream, _args: proc_macro::TokenStream,
input: proc_macro::TokenStream, input: proc_macro::TokenStream,
is_internal: bool, is_internal: bool,
) -> proc_macro::TokenStream { ) -> proc_macro::TokenStream {
let input = parse_macro_input!(input as ItemFn); let input = parse_macro_input!(input as ItemFn);
@ -35,15 +35,15 @@ pub fn macro_impl(
}) })
.collect(); .collect();
let internal_res = if is_internal { let internal_res = if is_internal {
quote! { quote! {
if !auth.internal() { if !auth.internal() {
return crate::error::FORBIDDEN_INTERNAL_ROUTE.into_response(); return crate::error::FORBIDDEN_INTERNAL_ROUTE.into_response();
} }
} }
} else { } else {
quote!() quote!()
}; };
let res = quote! { let res = quote! {
#[allow(unused_mut)] #[allow(unused_mut)]
@ -52,7 +52,7 @@ pub fn macro_impl(
#fn_body #fn_body
} }
#internal_res #internal_res
match inner(#(#pms),*).await { match inner(#(#pms),*).await {
Ok(res) => res.into_response(), Ok(res) => res.into_response(),
Err(err) => err.into_response(), Err(err) => err.into_response(),

View file

@ -11,7 +11,6 @@ jsonwebtoken = { workspace = true }
sea-query = "0.32.1" sea-query = "0.32.1"
serde = { workspace = true } serde = { workspace = true }
serde_json = { workspace = true, features = ["preserve_order"] } serde_json = { workspace = true, features = ["preserve_order"] }
# in theory we want to default-features = false for sqlx sqlx = { workspace = true }
# but cargo doesn't seem to support this
sqlx = { workspace = true, features = ["chrono"] }
uuid = { workspace = true } uuid = { workspace = true }
tracing = { workspace = true }

View file

@ -21,12 +21,12 @@ impl From<i32> for PrivacyLevel {
} }
impl PrivacyLevel { impl PrivacyLevel {
pub fn to_string(&self) -> String { pub fn to_string(&self) -> String {
match self { match self {
PrivacyLevel::Public => "public".into(), PrivacyLevel::Public => "public".into(),
PrivacyLevel::Private => "private".into(), PrivacyLevel::Private => "private".into(),
} }
} }
} }
macro_rules! model { macro_rules! model {