feat: improve dispatch security

This commit is contained in:
alyssa 2024-08-22 07:10:35 +09:00
parent aa04124639
commit 45640f08ee
18 changed files with 893 additions and 269 deletions

33
.github/workflows/dispatch.yml vendored Normal file
View file

@ -0,0 +1,33 @@
name: Build and push dispatch Docker image
on:
push:
paths:
- '.github/workflows/dispatch.yml'
- 'Cargo.lock'
- 'services/dispatch/'
jobs:
deploy:
runs-on: ubuntu-latest
permissions:
packages: write
if: github.repository == 'PluralKit/PluralKit'
steps:
- uses: docker/login-action@v1
with:
registry: ghcr.io
username: ${{ github.actor }}
password: ${{ secrets.CR_PAT }}
- uses: actions/checkout@v2
- run: echo "BRANCH_NAME=${GITHUB_REF#refs/heads/}" | sed 's|/|-|g' >> $GITHUB_ENV
- uses: docker/build-push-action@v2
with:
# https://github.com/docker/build-push-action/issues/378
context: .
push: true
file: services/dispatch/Dockerfile
tags: |
ghcr.io/pluralkit/dispatch:${{ github.sha }}
ghcr.io/pluralkit/dispatch:${{ env.BRANCH_NAME }}
cache-from: type=registry,ref=ghcr.io/pluralkit/pluralkit:${{ env.BRANCH_NAME }}
cache-to: type=inline

706
Cargo.lock generated

File diff suppressed because it is too large Load diff

View file

@ -1,7 +1,8 @@
[workspace]
members = [
"./lib/libpk",
"./services/api"
"./services/api",
"./services/dispatch"
]
[workspace.dependencies]
@ -15,6 +16,7 @@ serde_json = "1.0.117"
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "macros"] }
tokio = { version = "1.25.0", features = ["full"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
prost = "0.12"
prost-types = "0.12"

View file

@ -516,8 +516,8 @@
},
"Npgsql": {
"type": "Transitive",
"resolved": "4.1.5",
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
"resolved": "4.1.13",
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
"dependencies": {
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
}
@ -1637,7 +1637,7 @@
"Newtonsoft.Json": "[13.0.1, )",
"NodaTime": "[3.0.3, )",
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
"Npgsql": "[4.1.5, )",
"Npgsql": "[4.1.13, )",
"Npgsql.NodaTime": "[4.1.5, )",
"Serilog": "[2.12.0, )",
"Serilog.Extensions.Logging": "[3.0.1, )",

View file

@ -143,25 +143,32 @@ public class Api
if (_webhookRegex.IsMatch(newUrl))
throw new PKError("PluralKit does not currently support setting a Discord webhook URL as your system's webhook URL.");
try
{
await _dispatch.DoPostRequest(ctx.System.Id, newUrl, null, true);
}
catch (Exception e)
{
throw new PKError($"Could not verify that the new URL is working: {e.Message}");
}
var newToken = StringUtils.GenerateToken();
await ctx.Reply($"{Emojis.Warn} The following token is used to authenticate requests from PluralKit to you."
+ " If it is exposed publicly, you **must** clear and re-set the webhook URL to get a new token."
+ "\n\n**Please review the security requirements at <https://pluralkit.me/api/dispatch#security> before continuing.**"
+ "\n\nWhen the server is correctly validating the token, click or reply 'yes' to continue."
);
await ctx.PromptYesNo(newToken, "Continue", matchFlag: false);
var status = await _dispatch.TestUrl(ctx.System.Uuid, newUrl, newToken);
if (status != "OK")
{
var message = status switch
{
"BadData" => "the webhook url is invalid",
"NoIPs" => "could not find any valid IP addresses for the provided domain",
"InvalidIP" => "could not find any valid IP addresses for the provided domain",
"FetchFailed" => "unable to reach server",
"TestFailed" => "server failed to validate the signing token",
_ => $"an unknown error occurred ({status})"
};
throw new PKError($"Failed to validate the webhook url: {message}");
}
await ctx.Repository.UpdateSystem(ctx.System.Id, new SystemPatch { WebhookUrl = newUrl, WebhookToken = newToken });
await ctx.Reply($"{Emojis.Success} Successfully the new webhook URL for your system."
+ $"\n\n{Emojis.Warn} The following token is used to authenticate requests from PluralKit to you."
+ " If it leaks, you should clear and re-set the webhook URL to get a new token."
+ "\ntodo: add link to docs or something"
);
await ctx.Reply(newToken);
await ctx.Reply($"{Emojis.Success} Successfully the new webhook URL for your system.");
}
}

