feat(api): pull SP avatars

This commit is contained in:
alyssa 2025-04-26 12:03:00 +00:00
parent 63d9b411ae
commit 6c0c7a5c99
10 changed files with 153 additions and 58 deletions

View file

@ -6,4 +6,5 @@ public class ApiConfig
public string? ClientId { get; set; }
public string? ClientSecret { get; set; }
public bool TrustAuth { get; set; } = false;
public string? AvatarServiceUrl { get; set; }
}

View file

@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware
&& int.TryParse(sidHeaders[0], out var systemId))
ctx.Items.Add("SystemId", new SystemId(systemId));
if (cfg.TrustAuth
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)
&& aidHeaders.Count > 0
&& int.TryParse(aidHeaders[0], out var appId))
ctx.Items.Add("AppId", appId);
await _next.Invoke(ctx);
}
}

View file

@ -1,3 +1,7 @@
using System.Net;
using System.Net.Http;
using System.Net.Http.Json;
using Microsoft.AspNetCore.Mvc;
using Newtonsoft.Json.Linq;
@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase
if (patch.Errors.Count > 0)
throw new ModelParseError(patch.Errors);
if (patch.AvatarUrl.Value != null)
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
using var conn = await _db.Obtain();
using var tx = await conn.BeginTransactionAsync();
@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase
if (patch.Errors.Count > 0)
throw new ModelParseError(patch.Errors);
if (patch.AvatarUrl.Value != null)
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
var newMember = await _repo.UpdateMember(member.Id, patch);
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
}
@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase
return NoContent();
}
private async Task<string> TryUploadAvatar(string avatarUrl, PKSystem system)
{
if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl;
if (_config.AvatarServiceUrl == null) return avatarUrl;
if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl;
using var client = new HttpClient();
var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull",
new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() });
if (response.StatusCode != HttpStatusCode.OK)
{
var error = await response.Content.ReadFromJsonAsync<ErrorResponse>();
throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}");
}
var success = await response.Content.ReadFromJsonAsync<SuccessResponse>();
return success.Url;
}
public record ErrorResponse(string Error);
public record SuccessResponse(string Url, bool New);
}

View file

@ -1,3 +1,5 @@
#![feature(let_chains)]
use axum::{
body::Body,
extract::{Request as ExtractRequest, State},

View file

@ -1,45 +1,90 @@
use axum::{
extract::{Request, State},
http::HeaderValue,
http::StatusCode,
middleware::Next,
response::Response,
};
use tracing::error;
use crate::ApiContext;
use crate::{util::json_err, ApiContext};
use super::logger::DID_AUTHENTICATE_HEADER;
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
// todo: auth should pass down models in request context
// not numerical ids in headers
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
let headers = request.headers_mut();
headers.remove("x-pluralkit-systemid");
let auth_header = headers
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
let mut authed_system_id: Option<i32> = None;
let mut authed_app_id: Option<i32> = None;
// fetch user authorization
if let Some(system_auth_header) = headers
.get("authorization")
.map(|h| h.to_str().ok())
.flatten();
let mut authenticated = false;
if let Some(auth_header) = auth_header {
if let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
.flatten()
&& let Some(system_id) =
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
Ok(val) => val,
Err(err) => {
error!(?err, "failed to query authorization token in postgres");
None
return json_err(
StatusCode::INTERNAL_SERVER_ERROR,
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
);
}
}
{
headers.append(
"x-pluralkit-systemid",
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
);
authenticated = true;
{
authed_system_id = Some(system_id);
}
// fetch app authorization
// todo: actually fetch it from db
if let Some(app_auth_header) = headers
.get("x-pluralkit-app")
.map(|h| h.to_str().ok())
.flatten()
&& let Some(config_token2) = libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
.as_ref()
// this is NOT how you validate tokens
// but this is low abuse risk so we're keeping it for now
&& app_auth_header == config_token2
{
authed_app_id = Some(1);
}
// add headers for ratelimiter / dotnet-api
{
let headers = request.headers_mut();
if let Some(sid) = authed_system_id {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = authed_app_id {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
}
let mut response = next.run(request).await;
if authenticated {
response
.headers_mut()
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
// add headers for logger module (ugh)
{
let headers = response.headers_mut();
if let Some(sid) = authed_system_id {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = authed_app_id {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
}
response
}

View file

@ -4,14 +4,15 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
use metrics::{counter, histogram};
use tracing::{info, span, warn, Instrument, Level};
use crate::util::header_or_unknown;
use crate::{
middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
util::header_or_unknown,
};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
@ -40,14 +41,20 @@ pub async fn logger(request: Request, next: Next) -> Response {
let mut response = next.run(request).instrument(request_span).await;
let elapsed = start.elapsed().as_millis();
let authenticated = {
let (system_id, app_id) = {
let headers = response.headers_mut();
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
headers.remove(DID_AUTHENTICATE_HEADER);
true
} else {
false
}
(
headers
.remove(INTERNAL_SYSTEMID_HEADER)
.map(|h| h.to_str().ok().map(|v| v.to_string()))
.flatten()
.unwrap_or("none".to_string()),
headers
.remove(INTERNAL_APPID_HEADER)
.map(|h| h.to_str().ok().map(|v| v.to_string()))
.flatten()
.unwrap_or("none".to_string()),
)
};
counter!(
@ -55,7 +62,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);
histogram!(
@ -63,7 +71,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.record(elapsed as f64 / 1_000_f64);
@ -81,7 +90,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
"method" => method.to_string(),
"endpoint" => endpoint.clone(),
"status" => response.status().to_string(),
"authenticated" => authenticated.to_string(),
"system_id" => system_id.to_string(),
"app_id" => app_id.to_string(),
)
.increment(1);

View file

@ -12,6 +12,8 @@ use tracing::{debug, error, info, warn};
use crate::util::{header_or_unknown, json_err};
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
lazy_static::lazy_static! {
@ -103,28 +105,8 @@ pub async fn do_request_ratelimited(
if let Some(redis) = redis {
let headers = request.headers().clone();
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
// https://github.com/rust-lang/rust/issues/53667
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
{
if let Some(token2) = &libpk::config
.api
.as_ref()
.expect("missing api config")
.temp_token2
{
if header.to_str().unwrap_or("invalid") == token2 {
true
} else {
false
}
} else {
false
}
} else {
false
};
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
let endpoint = request
.extensions()
@ -133,7 +115,13 @@ pub async fn do_request_ratelimited(
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let rlimit = if is_temp_token2 {
// looks like this chooses the tokens/sec by app_id or endpoint
// then chooses the key by system_id or source_ip
// todo: key should probably be chosen by app_id when it's present
// todo: make x-ratelimit-scope actually meaningful
// hack: for now, we only have one "registered app", so we hardcode the app id
let rlimit = if authenticated_app_id == "1" {
RatelimitType::TempCustom
} else if endpoint == "/v2/messages/:message_id" {
RatelimitType::Message

View file

@ -56,6 +56,7 @@ pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::respo
)
}
// todo: make 500 not duplicated
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
let mut response = (code, text).into_response();
let headers = response.headers_mut();

View file

@ -93,7 +93,7 @@ async fn pull(
) -> Result<Json<PullResponse>, PKAvatarError> {
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
if !req.force {
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
// remove any pending image cleanup
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;

View file

@ -137,6 +137,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
match (url.scheme(), url.domain()) {
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
("https", Some("serve.apparyllis.com")) => {
return Ok(ParsedUrl {
channel_id: 0,
attachment_id: 0,
filename: "".to_string(),
full_url: url.to_string(),
})
}
_ => anyhow::bail!("not a discord cdn url"),
}