From 8401c464c124661260a4b93c6242722c7cd61d92 Mon Sep 17 00:00:00 2001 From: alyssa Date: Tue, 23 Dec 2025 00:45:45 -0500 Subject: [PATCH] feat: premium service boilerplate --- Cargo.lock | 250 ++++++++++++++++++++- Cargo.toml | 7 +- crates/api/Cargo.toml | 2 +- crates/api/src/error.rs | 8 + crates/api/src/lib.rs | 10 + crates/api/src/main.rs | 153 +++++-------- crates/api/src/proxyer.rs | 51 +++++ crates/libpk/src/_config.rs | 13 ++ crates/premium/Cargo.toml | 35 +++ crates/premium/src/auth.rs | 318 +++++++++++++++++++++++++++ crates/premium/src/mailer.rs | 44 ++++ crates/premium/src/main.rs | 63 ++++++ crates/premium/src/web.rs | 33 +++ crates/premium/static/stylesheet.css | 0 crates/premium/templates/index.html | 29 +++ 15 files changed, 912 insertions(+), 104 deletions(-) create mode 100644 crates/api/src/lib.rs create mode 100644 crates/api/src/proxyer.rs create mode 100644 crates/premium/Cargo.toml create mode 100644 crates/premium/src/auth.rs create mode 100644 crates/premium/src/mailer.rs create mode 100644 crates/premium/src/main.rs create mode 100644 crates/premium/src/web.rs create mode 100644 crates/premium/static/stylesheet.css create mode 100644 crates/premium/templates/index.html diff --git a/Cargo.lock b/Cargo.lock index 79f65d9f..851a8c8c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -138,6 +138,48 @@ version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +[[package]] +name = "askama" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f75363874b771be265f4ffe307ca705ef6f3baa19011c149da8674a87f1b75c4" +dependencies = [ + "askama_derive", + "itoa", + "percent-encoding", + "serde", + "serde_json", +] + +[[package]] +name = "askama_derive" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "129397200fe83088e8a68407a8e2b1f826cf0086b21ccdb866a722c8bcd3a94f" +dependencies = [ + "askama_parser", + "basic-toml", + "memchr", + "proc-macro2", + "quote", + "rustc-hash 2.1.1", + "serde", + "serde_derive", + "syn", +] + +[[package]] +name = "askama_parser" +version = "0.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d6ab5630b3d5eaf232620167977f95eb51f3432fc76852328774afbd242d4358" +dependencies = [ + "memchr", + "serde", + "serde_derive", + "winnow", +] + [[package]] name = "async-trait" version = "0.1.88" @@ -324,6 +366,32 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum" +version = "0.8.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" +dependencies = [ + "axum-core 0.5.5", + "bytes", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "itoa", + "matchit 0.8.4", + "memchr", + "mime", + "percent-encoding", + "pin-project-lite", + "serde_core", + "sync_wrapper 1.0.2", + "tower 0.5.2", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "axum-core" version = "0.3.4" @@ -360,6 +428,48 @@ dependencies = [ "tracing", ] +[[package]] +name = "axum-core" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59446ce19cd142f8833f856eb31f3eb097812d1479ab224f54d72428ca21ea22" +dependencies = [ + "bytes", + "futures-core", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "sync_wrapper 1.0.2", + "tower-layer", + "tower-service", + "tracing", +] + +[[package]] +name = "axum-extra" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9963ff19f40c6102c76756ef0a46004c0d58957d87259fc9208ff8441c12ab96" +dependencies = [ + "axum 0.8.8", + "axum-core 0.5.5", + "bytes", + "cookie", + "futures-util", + "http 1.3.1", + "http-body 1.0.1", + "http-body-util", + "mime", + "pin-project-lite", + "rustversion", + "serde_core", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "backtrace" version = "0.3.74" @@ -399,6 +509,15 @@ version = "1.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "89e25b6adfb930f02d1981565a6e5d9c547ac15a96606256d3b59040e5cd4ca3" +[[package]] +name = "basic-toml" +version = "0.1.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba62675e8242a4c4e806d12f11d136e626e6c8361d6b829310732241652a178a" +dependencies = [ + "serde", +] + [[package]] name = "bindgen" version = "0.69.5" @@ -638,6 +757,17 @@ dependencies = [ "unicode-segmentation", ] +[[package]] +name = "cookie" +version = "0.18.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" +dependencies = [ + "percent-encoding", + "time", + "version_check", +] + [[package]] name = "cookie-factory" version = "0.3.2" @@ -1574,6 +1704,12 @@ dependencies = [ "pin-project-lite", ] +[[package]] +name = "http-range-header" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9171a2ea8a68358193d15dd5d70c1c10a2afc3e7e4c5bc92bc9f025cebd7359c" + [[package]] name = "httparse" version = "1.10.1" @@ -2249,6 +2385,16 @@ version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6877bb514081ee2a7ff5ef9de3281f14a4dd4bceac4c09388074a6b5df8a139a" +[[package]] +name = "mime_guess" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f7c44f8e672c00fe5308fa235f821cb4198414e1c77935c1ab6948d3fd78550e" +dependencies = [ + "mime", + "unicase", +] + [[package]] name = "minimal-lexical" version = "0.2.1" @@ -2634,6 +2780,24 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e" +[[package]] +name = "postmark" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "846751b682939565add1f69358a595fa6f3f7d4f1eb15d920b16478e0f981fe2" +dependencies = [ + "async-trait", + "bytes", + "http 1.3.1", + "reqwest 0.12.15", + "serde", + "serde_json", + "thiserror 2.0.12", + "time", + "typed-builder", + "url", +] + [[package]] name = "powerfmt" version = "0.2.0" @@ -2649,6 +2813,39 @@ dependencies = [ "zerocopy 0.8.24", ] +[[package]] +name = "premium" +version = "0.1.0" +dependencies = [ + "anyhow", + "api", + "askama", + "axum 0.8.4", + "axum-extra", + "chrono", + "fred", + "hex", + "lazy_static", + "libpk", + "metrics", + "pk_macros", + "pluralkit_models", + "postmark", + "rand 0.8.5", + "reqwest 0.12.15", + "sea-query", + "serde", + "serde_json", + "serde_urlencoded", + "sqlx", + "thiserror 1.0.69", + "time", + "tokio", + "tower-http", + "tracing", + "twilight-http", +] + [[package]] name = "prettyplease" version = "0.2.36" @@ -3548,10 +3745,11 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.219" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6" +checksum = "9a8e94ea7f378bd32cbbd37198a4a91436180c5bb472411e48b5ec2e2124ae9e" dependencies = [ + "serde_core", "serde_derive", ] @@ -3566,10 +3764,19 @@ dependencies = [ ] [[package]] -name = "serde_derive" -version = "1.0.219" +name = "serde_core" +version = "1.0.228" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00" +checksum = "41d385c7d4ca58e59fc732af25c3983b67ac852c1a25000afe1175de458b67ad" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.228" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d540f220d3187173da220f885ab66608367b6574e925011a9353e4badda91d79" dependencies = [ "proc-macro2", "quote", @@ -4389,7 +4596,14 @@ dependencies = [ "http 1.3.1", "http-body 1.0.1", "http-body-util", + "http-range-header", + "httpdate", + "mime", + "mime_guess", + "percent-encoding", "pin-project-lite", + "tokio", + "tokio-util", "tower-layer", "tower-service", "tracing", @@ -4599,6 +4813,26 @@ dependencies = [ "twilight-model", ] +[[package]] +name = "typed-builder" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fef81aec2ca29576f9f6ae8755108640d0a86dd3161b2e8bca6cfa554e98f77d" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ecb9ecf7799210407c14a8cfdfe0173365780968dc57973ed082211958e0b18" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "typenum" version = "1.18.0" @@ -4620,6 +4854,12 @@ dependencies = [ "libc", ] +[[package]] +name = "unicase" +version = "2.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539" + [[package]] name = "unicode-bidi" version = "0.3.18" diff --git a/Cargo.toml b/Cargo.toml index 3707ce77..aa5ec504 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,7 +6,6 @@ resolver = "2" [workspace.dependencies] anyhow = "1" -axum-macros = "0.4.1" bytes = "1.6.0" chrono = "0.4" fred = { version = "9.3.0", default-features = false, features = ["tracing", "i-keys", "i-hashes", "i-scripts", "sha-1"] } @@ -25,6 +24,9 @@ tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } uuid = { version = "1.7.0", features = ["serde"] } axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" } +axum-macros = "0.4.1" +axum-extra = { version = "0.10", features = ["cookie"] } +tower-http = { version = "0.5.2", features = ["catch-panic", "fs"] } twilight-gateway = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-7f08d95" } twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", branch = "pluralkit-7f08d95", features = ["permission-calculator"] } @@ -37,3 +39,6 @@ twilight-http = { git = "https://github.com/pluralkit/twilight", branch = "plura # twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] } # twilight-model = { path = "../twilight/twilight-model" } # twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-aws_lc_rs", "rustls-native-roots"] } + +[patch.crates-io] +axum = { git = "https://github.com/pluralkit/axum", branch = "v0.8.4-pluralkit" } diff --git a/crates/api/Cargo.toml b/crates/api/Cargo.toml index 16c54a7f..3d37fa2b 100644 --- a/crates/api/Cargo.toml +++ b/crates/api/Cargo.toml @@ -19,6 +19,7 @@ serde = { workspace = true } serde_json = { workspace = true } sqlx = { workspace = true } tokio = { workspace = true } +tower-http = { workspace = true } tracing = { workspace = true } twilight-http = { workspace = true } @@ -27,6 +28,5 @@ hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1" reverse-proxy-service = { version = "0.2.1", features = ["axum"] } serde_urlencoded = "0.7.1" tower = "0.4.13" -tower-http = { version = "0.5.2", features = ["catch-panic"] } subtle = "2.6.1" sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] } diff --git a/crates/api/src/error.rs b/crates/api/src/error.rs index ae7e5a99..f13dbff2 100644 --- a/crates/api/src/error.rs +++ b/crates/api/src/error.rs @@ -93,6 +93,14 @@ macro_rules! fail { pub(crate) use fail; +#[macro_export] +macro_rules! fail_html { + ($($stuff:tt)+) => {{ + tracing::error!($($stuff)+); + return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response(); + }}; +} + macro_rules! define_error { ( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => { #[allow(dead_code)] diff --git a/crates/api/src/lib.rs b/crates/api/src/lib.rs new file mode 100644 index 00000000..5b01ba19 --- /dev/null +++ b/crates/api/src/lib.rs @@ -0,0 +1,10 @@ +mod auth; +pub mod error; +pub mod middleware; +pub mod util; + +#[derive(Clone)] +pub struct ApiContext { + pub db: sqlx::postgres::PgPool, + pub redis: fred::clients::RedisPool, +} diff --git a/crates/api/src/main.rs b/crates/api/src/main.rs index b69a304c..0f234d69 100644 --- a/crates/api/src/main.rs +++ b/crates/api/src/main.rs @@ -1,135 +1,95 @@ -use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}; +use api::ApiContext; +use auth::AuthState; use axum::{ Extension, Router, body::Body, - extract::{Request as ExtractRequest, State}, + extract::Request as ExtractRequest, http::Uri, - response::{IntoResponse, Response}, routing::{delete, get, patch, post}, }; -use hyper_util::{ - client::legacy::{Client, connect::HttpConnector}, - rt::TokioExecutor, -}; +use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use libpk::config; use tracing::{info, warn}; -use pk_macros::api_endpoint; +use crate::proxyer::Proxyer; mod auth; mod endpoints; mod error; mod middleware; +mod proxyer; mod util; -#[derive(Clone)] -pub struct ApiContext { - pub db: sqlx::postgres::PgPool, - pub redis: fred::clients::RedisPool, - - rproxy_uri: String, - rproxy_client: Client, -} - -#[api_endpoint] -async fn rproxy( - Extension(auth): Extension, - State(ctx): State, - mut req: ExtractRequest, -) -> Response { - let path = req.uri().path(); - let path_query = req - .uri() - .path_and_query() - .map(|v| v.as_str()) - .unwrap_or(path); - - let uri = format!("{}{}", ctx.rproxy_uri, path_query); - - *req.uri_mut() = Uri::try_from(uri).unwrap(); - - let headers = req.headers_mut(); - - headers.remove(INTERNAL_SYSTEMID_HEADER); - headers.remove(INTERNAL_APPID_HEADER); - - if let Some(sid) = auth.system_id() { - headers.append(INTERNAL_SYSTEMID_HEADER, sid.into()); - } - - if let Some(aid) = auth.app_id() { - headers.append(INTERNAL_APPID_HEADER, aid.into()); - } - - Ok(ctx.rproxy_client.request(req).await?.into_response()) -} - // this function is manually formatted for easier legibility of route_services #[rustfmt::skip] -fn router(ctx: ApiContext) -> Router { +fn router(ctx: ApiContext, proxyer: Proxyer) -> Router { + let rproxy = |Extension(auth): Extension, req: ExtractRequest| { + proxyer.rproxy(auth, req) + }; + // processed upside down (???) so we have to put middleware at the end Router::new() - .route("/v2/systems/{system_id}", get(rproxy)) - .route("/v2/systems/{system_id}", patch(rproxy)) + .route("/v2/systems/{system_id}", get(rproxy.clone())) + .route("/v2/systems/{system_id}", patch(rproxy.clone())) .route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings)) - .route("/v2/systems/{system_id}/settings", patch(rproxy)) + .route("/v2/systems/{system_id}/settings", patch(rproxy.clone())) - .route("/v2/systems/{system_id}/members", get(rproxy)) - .route("/v2/members", post(rproxy)) - .route("/v2/members/{member_id}", get(rproxy)) - .route("/v2/members/{member_id}", patch(rproxy)) - .route("/v2/members/{member_id}", delete(rproxy)) + .route("/v2/systems/{system_id}/members", get(rproxy.clone())) + .route("/v2/members", post(rproxy.clone())) + .route("/v2/members/{member_id}", get(rproxy.clone())) + .route("/v2/members/{member_id}", patch(rproxy.clone())) + .route("/v2/members/{member_id}", delete(rproxy.clone())) - .route("/v2/systems/{system_id}/groups", get(rproxy)) - .route("/v2/groups", post(rproxy)) - .route("/v2/groups/{group_id}", get(rproxy)) - .route("/v2/groups/{group_id}", patch(rproxy)) - .route("/v2/groups/{group_id}", delete(rproxy)) + .route("/v2/systems/{system_id}/groups", get(rproxy.clone())) + .route("/v2/groups", post(rproxy.clone())) + .route("/v2/groups/{group_id}", get(rproxy.clone())) + .route("/v2/groups/{group_id}", patch(rproxy.clone())) + .route("/v2/groups/{group_id}", delete(rproxy.clone())) - .route("/v2/groups/{group_id}/members", get(rproxy)) - .route("/v2/groups/{group_id}/members/add", post(rproxy)) - .route("/v2/groups/{group_id}/members/remove", post(rproxy)) - .route("/v2/groups/{group_id}/members/overwrite", post(rproxy)) + .route("/v2/groups/{group_id}/members", get(rproxy.clone())) + .route("/v2/groups/{group_id}/members/add", post(rproxy.clone())) + .route("/v2/groups/{group_id}/members/remove", post(rproxy.clone())) + .route("/v2/groups/{group_id}/members/overwrite", post(rproxy.clone())) - .route("/v2/members/{member_id}/groups", get(rproxy)) - .route("/v2/members/{member_id}/groups/add", post(rproxy)) - .route("/v2/members/{member_id}/groups/remove", post(rproxy)) - .route("/v2/members/{member_id}/groups/overwrite", post(rproxy)) + .route("/v2/members/{member_id}/groups", get(rproxy.clone())) + .route("/v2/members/{member_id}/groups/add", post(rproxy.clone())) + .route("/v2/members/{member_id}/groups/remove", post(rproxy.clone())) + .route("/v2/members/{member_id}/groups/overwrite", post(rproxy.clone())) - .route("/v2/systems/{system_id}/switches", get(rproxy)) - .route("/v2/systems/{system_id}/switches", post(rproxy)) - .route("/v2/systems/{system_id}/fronters", get(rproxy)) + .route("/v2/systems/{system_id}/switches", get(rproxy.clone())) + .route("/v2/systems/{system_id}/switches", post(rproxy.clone())) + .route("/v2/systems/{system_id}/fronters", get(rproxy.clone())) - .route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy)) - .route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy)) - .route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy)) - .route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy)) + .route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy.clone())) + .route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy.clone())) + .route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy.clone())) + .route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy.clone())) - .route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy)) - .route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy)) + .route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy.clone())) + .route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy.clone())) - .route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy)) - .route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy)) + .route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy.clone())) + .route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy.clone())) - .route("/v2/systems/{system_id}/autoproxy", get(rproxy)) - .route("/v2/systems/{system_id}/autoproxy", patch(rproxy)) + .route("/v2/systems/{system_id}/autoproxy", get(rproxy.clone())) + .route("/v2/systems/{system_id}/autoproxy", patch(rproxy.clone())) - .route("/v2/messages/{message_id}", get(rproxy)) + .route("/v2/messages/{message_id}", get(rproxy.clone())) .route("/v2/bulk", post(endpoints::bulk::bulk)) - .route("/private/bulk_privacy/member", post(rproxy)) - .route("/private/bulk_privacy/group", post(rproxy)) - .route("/private/discord/callback", post(rproxy)) + .route("/private/bulk_privacy/member", post(rproxy.clone())) + .route("/private/bulk_privacy/group", post(rproxy.clone())) + .route("/private/discord/callback", post(rproxy.clone())) .route("/private/discord/callback2", post(endpoints::private::discord_callback)) .route("/private/discord/shard_state", get(endpoints::private::discord_state)) .route("/private/dash_views", post(endpoints::private::dash_views)) .route("/private/dash_view/{id}", get(endpoints::private::dash_view)) .route("/private/stats", get(endpoints::private::meta)) - .route("/v2/systems/{system_id}/oembed.json", get(rproxy)) - .route("/v2/members/{member_id}/oembed.json", get(rproxy)) - .route("/v2/groups/{group_id}/oembed.json", get(rproxy)) + .route("/v2/systems/{system_id}/oembed.json", get(rproxy.clone())) + .route("/v2/members/{member_id}/oembed.json", get(rproxy.clone())) + .route("/v2/groups/{group_id}/oembed.json", get(rproxy.clone())) .layer(axum::middleware::from_fn_with_state( if config.api().use_ratelimiter { @@ -161,15 +121,14 @@ async fn main() -> anyhow::Result<()> { let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) .build(HttpConnector::new()); - let ctx = ApiContext { - db, - redis, - + let proxyer = Proxyer { rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(), rproxy_client, }; - let app = router(ctx); + let ctx = ApiContext { db, redis }; + + let app = router(ctx, proxyer); let addr: &str = libpk::config.api().addr.as_ref(); diff --git a/crates/api/src/proxyer.rs b/crates/api/src/proxyer.rs new file mode 100644 index 00000000..5d770235 --- /dev/null +++ b/crates/api/src/proxyer.rs @@ -0,0 +1,51 @@ +use crate::{ + auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER}, + error::PKError, +}; +use axum::{ + body::Body, + extract::Request as ExtractRequest, + http::Uri, + response::{IntoResponse, Response}, +}; +use hyper_util::client::legacy::{Client, connect::HttpConnector}; + +#[derive(Clone)] +pub struct Proxyer { + pub rproxy_uri: String, + pub rproxy_client: Client, +} + +impl Proxyer { + pub async fn rproxy( + self, + auth: AuthState, + mut req: ExtractRequest, + ) -> Result { + let path = req.uri().path(); + let path_query = req + .uri() + .path_and_query() + .map(|v| v.as_str()) + .unwrap_or(path); + + let uri = format!("{}{}", self.rproxy_uri, path_query); + + *req.uri_mut() = Uri::try_from(uri).unwrap(); + + let headers = req.headers_mut(); + + headers.remove(INTERNAL_SYSTEMID_HEADER); + headers.remove(INTERNAL_APPID_HEADER); + + if let Some(sid) = auth.system_id() { + headers.append(INTERNAL_SYSTEMID_HEADER, sid.into()); + } + + if let Some(aid) = auth.app_id() { + headers.append(INTERNAL_APPID_HEADER, aid.into()); + } + + Ok(self.rproxy_client.request(req).await?.into_response()) + } +} diff --git a/crates/libpk/src/_config.rs b/crates/libpk/src/_config.rs index ec76e27a..d51479ac 100644 --- a/crates/libpk/src/_config.rs +++ b/crates/libpk/src/_config.rs @@ -97,6 +97,13 @@ pub struct ScheduledTasksConfig { pub prometheus_url: String, } +#[derive(Deserialize, Clone, Debug)] +pub struct PremiumConfig { + pub postmark_token: String, + pub from_email: String, + pub base_url: String, +} + fn _metrics_default() -> bool { false } @@ -116,6 +123,8 @@ pub struct PKConfig { avatars: Option, #[serde(default)] pub scheduled_tasks: Option, + #[serde(default)] + premium: Option, #[serde(default = "_metrics_default")] pub run_metrics_server: bool, @@ -147,6 +156,10 @@ impl PKConfig { .as_ref() .expect("missing avatar service config") } + + pub fn premium(&self) -> &PremiumConfig { + self.premium.as_ref().expect("missing premium config") + } } // todo: consider passing this down instead of making it global diff --git a/crates/premium/Cargo.toml b/crates/premium/Cargo.toml new file mode 100644 index 00000000..bfc62ade --- /dev/null +++ b/crates/premium/Cargo.toml @@ -0,0 +1,35 @@ +[package] +name = "premium" +version = "0.1.0" +edition = "2024" + +[dependencies] +pluralkit_models = { path = "../models" } +pk_macros = { path = "../macros" } +libpk = { path = "../libpk" } +api = { path = "../api" } + +anyhow = { workspace = true } +axum = { workspace = true } +axum-extra = { workspace = true } +fred = { workspace = true } +lazy_static = { workspace = true } +metrics = { workspace = true } +reqwest = { workspace = true } +sea-query = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +sqlx = { workspace = true } +tokio = { workspace = true } +tower-http = { workspace = true } +tracing = { workspace = true } +twilight-http = { workspace = true } + +askama = "0.14.0" +postmark = { version = "0.11", features = ["reqwest"] } +rand = "0.8" +thiserror = "1.0" +hex = "0.4" +chrono = { workspace = true } +serde_urlencoded = "0.7" +time = "0.3" \ No newline at end of file diff --git a/crates/premium/src/auth.rs b/crates/premium/src/auth.rs new file mode 100644 index 00000000..2d3092d7 --- /dev/null +++ b/crates/premium/src/auth.rs @@ -0,0 +1,318 @@ +use api::{ApiContext, fail_html}; +use askama::Template; +use axum::{ + extract::{MatchedPath, Request, State}, + http::header::SET_COOKIE, + middleware::Next, + response::{AppendHeaders, IntoResponse, Redirect, Response}, +}; +use axum_extra::extract::cookie::CookieJar; +use fred::{ + prelude::{KeysInterface, LuaInterface}, + util::sha1_hash, +}; +use rand::{Rng, distributions::Alphanumeric}; +use serde::{Deserialize, Serialize}; + +use crate::web::{render, message}; + +const LOGIN_TOKEN_TTL_SECS: i64 = 60 * 10; + +const SESSION_LUA_SCRIPT: &str = r#" +local session_key = KEYS[1] +local ttl = ARGV[1] + +local session_data = redis.call('GET', session_key) +if session_data then + redis.call('EXPIRE', session_key, ttl) +end +return session_data +"#; + +const SESSION_TTL_SECS: i64 = 60 * 60 * 4; + +lazy_static::lazy_static! { + static ref SESSION_LUA_SCRIPT_SHA: String = sha1_hash(SESSION_LUA_SCRIPT); +} + +fn rand_token() -> String { + rand::thread_rng() + .sample_iter(&Alphanumeric) + .take(64) + .map(char::from) + .collect() +} + +#[derive(Clone, Serialize, Deserialize)] +pub struct AuthState { + pub email: String, + + pub csrf_token: String, + pub session_id: String, +} + +impl AuthState { + fn new(email: String) -> Self { + Self { + email, + csrf_token: rand_token(), + session_id: rand_token(), + } + } + + async fn from_request( + headers: axum::http::HeaderMap, + ctx: &ApiContext, + ) -> anyhow::Result> { + let jar = CookieJar::from_headers(&headers); + let Some(session_cookie) = jar.get("pk-session") else { + return Ok(None); + }; + let session_id = session_cookie.value(); + + let session_key = format!("premium:session:{}", session_id); + + let script_exists: Vec = ctx + .redis + .script_exists(vec![SESSION_LUA_SCRIPT_SHA.to_string()]) + .await?; + + if script_exists[0] != 1 { + ctx.redis + .script_load::(SESSION_LUA_SCRIPT.to_string()) + .await?; + } + + let session_data: Option = ctx + .redis + .evalsha( + SESSION_LUA_SCRIPT_SHA.to_string(), + vec![session_key], + vec![SESSION_TTL_SECS], + ) + .await?; + + let Some(session_data) = session_data else { + return Ok(None); + }; + + let session: AuthState = serde_json::from_str(&session_data)?; + Ok(Some(session)) + } + + async fn save(&self, ctx: &ApiContext) -> anyhow::Result<()> { + let session_key = format!("premium:session:{}", self.session_id); + let session_data = serde_json::to_string(&self)?; + ctx.redis + .set::<(), _, _>( + session_key, + session_data, + Some(fred::types::Expiration::EX(SESSION_TTL_SECS)), + None, + false, + ) + .await?; + Ok(()) + } + + async fn delete(&self, ctx: &ApiContext) -> anyhow::Result<()> { + let session_key = format!("premium:session:{}", self.session_id); + ctx.redis.del::<(), _>(session_key).await?; + Ok(()) + } +} + +fn refresh_session_cookie(session: &AuthState, mut response: Response) -> Response { + let cookie_value = format!( + "pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}", + session.session_id, SESSION_TTL_SECS + ); + response + .headers_mut() + .insert(SET_COOKIE, cookie_value.parse().unwrap()); + response +} + +pub async fn middleware( + State(ctx): State, + mut request: Request, + next: Next, +) -> Response { + let extensions = request.extensions().clone(); + + let endpoint = extensions + .get::() + .cloned() + .map(|v| v.as_str().to_string()) + .unwrap_or("unknown".to_string()); + + let session = match AuthState::from_request(request.headers().clone(), &ctx).await { + Ok(s) => s, + Err(err) => fail_html!(?err, "failed to fetch auth state from redis"), + }; + + if let Some(session) = session.clone() { + request.extensions_mut().insert(session); + } + + match endpoint.as_str() { + "/" => { + if let Some(ref session) = session { + let response = next.run(request).await; + refresh_session_cookie(session, response) + } else { + return render!(crate::web::Index { + session: None, + show_login_form: true, + message: None, + }); + } + } + "/login" => { + if let Some(ref session) = session { + // no session here because that shows the "you're logged in as" component + let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None)); + return refresh_session_cookie(session, response); + } else { + let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await { + Ok(b) => b, + Err(err) => fail_html!(?err, "failed to read request body"), + }; + let form: std::collections::HashMap = + match serde_urlencoded::from_bytes(&body) { + Ok(f) => f, + Err(err) => fail_html!(?err, "failed to parse form data"), + }; + let Some(email) = form.get("email") else { + return render!(crate::web::Index { + session: None, + show_login_form: true, + message: Some("email field is required".to_string()), + }); + }; + let email = email.trim().to_lowercase(); + if email.is_empty() { + return render!(crate::web::Index { + session: None, + show_login_form: true, + message: Some("email field is required".to_string()), + }); + } + + let token = rand_token(); + + let token_key = format!("premium:login_token:{}", token); + if let Err(err) = ctx + .redis + .set::<(), _, _>( + token_key, + &email, + Some(fred::types::Expiration::EX(LOGIN_TOKEN_TTL_SECS)), + None, + false, + ) + .await + { + fail_html!(?err, "failed to store login token in redis"); + } + + if let Err(err) = crate::mailer::login_token(email, token).await { + fail_html!(?err, "failed to send login email"); + } + + return render!(message( + "check your email for a login link! it will expire in 10 minutes.".to_string(), + None + )); + } + } + "/login/{token}" => { + if let Some(ref session) = session { + // no session here because that shows the "you're logged in as" component + let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None)); + return refresh_session_cookie(session, response); + } + + let path = request.uri().path(); + let token = path.strip_prefix("/login/").unwrap_or(""); + if token.is_empty() { + return render!(crate::web::Index { + session: None, + show_login_form: true, + message: Some("invalid login link".to_string()), + }); + } + + let token_key = format!("premium:login_token:{}", token); + let email: Option = match ctx.redis.get(&token_key).await { + Ok(e) => e, + Err(err) => fail_html!(?err, "failed to fetch login token from redis"), + }; + + let Some(email) = email else { + return render!(crate::web::Index { + session: None, + show_login_form: true, + message: Some( + "invalid or expired login link. please request a new one.".to_string() + ), + }); + }; + + if let Err(err) = ctx.redis.del::<(), _>(&token_key).await { + fail_html!(?err, "failed to delete login token from redis"); + } + + let session = AuthState::new(email); + if let Err(err) = session.save(&ctx).await { + fail_html!(?err, "failed to save session to redis"); + } + + let cookie_value = format!( + "pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}", + session.session_id, SESSION_TTL_SECS + ); + ( + AppendHeaders([(SET_COOKIE, cookie_value)]), + Redirect::to("/"), + ) + .into_response() + } + "/logout" => { + let Some(session) = session else { + return Redirect::to("/").into_response(); + }; + + let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await { + Ok(b) => b, + Err(err) => fail_html!(?err, "failed to read request body"), + }; + let form: std::collections::HashMap = + match serde_urlencoded::from_bytes(&body) { + Ok(f) => f, + Err(err) => fail_html!(?err, "failed to parse form data"), + }; + + let csrf_valid = form + .get("csrf_token") + .map(|t| t == &session.csrf_token) + .unwrap_or(false); + + if !csrf_valid { + return (axum::http::StatusCode::FORBIDDEN, "invalid csrf token").into_response(); + } + + if let Err(err) = session.delete(&ctx).await { + fail_html!(?err, "failed to delete session from redis"); + } + + let cookie_value = "pk-session=; Path=/; HttpOnly; Max-Age=0"; + ( + AppendHeaders([(SET_COOKIE, cookie_value)]), + Redirect::to("/"), + ) + .into_response() + } + _ => (axum::http::StatusCode::NOT_FOUND, "404 not found").into_response(), + } +} diff --git a/crates/premium/src/mailer.rs b/crates/premium/src/mailer.rs new file mode 100644 index 00000000..e1ea8715 --- /dev/null +++ b/crates/premium/src/mailer.rs @@ -0,0 +1,44 @@ +use lazy_static::lazy_static; +use postmark::{ + Query, + api::{Body, email::SendEmailRequest}, + reqwest::PostmarkClient, +}; + +lazy_static! { + pub static ref CLIENT: PostmarkClient = { + PostmarkClient::builder() + .server_token(&libpk::config.premium().postmark_token) + .build() + }; +} + +const LOGIN_TEXT: &'static str = r#"Hello, + +Someone (hopefully you) has requested a link to log in to the PluralKit Premium website. + +Click here to log in: {link} + +This link will expire in 10 minutes. + +If you did not request this link, please ignore this message. + +Thanks, +- PluralKit Team +"#; + +pub async fn login_token(rcpt: String, token: String) -> anyhow::Result<()> { + SendEmailRequest::builder() + .from(&libpk::config.premium().from_email) + .to(rcpt) + .subject("[PluralKit Premium] Your login link") + .body(Body::text(LOGIN_TEXT.replace( + "{link}", + format!("{}/login/{token}", libpk::config.premium().base_url).as_str(), + ))) + .build() + .execute(&(CLIENT.to_owned())) + .await?; + + Ok(()) +} diff --git a/crates/premium/src/main.rs b/crates/premium/src/main.rs new file mode 100644 index 00000000..c573a976 --- /dev/null +++ b/crates/premium/src/main.rs @@ -0,0 +1,63 @@ +use askama::Template; +use axum::{ + Extension, Router, + response::Html, + routing::{get, post}, +}; +use tower_http::{catch_panic::CatchPanicLayer, services::ServeDir}; +use tracing::info; + +use api::{ApiContext, middleware}; + +mod auth; +mod mailer; +mod web; + +// this function is manually formatted for easier legibility of route_services +#[rustfmt::skip] +fn router(ctx: ApiContext) -> Router { + // processed upside down (???) so we have to put middleware at the end + Router::new() + .route("/", get(|Extension(session): Extension| async move { + Html(web::Index { + session: Some(session), + show_login_form: false, + message: None, + }.render().unwrap()) + })) + + .route("/login/{token}", get(|| async { + "handled in auth middleware" + })) + .route("/login", post(|| async { + "handled in auth middleware" + })) + .route("/logout", post(|| async { + "handled in auth middleware" + })) + + .layer(axum::middleware::from_fn_with_state(ctx.clone(), auth::middleware)) + .layer(axum::middleware::from_fn(middleware::logger::logger)) + .nest_service("/static", ServeDir::new("static")) + .layer(CatchPanicLayer::custom(api::util::handle_panic)) + + .with_state(ctx) +} + +#[libpk::main] +async fn main() -> anyhow::Result<()> { + let db = libpk::db::init_data_db().await?; + let redis = libpk::db::init_redis().await?; + + let ctx = ApiContext { db, redis }; + + let app = router(ctx); + + let addr: &str = libpk::config.api().addr.as_ref(); + + let listener = tokio::net::TcpListener::bind(addr).await?; + info!("listening on {}", addr); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/crates/premium/src/web.rs b/crates/premium/src/web.rs new file mode 100644 index 00000000..ba610037 --- /dev/null +++ b/crates/premium/src/web.rs @@ -0,0 +1,33 @@ +use askama::Template; + +use crate::auth::AuthState; + +macro_rules! render { + ($stuff:expr) => {{ + let mut response = $stuff.render().unwrap().into_response(); + let headers = response.headers_mut(); + headers.insert( + "content-type", + axum::http::HeaderValue::from_static("text/html"), + ); + response + }}; +} + +pub(crate) use render; + +pub fn message(message: String, session: Option) -> Index { + Index { + session: session, + show_login_form: false, + message: Some(message) + } +} + +#[derive(Template)] +#[template(path = "index.html")] +pub struct Index { + pub session: Option, + pub show_login_form: bool, + pub message: Option, +} diff --git a/crates/premium/static/stylesheet.css b/crates/premium/static/stylesheet.css new file mode 100644 index 00000000..e69de29b diff --git a/crates/premium/templates/index.html b/crates/premium/templates/index.html new file mode 100644 index 00000000..df99e357 --- /dev/null +++ b/crates/premium/templates/index.html @@ -0,0 +1,29 @@ + + + PluralKit Premium + + + +

PluralKit Premium

+ +{% if let Some(session) = session %} +
+ +

logged in as {{ session.email }}.

+ +
+{% endif %} + +{% if show_login_form %} +

Enter your email address to log in.

+ +
+ + +
+{% endif %} + +{% if let Some(msg) = message %} +
{{ msg }}
+{% endif %} +