mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
feat: premium service boilerplate
This commit is contained in:
parent
c4f820e114
commit
f1471088d2
15 changed files with 912 additions and 104 deletions
|
|
@ -19,6 +19,7 @@ serde = { workspace = true }
|
|||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
|
|
@ -27,6 +28,5 @@ hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"
|
|||
reverse-proxy-service = { version = "0.2.1", features = ["axum"] }
|
||||
serde_urlencoded = "0.7.1"
|
||||
tower = "0.4.13"
|
||||
tower-http = { version = "0.5.2", features = ["catch-panic"] }
|
||||
subtle = "2.6.1"
|
||||
sea-query-sqlx = { version = "0.8.0-rc.8", features = ["sqlx-postgres", "with-chrono"] }
|
||||
|
|
|
|||
|
|
@ -93,6 +93,14 @@ macro_rules! fail {
|
|||
|
||||
pub(crate) use fail;
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! fail_html {
|
||||
($($stuff:tt)+) => {{
|
||||
tracing::error!($($stuff)+);
|
||||
return (axum::http::StatusCode::INTERNAL_SERVER_ERROR, "internal server error").into_response();
|
||||
}};
|
||||
}
|
||||
|
||||
macro_rules! define_error {
|
||||
( $name:ident, $response_code:expr, $json_code:expr, $message:expr ) => {
|
||||
#[allow(dead_code)]
|
||||
|
|
|
|||
10
crates/api/src/lib.rs
Normal file
10
crates/api/src/lib.rs
Normal file
|
|
@ -0,0 +1,10 @@
|
|||
mod auth;
|
||||
pub mod error;
|
||||
pub mod middleware;
|
||||
pub mod util;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiContext {
|
||||
pub db: sqlx::postgres::PgPool,
|
||||
pub redis: fred::clients::RedisPool,
|
||||
}
|
||||
|
|
@ -1,135 +1,95 @@
|
|||
use auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER};
|
||||
use api::ApiContext;
|
||||
use auth::AuthState;
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
body::Body,
|
||||
extract::{Request as ExtractRequest, State},
|
||||
extract::Request as ExtractRequest,
|
||||
http::Uri,
|
||||
response::{IntoResponse, Response},
|
||||
routing::{delete, get, patch, post},
|
||||
};
|
||||
use hyper_util::{
|
||||
client::legacy::{Client, connect::HttpConnector},
|
||||
rt::TokioExecutor,
|
||||
};
|
||||
use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
|
||||
use libpk::config;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use pk_macros::api_endpoint;
|
||||
use crate::proxyer::Proxyer;
|
||||
|
||||
mod auth;
|
||||
mod endpoints;
|
||||
mod error;
|
||||
mod middleware;
|
||||
mod proxyer;
|
||||
mod util;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct ApiContext {
|
||||
pub db: sqlx::postgres::PgPool,
|
||||
pub redis: fred::clients::RedisPool,
|
||||
|
||||
rproxy_uri: String,
|
||||
rproxy_client: Client<HttpConnector, Body>,
|
||||
}
|
||||
|
||||
#[api_endpoint]
|
||||
async fn rproxy(
|
||||
Extension(auth): Extension<AuthState>,
|
||||
State(ctx): State<ApiContext>,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Response {
|
||||
let path = req.uri().path();
|
||||
let path_query = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or(path);
|
||||
|
||||
let uri = format!("{}{}", ctx.rproxy_uri, path_query);
|
||||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
let headers = req.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
|
||||
Ok(ctx.rproxy_client.request(req).await?.into_response())
|
||||
}
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
fn router(ctx: ApiContext) -> Router {
|
||||
fn router(ctx: ApiContext, proxyer: Proxyer) -> Router {
|
||||
let rproxy = |Extension(auth): Extension<AuthState>, req: ExtractRequest<Body>| {
|
||||
proxyer.rproxy(auth, req)
|
||||
};
|
||||
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
Router::new()
|
||||
.route("/v2/systems/{system_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}", get(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}", patch(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/settings", get(endpoints::system::get_system_settings))
|
||||
.route("/v2/systems/{system_id}/settings", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/settings", patch(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/members", get(rproxy))
|
||||
.route("/v2/members", post(rproxy))
|
||||
.route("/v2/members/{member_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}", patch(rproxy))
|
||||
.route("/v2/members/{member_id}", delete(rproxy))
|
||||
.route("/v2/systems/{system_id}/members", get(rproxy.clone()))
|
||||
.route("/v2/members", post(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}", get(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}", patch(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}", delete(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/groups", get(rproxy))
|
||||
.route("/v2/groups", post(rproxy))
|
||||
.route("/v2/groups/{group_id}", get(rproxy))
|
||||
.route("/v2/groups/{group_id}", patch(rproxy))
|
||||
.route("/v2/groups/{group_id}", delete(rproxy))
|
||||
.route("/v2/systems/{system_id}/groups", get(rproxy.clone()))
|
||||
.route("/v2/groups", post(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}", get(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}", patch(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}", delete(rproxy.clone()))
|
||||
|
||||
.route("/v2/groups/{group_id}/members", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/add", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/remove", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy))
|
||||
.route("/v2/groups/{group_id}/members", get(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}/members/add", post(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}/members/remove", post(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}/members/overwrite", post(rproxy.clone()))
|
||||
|
||||
.route("/v2/members/{member_id}/groups", get(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/add", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/remove", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy))
|
||||
.route("/v2/members/{member_id}/groups", get(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}/groups/add", post(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}/groups/remove", post(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}/groups/overwrite", post(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/switches", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", post(rproxy))
|
||||
.route("/v2/systems/{system_id}/fronters", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches", get(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/switches", post(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/fronters", get(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", get(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", patch(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}/members", patch(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/switches/{switch_id}", delete(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", get(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/guilds/{guild_id}", patch(rproxy.clone()))
|
||||
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", get(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}/guilds/{guild_id}", patch(rproxy.clone()))
|
||||
|
||||
.route("/v2/systems/{system_id}/autoproxy", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy))
|
||||
.route("/v2/systems/{system_id}/autoproxy", get(rproxy.clone()))
|
||||
.route("/v2/systems/{system_id}/autoproxy", patch(rproxy.clone()))
|
||||
|
||||
.route("/v2/messages/{message_id}", get(rproxy))
|
||||
.route("/v2/messages/{message_id}", get(rproxy.clone()))
|
||||
|
||||
.route("/v2/bulk", post(endpoints::bulk::bulk))
|
||||
|
||||
.route("/private/bulk_privacy/member", post(rproxy))
|
||||
.route("/private/bulk_privacy/group", post(rproxy))
|
||||
.route("/private/discord/callback", post(rproxy))
|
||||
.route("/private/bulk_privacy/member", post(rproxy.clone()))
|
||||
.route("/private/bulk_privacy/group", post(rproxy.clone()))
|
||||
.route("/private/discord/callback", post(rproxy.clone()))
|
||||
.route("/private/discord/callback2", post(endpoints::private::discord_callback))
|
||||
.route("/private/discord/shard_state", get(endpoints::private::discord_state))
|
||||
.route("/private/dash_views", post(endpoints::private::dash_views))
|
||||
.route("/private/dash_view/{id}", get(endpoints::private::dash_view))
|
||||
.route("/private/stats", get(endpoints::private::meta))
|
||||
|
||||
.route("/v2/systems/{system_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/members/{member_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy))
|
||||
.route("/v2/systems/{system_id}/oembed.json", get(rproxy.clone()))
|
||||
.route("/v2/members/{member_id}/oembed.json", get(rproxy.clone()))
|
||||
.route("/v2/groups/{group_id}/oembed.json", get(rproxy.clone()))
|
||||
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
if config.api().use_ratelimiter {
|
||||
|
|
@ -161,15 +121,14 @@ async fn main() -> anyhow::Result<()> {
|
|||
let rproxy_client = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
|
||||
.build(HttpConnector::new());
|
||||
|
||||
let ctx = ApiContext {
|
||||
db,
|
||||
redis,
|
||||
|
||||
let proxyer = Proxyer {
|
||||
rproxy_uri: rproxy_uri[..rproxy_uri.len() - 1].to_string(),
|
||||
rproxy_client,
|
||||
};
|
||||
|
||||
let app = router(ctx);
|
||||
let ctx = ApiContext { db, redis };
|
||||
|
||||
let app = router(ctx, proxyer);
|
||||
|
||||
let addr: &str = libpk::config.api().addr.as_ref();
|
||||
|
||||
|
|
|
|||
51
crates/api/src/proxyer.rs
Normal file
51
crates/api/src/proxyer.rs
Normal file
|
|
@ -0,0 +1,51 @@
|
|||
use crate::{
|
||||
auth::{AuthState, INTERNAL_APPID_HEADER, INTERNAL_SYSTEMID_HEADER},
|
||||
error::PKError,
|
||||
};
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request as ExtractRequest,
|
||||
http::Uri,
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use hyper_util::client::legacy::{Client, connect::HttpConnector};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct Proxyer {
|
||||
pub rproxy_uri: String,
|
||||
pub rproxy_client: Client<HttpConnector, Body>,
|
||||
}
|
||||
|
||||
impl Proxyer {
|
||||
pub async fn rproxy(
|
||||
self,
|
||||
auth: AuthState,
|
||||
mut req: ExtractRequest<Body>,
|
||||
) -> Result<Response, PKError> {
|
||||
let path = req.uri().path();
|
||||
let path_query = req
|
||||
.uri()
|
||||
.path_and_query()
|
||||
.map(|v| v.as_str())
|
||||
.unwrap_or(path);
|
||||
|
||||
let uri = format!("{}{}", self.rproxy_uri, path_query);
|
||||
|
||||
*req.uri_mut() = Uri::try_from(uri).unwrap();
|
||||
|
||||
let headers = req.headers_mut();
|
||||
|
||||
headers.remove(INTERNAL_SYSTEMID_HEADER);
|
||||
headers.remove(INTERNAL_APPID_HEADER);
|
||||
|
||||
if let Some(sid) = auth.system_id() {
|
||||
headers.append(INTERNAL_SYSTEMID_HEADER, sid.into());
|
||||
}
|
||||
|
||||
if let Some(aid) = auth.app_id() {
|
||||
headers.append(INTERNAL_APPID_HEADER, aid.into());
|
||||
}
|
||||
|
||||
Ok(self.rproxy_client.request(req).await?.into_response())
|
||||
}
|
||||
}
|
||||
|
|
@ -97,6 +97,13 @@ pub struct ScheduledTasksConfig {
|
|||
pub prometheus_url: String,
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Clone, Debug)]
|
||||
pub struct PremiumConfig {
|
||||
pub postmark_token: String,
|
||||
pub from_email: String,
|
||||
pub base_url: String,
|
||||
}
|
||||
|
||||
fn _metrics_default() -> bool {
|
||||
false
|
||||
}
|
||||
|
|
@ -116,6 +123,8 @@ pub struct PKConfig {
|
|||
avatars: Option<AvatarsConfig>,
|
||||
#[serde(default)]
|
||||
pub scheduled_tasks: Option<ScheduledTasksConfig>,
|
||||
#[serde(default)]
|
||||
premium: Option<PremiumConfig>,
|
||||
|
||||
#[serde(default = "_metrics_default")]
|
||||
pub run_metrics_server: bool,
|
||||
|
|
@ -147,6 +156,10 @@ impl PKConfig {
|
|||
.as_ref()
|
||||
.expect("missing avatar service config")
|
||||
}
|
||||
|
||||
pub fn premium(&self) -> &PremiumConfig {
|
||||
self.premium.as_ref().expect("missing premium config")
|
||||
}
|
||||
}
|
||||
|
||||
// todo: consider passing this down instead of making it global
|
||||
|
|
|
|||
35
crates/premium/Cargo.toml
Normal file
35
crates/premium/Cargo.toml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
[package]
|
||||
name = "premium"
|
||||
version = "0.1.0"
|
||||
edition = "2024"
|
||||
|
||||
[dependencies]
|
||||
pluralkit_models = { path = "../models" }
|
||||
pk_macros = { path = "../macros" }
|
||||
libpk = { path = "../libpk" }
|
||||
api = { path = "../api" }
|
||||
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
axum-extra = { workspace = true }
|
||||
fred = { workspace = true }
|
||||
lazy_static = { workspace = true }
|
||||
metrics = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
sea-query = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
twilight-http = { workspace = true }
|
||||
|
||||
askama = "0.14.0"
|
||||
postmark = { version = "0.11", features = ["reqwest"] }
|
||||
rand = "0.8"
|
||||
thiserror = "1.0"
|
||||
hex = "0.4"
|
||||
chrono = { workspace = true }
|
||||
serde_urlencoded = "0.7"
|
||||
time = "0.3"
|
||||
318
crates/premium/src/auth.rs
Normal file
318
crates/premium/src/auth.rs
Normal file
|
|
@ -0,0 +1,318 @@
|
|||
use api::{ApiContext, fail_html};
|
||||
use askama::Template;
|
||||
use axum::{
|
||||
extract::{MatchedPath, Request, State},
|
||||
http::header::SET_COOKIE,
|
||||
middleware::Next,
|
||||
response::{AppendHeaders, IntoResponse, Redirect, Response},
|
||||
};
|
||||
use axum_extra::extract::cookie::CookieJar;
|
||||
use fred::{
|
||||
prelude::{KeysInterface, LuaInterface},
|
||||
util::sha1_hash,
|
||||
};
|
||||
use rand::{Rng, distributions::Alphanumeric};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::web::{render, message};
|
||||
|
||||
const LOGIN_TOKEN_TTL_SECS: i64 = 60 * 10;
|
||||
|
||||
const SESSION_LUA_SCRIPT: &str = r#"
|
||||
local session_key = KEYS[1]
|
||||
local ttl = ARGV[1]
|
||||
|
||||
local session_data = redis.call('GET', session_key)
|
||||
if session_data then
|
||||
redis.call('EXPIRE', session_key, ttl)
|
||||
end
|
||||
return session_data
|
||||
"#;
|
||||
|
||||
const SESSION_TTL_SECS: i64 = 60 * 60 * 4;
|
||||
|
||||
lazy_static::lazy_static! {
|
||||
static ref SESSION_LUA_SCRIPT_SHA: String = sha1_hash(SESSION_LUA_SCRIPT);
|
||||
}
|
||||
|
||||
fn rand_token() -> String {
|
||||
rand::thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(64)
|
||||
.map(char::from)
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[derive(Clone, Serialize, Deserialize)]
|
||||
pub struct AuthState {
|
||||
pub email: String,
|
||||
|
||||
pub csrf_token: String,
|
||||
pub session_id: String,
|
||||
}
|
||||
|
||||
impl AuthState {
|
||||
fn new(email: String) -> Self {
|
||||
Self {
|
||||
email,
|
||||
csrf_token: rand_token(),
|
||||
session_id: rand_token(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn from_request(
|
||||
headers: axum::http::HeaderMap,
|
||||
ctx: &ApiContext,
|
||||
) -> anyhow::Result<Option<Self>> {
|
||||
let jar = CookieJar::from_headers(&headers);
|
||||
let Some(session_cookie) = jar.get("pk-session") else {
|
||||
return Ok(None);
|
||||
};
|
||||
let session_id = session_cookie.value();
|
||||
|
||||
let session_key = format!("premium:session:{}", session_id);
|
||||
|
||||
let script_exists: Vec<usize> = ctx
|
||||
.redis
|
||||
.script_exists(vec![SESSION_LUA_SCRIPT_SHA.to_string()])
|
||||
.await?;
|
||||
|
||||
if script_exists[0] != 1 {
|
||||
ctx.redis
|
||||
.script_load::<String, String>(SESSION_LUA_SCRIPT.to_string())
|
||||
.await?;
|
||||
}
|
||||
|
||||
let session_data: Option<String> = ctx
|
||||
.redis
|
||||
.evalsha(
|
||||
SESSION_LUA_SCRIPT_SHA.to_string(),
|
||||
vec![session_key],
|
||||
vec![SESSION_TTL_SECS],
|
||||
)
|
||||
.await?;
|
||||
|
||||
let Some(session_data) = session_data else {
|
||||
return Ok(None);
|
||||
};
|
||||
|
||||
let session: AuthState = serde_json::from_str(&session_data)?;
|
||||
Ok(Some(session))
|
||||
}
|
||||
|
||||
async fn save(&self, ctx: &ApiContext) -> anyhow::Result<()> {
|
||||
let session_key = format!("premium:session:{}", self.session_id);
|
||||
let session_data = serde_json::to_string(&self)?;
|
||||
ctx.redis
|
||||
.set::<(), _, _>(
|
||||
session_key,
|
||||
session_data,
|
||||
Some(fred::types::Expiration::EX(SESSION_TTL_SECS)),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn delete(&self, ctx: &ApiContext) -> anyhow::Result<()> {
|
||||
let session_key = format!("premium:session:{}", self.session_id);
|
||||
ctx.redis.del::<(), _>(session_key).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
fn refresh_session_cookie(session: &AuthState, mut response: Response) -> Response {
|
||||
let cookie_value = format!(
|
||||
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
|
||||
session.session_id, SESSION_TTL_SECS
|
||||
);
|
||||
response
|
||||
.headers_mut()
|
||||
.insert(SET_COOKIE, cookie_value.parse().unwrap());
|
||||
response
|
||||
}
|
||||
|
||||
pub async fn middleware(
|
||||
State(ctx): State<ApiContext>,
|
||||
mut request: Request,
|
||||
next: Next,
|
||||
) -> Response {
|
||||
let extensions = request.extensions().clone();
|
||||
|
||||
let endpoint = extensions
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let session = match AuthState::from_request(request.headers().clone(), &ctx).await {
|
||||
Ok(s) => s,
|
||||
Err(err) => fail_html!(?err, "failed to fetch auth state from redis"),
|
||||
};
|
||||
|
||||
if let Some(session) = session.clone() {
|
||||
request.extensions_mut().insert(session);
|
||||
}
|
||||
|
||||
match endpoint.as_str() {
|
||||
"/" => {
|
||||
if let Some(ref session) = session {
|
||||
let response = next.run(request).await;
|
||||
refresh_session_cookie(session, response)
|
||||
} else {
|
||||
return render!(crate::web::Index {
|
||||
session: None,
|
||||
show_login_form: true,
|
||||
message: None,
|
||||
});
|
||||
}
|
||||
}
|
||||
"/login" => {
|
||||
if let Some(ref session) = session {
|
||||
// no session here because that shows the "you're logged in as" component
|
||||
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
|
||||
return refresh_session_cookie(session, response);
|
||||
} else {
|
||||
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
|
||||
Ok(b) => b,
|
||||
Err(err) => fail_html!(?err, "failed to read request body"),
|
||||
};
|
||||
let form: std::collections::HashMap<String, String> =
|
||||
match serde_urlencoded::from_bytes(&body) {
|
||||
Ok(f) => f,
|
||||
Err(err) => fail_html!(?err, "failed to parse form data"),
|
||||
};
|
||||
let Some(email) = form.get("email") else {
|
||||
return render!(crate::web::Index {
|
||||
session: None,
|
||||
show_login_form: true,
|
||||
message: Some("email field is required".to_string()),
|
||||
});
|
||||
};
|
||||
let email = email.trim().to_lowercase();
|
||||
if email.is_empty() {
|
||||
return render!(crate::web::Index {
|
||||
session: None,
|
||||
show_login_form: true,
|
||||
message: Some("email field is required".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let token = rand_token();
|
||||
|
||||
let token_key = format!("premium:login_token:{}", token);
|
||||
if let Err(err) = ctx
|
||||
.redis
|
||||
.set::<(), _, _>(
|
||||
token_key,
|
||||
&email,
|
||||
Some(fred::types::Expiration::EX(LOGIN_TOKEN_TTL_SECS)),
|
||||
None,
|
||||
false,
|
||||
)
|
||||
.await
|
||||
{
|
||||
fail_html!(?err, "failed to store login token in redis");
|
||||
}
|
||||
|
||||
if let Err(err) = crate::mailer::login_token(email, token).await {
|
||||
fail_html!(?err, "failed to send login email");
|
||||
}
|
||||
|
||||
return render!(message(
|
||||
"check your email for a login link! it will expire in 10 minutes.".to_string(),
|
||||
None
|
||||
));
|
||||
}
|
||||
}
|
||||
"/login/{token}" => {
|
||||
if let Some(ref session) = session {
|
||||
// no session here because that shows the "you're logged in as" component
|
||||
let response = render!(message("you are already logged in! go back home and log out if you need to log in to a different account.".to_string(), None));
|
||||
return refresh_session_cookie(session, response);
|
||||
}
|
||||
|
||||
let path = request.uri().path();
|
||||
let token = path.strip_prefix("/login/").unwrap_or("");
|
||||
if token.is_empty() {
|
||||
return render!(crate::web::Index {
|
||||
session: None,
|
||||
show_login_form: true,
|
||||
message: Some("invalid login link".to_string()),
|
||||
});
|
||||
}
|
||||
|
||||
let token_key = format!("premium:login_token:{}", token);
|
||||
let email: Option<String> = match ctx.redis.get(&token_key).await {
|
||||
Ok(e) => e,
|
||||
Err(err) => fail_html!(?err, "failed to fetch login token from redis"),
|
||||
};
|
||||
|
||||
let Some(email) = email else {
|
||||
return render!(crate::web::Index {
|
||||
session: None,
|
||||
show_login_form: true,
|
||||
message: Some(
|
||||
"invalid or expired login link. please request a new one.".to_string()
|
||||
),
|
||||
});
|
||||
};
|
||||
|
||||
if let Err(err) = ctx.redis.del::<(), _>(&token_key).await {
|
||||
fail_html!(?err, "failed to delete login token from redis");
|
||||
}
|
||||
|
||||
let session = AuthState::new(email);
|
||||
if let Err(err) = session.save(&ctx).await {
|
||||
fail_html!(?err, "failed to save session to redis");
|
||||
}
|
||||
|
||||
let cookie_value = format!(
|
||||
"pk-session={}; Path=/; HttpOnly; Secure; SameSite=Lax; Max-Age={}",
|
||||
session.session_id, SESSION_TTL_SECS
|
||||
);
|
||||
(
|
||||
AppendHeaders([(SET_COOKIE, cookie_value)]),
|
||||
Redirect::to("/"),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
"/logout" => {
|
||||
let Some(session) = session else {
|
||||
return Redirect::to("/").into_response();
|
||||
};
|
||||
|
||||
let body = match axum::body::to_bytes(request.into_body(), 1024 * 16).await {
|
||||
Ok(b) => b,
|
||||
Err(err) => fail_html!(?err, "failed to read request body"),
|
||||
};
|
||||
let form: std::collections::HashMap<String, String> =
|
||||
match serde_urlencoded::from_bytes(&body) {
|
||||
Ok(f) => f,
|
||||
Err(err) => fail_html!(?err, "failed to parse form data"),
|
||||
};
|
||||
|
||||
let csrf_valid = form
|
||||
.get("csrf_token")
|
||||
.map(|t| t == &session.csrf_token)
|
||||
.unwrap_or(false);
|
||||
|
||||
if !csrf_valid {
|
||||
return (axum::http::StatusCode::FORBIDDEN, "invalid csrf token").into_response();
|
||||
}
|
||||
|
||||
if let Err(err) = session.delete(&ctx).await {
|
||||
fail_html!(?err, "failed to delete session from redis");
|
||||
}
|
||||
|
||||
let cookie_value = "pk-session=; Path=/; HttpOnly; Max-Age=0";
|
||||
(
|
||||
AppendHeaders([(SET_COOKIE, cookie_value)]),
|
||||
Redirect::to("/"),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
_ => (axum::http::StatusCode::NOT_FOUND, "404 not found").into_response(),
|
||||
}
|
||||
}
|
||||
44
crates/premium/src/mailer.rs
Normal file
44
crates/premium/src/mailer.rs
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
use lazy_static::lazy_static;
|
||||
use postmark::{
|
||||
Query,
|
||||
api::{Body, email::SendEmailRequest},
|
||||
reqwest::PostmarkClient,
|
||||
};
|
||||
|
||||
lazy_static! {
|
||||
pub static ref CLIENT: PostmarkClient = {
|
||||
PostmarkClient::builder()
|
||||
.server_token(&libpk::config.premium().postmark_token)
|
||||
.build()
|
||||
};
|
||||
}
|
||||
|
||||
const LOGIN_TEXT: &'static str = r#"Hello,
|
||||
|
||||
Someone (hopefully you) has requested a link to log in to the PluralKit Premium website.
|
||||
|
||||
Click here to log in: {link}
|
||||
|
||||
This link will expire in 10 minutes.
|
||||
|
||||
If you did not request this link, please ignore this message.
|
||||
|
||||
Thanks,
|
||||
- PluralKit Team
|
||||
"#;
|
||||
|
||||
pub async fn login_token(rcpt: String, token: String) -> anyhow::Result<()> {
|
||||
SendEmailRequest::builder()
|
||||
.from(&libpk::config.premium().from_email)
|
||||
.to(rcpt)
|
||||
.subject("[PluralKit Premium] Your login link")
|
||||
.body(Body::text(LOGIN_TEXT.replace(
|
||||
"{link}",
|
||||
format!("{}/login/{token}", libpk::config.premium().base_url).as_str(),
|
||||
)))
|
||||
.build()
|
||||
.execute(&(CLIENT.to_owned()))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
63
crates/premium/src/main.rs
Normal file
63
crates/premium/src/main.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
use askama::Template;
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
response::Html,
|
||||
routing::{get, post},
|
||||
};
|
||||
use tower_http::{catch_panic::CatchPanicLayer, services::ServeDir};
|
||||
use tracing::info;
|
||||
|
||||
use api::{ApiContext, middleware};
|
||||
|
||||
mod auth;
|
||||
mod mailer;
|
||||
mod web;
|
||||
|
||||
// this function is manually formatted for easier legibility of route_services
|
||||
#[rustfmt::skip]
|
||||
fn router(ctx: ApiContext) -> Router {
|
||||
// processed upside down (???) so we have to put middleware at the end
|
||||
Router::new()
|
||||
.route("/", get(|Extension(session): Extension<auth::AuthState>| async move {
|
||||
Html(web::Index {
|
||||
session: Some(session),
|
||||
show_login_form: false,
|
||||
message: None,
|
||||
}.render().unwrap())
|
||||
}))
|
||||
|
||||
.route("/login/{token}", get(|| async {
|
||||
"handled in auth middleware"
|
||||
}))
|
||||
.route("/login", post(|| async {
|
||||
"handled in auth middleware"
|
||||
}))
|
||||
.route("/logout", post(|| async {
|
||||
"handled in auth middleware"
|
||||
}))
|
||||
|
||||
.layer(axum::middleware::from_fn_with_state(ctx.clone(), auth::middleware))
|
||||
.layer(axum::middleware::from_fn(middleware::logger::logger))
|
||||
.nest_service("/static", ServeDir::new("static"))
|
||||
.layer(CatchPanicLayer::custom(api::util::handle_panic))
|
||||
|
||||
.with_state(ctx)
|
||||
}
|
||||
|
||||
#[libpk::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
let db = libpk::db::init_data_db().await?;
|
||||
let redis = libpk::db::init_redis().await?;
|
||||
|
||||
let ctx = ApiContext { db, redis };
|
||||
|
||||
let app = router(ctx);
|
||||
|
||||
let addr: &str = libpk::config.api().addr.as_ref();
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("listening on {}", addr);
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
33
crates/premium/src/web.rs
Normal file
33
crates/premium/src/web.rs
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
use askama::Template;
|
||||
|
||||
use crate::auth::AuthState;
|
||||
|
||||
macro_rules! render {
|
||||
($stuff:expr) => {{
|
||||
let mut response = $stuff.render().unwrap().into_response();
|
||||
let headers = response.headers_mut();
|
||||
headers.insert(
|
||||
"content-type",
|
||||
axum::http::HeaderValue::from_static("text/html"),
|
||||
);
|
||||
response
|
||||
}};
|
||||
}
|
||||
|
||||
pub(crate) use render;
|
||||
|
||||
pub fn message(message: String, session: Option<AuthState>) -> Index {
|
||||
Index {
|
||||
session: session,
|
||||
show_login_form: false,
|
||||
message: Some(message)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Template)]
|
||||
#[template(path = "index.html")]
|
||||
pub struct Index {
|
||||
pub session: Option<AuthState>,
|
||||
pub show_login_form: bool,
|
||||
pub message: Option<String>,
|
||||
}
|
||||
0
crates/premium/static/stylesheet.css
Normal file
0
crates/premium/static/stylesheet.css
Normal file
29
crates/premium/templates/index.html
Normal file
29
crates/premium/templates/index.html
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
<!DOCTYPE html>
|
||||
<head>
|
||||
<title>PluralKit Premium</title>
|
||||
<link rel="stylesheet" href="/static/stylesheet.css" />
|
||||
</head>
|
||||
<body>
|
||||
<h2>PluralKit Premium</h2>
|
||||
|
||||
{% if let Some(session) = session %}
|
||||
<form action="/logout" method="post">
|
||||
<input type="hidden" name="csrf_token" value="{{ session.csrf_token }}" />
|
||||
<p>logged in as <strong>{{ session.email }}.</strong></p>
|
||||
<button type="submit">log out</button>
|
||||
</form>
|
||||
{% endif %}
|
||||
|
||||
{% if show_login_form %}
|
||||
<p>Enter your email address to log in.</p>
|
||||
|
||||
<form method="POST" action="/login">
|
||||
<input type="email" name="email" placeholder="you@example.com" required />
|
||||
<button type="submit">Send</button>
|
||||
</form>
|
||||
{% endif %}
|
||||
|
||||
{% if let Some(msg) = message %}
|
||||
<div>{{ msg }}</div>
|
||||
{% endif %}
|
||||
</body>
|
||||
Loading…
Add table
Add a link
Reference in a new issue