mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-13 09:10:14 +00:00
feat(api): pull SP avatars
This commit is contained in:
parent
63d9b411ae
commit
6c0c7a5c99
10 changed files with 153 additions and 58 deletions
|
|
@ -6,4 +6,5 @@ public class ApiConfig
|
||||||
public string? ClientId { get; set; }
|
public string? ClientId { get; set; }
|
||||||
public string? ClientSecret { get; set; }
|
public string? ClientSecret { get; set; }
|
||||||
public bool TrustAuth { get; set; } = false;
|
public bool TrustAuth { get; set; } = false;
|
||||||
|
public string? AvatarServiceUrl { get; set; }
|
||||||
}
|
}
|
||||||
|
|
@ -21,6 +21,12 @@ public class AuthorizationTokenHandlerMiddleware
|
||||||
&& int.TryParse(sidHeaders[0], out var systemId))
|
&& int.TryParse(sidHeaders[0], out var systemId))
|
||||||
ctx.Items.Add("SystemId", new SystemId(systemId));
|
ctx.Items.Add("SystemId", new SystemId(systemId));
|
||||||
|
|
||||||
|
if (cfg.TrustAuth
|
||||||
|
&& ctx.Request.Headers.TryGetValue("X-PluralKit-AppId", out var aidHeaders)
|
||||||
|
&& aidHeaders.Count > 0
|
||||||
|
&& int.TryParse(aidHeaders[0], out var appId))
|
||||||
|
ctx.Items.Add("AppId", appId);
|
||||||
|
|
||||||
await _next.Invoke(ctx);
|
await _next.Invoke(ctx);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1,3 +1,7 @@
|
||||||
|
using System.Net;
|
||||||
|
using System.Net.Http;
|
||||||
|
using System.Net.Http.Json;
|
||||||
|
|
||||||
using Microsoft.AspNetCore.Mvc;
|
using Microsoft.AspNetCore.Mvc;
|
||||||
|
|
||||||
using Newtonsoft.Json.Linq;
|
using Newtonsoft.Json.Linq;
|
||||||
|
|
@ -50,6 +54,9 @@ public class MemberControllerV2: PKControllerBase
|
||||||
if (patch.Errors.Count > 0)
|
if (patch.Errors.Count > 0)
|
||||||
throw new ModelParseError(patch.Errors);
|
throw new ModelParseError(patch.Errors);
|
||||||
|
|
||||||
|
if (patch.AvatarUrl.Value != null)
|
||||||
|
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||||
|
|
||||||
using var conn = await _db.Obtain();
|
using var conn = await _db.Obtain();
|
||||||
using var tx = await conn.BeginTransactionAsync();
|
using var tx = await conn.BeginTransactionAsync();
|
||||||
|
|
||||||
|
|
@ -110,6 +117,9 @@ public class MemberControllerV2: PKControllerBase
|
||||||
if (patch.Errors.Count > 0)
|
if (patch.Errors.Count > 0)
|
||||||
throw new ModelParseError(patch.Errors);
|
throw new ModelParseError(patch.Errors);
|
||||||
|
|
||||||
|
if (patch.AvatarUrl.Value != null)
|
||||||
|
patch.AvatarUrl = await TryUploadAvatar(patch.AvatarUrl.Value, system);
|
||||||
|
|
||||||
var newMember = await _repo.UpdateMember(member.Id, patch);
|
var newMember = await _repo.UpdateMember(member.Id, patch);
|
||||||
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
|
return Ok(newMember.ToJson(LookupContext.ByOwner, systemStr: system.Hid));
|
||||||
}
|
}
|
||||||
|
|
@ -129,4 +139,28 @@ public class MemberControllerV2: PKControllerBase
|
||||||
|
|
||||||
return NoContent();
|
return NoContent();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private async Task<string> TryUploadAvatar(string avatarUrl, PKSystem system)
|
||||||
|
{
|
||||||
|
if (!avatarUrl.StartsWith("https://serve.apparyllis.com/")) return avatarUrl;
|
||||||
|
if (_config.AvatarServiceUrl == null) return avatarUrl;
|
||||||
|
if (!HttpContext.Items.TryGetValue("AppId", out var appId) || (int)appId != 1) return avatarUrl;
|
||||||
|
|
||||||
|
using var client = new HttpClient();
|
||||||
|
var response = await client.PostAsJsonAsync(_config.AvatarServiceUrl + "/pull",
|
||||||
|
new { url = avatarUrl, kind = "avatar", uploaded_by = (string)null, system_id = system.Uuid.ToString() });
|
||||||
|
|
||||||
|
if (response.StatusCode != HttpStatusCode.OK)
|
||||||
|
{
|
||||||
|
var error = await response.Content.ReadFromJsonAsync<ErrorResponse>();
|
||||||
|
throw new PKError(500, 0, $"Error uploading image to CDN: {error.Error}");
|
||||||
|
}
|
||||||
|
|
||||||
|
var success = await response.Content.ReadFromJsonAsync<SuccessResponse>();
|
||||||
|
return success.Url;
|
||||||
|
}
|
||||||
|
|
||||||
|
public record ErrorResponse(string Error);
|
||||||
|
|
||||||
|
public record SuccessResponse(string Url, bool New);
|
||||||
}
|
}
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
#![feature(let_chains)]
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
extract::{Request as ExtractRequest, State},
|
extract::{Request as ExtractRequest, State},
|
||||||
|
|
|
||||||
|
|
@ -1,45 +1,90 @@
|
||||||
use axum::{
|
use axum::{
|
||||||
extract::{Request, State},
|
extract::{Request, State},
|
||||||
http::HeaderValue,
|
http::StatusCode,
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::Response,
|
response::Response,
|
||||||
};
|
};
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::ApiContext;
|
use crate::{util::json_err, ApiContext};
|
||||||
|
|
||||||
use super::logger::DID_AUTHENTICATE_HEADER;
|
pub const INTERNAL_SYSTEMID_HEADER: &'static str = "x-pluralkit-systemid";
|
||||||
|
pub const INTERNAL_APPID_HEADER: &'static str = "x-pluralkit-appid";
|
||||||
|
|
||||||
|
// todo: auth should pass down models in request context
|
||||||
|
// not numerical ids in headers
|
||||||
|
|
||||||
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
pub async fn authnz(State(ctx): State<ApiContext>, mut request: Request, next: Next) -> Response {
|
||||||
let headers = request.headers_mut();
|
let headers = request.headers_mut();
|
||||||
headers.remove("x-pluralkit-systemid");
|
|
||||||
let auth_header = headers
|
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||||
|
headers.remove(INTERNAL_APPID_HEADER);
|
||||||
|
|
||||||
|
let mut authed_system_id: Option<i32> = None;
|
||||||
|
let mut authed_app_id: Option<i32> = None;
|
||||||
|
|
||||||
|
// fetch user authorization
|
||||||
|
if let Some(system_auth_header) = headers
|
||||||
.get("authorization")
|
.get("authorization")
|
||||||
.map(|h| h.to_str().ok())
|
.map(|h| h.to_str().ok())
|
||||||
.flatten();
|
.flatten()
|
||||||
let mut authenticated = false;
|
&& let Some(system_id) =
|
||||||
if let Some(auth_header) = auth_header {
|
match libpk::db::repository::legacy_token_auth(&ctx.db, system_auth_header).await {
|
||||||
if let Some(system_id) =
|
|
||||||
match libpk::db::repository::legacy_token_auth(&ctx.db, auth_header).await {
|
|
||||||
Ok(val) => val,
|
Ok(val) => val,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
error!(?err, "failed to query authorization token in postgres");
|
error!(?err, "failed to query authorization token in postgres");
|
||||||
None
|
return json_err(
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
r#"{"message": "500: Internal Server Error", "code": 0}"#.to_string(),
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
headers.append(
|
authed_system_id = Some(system_id);
|
||||||
"x-pluralkit-systemid",
|
}
|
||||||
HeaderValue::from_str(format!("{system_id}").as_str()).unwrap(),
|
|
||||||
);
|
// fetch app authorization
|
||||||
authenticated = true;
|
// todo: actually fetch it from db
|
||||||
|
if let Some(app_auth_header) = headers
|
||||||
|
.get("x-pluralkit-app")
|
||||||
|
.map(|h| h.to_str().ok())
|
||||||
|
.flatten()
|
||||||
|
&& let Some(config_token2) = libpk::config
|
||||||
|
.api
|
||||||
|
.as_ref()
|
||||||
|
.expect("missing api config")
|
||||||
|
.temp_token2
|
||||||
|
.as_ref()
|
||||||
|
// this is NOT how you validate tokens
|
||||||
|
// but this is low abuse risk so we're keeping it for now
|
||||||
|
&& app_auth_header == config_token2
|
||||||
|
{
|
||||||
|
authed_app_id = Some(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// add headers for ratelimiter / dotnet-api
|
||||||
|
{
|
||||||
|
let headers = request.headers_mut();
|
||||||
|
if let Some(sid) = authed_system_id {
|
||||||
|
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||||
|
}
|
||||||
|
if let Some(aid) = authed_app_id {
|
||||||
|
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut response = next.run(request).await;
|
let mut response = next.run(request).await;
|
||||||
if authenticated {
|
|
||||||
response
|
// add headers for logger module (ugh)
|
||||||
.headers_mut()
|
{
|
||||||
.insert(DID_AUTHENTICATE_HEADER, HeaderValue::from_static("1"));
|
let headers = response.headers_mut();
|
||||||
|
if let Some(sid) = authed_system_id {
|
||||||
|
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||||
|
}
|
||||||
|
if let Some(aid) = authed_app_id {
|
||||||
|
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
response
|
response
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -4,14 +4,15 @@ use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::R
|
||||||
use metrics::{counter, histogram};
|
use metrics::{counter, histogram};
|
||||||
use tracing::{info, span, warn, Instrument, Level};
|
use tracing::{info, span, warn, Instrument, Level};
|
||||||
|
|
||||||
use crate::util::header_or_unknown;
|
use crate::{
|
||||||
|
middleware::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
|
||||||
|
util::header_or_unknown,
|
||||||
|
};
|
||||||
|
|
||||||
// log any requests that take longer than 2 seconds
|
// log any requests that take longer than 2 seconds
|
||||||
// todo: change as necessary
|
// todo: change as necessary
|
||||||
const MIN_LOG_TIME: u128 = 2_000;
|
const MIN_LOG_TIME: u128 = 2_000;
|
||||||
|
|
||||||
pub const DID_AUTHENTICATE_HEADER: &'static str = "x-pluralkit-didauthenticate";
|
|
||||||
|
|
||||||
pub async fn logger(request: Request, next: Next) -> Response {
|
pub async fn logger(request: Request, next: Next) -> Response {
|
||||||
let method = request.method().clone();
|
let method = request.method().clone();
|
||||||
|
|
||||||
|
|
@ -40,14 +41,20 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
||||||
let mut response = next.run(request).instrument(request_span).await;
|
let mut response = next.run(request).instrument(request_span).await;
|
||||||
let elapsed = start.elapsed().as_millis();
|
let elapsed = start.elapsed().as_millis();
|
||||||
|
|
||||||
let authenticated = {
|
let (system_id, app_id) = {
|
||||||
let headers = response.headers_mut();
|
let headers = response.headers_mut();
|
||||||
if headers.contains_key(DID_AUTHENTICATE_HEADER) {
|
(
|
||||||
headers.remove(DID_AUTHENTICATE_HEADER);
|
headers
|
||||||
true
|
.remove(INTERNAL_SYSTEMID_HEADER)
|
||||||
} else {
|
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||||
false
|
.flatten()
|
||||||
}
|
.unwrap_or("none".to_string()),
|
||||||
|
headers
|
||||||
|
.remove(INTERNAL_APPID_HEADER)
|
||||||
|
.map(|h| h.to_str().ok().map(|v| v.to_string()))
|
||||||
|
.flatten()
|
||||||
|
.unwrap_or("none".to_string()),
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
counter!(
|
counter!(
|
||||||
|
|
@ -55,7 +62,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
||||||
"method" => method.to_string(),
|
"method" => method.to_string(),
|
||||||
"endpoint" => endpoint.clone(),
|
"endpoint" => endpoint.clone(),
|
||||||
"status" => response.status().to_string(),
|
"status" => response.status().to_string(),
|
||||||
"authenticated" => authenticated.to_string(),
|
"system_id" => system_id.to_string(),
|
||||||
|
"app_id" => app_id.to_string(),
|
||||||
)
|
)
|
||||||
.increment(1);
|
.increment(1);
|
||||||
histogram!(
|
histogram!(
|
||||||
|
|
@ -63,7 +71,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
||||||
"method" => method.to_string(),
|
"method" => method.to_string(),
|
||||||
"endpoint" => endpoint.clone(),
|
"endpoint" => endpoint.clone(),
|
||||||
"status" => response.status().to_string(),
|
"status" => response.status().to_string(),
|
||||||
"authenticated" => authenticated.to_string(),
|
"system_id" => system_id.to_string(),
|
||||||
|
"app_id" => app_id.to_string(),
|
||||||
)
|
)
|
||||||
.record(elapsed as f64 / 1_000_f64);
|
.record(elapsed as f64 / 1_000_f64);
|
||||||
|
|
||||||
|
|
@ -81,7 +90,8 @@ pub async fn logger(request: Request, next: Next) -> Response {
|
||||||
"method" => method.to_string(),
|
"method" => method.to_string(),
|
||||||
"endpoint" => endpoint.clone(),
|
"endpoint" => endpoint.clone(),
|
||||||
"status" => response.status().to_string(),
|
"status" => response.status().to_string(),
|
||||||
"authenticated" => authenticated.to_string(),
|
"system_id" => system_id.to_string(),
|
||||||
|
"app_id" => app_id.to_string(),
|
||||||
)
|
)
|
||||||
.increment(1);
|
.increment(1);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use crate::util::{header_or_unknown, json_err};
|
use crate::util::{header_or_unknown, json_err};
|
||||||
|
|
||||||
|
use super::authnz::{INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||||
|
|
||||||
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
const LUA_SCRIPT: &str = include_str!("ratelimit.lua");
|
||||||
|
|
||||||
lazy_static::lazy_static! {
|
lazy_static::lazy_static! {
|
||||||
|
|
@ -103,28 +105,8 @@ pub async fn do_request_ratelimited(
|
||||||
if let Some(redis) = redis {
|
if let Some(redis) = redis {
|
||||||
let headers = request.headers().clone();
|
let headers = request.headers().clone();
|
||||||
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
let source_ip = header_or_unknown(headers.get("X-PluralKit-Client-IP"));
|
||||||
let authenticated_system_id = header_or_unknown(headers.get("x-pluralkit-systemid"));
|
let authenticated_system_id = header_or_unknown(headers.get(INTERNAL_SYSTEMID_HEADER));
|
||||||
|
let authenticated_app_id = header_or_unknown(headers.get(INTERNAL_APPID_HEADER));
|
||||||
// https://github.com/rust-lang/rust/issues/53667
|
|
||||||
let is_temp_token2 = if let Some(header) = request.headers().clone().get("X-PluralKit-App")
|
|
||||||
{
|
|
||||||
if let Some(token2) = &libpk::config
|
|
||||||
.api
|
|
||||||
.as_ref()
|
|
||||||
.expect("missing api config")
|
|
||||||
.temp_token2
|
|
||||||
{
|
|
||||||
if header.to_str().unwrap_or("invalid") == token2 {
|
|
||||||
true
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
false
|
|
||||||
};
|
|
||||||
|
|
||||||
let endpoint = request
|
let endpoint = request
|
||||||
.extensions()
|
.extensions()
|
||||||
|
|
@ -133,7 +115,13 @@ pub async fn do_request_ratelimited(
|
||||||
.map(|v| v.as_str().to_string())
|
.map(|v| v.as_str().to_string())
|
||||||
.unwrap_or("unknown".to_string());
|
.unwrap_or("unknown".to_string());
|
||||||
|
|
||||||
let rlimit = if is_temp_token2 {
|
// looks like this chooses the tokens/sec by app_id or endpoint
|
||||||
|
// then chooses the key by system_id or source_ip
|
||||||
|
// todo: key should probably be chosen by app_id when it's present
|
||||||
|
// todo: make x-ratelimit-scope actually meaningful
|
||||||
|
|
||||||
|
// hack: for now, we only have one "registered app", so we hardcode the app id
|
||||||
|
let rlimit = if authenticated_app_id == "1" {
|
||||||
RatelimitType::TempCustom
|
RatelimitType::TempCustom
|
||||||
} else if endpoint == "/v2/messages/:message_id" {
|
} else if endpoint == "/v2/messages/:message_id" {
|
||||||
RatelimitType::Message
|
RatelimitType::Message
|
||||||
|
|
|
||||||
|
|
@ -56,6 +56,7 @@ pub fn handle_panic(err: Box<dyn std::any::Any + Send + 'static>) -> axum::respo
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// todo: make 500 not duplicated
|
||||||
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
pub fn json_err(code: StatusCode, text: String) -> axum::response::Response {
|
||||||
let mut response = (code, text).into_response();
|
let mut response = (code, text).into_response();
|
||||||
let headers = response.headers_mut();
|
let headers = response.headers_mut();
|
||||||
|
|
|
||||||
|
|
@ -93,7 +93,7 @@ async fn pull(
|
||||||
) -> Result<Json<PullResponse>, PKAvatarError> {
|
) -> Result<Json<PullResponse>, PKAvatarError> {
|
||||||
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
let parsed = pull::parse_url(&req.url) // parsing beforehand to "normalize"
|
||||||
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
.map_err(|_| PKAvatarError::InvalidCdnUrl)?;
|
||||||
if !req.force {
|
if !(req.force || req.url.contains("https://serve.apparyllis.com/")) {
|
||||||
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
if let Some(existing) = db::get_by_attachment_id(&state.pool, parsed.attachment_id).await? {
|
||||||
// remove any pending image cleanup
|
// remove any pending image cleanup
|
||||||
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
db::remove_deletion_queue(&state.pool, parsed.attachment_id).await?;
|
||||||
|
|
|
||||||
|
|
@ -137,6 +137,14 @@ pub fn parse_url(url: &str) -> anyhow::Result<ParsedUrl> {
|
||||||
|
|
||||||
match (url.scheme(), url.domain()) {
|
match (url.scheme(), url.domain()) {
|
||||||
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
("https", Some("media.discordapp.net" | "cdn.discordapp.com")) => {}
|
||||||
|
("https", Some("serve.apparyllis.com")) => {
|
||||||
|
return Ok(ParsedUrl {
|
||||||
|
channel_id: 0,
|
||||||
|
attachment_id: 0,
|
||||||
|
filename: "".to_string(),
|
||||||
|
full_url: url.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
_ => anyhow::bail!("not a discord cdn url"),
|
_ => anyhow::bail!("not a discord cdn url"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue