From 882e9b66f231939e7e9162dd0c7f47a89f2c86d3 Mon Sep 17 00:00:00 2001 From: alyssa Date: Sun, 29 Dec 2024 21:48:28 +0000 Subject: [PATCH] feat(api): port discord/callback to rust --- Cargo.lock | 4 + lib/model_macros/src/lib.rs | 36 +++++-- lib/models/src/_util.rs | 35 ++++++ lib/models/src/lib.rs | 13 ++- lib/models/src/system.rs | 39 ++----- lib/models/src/system_config.rs | 89 ++++++++++++++++ services/api/Cargo.toml | 7 +- services/api/src/endpoints/private.rs | 146 ++++++++++++++++++++++++++ services/api/src/main.rs | 1 + 9 files changed, 327 insertions(+), 43 deletions(-) create mode 100644 lib/models/src/_util.rs create mode 100644 lib/models/src/system_config.rs diff --git a/Cargo.lock b/Cargo.lock index faf1c5f4..6e18bead 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -94,14 +94,18 @@ dependencies = [ "lazy_static", "libpk", "metrics", + "pluralkit_models", + "reqwest 0.12.8", "reverse-proxy-service", "serde", "serde_json", + "serde_urlencoded", "sqlx", "tokio", "tower", "tower-http 0.5.2", "tracing", + "twilight-http", ] [[package]] diff --git a/lib/model_macros/src/lib.rs b/lib/model_macros/src/lib.rs index fe5f5193..aead66df 100644 --- a/lib/model_macros/src/lib.rs +++ b/lib/model_macros/src/lib.rs @@ -16,6 +16,7 @@ struct ModelField { patch: ElemPatchability, json: Option, is_privacy: bool, + default: Option, } fn parse_field(field: syn::Field) -> ModelField { @@ -25,6 +26,7 @@ fn parse_field(field: syn::Field) -> ModelField { patch: ElemPatchability::None, json: None, is_privacy: false, + default: None, }; for attr in field.attrs.iter() { @@ -59,6 +61,12 @@ fn parse_field(field: syn::Field) -> ModelField { } f.json = Some(nv.value.clone()); } + "default" => { + if f.default.is_some() { + panic!("cannot set default multiple times for same field"); + } + f.default = Some(nv.value.clone()); + } _ => panic!("unknown attribute"), }, Meta::List(_) => panic!("unknown attribute"), @@ -69,6 +77,10 @@ fn parse_field(field: syn::Field) -> ModelField { panic!("must have json name to be publicly patchable"); } + if f.json.is_some() && f.is_privacy { + panic!("cannot set custom json name for privacy field"); + } + f } @@ -96,7 +108,7 @@ pub fn pk_model( panic!("fields of a struct must be named"); }; - println!("{}: {:#?}", tname, fields); + // println!("{}: {:#?}", tname, fields); let tfields = mk_tfields(fields.clone()); let from_json = mk_tfrom_json(fields.clone()); @@ -126,7 +138,7 @@ pub fn pk_model( #from_json } - pub fn to_json(self) -> String { + pub fn to_json(self) -> serde_json::Value { #to_json } } @@ -150,7 +162,7 @@ pub fn pk_model( #patch_to_sql } - pub fn to_json(self) -> String { + pub fn to_json(self) -> serde_json::Value { #patch_to_json } } @@ -165,7 +177,7 @@ fn mk_tfields(fields: Vec) -> TokenStream { let name = f.name.clone(); let ty = f.ty.clone(); quote! { - #name: #ty, + pub #name: #ty, } }) .collect() @@ -183,8 +195,14 @@ fn mk_tto_json(fields: Vec) -> TokenStream { .filter_map(|f| { f.json.as_ref().map(|v| { let tname = f.name.clone(); - quote! { - #v: self.#tname, + if let Some(default) = f.default.as_ref() { + quote! { + #v: self.#tname.unwrap_or(#default), + } + } else { + quote! { + #v: self.#tname, + } } }) }) @@ -206,12 +224,12 @@ fn mk_tto_json(fields: Vec) -> TokenStream { .collect(); quote! { - serde_json::to_string(&serde_json::json!({ + serde_json::json!({ #fielddefs "privacy": { #privacyfielddefs } - })).expect("json serializing generated models should not fail") + }) } } @@ -222,7 +240,7 @@ fn mk_patch_fields(fields: Vec) -> TokenStream { let name = f.name.clone(); let ty = f.ty.clone(); quote! { - #name: Option<#ty>, + pub #name: Option<#ty>, } }) .collect() diff --git a/lib/models/src/_util.rs b/lib/models/src/_util.rs new file mode 100644 index 00000000..c13a0bf6 --- /dev/null +++ b/lib/models/src/_util.rs @@ -0,0 +1,35 @@ +// postgres enums created in c# pluralkit implementations are "fake", i.e. they +// are actually ints in the database rather than postgres enums, because dapper +// does not support postgres enums +// here, we add some impls to support this kind of enum in sqlx +// there is probably a better way to do this, but works for now. +// note: caller needs to implement From for their type +macro_rules! fake_enum_impls { + ($n:ident) => { + impl Type for $n { + fn type_info() -> PgTypeInfo { + PgTypeInfo::with_name("INT4") + } + } + + impl From<$n> for i32 { + fn from(enum_value: $n) -> Self { + enum_value as i32 + } + } + + impl<'r, DB: Database> Decode<'r, DB> for $n + where + i32: Decode<'r, DB>, + { + fn decode( + value: ::ValueRef<'r>, + ) -> Result> { + let value = >::decode(value)?; + Ok(Self::from(value)) + } + } + }; +} + +pub(crate) use fake_enum_impls; diff --git a/lib/models/src/lib.rs b/lib/models/src/lib.rs index 8ca1df21..bb7bb08d 100644 --- a/lib/models/src/lib.rs +++ b/lib/models/src/lib.rs @@ -1,2 +1,11 @@ -mod system; -pub use system::*; +mod _util; + +macro_rules! model { + ($n:ident) => { + mod $n; + pub use $n::*; + }; +} + +model!(system); +model!(system_config); diff --git a/lib/models/src/system.rs b/lib/models/src/system.rs index 56b25070..d59d5957 100644 --- a/lib/models/src/system.rs +++ b/lib/models/src/system.rs @@ -6,54 +6,31 @@ use chrono::NaiveDateTime; use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type}; use uuid::Uuid; +use crate::_util::fake_enum_impls; + // todo: fix this pub type SystemId = i32; -// // todo: move this +// todo: move this #[derive(serde::Serialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] pub enum PrivacyLevel { - #[serde(rename = "public")] - Public = 1, - #[serde(rename = "private")] - Private = 2, + Public, + Private, } -impl Type for PrivacyLevel { - fn type_info() -> PgTypeInfo { - PgTypeInfo::with_name("INT4") - } -} - -impl From for i32 { - fn from(enum_value: PrivacyLevel) -> Self { - enum_value as i32 - } -} +fake_enum_impls!(PrivacyLevel); impl From for PrivacyLevel { fn from(value: i32) -> Self { match value { 1 => PrivacyLevel::Public, 2 => PrivacyLevel::Private, - _ => unimplemented!(), + _ => unreachable!(), } } } -struct MyType; - -impl<'r, DB: Database> Decode<'r, DB> for PrivacyLevel -where - i32: Decode<'r, DB>, -{ - fn decode( - value: ::ValueRef<'r>, - ) -> Result> { - let value = >::decode(value)?; - Ok(Self::from(value)) - } -} - #[pk_model] struct System { id: SystemId, diff --git a/lib/models/src/system_config.rs b/lib/models/src/system_config.rs new file mode 100644 index 00000000..d6b58a58 --- /dev/null +++ b/lib/models/src/system_config.rs @@ -0,0 +1,89 @@ +use model_macros::pk_model; + +use sqlx::{postgres::PgTypeInfo, Database, Decode, Postgres, Type}; +use std::error::Error; + +use crate::{SystemId, _util::fake_enum_impls}; + +pub const DEFAULT_MEMBER_LIMIT: i32 = 1000; +pub const DEFAULT_GROUP_LIMIT: i32 = 250; + +#[derive(serde::Serialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +enum HidPadFormat { + #[serde(rename = "off")] + None, + Left, + Right, +} +fake_enum_impls!(HidPadFormat); + +impl From for HidPadFormat { + fn from(value: i32) -> Self { + match value { + 0 => HidPadFormat::None, + 1 => HidPadFormat::Left, + 2 => HidPadFormat::Right, + _ => unreachable!(), + } + } +} + +#[derive(serde::Serialize, Debug, Clone)] +#[serde(rename_all = "snake_case")] +enum ProxySwitchAction { + Off, + New, + Add, +} +fake_enum_impls!(ProxySwitchAction); + +impl From for ProxySwitchAction { + fn from(value: i32) -> Self { + match value { + 0 => ProxySwitchAction::Off, + 1 => ProxySwitchAction::New, + 2 => ProxySwitchAction::Add, + _ => unreachable!(), + } + } +} + +#[pk_model] +struct SystemConfig { + system: SystemId, + #[json = "timezone"] + ui_tz: String, + #[json = "pings_enabled"] + pings_enabled: bool, + #[json = "latch_timeout"] + latch_timeout: Option, + #[json = "member_default_private"] + member_default_private: bool, + #[json = "group_default_private"] + group_default_private: bool, + #[json = "show_private_info"] + show_private_info: bool, + #[json = "member_limit"] + #[default = DEFAULT_MEMBER_LIMIT] + member_limit_override: Option, + #[json = "group_limit"] + #[default = DEFAULT_GROUP_LIMIT] + group_limit_override: Option, + #[json = "case_sensitive_proxy_tags"] + case_sensitive_proxy_tags: bool, + #[json = "proxy_error_message_enabled"] + proxy_error_message_enabled: bool, + #[json = "hid_display_split"] + hid_display_split: bool, + #[json = "hid_display_caps"] + hid_display_caps: bool, + #[json = "hid_list_padding"] + hid_list_padding: HidPadFormat, + #[json = "proxy_switch"] + proxy_switch: ProxySwitchAction, + #[json = "name_format"] + name_format: String, + #[json = "description_templates"] + description_templates: Vec, +} diff --git a/services/api/Cargo.toml b/services/api/Cargo.toml index a8de117b..58f2944b 100644 --- a/services/api/Cargo.toml +++ b/services/api/Cargo.toml @@ -4,20 +4,25 @@ version = "0.1.0" edition = "2021" [dependencies] +pluralkit_models = { path = "../../lib/models" } +libpk = { path = "../../lib/libpk" } + anyhow = { workspace = true } axum = { workspace = true } fred = { workspace = true } lazy_static = { workspace = true } -libpk = { path = "../../lib/libpk" } metrics = { workspace = true } +reqwest = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } sqlx = { workspace = true } tokio = { workspace = true } tracing = { workspace = true } +twilight-http = { workspace = true } hyper = { version = "1.3.1", features = ["http1"] } hyper-util = { version = "0.1.5", features = ["client", "client-legacy", "http1"] } reverse-proxy-service = { version = "0.2.1", features = ["axum"] } +serde_urlencoded = "0.7.1" tower = "0.4.13" tower-http = { version = "0.5.2", features = ["catch-panic"] } diff --git a/services/api/src/endpoints/private.rs b/services/api/src/endpoints/private.rs index edf9fee2..5d5049e9 100644 --- a/services/api/src/endpoints/private.rs +++ b/services/api/src/endpoints/private.rs @@ -55,3 +55,149 @@ pub async fn meta(State(ctx): State) -> Json { "channel_count": channel_count, })) } + +use std::time::Duration; + +use crate::util::json_err; +use axum::{ + extract, + response::{IntoResponse, Response}, +}; +use hyper::StatusCode; +use libpk::config; +use pluralkit_models::{PKSystem, PKSystemConfig}; +use reqwest::ClientBuilder; + +#[derive(serde::Deserialize, Debug)] +pub struct CallbackRequestData { + redirect_domain: String, + code: String, + // state: String, +} + +#[derive(serde::Serialize)] +struct CallbackDiscordData { + client_id: String, + client_secret: String, + grant_type: String, + redirect_uri: String, + code: String, +} + +pub async fn discord_callback( + State(ctx): State, + extract::Json(request_data): extract::Json, +) -> Response { + let client = ClientBuilder::new() + .connect_timeout(Duration::from_secs(3)) + .timeout(Duration::from_secs(3)) + .build() + .expect("error making client"); + + let reqbody = serde_urlencoded::to_string(&CallbackDiscordData { + client_id: config.discord.as_ref().unwrap().client_id.get().to_string(), + client_secret: config.discord.as_ref().unwrap().client_secret.clone(), + grant_type: "authorization_code".to_string(), + redirect_uri: request_data.redirect_domain, // change this! + code: request_data.code, + }) + .expect("could not serialize"); + + let discord_resp = client + .post("https://discord.com/api/v10/oauth2/token") + .header("content-type", "application/x-www-form-urlencoded") + .body(reqbody) + .send() + .await + .expect("failed to request discord"); + + let Value::Object(discord_data) = discord_resp + .json::() + .await + .expect("failed to deserialize discord response as json") + else { + panic!("discord response is not an object") + }; + + if !discord_data.contains_key("access_token") { + return json_err( + StatusCode::BAD_REQUEST, + format!( + "{{\"error\":\"{}\"\"}}", + discord_data + .get("error_description") + .expect("missing error_description from discord") + .to_string() + ), + ); + }; + + let token = format!( + "Bearer {}", + discord_data + .get("access_token") + .expect("missing access_token") + .as_str() + .unwrap() + ); + + let discord_client = twilight_http::Client::new(token); + + let user = discord_client + .current_user() + .await + .expect("failed to get current user from discord") + .model() + .await + .expect("failed to parse user model from discord"); + + let system: Option = sqlx::query_as( + r#" + select systems.* + from accounts + left join systems on accounts.system = systems.id + where accounts.uid = $1 + "#, + ) + .bind(user.id.get() as i64) + .fetch_optional(&ctx.db) + .await + .expect("failed to query"); + + if system.is_none() { + return json_err( + StatusCode::BAD_REQUEST, + "user does not have a system registered".to_string(), + ); + } + + let system = system.unwrap(); + + let system_config: Option = sqlx::query_as( + r#" + select * from system_config where system = $1 + "#, + ) + .bind(system.id) + .fetch_optional(&ctx.db) + .await + .expect("failed to query"); + + let system_config = system_config.unwrap(); + + // create dashboard token for system + + let token = system.clone().token; + + ( + StatusCode::OK, + serde_json::to_string(&serde_json::json!({ + "system": system.to_json(), + "config": system_config.to_json(), + "user": user, + "token": token, + })) + .expect("should not error"), + ) + .into_response() +} diff --git a/services/api/src/main.rs b/services/api/src/main.rs index 1a00f1a0..bf07ff8f 100644 --- a/services/api/src/main.rs +++ b/services/api/src/main.rs @@ -107,6 +107,7 @@ fn router(ctx: ApiContext) -> Router { .route("/private/bulk_privacy/member", post(rproxy)) .route("/private/bulk_privacy/group", post(rproxy)) .route("/private/discord/callback", post(rproxy)) + .route("/private/discord/callback2", post(endpoints::private::discord_callback)) .route("/private/discord/shard_state", get(endpoints::private::discord_state)) .route("/private/stats", get(endpoints::private::meta))