feat: gateway service

This commit is contained in:
alyssa 2024-09-14 12:19:47 +09:00
parent 1118d8bdf8
commit e4ed354536
50 changed files with 1737 additions and 545 deletions

3
.cargo/config.toml Normal file
View file

@ -0,0 +1,3 @@
[build]
rustflags = ["-C", "target-cpu=native"]

View file

@ -3,16 +3,18 @@
# todo: don't use docker/build-push-action
# todo: run builds on pull request
name: Build and push API Docker image
name: Build and push Rust service Docker images
on:
push:
branches:
- main
paths:
- 'lib/libpk/**'
- 'services/api/**'
- 'services/gateway/**'
- '.github/workflows/rust.yml'
- 'Dockerfile.rust'
- 'Dockerfile.bin'
- 'Cargo.toml'
- 'Cargo.lock'
jobs:
deploy:
@ -45,7 +47,7 @@ jobs:
# add more binaries here
- run: |
for binary in "api"; do
for binary in "api" "gateway"; do
for tag in latest ${{ env.BRANCH_NAME }} ${{ github.sha }}; do
cat Dockerfile.bin | sed "s/__BINARY__/$binary/g" | docker build -t ghcr.io/pluralkit/$binary:$tag -f - .
done

448
Cargo.lock generated
View file

@ -329,9 +329,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.4.0"
version = "1.6.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be"
checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9"
[[package]]
name = "bytes-utils"
@ -363,7 +363,9 @@ checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets 0.52.6",
]
@ -428,6 +430,16 @@ version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "396de984970346b0d9e93d1415082923c679e5ae5c3ee3dcbd104f5610af126b"
[[package]]
name = "core-foundation"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "core-foundation-sys"
version = "0.8.6"
@ -464,6 +476,15 @@ version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "338089f42c427b86394a5ee60ff321da23a5c89c9d89514c829687b26359fcff"
[[package]]
name = "crc32fast"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3"
dependencies = [
"cfg-if",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.14"
@ -508,6 +529,19 @@ dependencies = [
"typenum",
]
[[package]]
name = "dashmap"
version = "5.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc"
dependencies = [
"cfg-if",
"hashbrown 0.12.3",
"lock_api",
"once_cell",
"parking_lot_core 0.9.7",
]
[[package]]
name = "data-encoding"
version = "2.6.0"
@ -525,6 +559,15 @@ dependencies = [
"zeroize",
]
[[package]]
name = "deranged"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4"
dependencies = [
"powerfmt",
]
[[package]]
name = "digest"
version = "0.9.0"
@ -661,6 +704,17 @@ version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]]
name = "flate2"
version = "1.0.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae"
dependencies = [
"crc32fast",
"libz-sys",
"miniz_oxide",
]
[[package]]
name = "float-cmp"
version = "0.8.0"
@ -825,6 +879,30 @@ dependencies = [
"slab",
]
[[package]]
name = "gateway"
version = "0.1.0"
dependencies = [
"anyhow",
"axum 0.7.5",
"bytes",
"chrono",
"fred",
"futures",
"lazy_static",
"libpk",
"prost",
"serde_json",
"signal-hook",
"tokio",
"tracing",
"twilight-cache-inmemory",
"twilight-gateway",
"twilight-http",
"twilight-model",
"twilight-util",
]
[[package]]
name = "generic-array"
version = "0.14.6"
@ -881,6 +959,12 @@ dependencies = [
"tracing",
]
[[package]]
name = "hashbrown"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888"
[[package]]
name = "hashbrown"
version = "0.13.2"
@ -1161,18 +1245,36 @@ dependencies = [
[[package]]
name = "hyper-rustls"
version = "0.27.2"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155"
checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c"
dependencies = [
"futures-util",
"http 1.1.0",
"hyper 1.3.1",
"hyper-util",
"rustls",
"rustls 0.22.4",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tokio-rustls 0.25.0",
"tower-service",
]
[[package]]
name = "hyper-rustls"
version = "0.27.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333"
dependencies = [
"futures-util",
"http 1.1.0",
"hyper 1.3.1",
"hyper-util",
"rustls 0.23.10",
"rustls-pki-types",
"tokio",
"tokio-rustls 0.26.0",
"tower-service",
"webpki-roots",
]
@ -1341,6 +1443,7 @@ dependencies = [
"tracing",
"tracing-gelf",
"tracing-subscriber",
"twilight-model",
]
[[package]]
@ -1354,6 +1457,17 @@ dependencies = [
"vcpkg",
]
[[package]]
name = "libz-sys"
version = "1.1.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c15da26e5af7e25c90b37a2d75cdbf940cf4a55316de9d84c679c9b8bfabf82e"
dependencies = [
"cc",
"pkg-config",
"vcpkg",
]
[[package]]
name = "linked-hash-map"
version = "0.5.6"
@ -1561,6 +1675,12 @@ dependencies = [
"zeroize",
]
[[package]]
name = "num-conv"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9"
[[package]]
name = "num-integer"
version = "0.1.46"
@ -1622,6 +1742,21 @@ version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5"
[[package]]
name = "openssl-probe"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf"
[[package]]
name = "ordered-float"
version = "2.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c"
dependencies = [
"num-traits",
]
[[package]]
name = "ordered-multimap"
version = "0.6.0"
@ -1832,6 +1967,12 @@ version = "1.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce"
[[package]]
name = "powerfmt"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391"
[[package]]
name = "ppv-lite86"
version = "0.2.17"
@ -1952,7 +2093,7 @@ dependencies = [
"quinn-proto",
"quinn-udp",
"rustc-hash",
"rustls",
"rustls 0.23.10",
"socket2 0.5.7",
"thiserror",
"tokio",
@ -1969,7 +2110,7 @@ dependencies = [
"rand",
"ring",
"rustc-hash",
"rustls",
"rustls 0.23.10",
"slab",
"thiserror",
"tinyvec",
@ -2137,7 +2278,7 @@ dependencies = [
"http-body 1.0.0",
"http-body-util",
"hyper 1.3.1",
"hyper-rustls",
"hyper-rustls 0.27.3",
"hyper-util",
"ipnet",
"js-sys",
@ -2147,7 +2288,7 @@ dependencies = [
"percent-encoding",
"pin-project-lite",
"quinn",
"rustls",
"rustls 0.23.10",
"rustls-pemfile",
"rustls-pki-types",
"serde",
@ -2155,7 +2296,7 @@ dependencies = [
"serde_urlencoded",
"sync_wrapper 1.0.1",
"tokio",
"tokio-rustls",
"tokio-rustls 0.26.0",
"tower-service",
"url",
"wasm-bindgen",
@ -2263,9 +2404,22 @@ dependencies = [
[[package]]
name = "rustls"
version = "0.23.12"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044"
checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432"
dependencies = [
"ring",
"rustls-pki-types",
"rustls-webpki",
"subtle",
"zeroize",
]
[[package]]
name = "rustls"
version = "0.23.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402"
dependencies = [
"once_cell",
"ring",
@ -2276,10 +2430,23 @@ dependencies = [
]
[[package]]
name = "rustls-pemfile"
version = "2.1.3"
name = "rustls-native-certs"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425"
checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792"
dependencies = [
"openssl-probe",
"rustls-pemfile",
"rustls-pki-types",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d"
dependencies = [
"base64 0.22.1",
"rustls-pki-types",
@ -2287,15 +2454,15 @@ dependencies = [
[[package]]
name = "rustls-pki-types"
version = "1.8.0"
version = "1.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0"
checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d"
[[package]]
name = "rustls-webpki"
version = "0.102.6"
version = "0.102.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e"
checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e"
dependencies = [
"ring",
"rustls-pki-types",
@ -2314,12 +2481,44 @@ version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde"
[[package]]
name = "schannel"
version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534"
dependencies = [
"windows-sys 0.52.0",
]
[[package]]
name = "scopeguard"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd"
[[package]]
name = "security-framework"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0"
dependencies = [
"bitflags 2.5.0",
"core-foundation",
"core-foundation-sys",
"libc",
"security-framework-sys",
]
[[package]]
name = "security-framework-sys"
version = "2.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7"
dependencies = [
"core-foundation-sys",
"libc",
]
[[package]]
name = "semver"
version = "1.0.16"
@ -2335,6 +2534,16 @@ dependencies = [
"serde_derive",
]
[[package]]
name = "serde-value"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c"
dependencies = [
"ordered-float",
"serde",
]
[[package]]
name = "serde_derive"
version = "1.0.203"
@ -2366,6 +2575,17 @@ dependencies = [
"serde",
]
[[package]]
name = "serde_repr"
version = "0.1.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.66",
]
[[package]]
name = "serde_spanned"
version = "0.6.8"
@ -2411,6 +2631,12 @@ dependencies = [
"digest 0.10.7",
]
[[package]]
name = "sha1_smol"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012"
[[package]]
name = "sha2"
version = "0.10.6"
@ -2431,6 +2657,16 @@ dependencies = [
"lazy_static",
]
[[package]]
name = "signal-hook"
version = "0.3.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801"
dependencies = [
"libc",
"signal-hook-registry",
]
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
@ -2450,6 +2686,12 @@ dependencies = [
"rand_core",
]
[[package]]
name = "simdutf8"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a"
[[package]]
name = "sketches-ddsketch"
version = "0.2.0"
@ -2828,6 +3070,36 @@ dependencies = [
"once_cell",
]
[[package]]
name = "time"
version = "0.3.36"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885"
dependencies = [
"deranged",
"num-conv",
"powerfmt",
"serde",
"time-core",
"time-macros",
]
[[package]]
name = "time-core"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
[[package]]
name = "time-macros"
version = "0.2.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf"
dependencies = [
"num-conv",
"time-core",
]
[[package]]
name = "tiny-keccak"
version = "2.0.2"
@ -2881,13 +3153,24 @@ dependencies = [
"syn 2.0.66",
]
[[package]]
name = "tokio-rustls"
version = "0.25.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f"
dependencies = [
"rustls 0.22.4",
"rustls-pki-types",
"tokio",
]
[[package]]
name = "tokio-rustls"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4"
dependencies = [
"rustls",
"rustls 0.23.10",
"rustls-pki-types",
"tokio",
]
@ -2930,6 +3213,30 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-websockets"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "988c6e20955aa5043e0822cb27093ebaabb430a126cda0223824b6d65ea900c1"
dependencies = [
"base64 0.21.7",
"bytes",
"fastrand",
"futures-core",
"futures-sink",
"http 1.1.0",
"httparse",
"ring",
"rustls-native-certs",
"rustls-pki-types",
"sha1_smol",
"simdutf8",
"tokio",
"tokio-rustls 0.25.0",
"tokio-util 0.7.12",
"tracing",
]
[[package]]
name = "toml"
version = "0.8.19"
@ -3140,6 +3447,105 @@ version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed"
[[package]]
name = "twilight-cache-inmemory"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"bitflags 2.5.0",
"dashmap",
"serde",
"twilight-model",
"twilight-util",
]
[[package]]
name = "twilight-gateway"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"bitflags 2.5.0",
"fastrand",
"flate2",
"futures-core",
"futures-sink",
"serde",
"serde_json",
"tokio",
"tokio-websockets",
"tracing",
"twilight-gateway-queue",
"twilight-http",
"twilight-model",
]
[[package]]
name = "twilight-gateway-queue"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"tokio",
"tracing",
]
[[package]]
name = "twilight-http"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"fastrand",
"http 1.1.0",
"http-body-util",
"hyper 1.3.1",
"hyper-rustls 0.26.0",
"hyper-util",
"percent-encoding",
"serde",
"serde_json",
"tokio",
"tracing",
"twilight-http-ratelimiting",
"twilight-model",
"twilight-validate",
]
[[package]]
name = "twilight-http-ratelimiting"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"tokio",
"tracing",
]
[[package]]
name = "twilight-model"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"bitflags 2.5.0",
"serde",
"serde-value",
"serde_repr",
"time",
]
[[package]]
name = "twilight-util"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"twilight-model",
]
[[package]]
name = "twilight-validate"
version = "0.16.0-rc.1"
source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1"
dependencies = [
"twilight-model",
]
[[package]]
name = "typenum"
version = "1.16.0"

View file

@ -2,22 +2,40 @@
members = [
"./lib/libpk",
"./services/api",
"./services/dispatch"
"./services/dispatch",
"./services/gateway"
]
[workspace.dependencies]
anyhow = "1"
axum = "0.7.5"
axum-macros = "0.4.1"
bytes = "1.6.0"
chrono = "0.4"
fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] }
futures = "0.3.30"
lazy_static = "1.4.0"
metrics = "0.23.0"
serde = "1.0.152"
serde_json = "1.0.117"
signal-hook = "0.3.17"
sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "macros"] }
tokio = { version = "1.25.0", features = ["full"] }
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] }
twilight-gateway = { git = "https://github.com/pluralkit/twilight" }
twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
twilight-util = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] }
twilight-model = { git = "https://github.com/pluralkit/twilight" }
twilight-http = { git = "https://github.com/pluralkit/twilight", default-features = false, features = ["rustls-native-roots"] }
#twilight-gateway = { path = "../twilight/twilight-gateway" }
#twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] }
#twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] }
#twilight-model = { path = "../twilight/twilight-model" }
#twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-native-roots"] }
prost = "0.12"
prost-types = "0.12"
prost-build = "0.12"

View file

@ -27,9 +27,12 @@ COPY proto/ /build/proto
# this needs to match workspaces in Cargo.toml
COPY lib/libpk /build/lib/libpk
COPY services/api/ /build/services/api
COPY services/gateway/ /build/services/gateway
RUN cargo build --bin api --release --target x86_64-unknown-linux-musl
RUN cargo build --bin gateway --release --target x86_64-unknown-linux-musl
FROM scratch
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/api /api
COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gateway /gateway

View file

@ -100,15 +100,18 @@ public static class DiscordCacheExtensions
await cache.SaveChannel(thread);
}
public static async Task<PermissionSet> BotPermissionsIn(this IDiscordCache cache, ulong channelId)
public static async Task<PermissionSet> BotPermissionsIn(this IDiscordCache cache, ulong guildId, ulong channelId)
{
var channel = await cache.GetRootChannel(channelId);
if (cache is HttpDiscordCache)
return await ((HttpDiscordCache)cache).BotPermissions(guildId, channelId);
var channel = await cache.GetRootChannel(guildId, channelId);
if (channel.GuildId != null)
{
var userId = cache.GetOwnUser();
var member = await cache.TryGetSelfMember(channel.GuildId.Value);
return await cache.PermissionsFor2(channelId, userId, member);
return await cache.PermissionsFor2(guildId, channelId, userId, member);
}
return PermissionSet.Dm;

View file

@ -0,0 +1,75 @@
using Serilog;
using System.Net;
using System.Text.Json;
using Myriad.Serialization;
using Myriad.Types;
namespace Myriad.Cache;
public class HttpDiscordCache: IDiscordCache
{
private readonly ILogger _logger;
private readonly HttpClient _client;
private readonly string _cacheEndpoint;
private readonly ulong _ownUserId;
private readonly JsonSerializerOptions _jsonSerializerOptions;
public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, ulong ownUserId)
{
_logger = logger;
_client = client;
_cacheEndpoint = cacheEndpoint;
_ownUserId = ownUserId;
_jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad();
}
public ValueTask SaveGuild(Guild guild) => default;
public ValueTask SaveChannel(Channel channel) => default;
public ValueTask SaveUser(User user) => default;
public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member) => default;
public ValueTask SaveRole(ulong guildId, Myriad.Types.Role role) => default;
public ValueTask SaveDmChannelStub(ulong channelId) => default;
public ValueTask RemoveGuild(ulong guildId) => default;
public ValueTask RemoveChannel(ulong channelId) => default;
public ValueTask RemoveUser(ulong userId) => default;
public ValueTask RemoveRole(ulong guildId, ulong roleId) => default;
public ulong GetOwnUser() => _ownUserId;
// todo: cluster
private async Task<T?> QueryCache<T>(string endpoint)
{
var response = await _client.GetAsync($"{_cacheEndpoint}{endpoint}");
if (response.StatusCode == HttpStatusCode.NotFound)
return default;
if (response.StatusCode != HttpStatusCode.Found)
throw new Exception($"failed to query http cache: {response.StatusCode}");
var plaintext = await response.Content.ReadAsStringAsync();
return JsonSerializer.Deserialize<T>(plaintext, _jsonSerializerOptions);
}
public Task<Guild?> TryGetGuild(ulong guildId)
=> QueryCache<Guild?>($"/guilds/{guildId}");
public Task<Channel?> TryGetChannel(ulong guildId, ulong channelId)
=> QueryCache<Channel?>($"/guilds/{guildId}/channels/{channelId}");
// this should be a GetUserCached method on nirn-proxy (it's always called as GetOrFetchUser)
// so just return nothing
public Task<User?> TryGetUser(ulong userId)
=> Task.FromResult<User?>(null);
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId)
=> QueryCache<GuildMemberPartial?>($"/guilds/{guildId}/members/@me");
public Task<PermissionSet> BotPermissions(ulong guildId, ulong channelId)
=> QueryCache<PermissionSet>($"/guilds/{guildId}/channels/{channelId}/permissions/@me");
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId)
=> QueryCache<IEnumerable<Channel>>($"/guilds/{guildId}/channels");
}

View file

@ -18,11 +18,9 @@ public interface IDiscordCache
internal ulong GetOwnUser();
public Task<Guild?> TryGetGuild(ulong guildId);
public Task<Channel?> TryGetChannel(ulong channelId);
public Task<Channel?> TryGetChannel(ulong guildId, ulong channelId);
public Task<User?> TryGetUser(ulong userId);
public Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId);
public Task<Role?> TryGetRole(ulong roleId);
public IAsyncEnumerable<Guild> GetAllGuilds();
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId);
}

View file

@ -137,7 +137,7 @@ public class MemoryDiscordCache: IDiscordCache
return Task.FromResult(cg?.Guild);
}
public Task<Channel?> TryGetChannel(ulong channelId)
public Task<Channel?> TryGetChannel(ulong _, ulong channelId)
{
_channels.TryGetValue(channelId, out var channel);
return Task.FromResult(channel);
@ -155,19 +155,6 @@ public class MemoryDiscordCache: IDiscordCache
return Task.FromResult(guildMember);
}
public Task<Role?> TryGetRole(ulong roleId)
{
_roles.TryGetValue(roleId, out var role);
return Task.FromResult(role);
}
public IAsyncEnumerable<Guild> GetAllGuilds()
{
return _guilds.Values
.Select(g => g.Guild)
.ToAsyncEnumerable();
}
public Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId)
{
if (!_guilds.TryGetValue(guildId, out var guild))

View file

@ -1,340 +0,0 @@
using Google.Protobuf;
using StackExchange.Redis;
using StackExchange.Redis.KeyspaceIsolation;
using Serilog;
using Myriad.Types;
namespace Myriad.Cache;
#pragma warning disable 4014
public class RedisDiscordCache: IDiscordCache
{
private readonly ILogger _logger;
private readonly ulong _ownUserId;
public RedisDiscordCache(ILogger logger, ulong ownUserId)
{
_logger = logger;
_ownUserId = ownUserId;
}
private ConnectionMultiplexer _redis { get; set; }
public async Task InitAsync(string addr)
{
_redis = await ConnectionMultiplexer.ConnectAsync(addr);
}
private IDatabase db => _redis.GetDatabase().WithKeyPrefix("discord:");
public async ValueTask SaveGuild(Guild guild)
{
_logger.Verbose("Saving guild {GuildId} to redis", guild.Id);
var g = new CachedGuild();
g.Id = guild.Id;
g.Name = guild.Name;
g.OwnerId = guild.OwnerId;
g.PremiumTier = (int)guild.PremiumTier;
var tr = db.CreateTransaction();
tr.HashSetAsync("guilds", guild.Id.HashWrapper(g));
foreach (var role in guild.Roles)
{
// Don't call SaveRole because that updates guild state
// and we just got a brand new one :)
// actually with redis it doesn't update guild state, but we're still doing it here because transaction
tr.HashSetAsync("roles", role.Id.HashWrapper(new CachedRole()
{
Id = role.Id,
Name = role.Name,
Position = role.Position,
Permissions = (ulong)role.Permissions,
Mentionable = role.Mentionable,
}));
tr.HashSetAsync($"guild_roles:{guild.Id}", role.Id, true, When.NotExists);
}
await tr.ExecuteAsync();
}
public async ValueTask SaveChannel(Channel channel)
{
_logger.Verbose("Saving channel {ChannelId} to redis", channel.Id);
await db.HashSetAsync("channels", channel.Id.HashWrapper(channel.ToProtobuf()));
if (channel.GuildId != null)
await db.HashSetAsync($"guild_channels:{channel.GuildId.Value}", channel.Id, true, When.NotExists);
// todo: use a transaction for this?
if (channel.Recipients != null)
foreach (var recipient in channel.Recipients)
await SaveUser(recipient);
}
public async ValueTask SaveUser(User user)
{
_logger.Verbose("Saving user {UserId} to redis", user.Id);
var u = new CachedUser()
{
Id = user.Id,
Username = user.Username,
Discriminator = user.Discriminator,
Bot = user.Bot,
};
if (user.Avatar != null)
u.Avatar = user.Avatar;
await db.HashSetAsync("users", user.Id.HashWrapper(u));
}
public async ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member)
{
_logger.Verbose("Saving self member for guild {GuildId} to redis", guildId);
var gm = new CachedGuildMember();
foreach (var role in member.Roles)
gm.Roles.Add(role);
await db.HashSetAsync("members", guildId.HashWrapper(gm));
}
public async ValueTask SaveRole(ulong guildId, Myriad.Types.Role role)
{
_logger.Verbose("Saving role {RoleId} in {GuildId} to redis", role.Id, guildId);
await db.HashSetAsync("roles", role.Id.HashWrapper(new CachedRole()
{
Id = role.Id,
Mentionable = role.Mentionable,
Name = role.Name,
Permissions = (ulong)role.Permissions,
Position = role.Position,
}));
await db.HashSetAsync($"guild_roles:{guildId}", role.Id, true, When.NotExists);
}
public async ValueTask SaveDmChannelStub(ulong channelId)
{
// Use existing channel object if present, otherwise add a stub
// We may get a message create before channel create and we want to have it saved
if (await TryGetChannel(channelId) == null)
await db.HashSetAsync("channels", channelId.HashWrapper(new CachedChannel()
{
Id = channelId,
Type = (int)Channel.ChannelType.Dm,
}));
}
public async ValueTask RemoveGuild(ulong guildId)
=> await db.HashDeleteAsync("guilds", guildId);
public async ValueTask RemoveChannel(ulong channelId)
{
var oldChannel = await TryGetChannel(channelId);
if (oldChannel == null)
return;
await db.HashDeleteAsync("channels", channelId);
if (oldChannel.GuildId != null)
await db.HashDeleteAsync($"guild_channels:{oldChannel.GuildId.Value}", oldChannel.Id);
}
public async ValueTask RemoveUser(ulong userId)
=> await db.HashDeleteAsync("users", userId);
public ulong GetOwnUser() => _ownUserId;
public async ValueTask RemoveRole(ulong guildId, ulong roleId)
{
await db.HashDeleteAsync("roles", roleId);
await db.HashDeleteAsync($"guild_roles:{guildId}", roleId);
}
public async Task<Guild?> TryGetGuild(ulong guildId)
{
var redisGuild = await db.HashGetAsync("guilds", guildId);
if (redisGuild.IsNullOrEmpty)
return null;
var guild = ((byte[])redisGuild).Unmarshal<CachedGuild>();
var redisRoles = await db.HashGetAllAsync($"guild_roles:{guildId}");
// todo: put this in a transaction or something
var roles = await Task.WhenAll(redisRoles.Select(r => TryGetRole((ulong)r.Name)));
#pragma warning disable 8619
return guild.FromProtobuf() with { Roles = roles };
#pragma warning restore 8619
}
public async Task<Channel?> TryGetChannel(ulong channelId)
{
var redisChannel = await db.HashGetAsync("channels", channelId);
if (redisChannel.IsNullOrEmpty)
return null;
return ((byte[])redisChannel).Unmarshal<CachedChannel>().FromProtobuf();
}
public async Task<User?> TryGetUser(ulong userId)
{
var redisUser = await db.HashGetAsync("users", userId);
if (redisUser.IsNullOrEmpty)
return null;
return ((byte[])redisUser).Unmarshal<CachedUser>().FromProtobuf();
}
public async Task<GuildMemberPartial?> TryGetSelfMember(ulong guildId)
{
var redisMember = await db.HashGetAsync("members", guildId);
if (redisMember.IsNullOrEmpty)
return null;
return new GuildMemberPartial()
{
Roles = ((byte[])redisMember).Unmarshal<CachedGuildMember>().Roles.ToArray()
};
}
public async Task<Myriad.Types.Role?> TryGetRole(ulong roleId)
{
var redisRole = await db.HashGetAsync("roles", roleId);
if (redisRole.IsNullOrEmpty)
return null;
var role = ((byte[])redisRole).Unmarshal<CachedRole>();
return new Myriad.Types.Role()
{
Id = role.Id,
Name = role.Name,
Position = role.Position,
Permissions = (PermissionSet)role.Permissions,
Mentionable = role.Mentionable,
};
}
public IAsyncEnumerable<Guild> GetAllGuilds()
{
// return _guilds.Values
// .Select(g => g.Guild)
// .ToAsyncEnumerable();
return new Guild[] { }.ToAsyncEnumerable();
}
public async Task<IEnumerable<Channel>> GetGuildChannels(ulong guildId)
{
var redisChannels = await db.HashGetAllAsync($"guild_channels:{guildId}");
if (redisChannels.Length == 0)
throw new ArgumentException("Guild not found", nameof(guildId));
#pragma warning disable 8619
return await Task.WhenAll(redisChannels.Select(c => TryGetChannel((ulong)c.Name)));
#pragma warning restore 8619
}
}
internal static class CacheProtoExt
{
public static Guild FromProtobuf(this CachedGuild guild)
=> new Guild()
{
Id = guild.Id,
Name = guild.Name,
OwnerId = guild.OwnerId,
PremiumTier = (PremiumTier)guild.PremiumTier,
};
public static CachedChannel ToProtobuf(this Channel channel)
{
var c = new CachedChannel();
c.Id = channel.Id;
c.Type = (int)channel.Type;
if (channel.Position != null)
c.Position = channel.Position.Value;
c.Name = channel.Name;
if (channel.PermissionOverwrites != null)
foreach (var overwrite in channel.PermissionOverwrites)
c.PermissionOverwrites.Add(new Overwrite()
{
Id = overwrite.Id,
Type = (int)overwrite.Type,
Allow = (ulong)overwrite.Allow,
Deny = (ulong)overwrite.Deny,
});
if (channel.GuildId != null)
c.GuildId = channel.GuildId.Value;
return c;
}
public static Channel FromProtobuf(this CachedChannel channel)
=> new Channel()
{
Id = channel.Id,
Type = (Channel.ChannelType)channel.Type,
Position = channel.Position,
Name = channel.Name,
PermissionOverwrites = channel.PermissionOverwrites
.Select(x => new Channel.Overwrite()
{
Id = x.Id,
Type = (Channel.OverwriteType)x.Type,
Allow = (PermissionSet)x.Allow,
Deny = (PermissionSet)x.Deny,
}).ToArray(),
GuildId = channel.HasGuildId ? channel.GuildId : null,
ParentId = channel.HasParentId ? channel.ParentId : null,
};
public static User FromProtobuf(this CachedUser user)
=> new User()
{
Id = user.Id,
Username = user.Username,
Discriminator = user.Discriminator,
Avatar = user.HasAvatar ? user.Avatar : null,
Bot = user.Bot,
};
}
internal static class RedisExt
{
// convenience method
public static HashEntry[] HashWrapper<T>(this ulong key, T value) where T : IMessage
=> new[] { new HashEntry(key, value.ToByteArray()) };
}
public static class ProtobufExt
{
private static Dictionary<string, MessageParser> _parser = new();
public static byte[] Marshal(this IMessage message) => message.ToByteArray();
public static T Unmarshal<T>(this byte[] message) where T : IMessage<T>, new()
{
var type = typeof(T).ToString();
if (_parser.ContainsKey(type))
return (T)_parser[type].ParseFrom(message);
else
{
_parser.Add(type, new MessageParser<T>(() => new T()));
return Unmarshal<T>(message);
}
}
}

View file

@ -13,27 +13,13 @@ public static class CacheExtensions
return guild;
}
public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong channelId)
public static async Task<Channel> GetChannel(this IDiscordCache cache, ulong guildId, ulong channelId)
{
if (!(await cache.TryGetChannel(channelId) is Channel channel))
if (!(await cache.TryGetChannel(guildId, channelId) is Channel channel))
throw new KeyNotFoundException($"Channel {channelId} not found in cache");
return channel;
}
public static async Task<User> GetUser(this IDiscordCache cache, ulong userId)
{
if (!(await cache.TryGetUser(userId) is User user))
throw new KeyNotFoundException($"User {userId} not found in cache");
return user;
}
public static async Task<Role> GetRole(this IDiscordCache cache, ulong roleId)
{
if (!(await cache.TryGetRole(roleId) is Role role))
throw new KeyNotFoundException($"Role {roleId} not found in cache");
return role;
}
public static async ValueTask<User?> GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest,
ulong userId)
{
@ -47,9 +33,9 @@ public static class CacheExtensions
}
public static async ValueTask<Channel?> GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest,
ulong channelId)
ulong guildId, ulong channelId)
{
if (await cache.TryGetChannel(channelId) is { } cacheChannel)
if (await cache.TryGetChannel(guildId, channelId) is { } cacheChannel)
return cacheChannel;
var restChannel = await rest.GetChannel(channelId);
@ -58,13 +44,13 @@ public static class CacheExtensions
return restChannel;
}
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong channelOrThread)
public static async Task<Channel> GetRootChannel(this IDiscordCache cache, ulong guildId, ulong channelOrThread)
{
var channel = await cache.GetChannel(channelOrThread);
var channel = await cache.GetChannel(guildId, channelOrThread);
if (!channel.IsThread())
return channel;
var parent = await cache.GetChannel(channel.ParentId!.Value);
var parent = await cache.GetChannel(guildId, channel.ParentId!.Value);
return parent;
}
}

View file

