mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 13:06:50 +00:00
feat(api): pull SP avatars
This commit is contained in:
parent
63d9b411ae
commit
6c0c7a5c99
10 changed files with 153 additions and 58 deletions
|
|
@ -6,4 +6,5 @@ public class ApiConfig
|
|||
public string? ClientId { get; set; }
|
||||
public string? ClientSecret { get; set; }
|
||||
public bool TrustAuth { get; set; } = false;
|
||||
public string? AvatarServiceUrl { get; set; }
|
||||
}
|
||||
|
|
@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware
|
|||
&& int.TryParse(sidHeaders[0], out var systemId))
|
||||
ctx.Items.Add("SystemId", new SystemId(systemId));
|
||||
|
||||
if (cfg.TrustAuth
|
||||
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)
|
||||
&& aidHeaders.Count > 0
|
||||
&& int.TryParse(aidHeaders[0], out var appId))
|
||||
ctx.Items.Add("AppId", appId);
|
||||
|
||||
await _next.Invoke(ctx);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,3 +1,7 @@
|
|||
using System.Net;
|
||||
using System.Net.Http;
|
||||
using System.Net.Http.Json;
|
||||
|
||||
using Microsoft.AspNetCore.Mvc;
|
||||
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
|
@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase
|
|||
if (patch.Errors.Count > 0)
|
||||
throw new ModelParseError(patch.Errors);
|
||||
|
||||
if (patch.AvatarUrl.Value != null)
|
||||
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||
|
||||
using var conn = await _db.Obtain();
|
||||
using var tx = await conn.BeginTransactionAsync();
|
||||
|
||||
|
|
@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase
|
|||
if (patch.Errors.Count > 0)
|
||||
throw new ModelParseError(patch.Errors);
|
||||
|
||||
if (patch.AvatarUrl.Value != null)
|
||||
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||
|
||||
var newMember = await _repo.UpdateMember(member.Id, patch);
|
||||
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
|
||||
}
|
||||
|
|
@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase
|
|||
|
||||
return NoContent();
|
||||
}
|
||||
|
||||
private async Task<string> TryUploadAvatar(string avatarUrl, PKSystem system)
|
||||
{
|
||||
if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl;
|
||||
if (_config.AvatarServiceUrl == null) return avatarUrl;
|
||||
if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl;
|
||||
|
||||
using var client = new HttpClient();
|
||||
var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull",
|
||||
new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() });
|
||||
|
||||
if (response.StatusCode != HttpStatusCode.OK)
|
||||
{
|
||||
var error = await response.Content.ReadFromJsonAsync<ErrorResponse>();
|
||||
throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}");
|
||||
}
|
||||
|
||||
var success = await response.Content.ReadFromJsonAsync<SuccessResponse>();
|
||||
return success.Url;
|
||||
}
|
||||
|
||||
public record ErrorResponse(string Error);
|
||||
|
||||
public record SuccessResponse(string Url, bool New);
|
||||
}
|
||||
|
|
@ -1,3 +1,5 @@
|
|||
#![feature(let_chains)]
|
||||
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
|
|
|
|||
|
|
@ -1,45 +1,90 @@
|
|||
use axum::{
|
||||
extract::{Request, State},
|
||||
http::HeaderValue,
|
||||
http::StatusCode,
|
||||
middleware::Next,
|
||||
response::Response,
|
||||
};
|
||||
use tracing::error;
|
||||
|
||||
use crate::ApiContext;
|
||||
use crate::{util::json_err, ApiContext};
|
||||
|
||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
||||
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||
|
||||
// todo: auth should pass down models in request context
|
||||
// not numerical ids in headers
|
||||
|
||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||
let headers = request.headers_mut();
|
||||
headers.remove("x-pluralkit-systemid");
|
||||
let auth_header = headers
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
let mut authed_system_id: Option<i32> = None;
|
||||
let mut authed_app_id: Option<i32> = None;
|
||||
|
||||
// fetch user authorization
|
||||
if let Some(system_auth_header) = headers
|
||||
.get("authorization")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten();
|
||||
let mut authenticated = false;
|
||||
if let Some(auth_header) = auth_header {
|
||||
if let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
||||
.flatten()
|
||||
&& let Some(system_id) =
|
||||
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
|
||||
Ok(val) => val,
|
||||
Err(err) => {
|
||||
error!(?err, "failed to query authorization token in postgres");
|
||||
None
|
||||
return json_err(
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
{
|
||||
headers.append(
|
||||
"x-pluralkit-systemid",
|
||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
||||
);
|
||||
authenticated = true;
|
||||
{
|
||||
authed_system_id = Some(system_id);
|
||||
}
|
||||
|
||||
// fetch app authorization
|
||||
// todo: actually fetch it from db
|
||||
if let Some(app_auth_header) = headers
|
||||
.get("x-pluralkit-app")
|
||||
.map(|h| h.to_str().ok())
|
||||
.flatten()
|
||||
&& let Some(config_token2) = libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
.as_ref()
|
||||
// this is NOT how you validate tokens
|
||||
// but this is low abuse risk so we're keeping it for now
|
||||
&& app_auth_header == config_token2
|
||||
{
|
||||
authed_app_id = Some(1);
|
||||
}
|
||||
|
||||
// add headers for ratelimiter / dotnet-api
|
||||
{
|
||||
let headers = request.headers_mut();
|
||||
if let Some(sid) = authed_system_id {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
if let Some(aid) = authed_app_id {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
}
|
||||
|
||||
let mut response = next.run(request).await;
|
||||
if authenticated {
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
||||
|
||||
// add headers for logger module (ugh)
|
||||
{
|
||||
let headers = response.headers_mut();
|
||||
if let Some(sid) = authed_system_id {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
if let Some(aid) = authed_app_id {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
|
|
|
|||
|
|
@ -4,14 +4,15 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
|
|||
use metrics::{counter, histogram};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
use crate::util::header_or_unknown;
|
||||
use crate::{
|
||||
middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
|
||||
util::header_or_unknown,
|
||||
};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
|
|
@ -40,14 +41,20 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
let mut response = next.run(request).instrument(request_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
let authenticated = {
|
||||
let (system_id, app_id) = {
|
||||
let headers = response.headers_mut();
|
||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
(
|
||||
headers
|
||||
.remove(INTERNAL_SYSTEMID_HEADER)
|
||||
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||
.flatten()
|
||||
.unwrap_or("none".to_string()),
|
||||
headers
|
||||
.remove(INTERNAL_APPID_HEADER)
|
||||
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||
.flatten()
|
||||
.unwrap_or("none".to_string()),
|
||||
)
|
||||
};
|
||||
|
||||
counter!(
|
||||
|
|
@ -55,7 +62,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
histogram!(
|
||||
|
|
@ -63,7 +71,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.record(elapsed as f64 / 1_000_f64);
|
||||
|
||||
|
|
@ -81,7 +90,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
|||
"method" => method.to_string(),
|
||||
"endpoint" => endpoint.clone(),
|
||||
"status" => response.status().to_string(),
|
||||
"authenticated" => authenticated.to_string(),
|
||||
"system_id" => system_id.to_string(),
|
||||
"app_id" => app_id.to_string(),
|
||||
)
|
||||
.increment(1);
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ use tracing::{debug, error, info, warn};
|
|||
|
||||
use crate::util::{header_or_unknown, json_err};
|
||||
|
||||
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
|
||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
|
|
@ -103,28 +105,8 @@ pub async fn do_request_ratelimited(
|
|||
if let Some(redis) = redis {
|
||||
let headers = request.headers().clone();
|
||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
||||
|
||||
// https://github.com/rust-lang/rust/issues/53667
|
||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
||||
{
|
||||
if let Some(token2) = &libpk::config
|
||||
.api
|
||||
.as_ref()
|
||||
.expect("missing api config")
|
||||
.temp_token2
|
||||
{
|
||||
if header.to_str().unwrap_or("invalid") == token2 {
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
false
|
||||
};
|
||||
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
|
||||
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
|
|
@ -133,7 +115,13 @@ pub async fn do_request_ratelimited(
|
|||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let rlimit = if is_temp_token2 {
|
||||
// looks like this chooses the tokens/sec by app_id or endpoint
|
||||
// then chooses the key by system_id or source_ip
|
||||
// todo: key should probably be chosen by app_id when it's present
|
||||
// todo: make x-ratelimit-scope actually meaningful
|
||||
|
||||
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||
let rlimit = if authenticated_app_id == "1" {
|
||||
RatelimitType::TempCustom
|
||||
} else if endpoint == "/v2/messages/:message_id" {
|
||||
RatelimitType::Message
|
||||
|
|
|
|||
|
|
@ -56,6 +56,7 @@ pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::respo
|
|||
)
|
||||
}
|
||||
|
||||
// todo: make 500 not duplicated
|
||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||
let mut response = (code, text).into_response();
|
||||
let headers = response.headers_mut();
|
||||
|
|
|
|||
|
|
@ -93,7 +93,7 @@ async fn pull(
|
|||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||
if !req.force {
|
||||
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
|
||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||
// remove any pending image cleanup
|
||||
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
||||
|
|
|
|||
|
|
@ -137,6 +137,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
|
|||
|
||||
match (url.scheme(), url.domain()) {
|
||||
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
||||
("https", Some("serve.apparyllis.com")) => {
|
||||
return Ok(ParsedUrl {
|
||||
channel_id: 0,
|
||||
attachment_id: 0,
|
||||
filename: "".to_string(),
|
||||
full_url: url.to_string(),
|
||||
})
|
||||
}
|
||||
_ => anyhow::bail!("not a discord cdn url"),
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue