feat: premium service boilerplate

This commit is contained in:
alyssa 2025-12-23 00:45:45 -05:00
parent c4f820e114
commit f1471088d2
15 changed files with 912 additions and 104 deletions

View file

@ -19,6 +19,7 @@ serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
@ -27,6 +28,5 @@ hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
serde_urlencoded = "0.7.1"
tower = "0.4.13"
tower-http = { version = "0.5.2", features = ["catch-panic"] }
subtle = "2.6.1"
sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] }

View file

@ -93,6 +93,14 @@ macro_rules! fail {
pub(crate) use fail;
#[macro_export]
macro_rules! fail_html {
($($stuff:tt)+) => {{
tracing::error!($($stuff)+);
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response();
}};
}
macro_rules! define_error {
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
#[allow(dead_code)]

10
crates/api/src/lib.rs Normal file
View file

@ -0,0 +1,10 @@
mod auth;
pub mod error;
pub mod middleware;
pub mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
}

View file

@ -1,135 +1,95 @@
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
use api::ApiContext;
use auth::AuthState;
use axum::{
Extension, Router,
body::Body,
extract::{Request as ExtractRequest, State},
extract::Request as ExtractRequest,
http::Uri,
response::{IntoResponse, Response},
routing::{delete, get, patch, post},
};
use hyper_util::{
client::legacy::{Client, connect::HttpConnector},
rt::TokioExecutor,
};
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
use libpk::config;
use tracing::{info, warn};
use pk_macros::api_endpoint;
use crate::proxyer::Proxyer;
mod auth;
mod endpoints;
mod error;
mod middleware;
mod proxyer;
mod util;
#[derive(Clone)]
pub struct ApiContext {
pub db: sqlx::postgres::PgPool,
pub redis: fred::clients::RedisPool,
rproxy_uri: String,
rproxy_client: Client<HttpConnector, Body>,
}
#[api_endpoint]
async fn rproxy(
Extension(auth): Extension<AuthState>,
State(ctx): State<ApiContext>,
mut req: ExtractRequest<Body>,
) -> Response {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(ctx.rproxy_client.request(req).await?.into_response())
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
fn router(ctx: ApiContext, proxyer: Proxyer) -> Router {
let rproxy = |Extension(auth): Extension<AuthState>, req: ExtractRequest<Body>| {
proxyer.rproxy(auth, req)
};
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/v2/systems/{system_id}", get(rproxy))
.route("/v2/systems/{system_id}", patch(rproxy))
.route("/v2/systems/{system_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
.route("/v2/systems/{system_id}/settings", patch(rproxy))
.route("/v2/systems/{system_id}/settings", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/members", get(rproxy))
.route("/v2/members", post(rproxy))
.route("/v2/members/{member_id}", get(rproxy))
.route("/v2/members/{member_id}", patch(rproxy))
.route("/v2/members/{member_id}", delete(rproxy))
.route("/v2/systems/{system_id}/members", get(rproxy.clone()))
.route("/v2/members", post(rproxy.clone()))
.route("/v2/members/{member_id}", get(rproxy.clone()))
.route("/v2/members/{member_id}", patch(rproxy.clone()))
.route("/v2/members/{member_id}", delete(rproxy.clone()))
.route("/v2/systems/{system_id}/groups", get(rproxy))
.route("/v2/groups", post(rproxy))
.route("/v2/groups/{group_id}", get(rproxy))
.route("/v2/groups/{group_id}", patch(rproxy))
.route("/v2/groups/{group_id}", delete(rproxy))
.route("/v2/systems/{system_id}/groups", get(rproxy.clone()))
.route("/v2/groups", post(rproxy.clone()))
.route("/v2/groups/{group_id}", get(rproxy.clone()))
.route("/v2/groups/{group_id}", patch(rproxy.clone()))
.route("/v2/groups/{group_id}", delete(rproxy.clone()))
.route("/v2/groups/{group_id}/members", get(rproxy))
.route("/v2/groups/{group_id}/members/add", post(rproxy))
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
.route("/v2/groups/{group_id}/members", get(rproxy.clone()))
.route("/v2/groups/{group_id}/members/add", post(rproxy.clone()))
.route("/v2/groups/{group_id}/members/remove", post(rproxy.clone()))
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups", get(rproxy))
.route("/v2/members/{member_id}/groups/add", post(rproxy))
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
.route("/v2/members/{member_id}/groups", get(rproxy.clone()))
.route("/v2/members/{member_id}/groups/add", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups/remove", post(rproxy.clone()))
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy.clone()))
.route("/v2/systems/{system_id}/switches", get(rproxy))
.route("/v2/systems/{system_id}/switches", post(rproxy))
.route("/v2/systems/{system_id}/fronters", get(rproxy))
.route("/v2/systems/{system_id}/switches", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches", post(rproxy.clone()))
.route("/v2/systems/{system_id}/fronters", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy.clone()))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy.clone()))
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy.clone()))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy.clone()))
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy.clone()))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
.route("/v2/systems/{system_id}/autoproxy", get(rproxy.clone()))
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy.clone()))
.route("/v2/messages/{message_id}", get(rproxy))
.route("/v2/messages/{message_id}", get(rproxy.clone()))
.route("/v2/bulk", post(endpoints::bulk::bulk))
.route("/private/bulk_privacy/member", post(rproxy))
.route("/private/bulk_privacy/group", post(rproxy))
.route("/private/discord/callback", post(rproxy))
.route("/private/bulk_privacy/member", post(rproxy.clone()))
.route("/private/bulk_privacy/group", post(rproxy.clone()))
.route("/private/discord/callback", post(rproxy.clone()))
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
.route("/private/dash_views", post(endpoints::private::dash_views))
.route("/private/dash_view/{id}", get(endpoints::private::dash_view))
.route("/private/stats", get(endpoints::private::meta))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
.route("/v2/systems/{system_id}/oembed.json", get(rproxy.clone()))
.route("/v2/members/{member_id}/oembed.json", get(rproxy.clone()))
.route("/v2/groups/{group_id}/oembed.json", get(rproxy.clone()))
.layer(axum::middleware::from_fn_with_state(
if config.api().use_ratelimiter {
@ -161,15 +121,14 @@ async fn main() -> anyhow::Result<()> {
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
.build(HttpConnector::new());
let ctx = ApiContext {
db,
redis,
let proxyer = Proxyer {
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
rproxy_client,
};
let app = router(ctx);
let ctx = ApiContext { db, redis };
let app = router(ctx, proxyer);
let addr: &str = libpk::config.api().addr.as_ref();

51
crates/api/src/proxyer.rs Normal file
View file

@ -0,0 +1,51 @@
use crate::{
auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
error::PKError,
};
use axum::{
body::Body,
extract::Request as ExtractRequest,
http::Uri,
response::{IntoResponse, Response},
};
use hyper_util::client::legacy::{Client, connect::HttpConnector};
#[derive(Clone)]
pub struct Proxyer {
pub rproxy_uri: String,
pub rproxy_client: Client<HttpConnector, Body>,
}
impl Proxyer {
pub async fn rproxy(
self,
auth: AuthState,
mut req: ExtractRequest<Body>,
) -> Result<Response, PKError> {
let path = req.uri().path();
let path_query = req
.uri()
.path_and_query()
.map(|v| v.as_str())
.unwrap_or(path);
let uri = format!("{}{}", self.rproxy_uri, path_query);
*req.uri_mut() = Uri::try_from(uri).unwrap();
let headers = req.headers_mut();
headers.remove(INTERNAL_SYSTEMID_HEADER);
headers.remove(INTERNAL_APPID_HEADER);
if let Some(sid) = auth.system_id() {
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
}
if let Some(aid) = auth.app_id() {
headers.append(INTERNAL_APPID_HEADER, aid.into());
}
Ok(self.rproxy_client.request(req).await?.into_response())
}
}

View file

@ -97,6 +97,13 @@ pub struct ScheduledTasksConfig {
pub prometheus_url: String,
}
#[derive(Deserialize, Clone, Debug)]
pub struct PremiumConfig {
pub postmark_token: String,
pub from_email: String,
pub base_url: String,
}
fn _metrics_default() -> bool {
false
}
@ -116,6 +123,8 @@ pub struct PKConfig {
avatars: Option<AvatarsConfig>,
#[serde(default)]
pub scheduled_tasks: Option<ScheduledTasksConfig>,
#[serde(default)]
premium: Option<PremiumConfig>,
#[serde(default = "_metrics_default")]
pub run_metrics_server: bool,
@ -147,6 +156,10 @@ impl PKConfig {
.as_ref()
.expect("missing avatar service config")
}
pub fn premium(&self) -> &PremiumConfig {
self.premium.as_ref().expect("missing premium config")
}
}
// todo: consider passing this down instead of making it global

35
crates/premium/Cargo.toml Normal file
View file

@ -0,0 +1,35 @@
[package]
name = "premium"
version = "0.1.0"
edition = "2024"
[dependencies]
pluralkit_models = { path = "../models" }
pk_macros = { path = "../macros" }
libpk = { path = "../libpk" }
api = { path = "../api" }
anyhow = { workspace = true }
axum = { workspace = true }
axum-extra = { workspace = true }
fred = { workspace = true }
lazy_static = { workspace = true }
metrics = { workspace = true }
reqwest = { workspace = true }
sea-query = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tower-http = { workspace = true }
tracing = { workspace = true }
twilight-http = { workspace = true }
askama = "0.14.0"
postmark = { version = "0.11", features = ["reqwest"] }
rand = "0.8"
thiserror = "1.0"
hex = "0.4"
chrono = { workspace = true }
serde_urlencoded = "0.7"
time = "0.3"

318
crates/premium/src/auth.rs Normal file
View file

@ -0,0 +1,318 @@
use api::{ApiContext, fail_html};
use askama::Template;
use axum::{
extract::{MatchedPath, Request, State},
http::header::SET_COOKIE,
middleware::Next,
response::{AppendHeaders, IntoResponse, Redirect, Response},
};
use axum_extra::extract::cookie::CookieJar;
use fred::{
prelude::{KeysInterface, LuaInterface},
util::sha1_hash,
};
use rand::{Rng, distributions::Alphanumeric};
use serde::{Deserialize, Serialize};
use crate::web::{render, message};
const LOGIN_TOKEN_TTL_SECS: i64 = 60 * 10;
const SESSION_LUA_SCRIPT: &str = r#"
local session_key = KEYS[1]
local ttl = ARGV[1]
local session_data = redis.call('GET', session_key)
if session_data then
redis.call('EXPIRE', session_key, ttl)
end
return session_data
"#;
const SESSION_TTL_SECS: i64 = 60 * 60 * 4;
lazy_static::lazy_static! {
static ref SESSION_LUA_SCRIPT_SHA: String = sha1_hash(SESSION_LUA_SCRIPT);
}
fn rand_token() -> String {
rand::thread_rng()
.sample_iter(&Alphanumeric)
.take(64)
.map(char::from)
.collect()
}
#[derive(Clone, Serialize, Deserialize)]
pub struct AuthState {
pub email: String,
pub csrf_token: String,
pub session_id: String,
}
impl AuthState {
fn new(email: String) -> Self {
Self {
email,
csrf_token: rand_token(),
session_id: rand_token(),
}
}
async fn from_request(
headers: axum::http::HeaderMap,
ctx: &ApiContext,
) -> anyhow::Result<Option<Self>> {
let jar = CookieJar::from_headers(&headers);
let Some(session_cookie) = jar.get("pk-session") else {
return Ok(None);
};
let session_id = session_cookie.value();
let session_key = format!("premium:session:{}", session_id);
let script_exists: Vec<usize> = ctx
.redis
.script_exists(vec![SESSION_LUA_SCRIPT_SHA.to_string()])
.await?;
if script_exists[0] != 1 {
ctx.redis
.script_load::<String, String>(SESSION_LUA_SCRIPT.to_string())
.await?;
}
let session_data: Option<String> = ctx
.redis
.evalsha(
SESSION_LUA_SCRIPT_SHA.to_string(),
vec![session_key],
vec![SESSION_TTL_SECS],
)
.await?;
let Some(session_data) = session_data else {
return Ok(None);
};
let session: AuthState = serde_json::from_str(&session_data)?;
Ok(Some(session))
}
async fn save(&self, ctx: &ApiContext) -> anyhow::Result<()> {
let session_key = format!("premium:session:{}", self.session_id);
let session_data = serde_json::to_string(&self)?;
ctx.redis
.set::<(), _, _>(
session_key,
session_data,
Some(fred::types::Expiration::EX(SESSION_TTL_SECS)),
None,
false,
)
.await?;
Ok(())
}
async fn delete(&self, ctx: &ApiContext) -> anyhow::Result<()> {
let session_key = format!("premium:session:{}", self.session_id);
ctx.redis.del::<(), _>(session_key).await?;
Ok(())
}
}
fn refresh_session_cookie(session: &AuthState, mut response: Response) -> Response {
let cookie_value = format!(
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
session.session_id, SESSION_TTL_SECS
);
response
.headers_mut()
.insert(SET_COOKIE, cookie_value.parse().unwrap());
response
}
pub async fn middleware(
State(ctx): State<ApiContext>,
mut request: Request,
next: Next,
) -> Response {
let extensions = request.extensions().clone();
let endpoint = extensions
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let session = match AuthState::from_request(request.headers().clone(), &ctx).await {
Ok(s) => s,
Err(err) => fail_html!(?err, "failed to fetch auth state from redis"),
};
if let Some(session) = session.clone() {
request.extensions_mut().insert(session);
}
match endpoint.as_str() {
"/" => {
if let Some(ref session) = session {
let response = next.run(request).await;
refresh_session_cookie(session, response)
} else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: None,
});
}
}
"/login" => {
if let Some(ref session) = session {
// no session here because that shows the "you're logged in as" component
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
return refresh_session_cookie(session, response);
} else {
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
Ok(b) => b,
Err(err) => fail_html!(?err, "failed to read request body"),
};
let form: std::collections::HashMap<String, String> =
match serde_urlencoded::from_bytes(&body) {
Ok(f) => f,
Err(err) => fail_html!(?err, "failed to parse form data"),
};
let Some(email) = form.get("email") else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("email field is required".to_string()),
});
};
let email = email.trim().to_lowercase();
if email.is_empty() {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("email field is required".to_string()),
});
}
let token = rand_token();
let token_key = format!("premium:login_token:{}", token);
if let Err(err) = ctx
.redis
.set::<(), _, _>(
token_key,
&email,
Some(fred::types::Expiration::EX(LOGIN_TOKEN_TTL_SECS)),
None,
false,
)
.await
{
fail_html!(?err, "failed to store login token in redis");
}
if let Err(err) = crate::mailer::login_token(email, token).await {
fail_html!(?err, "failed to send login email");
}
return render!(message(
"check your email for a login link! it will expire in 10 minutes.".to_string(),
None
));
}
}
"/login/{token}" => {
if let Some(ref session) = session {
// no session here because that shows the "you're logged in as" component
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
return refresh_session_cookie(session, response);
}
let path = request.uri().path();
let token = path.strip_prefix("/login/").unwrap_or("");
if token.is_empty() {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some("invalid login link".to_string()),
});
}
let token_key = format!("premium:login_token:{}", token);
let email: Option<String> = match ctx.redis.get(&token_key).await {
Ok(e) => e,
Err(err) => fail_html!(?err, "failed to fetch login token from redis"),
};
let Some(email) = email else {
return render!(crate::web::Index {
session: None,
show_login_form: true,
message: Some(
"invalid or expired login link. please request a new one.".to_string()
),
});
};
if let Err(err) = ctx.redis.del::<(), _>(&token_key).await {
fail_html!(?err, "failed to delete login token from redis");
}
let session = AuthState::new(email);
if let Err(err) = session.save(&ctx).await {
fail_html!(?err, "failed to save session to redis");
}
let cookie_value = format!(
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
session.session_id, SESSION_TTL_SECS
);
(
AppendHeaders([(SET_COOKIE, cookie_value)]),
Redirect::to("/"),
)
.into_response()
}
"/logout" => {
let Some(session) = session else {
return Redirect::to("/").into_response();
};
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
Ok(b) => b,
Err(err) => fail_html!(?err, "failed to read request body"),
};
let form: std::collections::HashMap<String, String> =
match serde_urlencoded::from_bytes(&body) {
Ok(f) => f,
Err(err) => fail_html!(?err, "failed to parse form data"),
};
let csrf_valid = form
.get("csrf_token")
.map(|t| t == &session.csrf_token)
.unwrap_or(false);
if !csrf_valid {
return (axum::http::StatusCode::FORBIDDEN, "invalid csrf token").into_response();
}
if let Err(err) = session.delete(&ctx).await {
fail_html!(?err, "failed to delete session from redis");
}
let cookie_value = "pk-session=; Path=/; HttpOnly; Max-Age=0";
(
AppendHeaders([(SET_COOKIE, cookie_value)]),
Redirect::to("/"),
)
.into_response()
}
_ => (axum::http::StatusCode::NOT_FOUND, "404 not found").into_response(),
}
}

View file

@ -0,0 +1,44 @@
use lazy_static::lazy_static;
use postmark::{
Query,
api::{Body, email::SendEmailRequest},
reqwest::PostmarkClient,
};
lazy_static! {
pub static ref CLIENT: PostmarkClient = {
PostmarkClient::builder()
.server_token(&libpk::config.premium().postmark_token)
.build()
};
}
const LOGIN_TEXT: &'static str = r#"Hello,
Someone (hopefully you) has requested a link to log in to the PluralKit Premium website.
Click here to log in: {link}
This link will expire in 10 minutes.
If you did not request this link, please ignore this message.
Thanks,
- PluralKit Team
"#;
pub async fn login_token(rcpt: String, token: String) -> anyhow::Result<()> {
SendEmailRequest::builder()
.from(&libpk::config.premium().from_email)
.to(rcpt)
.subject("[PluralKit Premium] Your login link")
.body(Body::text(LOGIN_TEXT.replace(
"{link}",
format!("{}/login/{token}", libpk::config.premium().base_url).as_str(),
)))
.build()
.execute(&(CLIENT.to_owned()))
.await?;
Ok(())
}

View file

@ -0,0 +1,63 @@
use askama::Template;
use axum::{
Extension, Router,
response::Html,
routing::{get, post},
};
use tower_http::{catch_panic::CatchPanicLayer, services::ServeDir};
use tracing::info;
use api::{ApiContext, middleware};
mod auth;
mod mailer;
mod web;
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
fn router(ctx: ApiContext) -> Router {
// processed upside down (???) so we have to put middleware at the end
Router::new()
.route("/", get(|Extension(session): Extension<auth::AuthState>| async move {
Html(web::Index {
session: Some(session),
show_login_form: false,
message: None,
}.render().unwrap())
}))
.route("/login/{token}", get(|| async {
"handled in auth middleware"
}))
.route("/login", post(|| async {
"handled in auth middleware"
}))
.route("/logout", post(|| async {
"handled in auth middleware"
}))
.layer(axum::middleware::from_fn_with_state(ctx.clone(), auth::middleware))
.layer(axum::middleware::from_fn(middleware::logger::logger))
.nest_service("/static", ServeDir::new("static"))
.layer(CatchPanicLayer::custom(api::util::handle_panic))
.with_state(ctx)
}
#[libpk::main]
async fn main() -> anyhow::Result<()> {
let db = libpk::db::init_data_db().await?;
let redis = libpk::db::init_redis().await?;
let ctx = ApiContext { db, redis };
let app = router(ctx);
let addr: &str = libpk::config.api().addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

33
crates/premium/src/web.rs Normal file
View file

@ -0,0 +1,33 @@
use askama::Template;
use crate::auth::AuthState;
macro_rules! render {
($stuff:expr) => {{
let mut response = $stuff.render().unwrap().into_response();
let headers = response.headers_mut();
headers.insert(
"content-type",
axum::http::HeaderValue::from_static("text/html"),
);
response
}};
}
pub(crate) use render;
pub fn message(message: String, session: Option<AuthState>) -> Index {
Index {
session: session,
show_login_form: false,
message: Some(message)
}
}
#[derive(Template)]
#[template(path = "index.html")]
pub struct Index {
pub session: Option<AuthState>,
pub show_login_form: bool,
pub message: Option<String>,
}

View file

View file

@ -0,0 +1,29 @@
<!DOCTYPE html>
<head>
<title>PluralKit Premium</title>
<link rel="stylesheet" href="/static/stylesheet.css" />
</head>
<body>
<h2>PluralKit Premium</h2>
{% if let Some(session) = session %}
<form action="/logout" method="post">
<input type="hidden" name="csrf_token" value="{{ session.csrf_token }}" />
<p>logged in as <strong>{{ session.email }}.</strong></p>
<button type="submit">log out</button>
</form>
{% endif %}
{% if show_login_form %}
<p>Enter your email address to log in.</p>
<form method="POST" action="/login">
<input type="email" name="email" placeholder="you@example.com" required />
<button type="submit">Send</button>
</form>
{% endif %}
{% if let Some(msg) = message %}
<div>{{ msg }}</div>
{% endif %}
</body>