@ -32,23 +32,23 @@ public static class PermissionExtensions
PermissionSet.EmbedLinks;
public static Task<PermissionSet> PermissionsForMCE(this IDiscordCache cache, MessageCreateEvent message) =>
PermissionsFor2(cache, message.ChannelId, message.Author.Id, message.Member, message.WebhookId != null);
PermissionsFor2(cache, message.GuildId ?? 0, message.ChannelId, message.Author.Id, message.Member, message.WebhookId != null);
public static Task<PermissionSet>
PermissionsForMemberInChannel(this IDiscordCache cache, ulong channelId, GuildMember member) =>
PermissionsFor2(cache, channelId, member.User.Id, member);
PermissionsForMemberInChannel(this IDiscordCache cache, ulong guildId, ulong channelId, GuildMember member) =>
PermissionsFor2(cache, guildId, channelId, member.User.Id, member);
public static async Task<PermissionSet> PermissionsFor2(this IDiscordCache cache, ulong channelId, ulong userId,
public static async Task<PermissionSet> PermissionsFor2(this IDiscordCache cache, ulong guildId, ulong channelId, ulong userId,
GuildMemberPartial? member, bool isThread = false)
{
if (!(await cache.TryGetChannel(channelId) is Channel channel))
if (!(await cache.TryGetChannel(guildId, channelId) is Channel channel))
// todo: handle channel not found better
return PermissionSet.Dm;
if (channel.GuildId == null)
return PermissionSet.Dm;
var rootChannel = await cache.GetRootChannel(channelId);
var rootChannel = await cache.GetRootChannel(guildId, channelId);
var guild = await cache.GetGuild(channel.GuildId.Value);

View file

@ -63,14 +63,14 @@ public class ApplicationCommandProxiedMessage
var messageId = ctx.Event.Data!.TargetId!.Value;
// check for command messages
var (authorId, channelId) = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
if (authorId != null)
var cmessage = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
if (cmessage != null)
{
if (authorId != ctx.User.Id)
if (cmessage.AuthorId != ctx.User.Id)
throw new PKError("You can only delete command messages queried by this account.");
var isDM = (await _repo.GetDmChannel(ctx.User!.Id)) == channelId;
await DeleteMessageInner(ctx, channelId!.Value, messageId, isDM);
var isDM = (await _repo.GetDmChannel(ctx.User!.Id)) == cmessage.ChannelId;
await DeleteMessageInner(ctx, cmessage.GuildId, cmessage.ChannelId, messageId, isDM);
return;
}
@ -81,7 +81,7 @@ public class ApplicationCommandProxiedMessage
if (message.System?.Id != ctx.System.Id && message.Message.Sender != ctx.User.Id)
throw new PKError("You can only delete your own messages.");
await DeleteMessageInner(ctx, message.Message.Channel, message.Message.Mid, false);
await DeleteMessageInner(ctx, message.Message.Guild ?? 0, message.Message.Channel, message.Message.Mid, false);
return;
}
@ -89,9 +89,9 @@ public class ApplicationCommandProxiedMessage
throw Errors.MessageNotFound(messageId);
}
internal async Task DeleteMessageInner(InteractionContext ctx, ulong channelId, ulong messageId, bool isDM = false)
internal async Task DeleteMessageInner(InteractionContext ctx, ulong guildId, ulong channelId, ulong messageId, bool isDM = false)
{
if (!((await _cache.BotPermissionsIn(channelId)).HasFlag(PermissionSet.ManageMessages) || isDM))
if (!((await _cache.BotPermissionsIn(guildId, channelId)).HasFlag(PermissionSet.ManageMessages) || isDM))
throw new PKError("PluralKit does not have the *Manage Messages* permission in this channel, and thus cannot delete the message."
+ " Please contact a server administrator to remedy this.");
@ -110,7 +110,7 @@ public class ApplicationCommandProxiedMessage
// (if not, PK shouldn't send messages on their behalf)
var member = await _rest.GetGuildMember(ctx.GuildId, ctx.User.Id);
var requiredPerms = PermissionSet.ViewChannel | PermissionSet.SendMessages;
if (member == null || !(await _cache.PermissionsForMemberInChannel(ctx.ChannelId, member)).HasFlag(requiredPerms))
if (member == null || !(await _cache.PermissionsForMemberInChannel(ctx.GuildId, ctx.ChannelId, member)).HasFlag(requiredPerms))
{
throw new PKError("You do not have permission to send messages in this channel.");
};

View file

@ -99,11 +99,13 @@ public class Bot
private async Task OnEventReceived(int shardId, IGatewayEvent evt)
{
// we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent
await _cache.HandleGatewayEvent(evt);
await _cache.TryUpdateSelfMember(_config.ClientId, evt);
if (_cache is MemoryDiscordCache)
{
// we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent
await _cache.HandleGatewayEvent(evt);
await _cache.TryUpdateSelfMember(_config.ClientId, evt);
}
await OnEventReceivedInner(shardId, evt);
}
@ -175,7 +177,16 @@ public class Bot
}
using var _ = LogContext.PushProperty("EventId", Guid.NewGuid());
using var __ = LogContext.Push(await serviceScope.Resolve<SerilogGatewayEnricherFactory>().GetEnricher(shardId, evt));
// this fails when cache lookup fails, so put it in a try-catch
try
{
using var __ = LogContext.Push(await serviceScope.Resolve<SerilogGatewayEnricherFactory>().GetEnricher(shardId, evt));
}
catch (Exception exc)
{
await HandleError(handler, evt, serviceScope, exc);
}
_logger.Verbose("Received gateway event: {@Event}", evt);
try
@ -243,7 +254,7 @@ public class Bot
if (!exc.ShowToUser()) return;
// Once we've sent it to Sentry, report it to the user (if we have permission to)
var reportChannel = handler.ErrorChannelFor(evt, _config.ClientId);
var (guildId, reportChannel) = handler.ErrorChannelFor(evt, _config.ClientId);
if (reportChannel == null)
{
if (evt is InteractionCreateEvent ice && ice.Type == Interaction.InteractionType.ApplicationCommand)
@ -251,7 +262,7 @@ public class Bot
return;
}
var botPerms = await _cache.BotPermissionsIn(reportChannel.Value);
var botPerms = await _cache.BotPermissionsIn(guildId ?? 0, reportChannel.Value);
if (botPerms.HasFlag(PermissionSet.SendMessages | PermissionSet.EmbedLinks))
await _errorMessageService.SendErrorMessage(reportChannel.Value, sentryEvent.EventId.ToString());
}

View file

@ -20,7 +20,8 @@ public class BotConfig
public string? GatewayQueueUrl { get; set; }
public bool UseRedisRatelimiter { get; set; } = false;
public bool UseRedisCache { get; set; } = false;
public string? HttpCacheUrl { get; set; }
public string? RedisGatewayUrl { get; set; }

View file

@ -62,7 +62,7 @@ public class Context
public readonly int ShardId;
public readonly Cluster Cluster;
public Task<PermissionSet> BotPermissions => Cache.BotPermissionsIn(Channel.Id);
public Task<PermissionSet> BotPermissions => Cache.BotPermissionsIn(Guild?.Id ?? 0, Channel.Id);
public Task<PermissionSet> UserPermissions => Cache.PermissionsForMCE((MessageCreateEvent)Message);
@ -100,7 +100,7 @@ public class Context
// {
// Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example)
// but since we can, we just store all sent messages for possible deletion
await _commandMessageService.RegisterMessage(msg.Id, msg.ChannelId, Author.Id);
await _commandMessageService.RegisterMessage(msg.Id, Guild?.Id ?? 0, msg.ChannelId, Author.Id);
// }
return msg;

View file

@ -188,7 +188,8 @@ public static class ContextEntityArgumentsExt
if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id))
return null;
var channel = await ctx.Cache.TryGetChannel(id);
// todo: match channels in other guilds
var channel = await ctx.Cache.TryGetChannel(ctx.Guild!.Id, id);
if (channel == null)
channel = await ctx.Rest.GetChannelOrNull(id);
if (channel == null)

View file

@ -143,6 +143,7 @@ public class Checks
var error = "Channel not found or you do not have permissions to access it.";
// todo: this breaks if channel is not in cache and bot does not have View Channel permissions
// with new cache it breaks if channel is not in current guild
var channel = await ctx.MatchChannel();
if (channel == null || channel.GuildId == null)
throw new PKError(error);
@ -156,7 +157,8 @@ public class Checks
if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel))
throw new PKError(error);
var botPermissions = await _cache.BotPermissionsIn(channel.Id);
// todo: permcheck channel outside of guild?
var botPermissions = await _cache.BotPermissionsIn(ctx.Guild.Id, channel.Id);
// We use a bitfield so we can set individual permission bits
ulong missingPermissions = 0;
@ -231,11 +233,11 @@ public class Checks
var channel = await _rest.GetChannelOrNull(channelId.Value);
if (channel == null)
throw new PKError("Unable to get the channel associated with this message.");
var rootChannel = await _cache.GetRootChannel(channel.Id);
if (channel.GuildId == null)
throw new PKError("PluralKit is not able to proxy messages in DMs.");
var rootChannel = await _cache.GetRootChannel(channel.GuildId!.Value, channel.Id);
// using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId
var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId);
var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList();

View file

@ -218,7 +218,7 @@ public class ProxiedMessage
try
{
var editedMsg =
await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent, clearEmbeds);
await _webhookExecutor.EditWebhookMessage(msg.Guild ?? 0, msg.Channel, msg.Mid, newContent, clearEmbeds);
if (ctx.Guild == null)
await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success });
@ -436,14 +436,14 @@ public class ProxiedMessage
private async Task DeleteCommandMessage(Context ctx, ulong messageId)
{
var (authorId, channelId) = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
if (authorId == null)
var cmessage = await ctx.Services.Resolve<CommandMessageService>().GetCommandMessage(messageId);
if (cmessage == null)
throw Errors.MessageNotFound(messageId);
if (authorId != ctx.Author.Id)
if (cmessage!.AuthorId != ctx.Author.Id)
throw new PKError("You can only delete command messages queried by this account.");
await ctx.Rest.DeleteMessage(channelId!.Value, messageId);
await ctx.Rest.DeleteMessage(cmessage.ChannelId, messageId);
if (ctx.Guild != null)
await ctx.Rest.DeleteMessage(ctx.Message);

View file

@ -49,7 +49,7 @@ public class ServerConfig
if (channel.Type != Channel.ChannelType.GuildText && channel.Type != Channel.ChannelType.GuildPublicThread && channel.Type != Channel.ChannelType.GuildPrivateThread)
throw new PKError("PluralKit cannot log messages to this type of channel.");
var perms = await _cache.BotPermissionsIn(channel.Id);
var perms = await _cache.BotPermissionsIn(ctx.Guild.Id, channel.Id);
if (!perms.HasFlag(PermissionSet.SendMessages))
throw new PKError("PluralKit is missing **Send Messages** permissions in the new log channel.");
if (!perms.HasFlag(PermissionSet.EmbedLinks))
@ -104,7 +104,7 @@ public class ServerConfig
// Resolve all channels from the cache and order by position
var channels = (await Task.WhenAll(blacklist.Blacklist
.Select(id => _cache.TryGetChannel(id))))
.Select(id => _cache.TryGetChannel(ctx.Guild.Id, id))))
.Where(c => c != null)
.OrderBy(c => c.Position)
.ToList();
@ -121,7 +121,7 @@ public class ServerConfig
async (eb, l) =>
{
async Task<string> CategoryName(ulong? id) =>
id != null ? (await _cache.GetChannel(id.Value)).Name : "(no category)";
id != null ? (await _cache.GetChannel(ctx.Guild.Id, id.Value)).Name : "(no category)";
ulong? lastCategory = null;
@ -153,8 +153,9 @@ public class ServerConfig
var config = await ctx.Repository.GetGuild(ctx.Guild.Id);
// Resolve all channels from the cache and order by position
// todo: GetAllChannels?
var channels = (await Task.WhenAll(config.LogBlacklist
.Select(id => _cache.TryGetChannel(id))))
.Select(id => _cache.TryGetChannel(ctx.Guild.Id, id))))
.Where(c => c != null)
.OrderBy(c => c.Position)
.ToList();
@ -171,7 +172,7 @@ public class ServerConfig
async (eb, l) =>
{
async Task<string> CategoryName(ulong? id) =>
id != null ? (await _cache.GetChannel(id.Value)).Name : "(no category)";
id != null ? (await _cache.GetChannel(ctx.Guild.Id, id.Value)).Name : "(no category)";
ulong? lastCategory = null;

View file

@ -6,5 +6,5 @@ public interface IEventHandler<in T> where T : IGatewayEvent
{
Task Handle(int shardId, T evt);
ulong? ErrorChannelFor(T evt, ulong userId) => null;
(ulong?, ulong?) ErrorChannelFor(T evt, ulong userId) => (null, null);
}

View file

@ -52,7 +52,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
_dmCache = dmCache;
}
public ulong? ErrorChannelFor(MessageCreateEvent evt, ulong userId) => evt.ChannelId;
public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId);
private bool IsDuplicateMessage(Message msg) =>
// We consider a message duplicate if it has the same ID as the previous message that hit the gateway
_lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id;
@ -63,7 +63,7 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return;
if (IsDuplicateMessage(evt)) return;
var botPermissions = await _cache.BotPermissionsIn(evt.ChannelId);
var botPermissions = await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId);
if (!botPermissions.HasFlag(PermissionSet.SendMessages)) return;
// spawn off saving the private channel into another thread
@ -71,8 +71,8 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
_ = _dmCache.TrySavePrivateChannel(evt);
var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null;
var channel = await _cache.GetChannel(evt.ChannelId);
var rootChannel = await _cache.GetRootChannel(evt.ChannelId);
var channel = await _cache.GetChannel(evt.GuildId ?? 0, evt.ChannelId);
var rootChannel = await _cache.GetRootChannel(evt.GuildId ?? 0, evt.ChannelId);
// Log metrics and message info
_metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived);
@ -90,7 +90,8 @@ public class MessageCreated: IEventHandler<MessageCreateEvent>
if (await TryHandleCommand(shardId, evt, guild, channel))
return;
await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions);
if (evt.GuildId != null)
await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions);
}
private async Task TryHandleLogClean(Channel channel, MessageCreateEvent evt)

View file

@ -52,10 +52,12 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
if (!evt.Content.HasValue || !evt.Author.HasValue || !evt.Member.HasValue)
return;
var channel = await _cache.GetChannel(evt.ChannelId);
var guildIdMaybe = evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0;
var channel = await _cache.GetChannel(guildIdMaybe, evt.ChannelId); // todo: is this correct for message update?
if (!DiscordUtils.IsValidGuildChannel(channel))
return;
var rootChannel = await _cache.GetRootChannel(channel.Id);
var rootChannel = await _cache.GetRootChannel(guildIdMaybe, channel.Id);
var guild = await _cache.GetGuild(channel.GuildId!.Value);
var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current;
@ -69,7 +71,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId);
var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel);
var botPermissions = await _cache.BotPermissionsIn(channel.Id);
var botPermissions = await _cache.BotPermissionsIn(guildIdMaybe, channel.Id);
try
{
@ -91,7 +93,7 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
private async Task<MessageCreateEvent> GetMessageCreateEvent(MessageUpdateEvent evt, CachedMessage lastMessage,
Channel channel)
{
var referencedMessage = await GetReferencedMessage(evt.ChannelId, lastMessage.ReferencedMessage);
var referencedMessage = await GetReferencedMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId, lastMessage.ReferencedMessage);
var messageReference = lastMessage.ReferencedMessage != null
? new Message.Reference(channel.GuildId, evt.ChannelId, lastMessage.ReferencedMessage.Value)
@ -118,12 +120,12 @@ public class MessageEdited: IEventHandler<MessageUpdateEvent>
return equivalentEvt;
}
private async Task<Message?> GetReferencedMessage(ulong channelId, ulong? referencedMessageId)
private async Task<Message?> GetReferencedMessage(ulong guildId, ulong channelId, ulong? referencedMessageId)
{
if (referencedMessageId == null)
return null;
var botPermissions = await _cache.BotPermissionsIn(channelId);
var botPermissions = await _cache.BotPermissionsIn(guildId, channelId);
if (!botPermissions.HasFlag(PermissionSet.ReadMessageHistory))
{
_logger.Warning(

View file

@ -62,7 +62,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
// but we aren't able to get DMs from bots anyway, so it's not really needed
if (evt.GuildId != null && (evt.Member?.User?.Bot ?? false)) return;
var channel = await _cache.GetChannel(evt.ChannelId);
var channel = await _cache.GetChannel(evt.GuildId ?? 0, evt.ChannelId);
// check if it's a command message first
// since this can happen in DMs as well
@ -75,10 +75,10 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
return;
}
var (authorId, _) = await _commandMessageService.GetCommandMessage(evt.MessageId);
if (authorId != null)
var cmessage = await _commandMessageService.GetCommandMessage(evt.MessageId);
if (cmessage != null)
{
await HandleCommandDeleteReaction(evt, authorId.Value, false);
await HandleCommandDeleteReaction(evt, cmessage.AuthorId, false);
return;
}
}
@ -123,7 +123,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, PKMessage msg)
{
if (!(await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
if (!(await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
return;
var isSameSystem = msg.Member != null && await _repo.IsMemberOwnedByAccount(msg.Member.Value, evt.UserId);
@ -150,7 +150,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
if (authorId != null && authorId != evt.UserId)
return;
if (!((await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages) || isDM))
if (!((await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages) || isDM))
return;
// todo: don't try to delete the user's own messages in DMs
@ -206,14 +206,14 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
private async ValueTask HandlePingReaction(MessageReactionAddEvent evt, FullMessage msg)
{
if (!(await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
if (!(await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
return;
// Check if the "pinger" has permission to send messages in this channel
// (if not, PK shouldn't send messages on their behalf)
var member = await _rest.GetGuildMember(evt.GuildId!.Value, evt.UserId);
var requiredPerms = PermissionSet.ViewChannel | PermissionSet.SendMessages;
if (member == null || !(await _cache.PermissionsForMemberInChannel(evt.ChannelId, member)).HasFlag(requiredPerms)) return;
if (member == null || !(await _cache.PermissionsForMemberInChannel(evt.GuildId ?? 0, evt.ChannelId, member)).HasFlag(requiredPerms)) return;
if (msg.Member == null) return;
@ -266,7 +266,7 @@ public class ReactionAdded: IEventHandler<MessageReactionAddEvent>
private async Task TryRemoveOriginalReaction(MessageReactionAddEvent evt)
{
if ((await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
if ((await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages))
await _rest.DeleteUserReaction(evt.ChannelId, evt.MessageId, evt.Emoji, evt.UserId);
}
}

View file

@ -56,8 +56,6 @@ public class Init
await redis.InitAsync(coreConfig);
var cache = services.Resolve<IDiscordCache>();
if (cache is RedisDiscordCache)
await (cache as RedisDiscordCache).InitAsync(coreConfig.RedisAddr);
if (config.Cluster == null)
{

View file

@ -48,8 +48,10 @@ public class BotModule: Module
{
var botConfig = c.Resolve<BotConfig>();
if (botConfig.UseRedisCache)
return new RedisDiscordCache(c.Resolve<ILogger>(), botConfig.ClientId);
if (botConfig.HttpCacheUrl != null)
return new HttpDiscordCache(c.Resolve<ILogger>(),
c.Resolve<HttpClient>(), botConfig.HttpCacheUrl, botConfig.ClientId);
return new MemoryDiscordCache(botConfig.ClientId);
}).AsSelf().SingleInstance();
builder.RegisterType<PrivateChannelService>().AsSelf().SingleInstance();

View file

@ -59,7 +59,7 @@ public class ProxyService
public async Task<bool> HandleIncomingMessage(MessageCreateEvent message, MessageContext ctx,
Guild guild, Channel channel, bool allowAutoproxy, PermissionSet botPermissions)
{
var rootChannel = await _cache.GetRootChannel(message.ChannelId);
var rootChannel = await _cache.GetRootChannel(message.GuildId!.Value, message.ChannelId);
if (!ShouldProxy(channel, rootChannel, message, ctx))
return false;
@ -207,8 +207,8 @@ public class ProxyService
var content = match.ProxyContent;
if (!allowEmbeds) content = content.BreakLinkEmbeds();
var messageChannel = await _cache.GetChannel(trigger.ChannelId);
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
var messageChannel = await _cache.GetChannel(trigger.GuildId!.Value, trigger.ChannelId);
var rootChannel = await _cache.GetRootChannel(trigger.GuildId!.Value, trigger.ChannelId);
var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null;
var guild = await _cache.GetGuild(trigger.GuildId.Value);
var guildMember = await _rest.GetGuildMember(trigger.GuildId!.Value, trigger.Author.Id);

View file

@ -18,7 +18,7 @@ public class CommandMessageService
_logger = logger.ForContext<CommandMessageService>();
}
public async Task RegisterMessage(ulong messageId, ulong channelId, ulong authorId)
public async Task RegisterMessage(ulong messageId, ulong guildId, ulong channelId, ulong authorId)
{
if (_redis.Connection == null) return;
@ -27,17 +27,19 @@ public class CommandMessageService
messageId, authorId, channelId
);
await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}", expiry: CommandMessageRetention);
await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}-{guildId}", expiry: CommandMessageRetention);
}
public async Task<(ulong?, ulong?)> GetCommandMessage(ulong messageId)
public async Task<CommandMessage?> GetCommandMessage(ulong messageId)
{
var str = await _redis.Connection.GetDatabase().StringGetAsync(messageId.ToString());
if (str.HasValue)
{
var split = ((string)str).Split("-");
return (ulong.Parse(split[0]), ulong.Parse(split[1]));
return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2]));
}
return (null, null);
return null;
}
}
}
public record CommandMessage(ulong AuthorId, ulong ChannelId, ulong GuildId);

View file

@ -336,7 +336,7 @@ public class EmbedService
public async Task<Embed> CreateMessageInfoEmbed(FullMessage msg, bool showContent, SystemConfig? ccfg = null)
{
var channel = await _cache.GetOrFetchChannel(_rest, msg.Message.Channel);
var channel = await _cache.GetOrFetchChannel(_rest, msg.Message.Guild ?? 0, msg.Message.Channel);
var ctx = LookupContext.ByNonOwner;
var serverMsg = await _rest.GetMessageOrNull(msg.Message.Channel, msg.Message.Mid);
@ -403,14 +403,15 @@ public class EmbedService
var roles = memberInfo?.Roles?.ToList();
if (roles != null && roles.Count > 0 && showContent)
{
var rolesString = string.Join(", ", (await Task.WhenAll(roles
.Select(async id =>
var guild = await _cache.GetGuild(channel.GuildId!.Value);
var rolesString = string.Join(", ", (roles
.Select(id =>
{
var role = await _cache.TryGetRole(id);
var role = Array.Find(guild.Roles, r => r.Id == id);
if (role != null)
return role;
return new Role { Name = "*(unknown role)*", Position = 0 };
})))
}))
.OrderByDescending(role => role.Position)
.Select(role => role.Name));
eb.Field(new Embed.Field($"Account roles ({roles.Count})", rolesString.Truncate(1024)));

View file

@ -42,7 +42,7 @@ public class LogChannelService
if (logChannelId == null)
return;
var triggerChannel = await _cache.GetChannel(proxiedMessage.Channel);
var triggerChannel = await _cache.GetChannel(proxiedMessage.Guild!.Value, proxiedMessage.Channel);
var member = await _repo.GetMember(proxiedMessage.Member!.Value);
var system = await _repo.GetSystem(member.System);
@ -63,7 +63,7 @@ public class LogChannelService
return null;
var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value;
var rootChannel = await _cache.GetRootChannel(trigger.ChannelId);
var rootChannel = await _cache.GetRootChannel(guildId, trigger.ChannelId);
// get log channel info from the database
var guild = await _repo.GetGuild(guildId);
@ -109,7 +109,7 @@ public class LogChannelService
private async Task<Channel?> FindLogChannel(ulong guildId, ulong channelId)
{
// TODO: fetch it directly on cache miss?
if (await _cache.TryGetChannel(channelId) is Channel channel)
if (await _cache.TryGetChannel(guildId, channelId) is Channel channel)
return channel;
if (await _rest.GetChannelOrNull(channelId) is Channel restChannel)

View file

@ -100,10 +100,10 @@ public class LoggerCleanService
public async ValueTask HandleLoggerBotCleanup(Message msg)
{
var channel = await _cache.GetChannel(msg.ChannelId);
var channel = await _cache.GetChannel(msg.GuildId!.Value, msg.ChannelId!);
if (channel.Type != Channel.ChannelType.GuildText) return;
if (!(await _cache.BotPermissionsIn(channel.Id)).HasFlag(PermissionSet.ManageMessages)) return;
if (!(await _cache.BotPermissionsIn(msg.GuildId!.Value, channel.Id)).HasFlag(PermissionSet.ManageMessages)) return;
// If this message is from a *webhook*, check if the application ID matches one of the bots we know
// If it's from a *bot*, check the bot ID to see if we know it.

View file

@ -54,33 +54,6 @@ public class PeriodicStatCollector
var stopwatch = new Stopwatch();
stopwatch.Start();
// Aggregate guild/channel stats
var guildCount = 0;
var channelCount = 0;
// No LINQ today, sorry
await foreach (var guild in _cache.GetAllGuilds())
{
guildCount++;
foreach (var channel in await _cache.GetGuildChannels(guild.Id))
if (DiscordUtils.IsValidGuildChannel(channel))
channelCount++;
}
if (_config.UseRedisMetrics)
{
var db = _redis.Connection.GetDatabase();
await db.HashSetAsync("pluralkit:cluster_stats", new StackExchange.Redis.HashEntry[] {
new(_botConfig.Cluster.NodeIndex, JsonConvert.SerializeObject(new ClusterMetricInfo
{
GuildCount = guildCount,
ChannelCount = channelCount,
DatabaseConnectionCount = _countHolder.ConnectionCount,
WebhookCacheSize = _webhookCache.CacheSize,
})),
});
}
// Process info
var process = Process.GetCurrentProcess();
_metrics.Measure.Gauge.SetValue(CoreMetrics.ProcessPhysicalMemory, process.WorkingSet64);

View file

@ -87,7 +87,7 @@ public class WebhookExecutorService
return webhookMessage;
}
public async Task<Message> EditWebhookMessage(ulong channelId, ulong messageId, string newContent, bool clearEmbeds = false)
public async Task<Message> EditWebhookMessage(ulong guildId, ulong channelId, ulong messageId, string newContent, bool clearEmbeds = false)
{
var allowedMentions = newContent.ParseMentions() with
{
@ -96,7 +96,7 @@ public class WebhookExecutorService
};
ulong? threadId = null;
var channel = await _cache.GetOrFetchChannel(_rest, channelId);
var channel = await _cache.GetOrFetchChannel(_rest, guildId, channelId);
if (channel.IsThread())
{
threadId = channelId;

View file

@ -38,9 +38,11 @@ public class SerilogGatewayEnricherFactory
{
props.Add(new LogEventProperty("ChannelId", new ScalarValue(channel.Value)));
if (await _cache.TryGetChannel(channel.Value) != null)
var guildIdForCache = guild != null ? guild.Value : 0;
if (await _cache.TryGetChannel(guildIdForCache, channel.Value) != null)
{
var botPermissions = await _cache.BotPermissionsIn(channel.Value);
var botPermissions = await _cache.BotPermissionsIn(guildIdForCache, channel.Value);
props.Add(new LogEventProperty("BotPermissions", new ScalarValue(botPermissions)));
}
}

View file

@ -8,7 +8,6 @@ public class CoreConfig
public string? MessagesDatabase { get; set; }
public string? DatabasePassword { get; set; }
public string RedisAddr { get; set; }
public bool UseRedisMetrics { get; set; } = false;
public string SentryUrl { get; set; }
public string InfluxUrl { get; set; }
public string InfluxDb { get; set; }

View file

@ -17,6 +17,7 @@ tokio = { workspace = true }
tracing = { workspace = true }
tracing-gelf = "0.7.1"
tracing-subscriber = { workspace = true}
twilight-model = { workspace = true }
prost = { workspace = true }
prost-types = { workspace = true }

View file

@ -3,11 +3,23 @@ use lazy_static::lazy_static;
use serde::Deserialize;
use std::sync::Arc;
use twilight_model::id::{marker::UserMarker, Id};
#[derive(Clone, Deserialize, Debug)]
pub struct ClusterSettings {
pub node_id: u32,
pub total_shards: u32,
pub total_nodes: u32,
}
#[derive(Deserialize, Debug)]
pub struct DiscordConfig {
pub client_id: u32,
pub client_id: Id<UserMarker>,
pub bot_token: String,
pub client_secret: String,
pub max_concurrency: u32,
pub cluster: Option<ClusterSettings>,
pub api_base_url: Option<String>,
}
#[derive(Deserialize, Debug)]
@ -41,6 +53,9 @@ pub struct ApiConfig {
fn _metrics_default() -> bool {
false
}
fn _json_log_default() -> bool {
false
}
#[derive(Deserialize, Debug)]
pub struct PKConfig {
@ -52,13 +67,20 @@ pub struct PKConfig {
#[serde(default = "_metrics_default")]
pub run_metrics_server: bool,
pub(crate) gelf_log_url: Option<String>,
#[serde(default = "_json_log_default")]
pub(crate) json_log: bool,
}
lazy_static! {
#[derive(Debug)]
pub static ref CONFIG: Arc<PKConfig> = Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
.build().unwrap()
.try_deserialize::<PKConfig>().unwrap());
pub static ref CONFIG: Arc<PKConfig> = {
if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX") {
std::env::set_var("pluralkit__discord__cluster__node_id", var);
}
Arc::new(Config::builder()
.add_source(config::Environment::with_prefix("pluralkit").separator("__"))
.build().unwrap()
.try_deserialize::<PKConfig>().unwrap())
};
}

View file

@ -1,30 +1,24 @@
use gethostname::gethostname;
use metrics_exporter_prometheus::PrometheusBuilder;
use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry};
use tracing_subscriber::{EnvFilter, Registry};
pub mod db;
pub mod proto;
pub mod util;
pub mod _config;
pub use crate::_config::CONFIG as config;
pub fn init_logging(component: &str) -> anyhow::Result<()> {
let subscriber = Registry::default()
.with(EnvFilter::from_default_env())
.with(tracing_subscriber::fmt::layer());
if let Some(gelf_url) = &config.gelf_log_url {
let gelf_logger = tracing_gelf::Logger::builder()
.additional_field("component", component)
.additional_field("hostname", gethostname().to_str());
let mut conn_handle = gelf_logger
.init_udp_with_subscriber(gelf_url, subscriber)
.unwrap();
tokio::spawn(async move { conn_handle.connect().await });
// todo: fix component
if config.json_log {
tracing_subscriber::fmt()
.json()
.with_env_filter(EnvFilter::from_default_env())
.init();
} else {
// gelf_logger internally sets the global subscriber
tracing::subscriber::set_global_default(subscriber)
.expect("unable to set global subscriber");
tracing_subscriber::fmt()
.with_env_filter(EnvFilter::from_default_env())
.init();
}
Ok(())

View file

@ -0,0 +1 @@
pub mod redis;

View file

@ -0,0 +1,15 @@
use fred::error::RedisError;
pub trait RedisErrorExt<T> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError>;
}
impl<T> RedisErrorExt<T> for Result<T, RedisError> {
fn to_option_or_error(self) -> Result<Option<T>, RedisError> {
match self {
Ok(v) => Ok(Some(v)),
Err(error) if error.is_not_found() => Ok(None),
Err(error) => Err(error),
}
}
}

View file

@ -147,6 +147,7 @@ async fn main() -> anyhow::Result<()> {
let addr: &str = libpk::config.api.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())

View file

@ -0,0 +1,25 @@
[package]
name = "gateway"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = { workspace = true }
axum = { workspace = true }
bytes = { workspace = true }
chrono = { workspace = true }
fred = { workspace = true }
futures = { workspace = true }
lazy_static = { workspace = true }
libpk = { path = "../../lib/libpk" }
prost = { workspace = true }
serde_json = { workspace = true }
signal-hook = { workspace = true }
tokio = { workspace = true }
tracing = { workspace = true }
twilight-gateway = { workspace = true }
twilight-cache-inmemory = { workspace = true }
twilight-util = { workspace = true }
twilight-model = { workspace = true }
twilight-http = { workspace = true }

View file

@ -0,0 +1,168 @@
use axum::{
extract::{Path, State},
http::StatusCode,
response::{IntoResponse, Response},
routing::get,
Router,
};
use serde_json::to_string;
use tracing::{error, info};
use twilight_model::guild::Permissions;
use twilight_model::id::Id;
use crate::discord::cache::{dm_channel, DiscordCache, DM_PERMISSIONS};
use std::sync::Arc;
fn status_code(code: StatusCode, body: String) -> Response {
(code, body).into_response()
}
// this function is manually formatted for easier legibility of route_services
#[rustfmt::skip]
pub async fn run_server(cache: Arc<DiscordCache>) -> anyhow::Result<()> {
let app = Router::new()
.route(
"/guilds/:guild_id",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild(Id::new(guild_id)) {
Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/members/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.0.member(Id::new(guild_id), libpk::config.discord.client_id) {
Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string()),
}
}),
)
.route(
"/guilds/:guild_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.client_id).await {
Ok(val) => {
println!("hh {}", Permissions::all().bits());
status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap())
},
Err(err) => {
error!(?err, ?guild_id, "failed to get own guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/permissions/:user_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, user_id)): Path<(u64, u64)>| async move {
match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()),
Err(err) => {
error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) {
Some(channels) => channels.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut channels = Vec::new();
for id in channel_ids {
match cache.0.channel(id) {
Some(channel) => channels.push(channel.to_owned()),
None => {
tracing::error!(
channel_id = id.get(),
"referenced channel {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&channels).unwrap())
})
)
.route(
"/guilds/:guild_id/channels/:channel_id",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap());
}
match cache.0.channel(Id::new(channel_id)) {
Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()),
None => status_code(StatusCode::NOT_FOUND, "".to_string())
}
})
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/@me",
get(|State(cache): State<Arc<DiscordCache>>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move {
if guild_id == 0 {
return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap());
}
match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.client_id).await {
Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()),
Err(err) => {
error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions");
status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string())
},
}
}),
)
.route(
"/guilds/:guild_id/channels/:channel_id/permissions/:user_id",
get(|| async { "todo" }),
)
.route(
"/guilds/:guild_id/channels/:channel_id/last_message",
get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }),
)
.route(
"/guilds/:guild_id/roles",
get(|State(cache): State<Arc<DiscordCache>>, Path(guild_id): Path<u64>| async move {
let role_ids = match cache.0.guild_roles(Id::new(guild_id)) {
Some(roles) => roles.to_owned(),
None => return status_code(StatusCode::NOT_FOUND, "".to_string()),
};
let mut roles = Vec::new();
for id in role_ids {
match cache.0.role(id) {
Some(role) => roles.push(role.value().resource().to_owned()),
None => {
tracing::error!(
role_id = id.get(),
"referenced role {} from guild {} not found in cache",
id.get(), guild_id,
);
return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string());
}
}
}
status_code(StatusCode::FOUND, to_string(&roles).unwrap())
})
)
.layer(axum::middleware::from_fn(crate::logger::logger))
.with_state(cache);
let addr: &str = libpk::config.api.addr.as_ref();
let listener = tokio::net::TcpListener::bind(addr).await?;
info!("listening on {}", addr);
axum::serve(listener, app).await?;
Ok(())
}

View file

@ -0,0 +1,339 @@
use anyhow::format_err;
use lazy_static::lazy_static;
use std::sync::Arc;
use twilight_cache_inmemory::{
model::CachedMember,
permission::{MemberRoles, RootError},
traits::CacheableChannel,
InMemoryCache, ResourceType,
};
use twilight_model::{
channel::{Channel, ChannelType},
guild::{Guild, Member, Permissions},
id::{
marker::{ChannelMarker, GuildMarker, UserMarker},
Id,
},
};
use twilight_util::permission_calculator::PermissionCalculator;
lazy_static! {
pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
| Permissions::SEND_MESSAGES
| Permissions::READ_MESSAGE_HISTORY
| Permissions::ADD_REACTIONS
| Permissions::ATTACH_FILES
| Permissions::EMBED_LINKS
| Permissions::USE_EXTERNAL_EMOJIS
| Permissions::CONNECT
| Permissions::SPEAK
| Permissions::USE_VAD;
}
pub fn dm_channel(id: Id<ChannelMarker>) -> Channel {
Channel {
id,
kind: ChannelType::Private,
application_id: None,
applied_tags: None,
available_tags: None,
bitrate: None,
default_auto_archive_duration: None,
default_forum_layout: None,
default_reaction_emoji: None,
default_sort_order: None,
default_thread_rate_limit_per_user: None,
flags: None,
guild_id: None,
icon: None,
invitable: None,
last_message_id: None,
last_pin_timestamp: None,
managed: None,
member: None,
member_count: None,
message_count: None,
name: None,
newly_created: None,
nsfw: None,
owner_id: None,
parent_id: None,
permission_overwrites: None,
position: None,
rate_limit_per_user: None,
recipients: None,
rtc_region: None,
thread_metadata: None,
topic: None,
user_limit: None,
video_quality_mode: None,
}
}
fn member_to_cached_member(item: Member, id: Id<UserMarker>) -> CachedMember {
CachedMember {
avatar: item.avatar,
communication_disabled_until: item.communication_disabled_until,
deaf: Some(item.deaf),
flags: item.flags,
joined_at: item.joined_at,
mute: Some(item.mute),
nick: item.nick,
premium_since: item.premium_since,
roles: item.roles,
pending: false,
user_id: id,
}
}
pub fn new() -> DiscordCache {
let mut client_builder =
twilight_http::Client::builder().token(libpk::config.discord.bot_token.clone());
if let Some(base_url) = libpk::config.discord.api_base_url.clone() {
client_builder = client_builder.proxy(base_url, true);
}
let client = Arc::new(client_builder.build());
let cache = Arc::new(
InMemoryCache::builder()
.resource_types(
ResourceType::GUILD
| ResourceType::CHANNEL
| ResourceType::ROLE
| ResourceType::USER_CURRENT
| ResourceType::MEMBER_CURRENT,
)
.message_cache_size(0)
.build(),
);
DiscordCache(cache, client)
}
pub struct DiscordCache(pub Arc<InMemoryCache>, pub Arc<twilight_http::Client>);
impl DiscordCache {
pub async fn guild_permissions(
&self,
guild_id: Id<GuildMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
if self
.0
.guild(guild_id)
.ok_or_else(|| format_err!("guild not found"))?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id == libpk::config.discord.client_id {
self.0
.member(guild_id, user_id)
.ok_or(format_err!("self member not found"))?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.root();
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
pub async fn channel_permissions(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
) -> anyhow::Result<Permissions> {
let channel = self
.0
.channel(channel_id)
.ok_or(format_err!("channel not found"))?;
if channel.value().guild_id.is_none() {
return Ok(*DM_PERMISSIONS);
}
let guild_id = channel.value().guild_id.unwrap();
if self
.0
.guild(guild_id)
.ok_or_else(|| {
tracing::error!(
channel_id = channel_id.get(),
guild_id = guild_id.get(),
"referenced guild from cached channel {channel_id} not found in cache"
);
format_err!("internal cache error")
})?
.owner_id()
== user_id
{
return Ok(Permissions::all());
}
let member = if user_id == libpk::config.discord.client_id {
self.0
.member(guild_id, user_id)
.ok_or_else(|| {
tracing::error!(
guild_id = guild_id.get(),
"self member for cached guild {guild_id} not found in cache"
);
format_err!("internal cache error")
})?
.value()
.to_owned()
} else {
member_to_cached_member(
self.1
.guild_member(guild_id, user_id)
.await?
.model()
.await?,
user_id,
)
};
let MemberRoles { assigned, everyone } = self
.0
.permissions()
.member_roles(guild_id, &member)
.map_err(RootError::from_member_roles)?;
let overwrites = match channel.kind {
ChannelType::AnnouncementThread
| ChannelType::PrivateThread
| ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?,
_ => channel
.value()
.permission_overwrites()
.unwrap_or_default()
.to_vec(),
};
let calculator =
PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice());
let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice());
Ok(self
.0
.permissions()
.disable_member_communication(&member, permissions))
}
// from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs
pub fn guild(&self, id: Id<GuildMarker>) -> Option<Guild> {
self.0.guild(id).map(|guild| {
let channels = self
.0
.guild_channels(id)
.map(|reference| {
reference
.iter()
.filter_map(|channel_id| {
let channel = self.0.channel(*channel_id)?;
if channel.kind.is_thread() {
None
} else {
Some(channel.value().clone())
}
})
.collect()
})
.unwrap_or_default();
let roles = self
.0
.guild_roles(id)
.map(|reference| {
reference
.iter()
.filter_map(|role_id| {
Some(self.0.role(*role_id)?.value().resource().clone())
})
.collect()
})
.unwrap_or_default();
Guild {
afk_channel_id: guild.afk_channel_id(),
afk_timeout: guild.afk_timeout(),
application_id: guild.application_id(),
approximate_member_count: None, // Only present in with_counts HTTP endpoint
banner: guild.banner().map(ToOwned::to_owned),
approximate_presence_count: None, // Only present in with_counts HTTP endpoint
channels,
default_message_notifications: guild.default_message_notifications(),
description: guild.description().map(ToString::to_string),
discovery_splash: guild.discovery_splash().map(ToOwned::to_owned),
emojis: vec![],
explicit_content_filter: guild.explicit_content_filter(),
features: guild.features().cloned().collect(),
icon: guild.icon().map(ToOwned::to_owned),
id: guild.id(),
joined_at: guild.joined_at(),
large: guild.large(),
max_members: guild.max_members(),
max_presences: guild.max_presences(),
max_video_channel_users: guild.max_video_channel_users(),
member_count: guild.member_count(),
members: vec![],
mfa_level: guild.mfa_level(),
name: guild.name().to_string(),
nsfw_level: guild.nsfw_level(),
owner_id: guild.owner_id(),
owner: guild.owner(),
permissions: guild.permissions(),
public_updates_channel_id: guild.public_updates_channel_id(),
preferred_locale: guild.preferred_locale().to_string(),
premium_progress_bar_enabled: guild.premium_progress_bar_enabled(),
premium_subscription_count: guild.premium_subscription_count(),
premium_tier: guild.premium_tier(),
presences: vec![],
roles,
rules_channel_id: guild.rules_channel_id(),
safety_alerts_channel_id: guild.safety_alerts_channel_id(),
splash: guild.splash().map(ToOwned::to_owned),
stage_instances: vec![],
stickers: vec![],
system_channel_flags: guild.system_channel_flags(),
system_channel_id: guild.system_channel_id(),
threads: vec![],
unavailable: false,
vanity_url_code: guild.vanity_url_code().map(ToString::to_string),
verification_level: guild.verification_level(),
voice_states: vec![],
widget_channel_id: guild.widget_channel_id(),
widget_enabled: guild.widget_enabled(),
}
})
}
}