View file

@ -42,9 +42,9 @@
},
"SixLabors.ImageSharp": {
"type": "Direct",
"requested": "[3.0.1, )",
"resolved": "3.0.1",
"contentHash": "o0v/J6SJwp3RFrzR29beGx0cK7xcMRgOyIuw8ZNLQyNnBhiyL/vIQKn7cfycthcWUPG3XezUjFwBWzkcUUDFbg=="
"requested": "[3.1.5, )",
"resolved": "3.1.5",
"contentHash": "lNtlq7dSI/QEbYey+A0xn48z5w4XHSffF8222cC4F4YwTXfEImuiBavQcWjr49LThT/pRmtWJRcqA/PlL+eJ6g=="
},
"App.Metrics": {
"type": "Transitive",
@ -466,8 +466,8 @@
},
"Npgsql": {
"type": "Transitive",
"resolved": "4.1.5",
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
"resolved": "4.1.13",
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
"dependencies": {
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
}
@ -1556,7 +1556,7 @@
"Newtonsoft.Json": "[13.0.1, )",
"NodaTime": "[3.0.3, )",
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
"Npgsql": "[4.1.5, )",
"Npgsql": "[4.1.13, )",
"Npgsql.NodaTime": "[4.1.5, )",
"Serilog": "[2.12.0, )",
"Serilog.Extensions.Logging": "[3.0.1, )",

View file

@ -15,6 +15,8 @@ public class CoreConfig
public string LogDir { get; set; }
public string? ElasticUrl { get; set; }
public string? SeqLogUrl { get; set; }
public string? DispatchProxyUrl { get; set; }
public string? DispatchProxyToken { get; set; }
public LogEventLevel ConsoleLogLevel { get; set; } = LogEventLevel.Debug;
public LogEventLevel ElasticLogLevel { get; set; } = LogEventLevel.Information;

View file

@ -43,7 +43,7 @@ public struct UpdateDispatchData
public static class DispatchExt
{
public static StringContent GetPayloadBody(this UpdateDispatchData data)
public static string GetPayloadBody(this UpdateDispatchData data)
{
var o = new JObject();
@ -53,7 +53,18 @@ public static class DispatchExt
o.Add("id", data.EntityId);
o.Add("data", data.EventData);
return new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
return JsonConvert.SerializeObject(o);
}
public static string GetPingBody(string systemId, string token)
{
var o = new JObject();
o.Add("type", "PING");
o.Add("signing_token", token);
o.Add("system_id", systemId);
return JsonConvert.SerializeObject(o);
}
private static List<IPNetwork> _privateNetworks = new()
@ -71,6 +82,7 @@ public static class DispatchExt
try
{
var uri = new Uri(url);
if (uri.Scheme != "https") return false;
host = await Dns.GetHostEntryAsync(uri.DnsSafeHost);
}
catch (Exception)

View file

@ -1,5 +1,10 @@
using Autofac;
using System.Text;
using Newtonsoft.Json;
using Newtonsoft.Json.Linq;
using Serilog;
namespace PluralKit.Core;
@ -8,32 +13,55 @@ public class DispatchService
{
private readonly HttpClient _client = new();
private readonly ILogger _logger;
private readonly CoreConfig _cfg;
private readonly ILifetimeScope _provider;
public DispatchService(ILogger logger, ILifetimeScope provider, CoreConfig cfg)
{
_logger = logger;
_cfg = cfg;
_provider = provider;
}
public async Task DoPostRequest(SystemId system, string webhookUrl, HttpContent content, bool isVerify = false)
public async Task<string> TestUrl(Guid systemUuid, string newUrl, string newToken)
{
if (!await DispatchExt.ValidateUri(webhookUrl))
if (_cfg.DispatchProxyUrl == null || _cfg.DispatchProxyToken == null)
throw new Exception("tried to dispatch without a proxy set!");
var o = new JObject();
o.Add("auth", _cfg.DispatchProxyToken);
o.Add("url", newUrl);
o.Add("payload", DispatchExt.GetPingBody(systemUuid.ToString(), newToken));
o.Add("test", DispatchExt.GetPingBody(systemUuid.ToString(), StringUtils.GenerateToken()));
var body = new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
var res = await _client.PostAsync(_cfg.DispatchProxyUrl, body);
return await res.Content.ReadAsStringAsync();
}
public async Task DoPostRequest(SystemId system, string webhookUrl, string content)
{
if (_cfg.DispatchProxyUrl == null || _cfg.DispatchProxyToken == null)
{
_logger.Warning(
"Failed to dispatch webhook for system {SystemId}: URL is invalid or points to a private address",
system);
_logger.Warning("tried to dispatch without a proxy set!");
return;
}
var o = new JObject();
o.Add("auth", _cfg.DispatchProxyToken);
o.Add("url", webhookUrl);
o.Add("payload", content);
var body = new StringContent(JsonConvert.SerializeObject(o), Encoding.UTF8, "application/json");
try
{
await _client.PostAsync(webhookUrl, content);
await _client.PostAsync(_cfg.DispatchProxyUrl, body);
// todo: do something with proxy errors
}
catch (HttpRequestException e)
{
if (isVerify)
throw;
_logger.Error(e, "Could not dispatch webhook request!");
}
}

View file

@ -2,7 +2,7 @@ namespace PluralKit.Core;
public static class Emojis
{
public static readonly string Warn = "\u26A0";
public static readonly string Warn = "\u26A0\uFE0F";
public static readonly string Success = "\u2705";
public static readonly string Error = "\u274C";
public static readonly string Note = "\U0001f4dd";

View file

@ -183,9 +183,9 @@
},
"Npgsql": {
"type": "Direct",
"requested": "[4.1.5, )",
"resolved": "4.1.5",
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
"requested": "[4.1.13, )",
"resolved": "4.1.13",
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
"dependencies": {
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
}

View file

@ -592,8 +592,8 @@
},
"Npgsql": {
"type": "Transitive",
"resolved": "4.1.5",
"contentHash": "juDlNse+SKfXRP0VSgpJkpdCcaVLZt8m37EHdRX+8hw+GG69Eat1Y0MdEfl+oetdOnf9E133GjIDEjg9AF6HSQ==",
"resolved": "4.1.13",
"contentHash": "p79cObfuRgS8KD5sFmQUqVlINEkJm39bCrzRclicZE1942mKcbLlc0NdoVKhBeZPv//prK/sVTUmRVxdnoPCoA==",
"dependencies": {
"System.Runtime.CompilerServices.Unsafe": "4.6.0"
}
@ -891,8 +891,8 @@
},
"SixLabors.ImageSharp": {
"type": "Transitive",
"resolved": "3.0.1",
"contentHash": "o0v/J6SJwp3RFrzR29beGx0cK7xcMRgOyIuw8ZNLQyNnBhiyL/vIQKn7cfycthcWUPG3XezUjFwBWzkcUUDFbg=="
"resolved": "3.1.5",
"contentHash": "lNtlq7dSI/QEbYey+A0xn48z5w4XHSffF8222cC4F4YwTXfEImuiBavQcWjr49LThT/pRmtWJRcqA/PlL+eJ6g=="
},
"SqlKata": {
"type": "Transitive",
@ -1810,7 +1810,7 @@
"Myriad": "[1.0.0, )",
"PluralKit.Core": "[1.0.0, )",
"Sentry": "[3.11.1, )",
"SixLabors.ImageSharp": "[3.0.1, )"
"SixLabors.ImageSharp": "[3.1.5, )"
}
},
"pluralkit.core": {
@ -1834,7 +1834,7 @@
"Newtonsoft.Json": "[13.0.1, )",
"NodaTime": "[3.0.3, )",
"NodaTime.Serialization.JsonNet": "[3.0.0, )",
"Npgsql": "[4.1.5, )",
"Npgsql": "[4.1.13, )",
"Npgsql.NodaTime": "[4.1.5, )",
"Serilog": "[2.12.0, )",
"Serilog.Extensions.Logging": "[3.0.1, )",

View file

@ -19,7 +19,7 @@ To get dispatch events from PluralKit, you must set up a *public* HTTP endpoint.
For this reason, when you register a webhook URL, PluralKit generates a secret token, and then includes it with every event sent to you in the `signing_token` key. If you receive an event with an invalid `signing_token`, you **must** stop processing the request and **respond with a 401 status code**.
PluralKit will send invalid requests to your endpoint, with `PING` event type, once in a while to confirm that you are correctly validating requests.
PluralKit will send invalid requests to your endpoint, with `PING` event type, once in a while to confirm that you are correctly validating requests. If validation fails, or if requests to your endpoint are repeatedly unsuccessful, the endpoint will be removed.
## Dispatch Event Model

View file

@ -15,8 +15,7 @@ serde = { workspace = true }
sqlx = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-gelf = "0.7.1"
tracing-subscriber = { version = "0.3.16", features = ["env-filter"] }
tracing-subscriber = { workspace = true}
prost = { workspace = true }
prost-types = { workspace = true }

View file

@ -0,0 +1,16 @@
[package]
name = "dispatch"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
serde = { workspace = true }
serde_json = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
tracing-subscriber = { workspace = true }
hickory-client = "0.24.1"
reqwest = { version = "0.12.7", default-features = false, features = ["rustls-tls"] }

View file

@ -0,0 +1,17 @@
FROM alpine:latest AS builder
WORKDIR /build
RUN apk add rustup build-base
RUN rustup-init --default-host x86_64-unknown-linux-musl --default-toolchain nightly-2024-08-20 --profile default -y
ENV PATH=/root/.cargo/bin:$PATH
ENV RUSTFLAGS='-C link-arg=-s'
COPY . .
RUN cargo build --bin dispatch --release --target x86_64-unknown-linux-musl
FROM alpine:latest
COPY --from=builder /build/target/x86_64-unknown-linux-musl/release/dispatch /usr/local/bin/dispatch
ENTRYPOINT ["/usr/local/bin/dispatch"]

View file

@ -0,0 +1,52 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

View file

@ -0,0 +1,190 @@
#![feature(ip)]
use hickory_client::{
client::{AsyncClient, ClientHandle},
rr::{DNSClass, Name, RData, RecordType},
udp::UdpClientStream,
};
use reqwest::{redirect::Policy, StatusCode};
use std::{
net::{Ipv4Addr, SocketAddr, SocketAddrV4},
sync::Arc,
time::Duration,
};
use tokio::{net::UdpSocket, sync::RwLock};
use tracing::{debug, error, info};
use tracing_subscriber::EnvFilter;
use axum::{extract::State, http::Uri, routing::post, Json, Router};
mod logger;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt()
.json()
.with_env_filter(EnvFilter::from_default_env())
.init();
info!("hello world");
let address = std::env::var("DNS_UPSTREAM").unwrap().parse().unwrap();
let stream = UdpClientStream::<UdpSocket>::with_timeout(address, Duration::from_secs(1));
let (client, bg) = AsyncClient::connect(stream).await?;
tokio::spawn(bg);
let app = Router::new()
.route("/", post(dispatch))
.with_state(Arc::new(RwLock::new(DNSClient(client))))
.layer(axum::middleware::from_fn(logger::logger));
let listener = tokio::net::TcpListener::bind("0.0.0.0:5000").await?;
axum::serve(listener, app).await?;
Ok(())
}
#[derive(Debug, serde::Deserialize)]
struct DispatchRequest {
auth: String,
url: String,
payload: String,
test: Option<String>,
}
#[derive(Debug)]
enum DispatchResponse {
OK,
BadData,
ResolveFailed,
NoIPs,
InvalidIP,
FetchFailed,
InvalidResponseCode(StatusCode),
TestFailed,
}
impl std::fmt::Display for DispatchResponse {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "{:?}", self)
}
}
async fn dispatch(
// not entirely sure if this RwLock is the right way to do it
State(dns): State<Arc<RwLock<DNSClient>>>,
Json(req): Json<DispatchRequest>,
) -> String {
// todo: fix
if req.auth != std::env::var("HTTP_AUTH_TOKEN").unwrap() {
return "".to_string();
}
let uri = match req.url.parse::<Uri>() {
Ok(v) if v.scheme_str() == Some("https") && v.host().is_some() => v,
Err(error) => {
error!(?error, "failed to parse uri {}", req.url);
return DispatchResponse::BadData.to_string();
}
_ => {
error!("uri {} is invalid", req.url);
return DispatchResponse::BadData.to_string();
}
};
let ips = {
let mut dns = dns.write().await;
match dns.resolve(uri.host().unwrap().to_string()).await {
Ok(v) => v,
Err(error) => {
error!(?error, "failed to resolve");
return DispatchResponse::ResolveFailed.to_string();
}
}
};
if ips.iter().any(|ip| !ip.is_global()) {
return DispatchResponse::InvalidIP.to_string();
}
if ips.len() == 0 {
return DispatchResponse::NoIPs.to_string();
}
let ips: Vec<SocketAddr> = ips
.iter()
.map(|ip| SocketAddr::V4(SocketAddrV4::new(*ip, 443)))
.collect();
let client = reqwest::ClientBuilder::new()
.user_agent("PluralKit Dispatch (https://pluralkit.me/api/dispatch/)")
.redirect(Policy::none())
.timeout(Duration::from_secs(10))
.http1_only()
.use_rustls_tls()
.https_only(true)
.resolve_to_addrs(uri.host().unwrap(), &ips)
.build()
.unwrap();
let res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(req.payload)
.send()
.await;
match res {
Ok(res) if res.status() != 200 => {
return DispatchResponse::InvalidResponseCode(res.status()).to_string()
}
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
if let Some(test) = req.test {
let test_res = client
.post(req.url.clone())
.header("content-type", "application/json")
.body(test)
.send()
.await;
match test_res {
Ok(res) if res.status() != 401 => return DispatchResponse::TestFailed.to_string(),
Err(error) => {
error!(?error, url = req.url.clone(), "failed to fetch");
return DispatchResponse::FetchFailed.to_string();
}
_ => {}
}
}
DispatchResponse::OK.to_string()
}
struct DNSClient(AsyncClient);
impl DNSClient {
async fn resolve(&mut self, host: String) -> anyhow::Result<Vec<Ipv4Addr>> {
let resp = self
.0
.query(Name::from_ascii(host)?, DNSClass::IN, RecordType::A)
.await?;
debug!("got dns response: {resp:?}");
Ok(resp
.answers()
.iter()
.filter_map(|ans| {
if let Some(RData::A(val)) = ans.data() {
Some(val.0)
} else {
None
}
})
.collect())
}
}