feat: premium service boilerplate

This commit is contained in:
alyssa 2025-12-23 00:45:45 -05:00
parent c4f820e114
commit f1471088d2
15 changed files with 912 additions and 104 deletions

250
Cargo.lock generated
View file

@ -138,6 +138,48 @@ version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50"
[[package]]
name = "askama"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f75363874b771be265f4ffe307ca705ef6f3baa19011c149da8674a87f1b75c4"
dependencies = [
"askama_derive",
"itoa",
"percent-encoding",
"serde",
"serde_json",
]
[[package]]
name = "askama_derive"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "129397200fe83088e8a68407a8e2b1f826cf0086b21ccdb866a722c8bcd3a94f"
dependencies = [
"askama_parser",
"basic-toml",
"memchr",
"proc-macro2",
"quote",
"rustc-hash 2.1.1",
"serde",
"serde_derive",
"syn",
]
[[package]]
name = "askama_parser"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d6ab5630b3d5eaf232620167977f95eb51f3432fc76852328774afbd242d4358"
dependencies = [
"memchr",
"serde",
"serde_derive",
"winnow",
]
[[package]]
name = "async-trait"
version = "0.1.88"
@ -324,6 +366,32 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum"
version = "0.8.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [
"axum-core 0.5.5",
"bytes",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"itoa",
"matchit 0.8.4",
"memchr",
"mime",
"percent-encoding",
"pin-project-lite",
"serde_core",
"sync_wrapper 1.0.2",
"tower 0.5.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-core"
version = "0.3.4"
@ -360,6 +428,48 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-core"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22"
dependencies = [
"bytes",
"futures-core",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"sync_wrapper 1.0.2",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "axum-extra"
version = "0.10.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9963ff19f40c6102c76756ef0a46004c0d58957d87259fc9208ff8441c12ab96"
dependencies = [
"axum 0.8.8",
"axum-core 0.5.5",
"bytes",
"cookie",
"futures-util",
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
"rustversion",
"serde_core",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "backtrace"
version = "0.3.74"
@ -399,6 +509,15 @@ version = "1.7.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3"
[[package]]
name = "basic-toml"
version = "0.1.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a"
dependencies = [
"serde",
]
[[package]]
name = "bindgen"
version = "0.69.5"
@ -638,6 +757,17 @@ dependencies = [
"unicode-segmentation",
]
[[package]]
name = "cookie"
version = "0.18.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747"
dependencies = [
"percent-encoding",
"time",
"version_check",
]
[[package]]
name = "cookie-factory"
version = "0.3.2"
@ -1574,6 +1704,12 @@ dependencies = [
"pin-project-lite",
]
[[package]]
name = "http-range-header"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c"
[[package]]
name = "httparse"
version = "1.10.1"
@ -2249,6 +2385,16 @@ version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a"
[[package]]
name = "mime_guess"
version = "2.0.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e"
dependencies = [
"mime",
"unicase",
]
[[package]]
name = "minimal-lexical"
version = "0.2.1"
@ -2634,6 +2780,24 @@ version = "1.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
[[package]]
name = "postmark"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "846751b682939565add1f69358a595fa6f3f7d4f1eb15d920b16478e0f981fe2"
dependencies = [
"async-trait",
"bytes",
"http 1.3.1",
"reqwest 0.12.15",
"serde",
"serde_json",
"thiserror 2.0.12",
"time",
"typed-builder",
"url",
]
[[package]]
name = "powerfmt"
version = "0.2.0"
@ -2649,6 +2813,39 @@ dependencies = [
"zerocopy 0.8.24",
]
[[package]]
name = "premium"
version = "0.1.0"
dependencies = [
"anyhow",
"api",
"askama",
"axum 0.8.4",
"axum-extra",
"chrono",
"fred",
"hex",
"lazy_static",
"libpk",
"metrics",
"pk_macros",
"pluralkit_models",
"postmark",
"rand 0.8.5",
"reqwest 0.12.15",
"sea-query",
"serde",
"serde_json",
"serde_urlencoded",
"sqlx",
"thiserror 1.0.69",
"time",
"tokio",
"tower-http",
"tracing",
"twilight-http",
]
[[package]]
name = "prettyplease"
version = "0.2.36"
@ -3548,10 +3745,11 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.219"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e"
dependencies = [
"serde_core",
"serde_derive",
]
@ -3566,10 +3764,19 @@ dependencies = [
]
[[package]]
name = "serde_derive"
version = "1.0.219"
name = "serde_core"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.228"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79"
dependencies = [
"proc-macro2",
"quote",
@ -4389,7 +4596,14 @@ dependencies = [
"http 1.3.1",
"http-body 1.0.1",
"http-body-util",
"http-range-header",
"httpdate",
"mime",
"mime_guess",
"percent-encoding",
"pin-project-lite",
"tokio",
"tokio-util",
"tower-layer",
"tower-service",
"tracing",
@ -4599,6 +4813,26 @@ dependencies = [
"twilight-model",
]
[[package]]
name = "typed-builder"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fef81aec2ca29576f9f6ae8755108640d0a86dd3161b2e8bca6cfa554e98f77d"
dependencies = [
"typed-builder-macro",
]
[[package]]
name = "typed-builder-macro"
version = "0.21.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ecb9ecf7799210407c14a8cfdfe0173365780968dc57973ed082211958e0b18"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "typenum"
version = "1.18.0"
@ -4620,6 +4854,12 @@ dependencies = [
"libc",
]
[[package]]
name = "unicase"
version = "2.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
[[package]]
name = "unicode-bidi"
version = "0.3.18"

View file

@ -6,7 +6,6 @@ resolver = "2"
[workspace.dependencies]
anyhow = "1"
axum-macros = "0.4.1"
bytes = "1.6.0"
chrono = "0.4"
fred = { version = "9.3.0", default-features = false, features = ["tracing", "i-keys", "i-hashes", "i-scripts", "sha-1"] }
@ -25,6 +24,9 @@ tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] }
uuid = { version = "1.7.0", features = ["serde"] }
axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" }
axum-macros = "0.4.1"
axum-extra = { version = "0.10", features = ["cookie"] }
tower-http = { version = "0.5.2", features = ["catch-panic", "fs"] }
twilight-gateway = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-7f08d95" }
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-7f08d95", features = ["permission-calculator"] }
@ -37,3 +39,6 @@ twilight-http = { git = "https://github.com/pluralkit/twilight", branch = "plura
# twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
# twilight-model = { path = "../twilight/twilight-model" }
# twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] }
[patch.crates-io]
axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" }

View file

@ -19,6 +19,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
@ -27,6 +28,5 @@ hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde_urlencoded = "0.7.1"
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
subtle = "2.6.1"
sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] }

View file

@ -93,6 +93,14 @@ macro_rules! fail {
pub(crate) use fail;
#[macro_export]
macro_rules! fail_html {
($($stuff:tt)+) => {{
tracing::error!($($stuff)+);
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response();
}};
}
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
#[allow(dead_code)]

10
crates/api/src/lib.rs Normal file
View file

@ -0,0 +1,10 @@
mod auth;
pub mod error;
pub mod middleware;
pub mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
}

View file

@ -1,135 +1,95 @@
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use api::ApiContext;
use auth::AuthState;
use axum::{
Extension, Router,
body::Body,
extract::{Request as ExtractRequest, State},
extract::Request as ExtractRequest,
http::Uri,
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
};
use hyper_util::{
client::legacy::{Client, connect::HttpConnector},
rt::TokioExecutor,
};
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use libpk::config;
use tracing::{info, warn};
use pk_macros::api_endpoint;
use crate::proxyer::Proxyer;
mod auth;
mod endpoints;
mod error;
mod middleware;
mod proxyer;
mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
}
#[api_endpoint]
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Response {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx.rproxy_client.request(req).await?.into_response())
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
fn router(ctx: ApiContext, proxyer: Proxyer) -> Router {
let rproxy = |Extension(auth): Extension<AuthState>, req: ExtractRequest<Body>| {
proxyer.rproxy(auth, req)
};
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/v2/systems/{system_id}", get(rproxy))
.route("/v2/systems/{system_id}", patch(rproxy))
.route("/v2/systems/{system_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
.route("/v2/systems/{system_id}/settings", patch(rproxy))
.route("/v2/systems/{system_id}/settings", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/{member_id}", get(rproxy))
.route("/v2/members/{member_id}", patch(rproxy))
.route("/v2/members/{member_id}", delete(rproxy))
.route("/v2/systems/{system_id}/members", get(rproxy.clone()))
.route("/v2/members", post(rproxy.clone()))
.route("/v2/members/{member_id}", get(rproxy.clone()))
.route("/v2/members/{member_id}", patch(rproxy.clone()))
.route("/v2/members/{member_id}", delete(rproxy.clone()))
.route("/v2/systems/{system_id}/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/{group_id}", get(rproxy))
.route("/v2/groups/{group_id}", patch(rproxy))
.route("/v2/groups/{group_id}", delete(rproxy))
.route("/v2/systems/{system_id}/groups", get(rproxy.clone()))
.route("/v2/groups", post(rproxy.clone()))
.route("/v2/groups/{group_id}", get(rproxy.clone()))
.route("/v2/groups/{group_id}", patch(rproxy.clone()))
.route("/v2/groups/{group_id}", delete(rproxy.clone()))
.route("/v2/groups/{group_id}/members", get(rproxy))
.route("/v2/groups/{group_id}/members/add", post(rproxy))
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
.route("/v2/groups/{group_id}/members", get(rproxy.clone()))
.route("/v2/groups/{group_id}/members/add", post(rproxy.clone()))
.route("/v2/groups/{group_id}/members/remove", post(rproxy.clone()))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups", get(rproxy))
.route("/v2/members/{member_id}/groups/add", post(rproxy))
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
.route("/v2/members/{member_id}/groups", get(rproxy.clone()))
.route("/v2/members/{member_id}/groups/add", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups/remove", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy.clone()))
.route("/v2/systems/{system_id}/switches", get(rproxy))
.route("/v2/systems/{system_id}/switches", post(rproxy))
.route("/v2/systems/{system_id}/fronters", get(rproxy))
.route("/v2/systems/{system_id}/switches", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches", post(rproxy.clone()))
.route("/v2/systems/{system_id}/fronters", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy.clone()))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy.clone()))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy.clone()))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy.clone()))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy.clone()))
.route("/v2/messages/{message_id}", get(rproxy))
.route("/v2/messages/{message_id}", get(rproxy.clone()))
.route("/v2/bulk", post(endpoints::bulk::bulk))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))
.route("/private/bulk_privacy/member", post(rproxy.clone()))
.route("/private/bulk_privacy/group", post(rproxy.clone()))
.route("/private/discord/callback", post(rproxy.clone()))
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/dash_views", post(endpoints::private::dash_views))
.route("/private/dash_view/{id}", get(endpoints::private::dash_view))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy.clone()))
.route("/v2/members/{member_id}/oembed.json", get(rproxy.clone()))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy.clone()))
.layer(axum::middleware::from_fn_with_state(
if config.api().use_ratelimiter {
@ -161,15 +121,14 @@ async fn main() -> anyhow::Result<()> {
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let ctx = ApiContext {
db,
redis,
let proxyer = Proxyer {
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
};
let app = router(ctx);
let ctx = ApiContext { db, redis };
let app = router(ctx, proxyer);
let addr: &str = libpk::config.api().addr.as_ref();

51
crates/api/src/proxyer.rs Normal file
View file

@ -0,0 +1,51 @@
use crate::{
auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
error::PKError,
};
use axum::{
body::Body,
extract::Request as ExtractRequest,
http::Uri,
response::{IntoResponse, Response},
};
use hyper_util::client::legacy::{Client, connect::HttpConnector};
#[derive(Clone)]
pub struct Proxyer {
pub rproxy_uri: String,
pub rproxy_client: Client<HttpConnector, Body>,
}
impl Proxyer {
pub async fn rproxy(
self,
auth: AuthState,
mut req: ExtractRequest<Body>,
) -> Result<Response, PKError> {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", self.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(self.rproxy_client.request(req).await?.into_response())
}
}

View file

@ -97,6 +97,13 @@ pub struct ScheduledTasksConfig {
pub prometheus_url: String,
}
#[derive(Deserialize, Clone, Debug)]
pub struct PremiumConfig {
pub postmark_token: String,
pub from_email: String,
pub base_url: String,
}
fn _metrics_default() -> bool {
false
}
@ -116,6 +123,8 @@ pub struct PKConfig {
avatars: Option<AvatarsConfig>,
#[serde(default)]
pub scheduled_tasks: Option<ScheduledTasksConfig>,
#[serde(default)]
premium: Option<PremiumConfig>,
#[serde(default = "_metrics_default")]
pub run_metrics_server: bool,
@ -147,6 +156,10 @@ impl PKConfig {
.as_ref()
.expect("missing avatar service config")
}
pub fn premium(&self) -> &PremiumConfig {
self.premium.as_ref().expect("missing premium config")
}
}
// todo: consider passing this down instead of making it global

35
crates/premium/Cargo.toml Normal file
View file

@ -0,0 +1,35 @@
[package]
name = "premium"
version = "0.1.0"
edition = "2024"
[dependencies]
pluralkit_models = { path = "../models" }
pk_macros = { path = "../macros" }
libpk = { path = "../libpk" }
api = { path = "../api" }
anyhow = { workspace = true }
axum = { workspace = true }
axum-extra = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
sea-query = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
askama = "0.14.0"
postmark = { version = "0.11", features = ["reqwest"] }
rand = "0.8"
thiserror = "1.0"
hex = "0.4"
chrono = { workspace = true }
serde_urlencoded = "0.7"
time = "0.3"

318
crates/premium/src/auth.rs Normal file
View file

@ -0,0 +1,318 @@
use api::{ApiContext, fail_html};
use askama::Template;
use axum::{
extract::{MatchedPath, Request, State},
http::header::SET_COOKIE,
middleware::Next,
response::{AppendHeaders, IntoResponse, Redirect, Response},
};
use axum_extra::extract::cookie::CookieJar;
use fred::{
prelude::{KeysInterface, LuaInterface},
util::sha1_hash,
};
use rand::{Rng, distributions::Alphanumeric};
use serde::{Deserialize, Serialize};
use crate::web::{render, message};
const LOGIN_TOKEN_TTL_SECS: i64 = 60 * 10;
const SESSION_LUA_SCRIPT: &str = r#"
local session_key = KEYS[1]
local ttl = ARGV[1]
local session_data = redis.call('GET', session_key)
if session_data then
redis.call('EXPIRE', session_key, ttl)
end
return session_data
"#;
const SESSION_TTL_SECS: i64 = 60 * 60 * 4;
lazy_static::lazy_static! {
static ref SESSION_LUA_SCRIPT_SHA: String = sha1_hash(SESSION_LUA_SCRIPT);
}
fn rand_token() -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect()
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthState {
pub email: String,
pub csrf_token: String,
pub session_id: String,
}
impl AuthState {
fn new(email: String) -> Self {
Self {
email,
csrf_token: rand_token(),
session_id: rand_token(),
}
}
async fn from_request(
headers: axum::http::HeaderMap,
ctx: &ApiContext,
) -> anyhow::Result<Option<Self>> {
let jar = CookieJar::from_headers(&headers);
let Some(session_cookie) = jar.get("pk-session") else {
return Ok(None);
};
let session_id = session_cookie.value();
let session_key = format!("premium:session:{}", session_id);
let script_exists: Vec<usize> = ctx
.redis
.script_exists(vec![SESSION_LUA_SCRIPT_SHA.to_string()])
.await?;
if script_exists[0] != 1 {
ctx.redis
.script_load::<String, String>(SESSION_LUA_SCRIPT.to_string())
.await?;
}
let session_data: Option<String> = ctx
.redis
.evalsha(
SESSION_LUA_SCRIPT_SHA.to_string(),
vec![session_key],
vec![SESSION_TTL_SECS],
)
.await?;
let Some(session_data) = session_data else {
return Ok(None);
};
let session: AuthState = serde_json::from_str(&session_data)?;
Ok(Some(session))
}
async fn save(&self, ctx: &ApiContext) -> anyhow::Result<()> {
let session_key = format!("premium:session:{}", self.session_id);
let session_data = serde_json::to_string(&self)?;
ctx.redis
.set::<(), _, _>(
session_key,
session_data,
Some(fred::types::Expiration::EX(SESSION_TTL_SECS)),
None,
false,
)
.await?;
Ok(())
}
async fn delete(&self, ctx: &ApiContext) -> anyhow::Result<()> {
let session_key = format!("premium:session:{}", self.session_id);
ctx.redis.del::<(), _>(session_key).await?;
Ok(())
}
}
fn refresh_session_cookie(session: &AuthState, mut response: Response) -> Response {
let cookie_value = format!(
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
session.session_id, SESSION_TTL_SECS
);
response
.headers_mut()
.insert(SET_COOKIE, cookie_value.parse().unwrap());
response
}
pub async fn middleware(
State(ctx): State<ApiContext>,
mut request: Request,
next: Next,
) -> Response {
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let session = match AuthState::from_request(request.headers().clone(), &ctx).await {
Ok(s) => s,
Err(err) => fail_html!(?err, "failed to fetch auth state from redis"),
};
if let Some(session) = session.clone() {
request.extensions_mut().insert(session);
}
match endpoint.as_str() {
"/" => {
if let Some(ref session) = session {
let response = next.run(request).await;
refresh_session_cookie(session, response)
} else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: None,
});
}
}
"/login" => {
if let Some(ref session) = session {
// no session here because that shows the "you're logged in as" component
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
return refresh_session_cookie(session, response);
} else {
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
Ok(b) => b,
Err(err) => fail_html!(?err, "failed to read request body"),
};
let form: std::collections::HashMap<String, String> =
match serde_urlencoded::from_bytes(&body) {
Ok(f) => f,
Err(err) => fail_html!(?err, "failed to parse form data"),
};
let Some(email) = form.get("email") else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("email field is required".to_string()),
});
};
let email = email.trim().to_lowercase();
if email.is_empty() {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("email field is required".to_string()),
});
}
let token = rand_token();
let token_key = format!("premium:login_token:{}", token);
if let Err(err) = ctx
.redis
.set::<(), _, _>(
token_key,
&email,
Some(fred::types::Expiration::EX(LOGIN_TOKEN_TTL_SECS)),
None,
false,
)
.await
{
fail_html!(?err, "failed to store login token in redis");
}
if let Err(err) = crate::mailer::login_token(email, token).await {
fail_html!(?err, "failed to send login email");
}
return render!(message(
"check your email for a login link! it will expire in 10 minutes.".to_string(),
None
));
}
}
"/login/{token}" => {
if let Some(ref session) = session {
// no session here because that shows the "you're logged in as" component
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
return refresh_session_cookie(session, response);
}
let path = request.uri().path();
let token = path.strip_prefix("/login/").unwrap_or("");
if token.is_empty() {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("invalid login link".to_string()),
});
}
let token_key = format!("premium:login_token:{}", token);
let email: Option<String> = match ctx.redis.get(&token_key).await {
Ok(e) => e,
Err(err) => fail_html!(?err, "failed to fetch login token from redis"),
};
let Some(email) = email else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some(
"invalid or expired login link. please request a new one.".to_string()
),
});
};
if let Err(err) = ctx.redis.del::<(), _>(&token_key).await {
fail_html!(?err, "failed to delete login token from redis");
}
let session = AuthState::new(email);
if let Err(err) = session.save(&ctx).await {
fail_html!(?err, "failed to save session to redis");
}
let cookie_value = format!(
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
session.session_id, SESSION_TTL_SECS
);
(
AppendHeaders([(SET_COOKIE, cookie_value)]),
Redirect::to("/"),
)
.into_response()
}
"/logout" => {
let Some(session) = session else {
return Redirect::to("/").into_response();
};
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
Ok(b) => b,
Err(err) => fail_html!(?err, "failed to read request body"),
};
let form: std::collections::HashMap<String, String> =
match serde_urlencoded::from_bytes(&body) {
Ok(f) => f,
Err(err) => fail_html!(?err, "failed to parse form data"),
};
let csrf_valid = form
.get("csrf_token")
.map(|t| t == &session.csrf_token)
.unwrap_or(false);
if !csrf_valid {
return (axum::http::StatusCode::FORBIDDEN, "invalid csrf token").into_response();
}
if let Err(err) = session.delete(&ctx).await {
fail_html!(?err, "failed to delete session from redis");
}
let cookie_value = "pk-session=; Path=/; HttpOnly; Max-Age=0";
(
AppendHeaders([(SET_COOKIE, cookie_value)]),
Redirect::to("/"),
)
.into_response()
}
_ => (axum::http::StatusCode::NOT_FOUND, "404 not found").into_response(),
}
}

View file

@ -0,0 +1,44 @@
use lazy_static::lazy_static;
use postmark::{
Query,
api::{Body, email::SendEmailRequest},
reqwest::PostmarkClient,
};
lazy_static! {
pub static ref CLIENT: PostmarkClient = {
PostmarkClient::builder()
.server_token(&libpk::config.premium().postmark_token)
.build()
};
}
const LOGIN_TEXT: &'static str = r#"Hello,
Someone (hopefully you) has requested a link to log in to the PluralKit Premium website.
Click here to log in: {link}
This link will expire in 10 minutes.
If you did not request this link, please ignore this message.
Thanks,
- PluralKit Team
"#;
pub async fn login_token(rcpt: String, token: String) -> anyhow::Result<()> {
SendEmailRequest::builder()
.from(&libpk::config.premium().from_email)
.to(rcpt)
.subject("[PluralKit Premium] Your login link")
.body(Body::text(LOGIN_TEXT.replace(
"{link}",
format!("{}/login/{token}", libpk::config.premium().base_url).as_str(),
)))
.build()
.execute(&(CLIENT.to_owned()))
.await?;
Ok(())
}

View file

@ -0,0 +1,63 @@
use askama::Template;
use axum::{
Extension, Router,
response::Html,
routing::{get, post},
};
use tower_http::{catch_panic::CatchPanicLayer, services::ServeDir};
use tracing::info;
use api::{ApiContext, middleware};
mod auth;
mod mailer;
mod web;
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/", get(|Extension(session): Extension<auth::AuthState>| async move {
Html(web::Index {
session: Some(session),
show_login_form: false,
message: None,
}.render().unwrap())
}))
.route("/login/{token}", get(|| async {
"handled in auth middleware"
}))
.route("/login", post(|| async {
"handled in auth middleware"
}))
.route("/logout", post(|| async {
"handled in auth middleware"
}))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), auth::middleware))
.layer(axum::middleware::from_fn(middleware::logger::logger))
.nest_service("/static", ServeDir::new("static"))
.layer(CatchPanicLayer::custom(api::util::handle_panic))
.with_state(ctx)
}
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let ctx = ApiContext { db, redis };
let app = router(ctx);
let addr: &str = libpk::config.api().addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

33
crates/premium/src/web.rs Normal file
View file

@ -0,0 +1,33 @@
use askama::Template;
use crate::auth::AuthState;
macro_rules! render {
($stuff:expr) => {{
let mut response = $stuff.render().unwrap().into_response();
let headers = response.headers_mut();
headers.insert(
"content-type",
axum::http::HeaderValue::from_static("text/html"),
);
response
}};
}
pub(crate) use render;
pub fn message(message: String, session: Option<AuthState>) -> Index {
Index {
session: session,
show_login_form: false,
message: Some(message)
}
}
#[derive(Template)]
#[template(path = "index.html")]
pub struct Index {
pub session: Option<AuthState>,
pub show_login_form: bool,
pub message: Option<String>,
}

View file

View file

@ -0,0 +1,29 @@
<!DOCTYPE html>
<head>
<title>PluralKit Premium</title>
<link rel="stylesheet" href="/static/stylesheet.css" />
</head>
<body>
<h2>PluralKit Premium</h2>
{% if let Some(session) = session %}
<form action="/logout" method="post">
<input type="hidden" name="csrf_token" value="{{ session.csrf_token }}" />
<p>logged in as <strong>{{ session.email }}.</strong></p>
<button type="submit">log out</button>
</form>
{% endif %}
{% if show_login_form %}
<p>Enter your email address to log in.</p>
<form method="POST" action="/login">
<input type="email" name="email" placeholder="you@example.com" required />
<button type="submit">Send</button>
</form>
{% endif %}
{% if let Some(msg) = message %}
<div>{{ msg }}</div>
{% endif %}
</body>