View file

@ -0,0 +1,121 @@
use std::sync::{mpsc::Sender, Arc};
use tracing::{info, warn};
use twilight_gateway::{
create_iterator, ConfigBuilder, Event, EventTypeFlags, Shard, ShardId, StreamExt,
};
use twilight_model::gateway::{
payload::outgoing::update_presence::UpdatePresencePayload,
presence::{Activity, ActivityType, Status},
Intents,
};
use crate::discord::identify_queue::{self, RedisQueue};
use super::{cache::DiscordCache, shard_state::ShardStateManager};
pub fn create_shards(redis: fred::pool::RedisPool) -> anyhow::Result<Vec<Shard<RedisQueue>>> {
let intents = Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::DIRECT_MESSAGE_REACTIONS
| Intents::GUILD_MESSAGES
| Intents::GUILD_MESSAGE_REACTIONS
| Intents::MESSAGE_CONTENT;
let queue = identify_queue::new(redis);
let cluster_settings =
libpk::config
.discord
.cluster
.clone()
.unwrap_or(libpk::_config::ClusterSettings {
node_id: 0,
total_shards: 1,
total_nodes: 1,
});
let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 {
warn!("we have less than 16 shards, assuming single gateway process");
(0, (cluster_settings.total_shards - 1).into())
} else {
(
(cluster_settings.node_id * 16).into(),
(((cluster_settings.node_id + 1) * 16) - 1).into(),
)
};
let shards = create_iterator(
start_shard..end_shard + 1,
cluster_settings.total_shards,
ConfigBuilder::new(libpk::config.discord.bot_token.to_owned(), intents)
.presence(presence("pk;help", false))
.queue(queue.clone())
.build(),
|_, builder| builder.build(),
);
let mut shards_vec = Vec::new();
shards_vec.extend(shards);
Ok(shards_vec)
}
pub async fn runner(
mut shard: Shard<RedisQueue>,
tx: Sender<(ShardId, Event)>,
shard_state: ShardStateManager,
cache: Arc<DiscordCache>,
) {
//let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered();
info!("waiting for events");
while let Some(item) = shard.next_event(EventTypeFlags::all()).await {
match item {
Ok(event) => {
if let Err(error) = shard_state
.handle_event(shard.id().number(), event.clone())
.await
{
tracing::warn!(?error, "error updating redis state")
}
cache.0.update(&event);
//if let Err(error) = tx.send((shard.id(), event)) {
// tracing::warn!(?error, "error sending event to global handler: {error}",);
//}
}
Err(error) => {
tracing::warn!(?error, "error receiving event from shard {}", shard.id());
continue;
}
}
}
}
pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload {
UpdatePresencePayload {
activities: vec![Activity {
application_id: None,
assets: None,
buttons: vec![],
created_at: None,
details: None,
id: None,
state: None,
url: None,
emoji: None,
flags: None,
instance: None,
kind: ActivityType::Playing,
name: status.to_string(),
party: None,
secrets: None,
timestamps: None,
}],
afk: false,
since: None,
status: if going_away {
Status::Idle
} else {
Status::Online
},
}
}

