chore: reorganize rust crates

This commit is contained in:
alyssa 2025-01-02 00:50:36 +00:00
parent 357122a892
commit 16ce67e02c
58 changed files with 6 additions and 13 deletions

View file

@ -0,0 +1,16 @@
[package]
name = "dispatch"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
reqwest = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
hickory-client = "0.24.1"

View file

@ -0,0 +1,52 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

192
crates/dispatch/src/main.rs Normal file
View file

@ -0,0 +1,192 @@
#![feature(ip)]
use hickory_client::{
client::{AsyncClient, ClientHandle},
rr::{DNSClass, Name, RData, RecordType},
udp::UdpClientStream,
};
use reqwest::{redirect::Policy, StatusCode};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
use axum::{extract::State, http::Uri, routing::post, Json, Router};
mod logger;
// this package does not currently use libpk
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.json()
.with_env_filter(EnvFilter::from_default_env())
.init();
info!("hello world");
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(3));
let (client, bg) = AsyncClient::connect(stream).await?;
tokio::spawn(bg);
let app = Router::new()
.route("/", post(dispatch))
.with_state(Arc::new(RwLock::new(DNSClient(client))))
.layer(axum::middleware::from_fn(logger::logger));
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await?;
axum::serve(listener, app).await?;
Ok(())
}
#[derive(Debug, serde::Deserialize)]
struct DispatchRequest {
auth: String,
url: String,
payload: String,
test: Option<String>,
}
#[derive(Debug)]
enum DispatchResponse {
OK,
BadData,
ResolveFailed,
NoIPs,
InvalidIP,
FetchFailed,
InvalidResponseCode(StatusCode),
TestFailed,
}
impl std::fmt::Display for DispatchResponse {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
async fn dispatch(
// not entirely sure if this RwLock is the right way to do it
State(dns): State<Arc<RwLock<DNSClient>>>,
Json(req): Json<DispatchRequest>,
) -> String {
// todo: fix
if req.auth != std::env::var("HTTP_AUTH_TOKEN").unwrap() {
return "".to_string();
}
let uri = match req.url.parse::<Uri>() {
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
Err(error) => {
error!(?error, "failed to parse uri {}", req.url);
return DispatchResponse::BadData.to_string();
}
_ => {
error!("uri {} is invalid", req.url);
return DispatchResponse::BadData.to_string();
}
};
let ips = {
let mut dns = dns.write().await;
match dns.resolve(uri.host().unwrap().to_string()).await {
Ok(v) => v,
Err(error) => {
error!(?error, "failed to resolve");
return DispatchResponse::ResolveFailed.to_string();
}
}
};
if ips.iter().any(|ip| !ip.is_global()) {
return DispatchResponse::InvalidIP.to_string();
}
if ips.len() == 0 {
return DispatchResponse::NoIPs.to_string();
}
let ips: Vec<SocketAddr> = ips
.iter()
.map(|ip| SocketAddr::V4(SocketAddrV4::new(*ip, 443)))
.collect();
let client = reqwest::ClientBuilder::new()
.user_agent("PluralKit Dispatch (https://pluralkit.me/api/dispatch/)")
.redirect(Policy::none())
.timeout(Duration::from_secs(10))
.http1_only()
.use_rustls_tls()
.https_only(true)
.resolve_to_addrs(uri.host().unwrap(), &ips)
.build()
.unwrap();
let res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(req.payload)
.send()
.await;
match res {
Ok(res) if res.status() != 200 => {
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
}
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
if let Some(test) = req.test {
let test_res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(test)
.send()
.await;
match test_res {
Ok(res) if res.status() != 401 => return DispatchResponse::TestFailed.to_string(),
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
}
DispatchResponse::OK.to_string()
}
struct DNSClient(AsyncClient);
impl DNSClient {
async fn resolve(&mut self, host: String) -> anyhow::Result<Vec<Ipv4Addr>> {
let resp = self
.0
.query(Name::from_ascii(host)?, DNSClass::IN, RecordType::A)
.await?;
debug!("got dns response: {resp:?}");
Ok(resp
.answers()
.iter()
.filter_map(|ans| {
if let Some(RData::A(val)) = ans.data() {
Some(val.0)
} else {
None
}
})
.collect())
}
}