From 6c0c7a5c9918c106bef1ccc09b4ee434ca830c1b Mon Sep 17 00:00:00 2001 From: alyssa Date: Sat, 26 Apr 2025 12:03:00 +0000 Subject: [PATCH] feat(api): pull SP avatars --- PluralKit.API/ApiConfig.cs | 1 + .../AuthorizationTokenHandlerMiddleware.cs | 6 ++ .../Controllers/v2/MemberControllerV2.cs | 34 ++++++++ crates/api/src/main.rs | 2 + crates/api/src/middleware/authnz.rs | 87 ++++++++++++++----- crates/api/src/middleware/logger.rs | 36 +++++--- crates/api/src/middleware/ratelimit.rs | 34 +++----- crates/api/src/util.rs | 1 + crates/avatars/src/main.rs | 2 +- crates/avatars/src/pull.rs | 8 ++ 10 files changed, 153 insertions(+), 58 deletions(-) diff --git a/PluralKit.API/ApiConfig.cs b/PluralKit.API/ApiConfig.cs index fc34d515..46556a79 100644 --- a/PluralKit.API/ApiConfig.cs +++ b/PluralKit.API/ApiConfig.cs @@ -6,4 +6,5 @@ public class ApiConfig public string? ClientId { get; set; } public string? ClientSecret { get; set; } public bool TrustAuth { get; set; } = false; + public string? AvatarServiceUrl { get; set; } } \ No newline at end of file diff --git a/PluralKit.API/AuthorizationTokenHandlerMiddleware.cs b/PluralKit.API/AuthorizationTokenHandlerMiddleware.cs index a09c869e..5f1c4011 100644 --- a/PluralKit.API/AuthorizationTokenHandlerMiddleware.cs +++ b/PluralKit.API/AuthorizationTokenHandlerMiddleware.cs @@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware && int.TryParse(sidHeaders[0], out var systemId)) ctx.Items.Add("SystemId", new SystemId(systemId)); + if (cfg.TrustAuth + && ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders) + && aidHeaders.Count > 0 + && int.TryParse(aidHeaders[0], out var appId)) + ctx.Items.Add("AppId", appId); + await _next.Invoke(ctx); } } \ No newline at end of file diff --git a/PluralKit.API/Controllers/v2/MemberControllerV2.cs b/PluralKit.API/Controllers/v2/MemberControllerV2.cs index 25163dff..6c37fa81 100644 --- a/PluralKit.API/Controllers/v2/MemberControllerV2.cs +++ b/PluralKit.API/Controllers/v2/MemberControllerV2.cs @@ -1,3 +1,7 @@ +using System.Net; +using System.Net.Http; +using System.Net.Http.Json; + using Microsoft.AspNetCore.Mvc; using Newtonsoft.Json.Linq; @@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase if (patch.Errors.Count > 0) throw new ModelParseError(patch.Errors); + if (patch.AvatarUrl.Value != null) + patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system); + using var conn = await _db.Obtain(); using var tx = await conn.BeginTransactionAsync(); @@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase if (patch.Errors.Count > 0) throw new ModelParseError(patch.Errors); + if (patch.AvatarUrl.Value != null) + patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system); + var newMember = await _repo.UpdateMember(member.Id, patch); return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid)); } @@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase return NoContent(); } + + private async Task TryUploadAvatar(string avatarUrl, PKSystem system) + { + if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl; + if (_config.AvatarServiceUrl == null) return avatarUrl; + if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl; + + using var client = new HttpClient(); + var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull", + new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() }); + + if (response.StatusCode != HttpStatusCode.OK) + { + var error = await response.Content.ReadFromJsonAsync(); + throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}"); + } + + var success = await response.Content.ReadFromJsonAsync(); + return success.Url; + } + + public record ErrorResponse(string Error); + + public record SuccessResponse(string Url, bool New); } \ No newline at end of file diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index bf07ff8f..7e23a22d 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -1,3 +1,5 @@ +#![feature(let_chains)] + use axum::{ body::Body, extract::{Request as ExtractRequest, State}, diff --git a/crates/api/src/middleware/authnz.rs b/crates/api/src/middleware/authnz.rs index 4544e6bf..47140767 100644 --- a/crates/api/src/middleware/authnz.rs +++ b/crates/api/src/middleware/authnz.rs @@ -1,45 +1,90 @@ use axum::{ extract::{Request, State}, - http::HeaderValue, + http::StatusCode, middleware::Next, response::Response, }; use tracing::error; -use crate::ApiContext; +use crate::{util::json_err, ApiContext}; -use super::logger::DID_AUTHENTICATE_HEADER; +pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid"; +pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid"; + +// todo: auth should pass down models in request context +// not numerical ids in headers pub async fn authnz(State(ctx): State, mut request: Request, next: Next) -> Response { let headers = request.headers_mut(); - headers.remove("x-pluralkit-systemid"); - let auth_header = headers + + headers.remove(INTERNAL_SYSTEMID_HEADER); + headers.remove(INTERNAL_APPID_HEADER); + + let mut authed_system_id: Option = None; + let mut authed_app_id: Option = None; + + // fetch user authorization + if let Some(system_auth_header) = headers .get("authorization") .map(|h| h.to_str().ok()) - .flatten(); - let mut authenticated = false; - if let Some(auth_header) = auth_header { - if let Some(system_id) = - match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await { + .flatten() + && let Some(system_id) = + match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await { Ok(val) => val, Err(err) => { error!(?err, "failed to query authorization token in postgres"); - None + return json_err( + StatusCode::INTERNAL_SERVER_ERROR, + r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(), + ); } } - { - headers.append( - "x-pluralkit-systemid", - HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(), - ); - authenticated = true; + { + authed_system_id = Some(system_id); + } + + // fetch app authorization + // todo: actually fetch it from db + if let Some(app_auth_header) = headers + .get("x-pluralkit-app") + .map(|h| h.to_str().ok()) + .flatten() + && let Some(config_token2) = libpk::config + .api + .as_ref() + .expect("missing api config") + .temp_token2 + .as_ref() + // this is NOT how you validate tokens + // but this is low abuse risk so we're keeping it for now + && app_auth_header == config_token2 + { + authed_app_id = Some(1); + } + + // add headers for ratelimiter / dotnet-api + { + let headers = request.headers_mut(); + if let Some(sid) = authed_system_id { + headers.append(INTERNAL_SYSTEMID_HEADER, sid.into()); + } + if let Some(aid) = authed_app_id { + headers.append(INTERNAL_APPID_HEADER, aid.into()); } } + let mut response = next.run(request).await; - if authenticated { - response - .headers_mut() - .insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1")); + + // add headers for logger module (ugh) + { + let headers = response.headers_mut(); + if let Some(sid) = authed_system_id { + headers.append(INTERNAL_SYSTEMID_HEADER, sid.into()); + } + if let Some(aid) = authed_app_id { + headers.append(INTERNAL_APPID_HEADER, aid.into()); + } } + response } diff --git a/crates/api/src/middleware/logger.rs b/crates/api/src/middleware/logger.rs index 020de2e2..8f239042 100644 --- a/crates/api/src/middleware/logger.rs +++ b/crates/api/src/middleware/logger.rs @@ -4,14 +4,15 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R use metrics::{counter, histogram}; use tracing::{info, span, warn, Instrument, Level}; -use crate::util::header_or_unknown; +use crate::{ + middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}, + util::header_or_unknown, +}; // log any requests that take longer than 2 seconds // todo: change as necessary const MIN_LOG_TIME: u128 = 2_000; -pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate"; - pub async fn logger(request: Request, next: Next) -> Response { let method = request.method().clone(); @@ -40,14 +41,20 @@ pub async fn logger(request: Request, next: Next) -> Response { let mut response = next.run(request).instrument(request_span).await; let elapsed = start.elapsed().as_millis(); - let authenticated = { + let (system_id, app_id) = { let headers = response.headers_mut(); - if headers.contains_key(DID_AUTHENTICATE_HEADER) { - headers.remove(DID_AUTHENTICATE_HEADER); - true - } else { - false - } + ( + headers + .remove(INTERNAL_SYSTEMID_HEADER) + .map(|h| h.to_str().ok().map(|v| v.to_string())) + .flatten() + .unwrap_or("none".to_string()), + headers + .remove(INTERNAL_APPID_HEADER) + .map(|h| h.to_str().ok().map(|v| v.to_string())) + .flatten() + .unwrap_or("none".to_string()), + ) }; counter!( @@ -55,7 +62,8 @@ pub async fn logger(request: Request, next: Next) -> Response { "method" => method.to_string(), "endpoint" => endpoint.clone(), "status" => response.status().to_string(), - "authenticated" => authenticated.to_string(), + "system_id" => system_id.to_string(), + "app_id" => app_id.to_string(), ) .increment(1); histogram!( @@ -63,7 +71,8 @@ pub async fn logger(request: Request, next: Next) -> Response { "method" => method.to_string(), "endpoint" => endpoint.clone(), "status" => response.status().to_string(), - "authenticated" => authenticated.to_string(), + "system_id" => system_id.to_string(), + "app_id" => app_id.to_string(), ) .record(elapsed as f64 / 1_000_f64); @@ -81,7 +90,8 @@ pub async fn logger(request: Request, next: Next) -> Response { "method" => method.to_string(), "endpoint" => endpoint.clone(), "status" => response.status().to_string(), - "authenticated" => authenticated.to_string(), + "system_id" => system_id.to_string(), + "app_id" => app_id.to_string(), ) .increment(1); diff --git a/crates/api/src/middleware/ratelimit.rs b/crates/api/src/middleware/ratelimit.rs index e7bb1dd0..3c4a6be4 100644 --- a/crates/api/src/middleware/ratelimit.rs +++ b/crates/api/src/middleware/ratelimit.rs @@ -12,6 +12,8 @@ use tracing::{debug, error, info, warn}; use crate::util::{header_or_unknown, json_err}; +use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; + const LUA_SCRIPT: &str = include_str!("ratelimit.lua"); lazy_static::lazy_static! { @@ -103,28 +105,8 @@ pub async fn do_request_ratelimited( if let Some(redis) = redis { let headers = request.headers().clone(); let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP")); - let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid")); - - // https://github.com/rust-lang/rust/issues/53667 - let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App") - { - if let Some(token2) = &libpk::config - .api - .as_ref() - .expect("missing api config") - .temp_token2 - { - if header.to_str().unwrap_or("invalid") == token2 { - true - } else { - false - } - } else { - false - } - } else { - false - }; + let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER)); + let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER)); let endpoint = request .extensions() @@ -133,7 +115,13 @@ pub async fn do_request_ratelimited( .map(|v| v.as_str().to_string()) .unwrap_or("unknown".to_string()); - let rlimit = if is_temp_token2 { + // looks like this chooses the tokens/sec by app_id or endpoint + // then chooses the key by system_id or source_ip + // todo: key should probably be chosen by app_id when it's present + // todo: make x-ratelimit-scope actually meaningful + + // hack: for now, we only have one "registered app", so we hardcode the app id + let rlimit = if authenticated_app_id == "1" { RatelimitType::TempCustom } else if endpoint == "/v2/messages/:message_id" { RatelimitType::Message diff --git a/crates/api/src/util.rs b/crates/api/src/util.rs index 03121659..0abb1337 100644 --- a/crates/api/src/util.rs +++ b/crates/api/src/util.rs @@ -56,6 +56,7 @@ pub fn handle_panic(err: Box) -> axum::respo ) } +// todo: make 500 not duplicated pub fn json_err(code: StatusCode, text: String) -> axum::response::Response { let mut response = (code, text).into_response(); let headers = response.headers_mut(); diff --git a/crates/avatars/src/main.rs b/crates/avatars/src/main.rs index 3b621f52..b6f7a2c7 100644 --- a/crates/avatars/src/main.rs +++ b/crates/avatars/src/main.rs @@ -93,7 +93,7 @@ async fn pull( ) -> Result, PKAvatarError> { let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize" .map_err(|_| PKAvatarError::InvalidCdnUrl)?; - if !req.force { + if !(req.force || req.url.contains("https://serve.apparyllis.com/")) { if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? { // remove any pending image cleanup db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?; diff --git a/crates/avatars/src/pull.rs b/crates/avatars/src/pull.rs index 8b9064d0..3ffb24dc 100644 --- a/crates/avatars/src/pull.rs +++ b/crates/avatars/src/pull.rs @@ -137,6 +137,14 @@ pub fn parse_url(url: &str) -> anyhow::Result { match (url.scheme(), url.domain()) { ("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {} + ("https", Some("serve.apparyllis.com")) => { + return Ok(ParsedUrl { + channel_id: 0, + attachment_id: 0, + filename: "".to_string(), + full_url: url.to_string(), + }) + } _ => anyhow::bail!("not a discord cdn url"), }