View file

@ -0,0 +1,87 @@
use fred::{
error::RedisError,
interfaces::KeysInterface,
pool::RedisPool,
types::{Expiration, SetOptions},
};
use std::fmt::Debug;
use std::time::Duration;
use tokio::sync::oneshot;
use tracing::{error, info};
use twilight_gateway::queue::Queue;
use libpk::util::redis::RedisErrorExt;
pub fn new(redis: RedisPool) -> RedisQueue {
RedisQueue {
redis,
concurrency: libpk::config.discord.max_concurrency,
}
}
#[derive(Clone)]
pub struct RedisQueue {
pub redis: RedisPool,
pub concurrency: u32,
}
impl Debug for RedisQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RedisQueue")
.field("concurrency", &self.concurrency)
.finish()
}
}
impl Queue for RedisQueue {
fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> {
let (tx, rx) = oneshot::channel();
tokio::spawn(request_inner(
self.redis.clone(),
self.concurrency,
shard_id,
tx,
));
rx
}
}
const EXPIRY: i64 = 6;
const RETRY_INTERVAL: u64 = 500;
async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) {
let bucket = shard_id % concurrency;
let key = format!("pluralkit:identify:{}", bucket);
info!(shard_id, bucket, "waiting for allowance...");
loop {
let done: Result<Option<String>, RedisError> = redis
.set(
key.to_string(),
"1",
Some(Expiration::EX(EXPIRY)),
Some(SetOptions::NX),
false,
)
.await
.to_option_or_error();
match done {
Ok(Some(_)) => {
info!(shard_id, bucket, "got allowance!");
// if this fails, it's probably already doing something else
let _ = tx.send(());
return;
}
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
}
}
tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await;
}
}

View file

@ -0,0 +1,4 @@
pub mod cache;
pub mod gateway;
pub mod identify_queue;
pub mod shard_state;

View file

@ -0,0 +1,84 @@
use bytes::Bytes;
use fred::{interfaces::HashesInterface, pool::RedisPool};
use prost::Message;
use tracing::info;
use twilight_gateway::Event;
use libpk::{proto::*, util::redis::*};
#[derive(Clone)]
pub struct ShardStateManager {
redis: RedisPool,
}
pub fn new(redis: RedisPool) -> ShardStateManager {
ShardStateManager { redis }
}
impl ShardStateManager {
pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> {
match event {
Event::Ready(_) => self.ready_or_resumed(shard_id).await,
Event::Resumed => self.ready_or_resumed(shard_id).await,
Event::GatewayClose(_) => self.socket_closed(shard_id).await,
Event::GatewayHeartbeat(_) => self.heartbeated(shard_id).await,
_ => Ok(()),
}
}
async fn get_shard(&self, shard_id: u32) -> anyhow::Result<ShardState> {
let data: Option<Vec<u8>> = self
.redis
.hget("pluralkit:shardstatus", shard_id)
.await
.to_option_or_error()?;
match data {
Some(buf) => {
Ok(ShardState::decode(buf.as_slice()).expect("could not decode shard data!"))
}
None => Ok(ShardState::default()),
}
}
async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> {
self.redis
.hset(
"pluralkit:shardstatus",
(
shard_id.to_string(),
Bytes::copy_from_slice(&info.encode_to_vec()),
),
)
.await?;
Ok(())
}
async fn ready_or_resumed(&self, shard_id: u32) -> anyhow::Result<()> {
info!("shard {} ready", shard_id);
let mut info = self.get_shard(shard_id).await?;
info.last_connection = chrono::offset::Utc::now().timestamp() as i32;
info.up = true;
self.save_shard(shard_id, info).await?;
Ok(())
}
async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> {
info!("shard {} closed", shard_id);
let mut info = self.get_shard(shard_id).await?;
info.up = false;
info.disconnection_count += 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
async fn heartbeated(&self, shard_id: u32) -> anyhow::Result<()> {
let mut info = self.get_shard(shard_id).await?;
info.up = true;
info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32;
// todo
// info.latency = latency.recent().front().map_or_else(|| 0, |d| d.as_millis()) as i32;
info.latency = 1;
self.save_shard(shard_id, info).await?;
Ok(())
}
}

View file

@ -0,0 +1,52 @@
use std::time::Instant;
use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response};
use tracing::{info, span, warn, Instrument, Level};
// log any requests that take longer than 2 seconds
// todo: change as necessary
const MIN_LOG_TIME: u128 = 2_000;
pub async fn logger(request: Request, next: Next) -> Response {
let method = request.method().clone();
let endpoint = request
.extensions()
.get::<MatchedPath>()
.cloned()
.map(|v| v.as_str().to_string())
.unwrap_or("unknown".to_string());
let uri = request.uri().clone();
let request_id_span = span!(
Level::INFO,
"request",
method = method.as_str(),
endpoint = endpoint.clone(),
);
let start = Instant::now();
let response = next.run(request).instrument(request_id_span).await;
let elapsed = start.elapsed().as_millis();
info!(
"{} handled request for {} {} in {}ms",
response.status(),
method,
uri.path(),
elapsed
);
if elapsed > MIN_LOG_TIME {
warn!(
"request to {} full path {} (endpoint {}) took a long time ({}ms)!",
method,
uri.path(),
endpoint,
elapsed
)
}
response
}

View file

@ -0,0 +1,141 @@
use chrono::Timelike;
use fred::{interfaces::*, pool::RedisPool};
use signal_hook::{
consts::{SIGINT, SIGTERM},
iterator::Signals,
};
use std::{
sync::{mpsc::channel, Arc},
time::Duration,
vec::Vec,
};
use tokio::task::JoinSet;
use tracing::{info, warn};
use twilight_gateway::{MessageSender, ShardId};
use twilight_model::gateway::payload::outgoing::UpdatePresence;
mod cache_api;
mod discord;
mod logger;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
libpk::init_logging("gateway")?;
libpk::init_metrics()?;
info!("hello world");
let (shutdown_tx, shutdown_rx) = channel::<()>();
let shutdown_tx = Arc::new(shutdown_tx);
let redis = libpk::db::init_redis().await?;
let shard_state = discord::shard_state::new(redis.clone());
let cache = Arc::new(discord::cache::new());
let shards = discord::gateway::create_shards(redis.clone())?;
let (event_tx, _event_rx) = channel();
let mut senders = Vec::new();
let mut signal_senders = Vec::new();
let mut set = JoinSet::new();
for shard in shards {
senders.push((shard.id(), shard.sender()));
signal_senders.push(shard.sender());
set.spawn(tokio::spawn(discord::gateway::runner(
shard,
event_tx.clone(),
shard_state.clone(),
cache.clone(),
)));
}
set.spawn(tokio::spawn(
async move { scheduled_task(redis, senders).await },
));
// todo: probably don't do it this way
let api_shutdown_tx = shutdown_tx.clone();
set.spawn(tokio::spawn(async move {
match cache_api::run_server(cache).await {
Err(error) => {
tracing::error!(?error, "failed to serve cache api");
let _ = api_shutdown_tx.send(());
}
_ => unreachable!(),
}
}));
let mut signals = Signals::new(&[SIGINT, SIGTERM])?;
tokio::spawn(async move {
for sig in signals.forever() {
info!("received signal {:?}", sig);
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence("Restarting... (please wait)", true),
};
for sender in signal_senders.iter() {
let presence = presence.clone();
let _ = sender.command(&presence);
}
let _ = shutdown_tx.send(());
break;
}
});
let _ = shutdown_rx.recv();
// sleep 500ms to allow everything to clean up properly
tokio::time::sleep(Duration::from_millis(500)).await;
set.abort_all();
info!("gateway exiting, have a nice day!");
Ok(())
}
async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) {
loop {
tokio::time::sleep(Duration::from_secs(
(60 - chrono::offset::Utc::now().second()).into(),
))
.await;
info!("running per-minute scheduled tasks");
let status: Option<String> = match redis.get("pluralkit:botstatus").await {
Ok(val) => Some(val),
Err(error) => {
tracing::warn!(?error, "failed to fetch bot status from redis");
None
}
};
let presence = UpdatePresence {
op: twilight_model::gateway::OpCode::PresenceUpdate,
d: discord::gateway::presence(
if let Some(status) = status {
format!("pk;help | {}", status)
} else {
"pk;help".to_string()
}
.as_str(),
false,
),
};
for sender in senders.iter() {
match sender.1.command(&presence) {
Err(error) => {
warn!(?error, "could not update presence on shard {}", sender.0)
}
_ => {}
};
}
}
}