mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
feat: improve dispatch security
This commit is contained in:
parent
aa04124639
commit
45640f08ee
18 changed files with 893 additions and 269 deletions
33
.github/workflows/dispatch.yml
vendored
Normal file
33
.github/workflows/dispatch.yml
vendored
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
name: Build and push dispatch Docker image
|
||||
on:
|
||||
push:
|
||||
paths:
|
||||
- '.github/workflows/dispatch.yml'
|
||||
- 'Cargo.lock'
|
||||
- 'services/dispatch/'
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
packages: write
|
||||
if: github.repository == 'PluralKit/PluralKit'
|
||||
steps:
|
||||
- uses: docker/login-action@v1
|
||||
with:
|
||||
registry: ghcr.io
|
||||
username: ${{ github.actor }}
|
||||
password: ${{ secrets.CR_PAT }}
|
||||
- uses: actions/checkout@v2
|
||||
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" | sed 's|/|-|g' >> $GITHUB_ENV
|
||||
- uses: docker/build-push-action@v2
|
||||
with:
|
||||
# https://github.com/docker/build-push-action/issues/378
|
||||
context: .
|
||||
push: true
|
||||
file: services/dispatch/Dockerfile
|
||||
tags: |
|
||||
ghcr.io/pluralkit/dispatch:${{ github.sha }}
|
||||
ghcr.io/pluralkit/dispatch:${{ env.BRANCH_NAME }}
|
||||
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
|
||||
cache-to: type=inline
|
||||
706
Cargo.lock
generated
706
Cargo.lock
generated
File diff suppressed because it is too large
Load diff
|
|
@ -1,7 +1,8 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"./lib/libpk",
|
||||
"./services/api"
|
||||
"./services/api",
|
||||
"./services/dispatch"
|
||||
]
|
||||
|
||||
[workspace.dependencies]
|
||||
|
|
@ -15,6 +16,7 @@ serde_json = "1.0.117"
|
|||
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "macros"] }
|
||||
tokio = { version = "1.25.0", features = ["full"] }
|
||||
tracing = "0.1.37"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
|
||||
|
||||
prost = "0.12"
|
||||
prost-types = "0.12"
|
||||
|
|
|
|||
|
|
@ -516,8 +516,8 @@
|
|||
},
|
||||
"Npgsql": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.1.5",
|
||||
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
|
||||
"resolved": "4.1.13",
|
||||
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
|
||||
"dependencies": {
|
||||
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
|
||||
}
|
||||
|
|
@ -1637,7 +1637,7 @@
|
|||
"Newtonsoft.Json": "[13.0.1, )",
|
||||
"NodaTime": "[3.0.3, )",
|
||||
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
|
||||
"Npgsql": "[4.1.5, )",
|
||||
"Npgsql": "[4.1.13, )",
|
||||
"Npgsql.NodaTime": "[4.1.5, )",
|
||||
"Serilog": "[2.12.0, )",
|
||||
"Serilog.Extensions.Logging": "[3.0.1, )",
|
||||
|
|
|
|||
|
|
@ -143,25 +143,32 @@ public class Api
|
|||
if (_webhookRegex.IsMatch(newUrl))
|
||||
throw new PKError("PluralKit does not currently support setting a Discord webhook URL as your system's webhook URL.");
|
||||
|
||||
try
|
||||
{
|
||||
await _dispatch.DoPostRequest(ctx.System.Id, newUrl, null, true);
|
||||
}
|
||||
catch (Exception e)
|
||||
{
|
||||
throw new PKError($"Could not verify that the new URL is working: {e.Message}");
|
||||
}
|
||||
|
||||
var newToken = StringUtils.GenerateToken();
|
||||
|
||||
await ctx.Reply($"{Emojis.Warn} The following token is used to authenticate requests from PluralKit to you."
|
||||
+ " If it is exposed publicly, you **must** clear and re-set the webhook URL to get a new token."
|
||||
+ "\n\n**Please review the security requirements at <https://pluralkit.me/api/dispatch#security> before continuing.**"
|
||||
+ "\n\nWhen the server is correctly validating the token, click or reply 'yes' to continue."
|
||||
);
|
||||
await ctx.PromptYesNo(newToken, "Continue", matchFlag: false);
|
||||
|
||||
var status = await _dispatch.TestUrl(ctx.System.Uuid, newUrl, newToken);
|
||||
if (status != "OK")
|
||||
{
|
||||
var message = status switch
|
||||
{
|
||||
"BadData" => "the webhook url is invalid",
|
||||
"NoIPs" => "could not find any valid IP addresses for the provided domain",
|
||||
"InvalidIP" => "could not find any valid IP addresses for the provided domain",
|
||||
"FetchFailed" => "unable to reach server",
|
||||
"TestFailed" => "server failed to validate the signing token",
|
||||
_ => $"an unknown error occurred ({status})"
|
||||
};
|
||||
throw new PKError($"Failed to validate the webhook url: {message}");
|
||||
}
|
||||
|
||||
await ctx.Repository.UpdateSystem(ctx.System.Id, new SystemPatch { WebhookUrl = newUrl, WebhookToken = newToken });
|
||||
|
||||
await ctx.Reply($"{Emojis.Success} Successfully the new webhook URL for your system."
|
||||
+ $"\n\n{Emojis.Warn} The following token is used to authenticate requests from PluralKit to you."
|
||||
+ " If it leaks, you should clear and re-set the webhook URL to get a new token."
|
||||
+ "\ntodo: add link to docs or something"
|
||||
);
|
||||
|
||||
await ctx.Reply(newToken);
|
||||
await ctx.Reply($"{Emojis.Success} Successfully the new webhook URL for your system.");
|
||||
}
|
||||
}
|
||||
|
|
@ -42,9 +42,9 @@
|
|||
},
|
||||
"SixLabors.ImageSharp": {
|
||||
"type": "Direct",
|
||||
"requested": "[3.0.1, )",
|
||||
"resolved": "3.0.1",
|
||||
"contentHash": "o0v/J6SJwp3RFrzR29beGx0cK7xcMRgOyIuw8ZNLQyNnBhiyL/vIQKn7cfycthcWUPG3XezUjFwBWzkcUUDFbg=="
|
||||
"requested": "[3.1.5, )",
|
||||
"resolved": "3.1.5",
|
||||
"contentHash": "lNtlq7dSI/QEbYey+A0xn48z5w4XHSffF8222cC4F4YwTXfEImuiBavQcWjr49LThT/pRmtWJRcqA/PlL+eJ6g=="
|
||||
},
|
||||
"App.Metrics": {
|
||||
"type": "Transitive",
|
||||
|
|
@ -466,8 +466,8 @@
|
|||
},
|
||||
"Npgsql": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.1.5",
|
||||
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
|
||||
"resolved": "4.1.13",
|
||||
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
|
||||
"dependencies": {
|
||||
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
|
||||
}
|
||||
|
|
@ -1556,7 +1556,7 @@
|
|||
"Newtonsoft.Json": "[13.0.1, )",
|
||||
"NodaTime": "[3.0.3, )",
|
||||
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
|
||||
"Npgsql": "[4.1.5, )",
|
||||
"Npgsql": "[4.1.13, )",
|
||||
"Npgsql.NodaTime": "[4.1.5, )",
|
||||
"Serilog": "[2.12.0, )",
|
||||
"Serilog.Extensions.Logging": "[3.0.1, )",
|
||||
|
|
|
|||
|
|
@ -15,6 +15,8 @@ public class CoreConfig
|
|||
public string LogDir { get; set; }
|
||||
public string? ElasticUrl { get; set; }
|
||||
public string? SeqLogUrl { get; set; }
|
||||
public string? DispatchProxyUrl { get; set; }
|
||||
public string? DispatchProxyToken { get; set; }
|
||||
|
||||
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Debug;
|
||||
public LogEventLevel ElasticLogLevel { get; set; } = LogEventLevel.Information;
|
||||
|
|
|
|||
|
|
@ -43,7 +43,7 @@ public struct UpdateDispatchData
|
|||
|
||||
public static class DispatchExt
|
||||
{
|
||||
public static StringContent GetPayloadBody(this UpdateDispatchData data)
|
||||
public static string GetPayloadBody(this UpdateDispatchData data)
|
||||
{
|
||||
var o = new JObject();
|
||||
|
||||
|
|
@ -53,7 +53,18 @@ public static class DispatchExt
|
|||
o.Add("id", data.EntityId);
|
||||
o.Add("data", data.EventData);
|
||||
|
||||
return new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
|
||||
return JsonConvert.SerializeObject(o);
|
||||
}
|
||||
|
||||
public static string GetPingBody(string systemId, string token)
|
||||
{
|
||||
var o = new JObject();
|
||||
|
||||
o.Add("type", "PING");
|
||||
o.Add("signing_token", token);
|
||||
o.Add("system_id", systemId);
|
||||
|
||||
return JsonConvert.SerializeObject(o);
|
||||
}
|
||||
|
||||
private static List<IPNetwork> _privateNetworks = new()
|
||||
|
|
@ -71,6 +82,7 @@ public static class DispatchExt
|
|||
try
|
||||
{
|
||||
var uri = new Uri(url);
|
||||
if (uri.Scheme != "https") return false;
|
||||
host = await Dns.GetHostEntryAsync(uri.DnsSafeHost);
|
||||
}
|
||||
catch (Exception)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
using Autofac;
|
||||
|
||||
using System.Text;
|
||||
|
||||
using Newtonsoft.Json;
|
||||
using Newtonsoft.Json.Linq;
|
||||
|
||||
using Serilog;
|
||||
|
||||
namespace PluralKit.Core;
|
||||
|
|
@ -8,32 +13,55 @@ public class DispatchService
|
|||
{
|
||||
private readonly HttpClient _client = new();
|
||||
private readonly ILogger _logger;
|
||||
private readonly CoreConfig _cfg;
|
||||
private readonly ILifetimeScope _provider;
|
||||
|
||||
public DispatchService(ILogger logger, ILifetimeScope provider, CoreConfig cfg)
|
||||
{
|
||||
_logger = logger;
|
||||
_cfg = cfg;
|
||||
_provider = provider;
|
||||
}
|
||||
|
||||
public async Task DoPostRequest(SystemId system, string webhookUrl, HttpContent content, bool isVerify = false)
|
||||
public async Task<string> TestUrl(Guid systemUuid, string newUrl, string newToken)
|
||||
{
|
||||
if (!await DispatchExt.ValidateUri(webhookUrl))
|
||||
if (_cfg.DispatchProxyUrl == null || _cfg.DispatchProxyToken == null)
|
||||
throw new Exception("tried to dispatch without a proxy set!");
|
||||
|
||||
var o = new JObject();
|
||||
o.Add("auth", _cfg.DispatchProxyToken);
|
||||
o.Add("url", newUrl);
|
||||
o.Add("payload", DispatchExt.GetPingBody(systemUuid.ToString(), newToken));
|
||||
o.Add("test", DispatchExt.GetPingBody(systemUuid.ToString(), StringUtils.GenerateToken()));
|
||||
|
||||
var body = new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
|
||||
|
||||
var res = await _client.PostAsync(_cfg.DispatchProxyUrl, body);
|
||||
return await res.Content.ReadAsStringAsync();
|
||||
}
|
||||
|
||||
public async Task DoPostRequest(SystemId system, string webhookUrl, string content)
|
||||
{
|
||||
if (_cfg.DispatchProxyUrl == null || _cfg.DispatchProxyToken == null)
|
||||
{
|
||||
_logger.Warning(
|
||||
"Failed to dispatch webhook for system {SystemId}: URL is invalid or points to a private address",
|
||||
system);
|
||||
_logger.Warning("tried to dispatch without a proxy set!");
|
||||
return;
|
||||
}
|
||||
|
||||
var o = new JObject();
|
||||
o.Add("auth", _cfg.DispatchProxyToken);
|
||||
o.Add("url", webhookUrl);
|
||||
o.Add("payload", content);
|
||||
|
||||
var body = new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
|
||||
|
||||
try
|
||||
{
|
||||
await _client.PostAsync(webhookUrl, content);
|
||||
await _client.PostAsync(_cfg.DispatchProxyUrl, body);
|
||||
// todo: do something with proxy errors
|
||||
}
|
||||
catch (HttpRequestException e)
|
||||
{
|
||||
if (isVerify)
|
||||
throw;
|
||||
_logger.Error(e, "Could not dispatch webhook request!");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ namespace PluralKit.Core;
|
|||
|
||||
public static class Emojis
|
||||
{
|
||||
public static readonly string Warn = "\u26A0";
|
||||
public static readonly string Warn = "\u26A0\uFE0F";
|
||||
public static readonly string Success = "\u2705";
|
||||
public static readonly string Error = "\u274C";
|
||||
public static readonly string Note = "\U0001f4dd";
|
||||
|
|
|
|||
|
|
@ -183,9 +183,9 @@
|
|||
},
|
||||
"Npgsql": {
|
||||
"type": "Direct",
|
||||
"requested": "[4.1.5, )",
|
||||
"resolved": "4.1.5",
|
||||
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
|
||||
"requested": "[4.1.13, )",
|
||||
"resolved": "4.1.13",
|
||||
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
|
||||
"dependencies": {
|
||||
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
|
||||
}
|
||||
|
|
|
|||
|
|
@ -592,8 +592,8 @@
|
|||
},
|
||||
"Npgsql": {
|
||||
"type": "Transitive",
|
||||
"resolved": "4.1.5",
|
||||
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
|
||||
"resolved": "4.1.13",
|
||||
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
|
||||
"dependencies": {
|
||||
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
|
||||
}
|
||||
|
|
@ -891,8 +891,8 @@
|
|||
},
|
||||
"SixLabors.ImageSharp": {
|
||||
"type": "Transitive",
|
||||
"resolved": "3.0.1",
|
||||
"contentHash": "o0v/J6SJwp3RFrzR29beGx0cK7xcMRgOyIuw8ZNLQyNnBhiyL/vIQKn7cfycthcWUPG3XezUjFwBWzkcUUDFbg=="
|
||||
"resolved": "3.1.5",
|
||||
"contentHash": "lNtlq7dSI/QEbYey+A0xn48z5w4XHSffF8222cC4F4YwTXfEImuiBavQcWjr49LThT/pRmtWJRcqA/PlL+eJ6g=="
|
||||
},
|
||||
"SqlKata": {
|
||||
"type": "Transitive",
|
||||
|
|
@ -1810,7 +1810,7 @@
|
|||
"Myriad": "[1.0.0, )",
|
||||
"PluralKit.Core": "[1.0.0, )",
|
||||
"Sentry": "[3.11.1, )",
|
||||
"SixLabors.ImageSharp": "[3.0.1, )"
|
||||
"SixLabors.ImageSharp": "[3.1.5, )"
|
||||
}
|
||||
},
|
||||
"pluralkit.core": {
|
||||
|
|
@ -1834,7 +1834,7 @@
|
|||
"Newtonsoft.Json": "[13.0.1, )",
|
||||
"NodaTime": "[3.0.3, )",
|
||||
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
|
||||
"Npgsql": "[4.1.5, )",
|
||||
"Npgsql": "[4.1.13, )",
|
||||
"Npgsql.NodaTime": "[4.1.5, )",
|
||||
"Serilog": "[2.12.0, )",
|
||||
"Serilog.Extensions.Logging": "[3.0.1, )",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ To get dispatch events from PluralKit, you must set up a *public* HTTP endpoint.
|
|||
|
||||
For this reason, when you register a webhook URL, PluralKit generates a secret token, and then includes it with every event sent to you in the `signing_token` key. If you receive an event with an invalid `signing_token`, you **must** stop processing the request and **respond with a 401 status code**.
|
||||
|
||||
PluralKit will send invalid requests to your endpoint, with `PING` event type, once in a while to confirm that you are correctly validating requests.
|
||||
PluralKit will send invalid requests to your endpoint, with `PING` event type, once in a while to confirm that you are correctly validating requests. If validation fails, or if requests to your endpoint are repeatedly unsuccessful, the endpoint will be removed.
|
||||
|
||||
## Dispatch Event Model
|
||||
|
||||
|
|
|
|||
|
|
@ -15,8 +15,7 @@ serde = { workspace = true }
|
|||
sqlx = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-gelf = "0.7.1"
|
||||
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
|
||||
tracing-subscriber = { workspace = true}
|
||||
|
||||
prost = { workspace = true }
|
||||
prost-types = { workspace = true }
|
||||
|
|
|
|||
16
services/dispatch/Cargo.toml
Normal file
16
services/dispatch/Cargo.toml
Normal file
|
|
@ -0,0 +1,16 @@
|
|||
[package]
|
||||
name = "dispatch"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
axum = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
|
||||
hickory-client = "0.24.1"
|
||||
reqwest = { version = "0.12.7", default-features = false, features = ["rustls-tls"] }
|
||||
17
services/dispatch/Dockerfile
Normal file
17
services/dispatch/Dockerfile
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
FROM alpine:latest AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
RUN apk add rustup build-base
|
||||
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain nightly-2024-08-20 --profile default -y
|
||||
|
||||
ENV PATH=/root/.cargo/bin:$PATH
|
||||
ENV RUSTFLAGS='-C link-arg=-s'
|
||||
|
||||
COPY . .
|
||||
|
||||
RUN cargo build --bin dispatch --release --target x86_64-unknown-linux-musl
|
||||
|
||||
FROM alpine:latest
|
||||
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/dispatch /usr/local/bin/dispatch
|
||||
ENTRYPOINT ["/usr/local/bin/dispatch"]
|
||||
52
services/dispatch/src/logger.rs
Normal file
52
services/dispatch/src/logger.rs
Normal file
|
|
@ -0,0 +1,52 @@
|
|||
use std::time::Instant;
|
||||
|
||||
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
|
||||
use tracing::{info, span, warn, Instrument, Level};
|
||||
|
||||
// log any requests that take longer than 2 seconds
|
||||
// todo: change as necessary
|
||||
const MIN_LOG_TIME: u128 = 2_000;
|
||||
|
||||
pub async fn logger(request: Request, next: Next) -> Response {
|
||||
let method = request.method().clone();
|
||||
|
||||
let endpoint = request
|
||||
.extensions()
|
||||
.get::<MatchedPath>()
|
||||
.cloned()
|
||||
.map(|v| v.as_str().to_string())
|
||||
.unwrap_or("unknown".to_string());
|
||||
|
||||
let uri = request.uri().clone();
|
||||
|
||||
let request_id_span = span!(
|
||||
Level::INFO,
|
||||
"request",
|
||||
method = method.as_str(),
|
||||
endpoint = endpoint.clone(),
|
||||
);
|
||||
|
||||
let start = Instant::now();
|
||||
let response = next.run(request).instrument(request_id_span).await;
|
||||
let elapsed = start.elapsed().as_millis();
|
||||
|
||||
info!(
|
||||
"{} handled request for {} {} in {}ms",
|
||||
response.status(),
|
||||
method,
|
||||
uri.path(),
|
||||
elapsed
|
||||
);
|
||||
|
||||
if elapsed > MIN_LOG_TIME {
|
||||
warn!(
|
||||
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
|
||||
method,
|
||||
uri.path(),
|
||||
endpoint,
|
||||
elapsed
|
||||
)
|
||||
}
|
||||
|
||||
response
|
||||
}
|
||||
190
services/dispatch/src/main.rs
Normal file
190
services/dispatch/src/main.rs
Normal file
|
|
@ -0,0 +1,190 @@
|
|||
#![feature(ip)]
|
||||
|
||||
use hickory_client::{
|
||||
client::{AsyncClient, ClientHandle},
|
||||
rr::{DNSClass, Name, RData, RecordType},
|
||||
udp::UdpClientStream,
|
||||
};
|
||||
use reqwest::{redirect::Policy, StatusCode};
|
||||
use std::{
|
||||
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
|
||||
sync::Arc,
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{net::UdpSocket, sync::RwLock};
|
||||
use tracing::{debug, error, info};
|
||||
use tracing_subscriber::EnvFilter;
|
||||
|
||||
use axum::{extract::State, http::Uri, routing::post, Json, Router};
|
||||
|
||||
mod logger;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt()
|
||||
.json()
|
||||
.with_env_filter(EnvFilter::from_default_env())
|
||||
.init();
|
||||
|
||||
info!("hello world");
|
||||
|
||||
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
|
||||
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(1));
|
||||
let (client, bg) = AsyncClient::connect(stream).await?;
|
||||
tokio::spawn(bg);
|
||||
|
||||
let app = Router::new()
|
||||
.route("/", post(dispatch))
|
||||
.with_state(Arc::new(RwLock::new(DNSClient(client))))
|
||||
.layer(axum::middleware::from_fn(logger::logger));
|
||||
|
||||
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await?;
|
||||
axum::serve(listener, app).await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct DispatchRequest {
|
||||
auth: String,
|
||||
url: String,
|
||||
payload: String,
|
||||
test: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum DispatchResponse {
|
||||
OK,
|
||||
BadData,
|
||||
ResolveFailed,
|
||||
NoIPs,
|
||||
InvalidIP,
|
||||
FetchFailed,
|
||||
InvalidResponseCode(StatusCode),
|
||||
TestFailed,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for DispatchResponse {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "{:?}", self)
|
||||
}
|
||||
}
|
||||
|
||||
async fn dispatch(
|
||||
// not entirely sure if this RwLock is the right way to do it
|
||||
State(dns): State<Arc<RwLock<DNSClient>>>,
|
||||
Json(req): Json<DispatchRequest>,
|
||||
) -> String {
|
||||
// todo: fix
|
||||
if req.auth != std::env::var("HTTP_AUTH_TOKEN").unwrap() {
|
||||
return "".to_string();
|
||||
}
|
||||
|
||||
let uri = match req.url.parse::<Uri>() {
|
||||
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to parse uri {}", req.url);
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
_ => {
|
||||
error!("uri {} is invalid", req.url);
|
||||
return DispatchResponse::BadData.to_string();
|
||||
}
|
||||
};
|
||||
let ips = {
|
||||
let mut dns = dns.write().await;
|
||||
match dns.resolve(uri.host().unwrap().to_string()).await {
|
||||
Ok(v) => v,
|
||||
Err(error) => {
|
||||
error!(?error, "failed to resolve");
|
||||
return DispatchResponse::ResolveFailed.to_string();
|
||||
}
|
||||
}
|
||||
};
|
||||
if ips.iter().any(|ip| !ip.is_global()) {
|
||||
return DispatchResponse::InvalidIP.to_string();
|
||||
}
|
||||
|
||||
if ips.len() == 0 {
|
||||
return DispatchResponse::NoIPs.to_string();
|
||||
}
|
||||
|
||||
let ips: Vec<SocketAddr> = ips
|
||||
.iter()
|
||||
.map(|ip| SocketAddr::V4(SocketAddrV4::new(*ip, 443)))
|
||||
.collect();
|
||||
|
||||
let client = reqwest::ClientBuilder::new()
|
||||
.user_agent("PluralKit Dispatch (https://pluralkit.me/api/dispatch/)")
|
||||
.redirect(Policy::none())
|
||||
.timeout(Duration::from_secs(10))
|
||||
.http1_only()
|
||||
.use_rustls_tls()
|
||||
.https_only(true)
|
||||
.resolve_to_addrs(uri.host().unwrap(), &ips)
|
||||
.build()
|
||||
.unwrap();
|
||||
|
||||
let res = client
|
||||
.post(req.url.clone())
|
||||
.header("content-type", "application/json")
|
||||
.body(req.payload)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match res {
|
||||
Ok(res) if res.status() != 200 => {
|
||||
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
|
||||
}
|
||||
Err(error) => {
|
||||
error!(?error, url = req.url.clone(), "failed to fetch");
|
||||
return DispatchResponse::FetchFailed.to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
|
||||
if let Some(test) = req.test {
|
||||
let test_res = client
|
||||
.post(req.url.clone())
|
||||
.header("content-type", "application/json")
|
||||
.body(test)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
match test_res {
|
||||
Ok(res) if res.status() != 401 => return DispatchResponse::TestFailed.to_string(),
|
||||
Err(error) => {
|
||||
error!(?error, url = req.url.clone(), "failed to fetch");
|
||||
return DispatchResponse::FetchFailed.to_string();
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
DispatchResponse::OK.to_string()
|
||||
}
|
||||
|
||||
struct DNSClient(AsyncClient);
|
||||
|
||||
impl DNSClient {
|
||||
async fn resolve(&mut self, host: String) -> anyhow::Result<Vec<Ipv4Addr>> {
|
||||
let resp = self
|
||||
.0
|
||||
.query(Name::from_ascii(host)?, DNSClass::IN, RecordType::A)
|
||||
.await?;
|
||||
|
||||
debug!("got dns response: {resp:?}");
|
||||
|
||||
Ok(resp
|
||||
.answers()
|
||||
.iter()
|
||||
.filter_map(|ans| {
|
||||
if let Some(RData::A(val)) = ans.data() {
|
||||
Some(val.0)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect())
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue