diff --git a/.cargo/config.toml b/.cargo/config.toml new file mode 100644 index 00000000..958d7fca --- /dev/null +++ b/.cargo/config.toml @@ -0,0 +1,3 @@ +[build] +rustflags = ["-C", "target-cpu=native"] + diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index 39bd4a5a..584ab179 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -3,16 +3,18 @@ # todo: don't use docker/build-push-action # todo: run builds on pull request -name: Build and push API Docker image +name: Build and push Rust service Docker images on: push: - branches: - - main paths: - 'lib/libpk/**' - 'services/api/**' + - 'services/gateway/**' - '.github/workflows/rust.yml' - 'Dockerfile.rust' + - 'Dockerfile.bin' + - 'Cargo.toml' + - 'Cargo.lock' jobs: deploy: @@ -45,7 +47,7 @@ jobs: # add more binaries here - run: | - for binary in "api"; do + for binary in "api" "gateway"; do for tag in latest ${{ env.BRANCH_NAME }} ${{ github.sha }}; do cat Dockerfile.bin | sed "s/__BINARY__/$binary/g" | docker build -t ghcr.io/pluralkit/$binary:$tag -f - . done diff --git a/Cargo.lock b/Cargo.lock index 2b79edd1..2765ac8d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -329,9 +329,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.4.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "89b2fd2a0dcf38d7971e2194b6b6eebab45ae01067456a7fd93d5547a61b70be" +checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" [[package]] name = "bytes-utils" @@ -363,7 +363,9 @@ checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ "android-tzdata", "iana-time-zone", + "js-sys", "num-traits", + "wasm-bindgen", "windows-targets 0.52.6", ] @@ -428,6 +430,16 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "396de984970346b0d9e93d1415082923c679e5ae5c3ee3dcbd104f5610af126b" +[[package]] +name = "core-foundation" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -464,6 +476,15 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "338089f42c427b86394a5ee60ff321da23a5c89c9d89514c829687b26359fcff" +[[package]] +name = "crc32fast" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" +dependencies = [ + "cfg-if", +] + [[package]] name = "crossbeam-epoch" version = "0.9.14" @@ -508,6 +529,19 @@ dependencies = [ "typenum", ] +[[package]] +name = "dashmap" +version = "5.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "907076dfda823b0b36d2a1bb5f90c96660a5bbcd7729e10727f07858f22c4edc" +dependencies = [ + "cfg-if", + "hashbrown 0.12.3", + "lock_api", + "once_cell", + "parking_lot_core 0.9.7", +] + [[package]] name = "data-encoding" version = "2.6.0" @@ -525,6 +559,15 @@ dependencies = [ "zeroize", ] +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.9.0" @@ -661,6 +704,17 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80" +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "libz-sys", + "miniz_oxide", +] + [[package]] name = "float-cmp" version = "0.8.0" @@ -825,6 +879,30 @@ dependencies = [ "slab", ] +[[package]] +name = "gateway" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.7.5", + "bytes", + "chrono", + "fred", + "futures", + "lazy_static", + "libpk", + "prost", + "serde_json", + "signal-hook", + "tokio", + "tracing", + "twilight-cache-inmemory", + "twilight-gateway", + "twilight-http", + "twilight-model", + "twilight-util", +] + [[package]] name = "generic-array" version = "0.14.6" @@ -881,6 +959,12 @@ dependencies = [ "tracing", ] +[[package]] +name = "hashbrown" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" + [[package]] name = "hashbrown" version = "0.13.2" @@ -1161,18 +1245,36 @@ dependencies = [ [[package]] name = "hyper-rustls" -version = "0.27.2" +version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ee4be2c948921a1a5320b629c4193916ed787a7f7f293fd3f7f5a6c9de74155" +checksum = "a0bea761b46ae2b24eb4aef630d8d1c398157b6fc29e6350ecf090a0b70c952c" dependencies = [ "futures-util", "http 1.1.0", "hyper 1.3.1", "hyper-util", - "rustls", + "rustls 0.22.4", + "rustls-native-certs", "rustls-pki-types", "tokio", - "tokio-rustls", + "tokio-rustls 0.25.0", + "tower-service", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.3.1", + "hyper-util", + "rustls 0.23.10", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", "tower-service", "webpki-roots", ] @@ -1341,6 +1443,7 @@ dependencies = [ "tracing", "tracing-gelf", "tracing-subscriber", + "twilight-model", ] [[package]] @@ -1354,6 +1457,17 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "libz-sys" +version = "1.1.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c15da26e5af7e25c90b37a2d75cdbf940cf4a55316de9d84c679c9b8bfabf82e" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -1561,6 +1675,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + [[package]] name = "num-integer" version = "0.1.46" @@ -1622,6 +1742,21 @@ version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "624a8340c38c1b80fd549087862da4ba43e08858af025b236e509b6649fc13d5" +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + +[[package]] +name = "ordered-float" +version = "2.10.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68f19d67e5a2795c94e73e0bb1cc1a7edeb2e28efd39e2e1c9b7a40c1108b11c" +dependencies = [ + "num-traits", +] + [[package]] name = "ordered-multimap" version = "0.6.0" @@ -1832,6 +1967,12 @@ version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + [[package]] name = "ppv-lite86" version = "0.2.17" @@ -1952,7 +2093,7 @@ dependencies = [ "quinn-proto", "quinn-udp", "rustc-hash", - "rustls", + "rustls 0.23.10", "socket2 0.5.7", "thiserror", "tokio", @@ -1969,7 +2110,7 @@ dependencies = [ "rand", "ring", "rustc-hash", - "rustls", + "rustls 0.23.10", "slab", "thiserror", "tinyvec", @@ -2137,7 +2278,7 @@ dependencies = [ "http-body 1.0.0", "http-body-util", "hyper 1.3.1", - "hyper-rustls", + "hyper-rustls 0.27.3", "hyper-util", "ipnet", "js-sys", @@ -2147,7 +2288,7 @@ dependencies = [ "percent-encoding", "pin-project-lite", "quinn", - "rustls", + "rustls 0.23.10", "rustls-pemfile", "rustls-pki-types", "serde", @@ -2155,7 +2296,7 @@ dependencies = [ "serde_urlencoded", "sync_wrapper 1.0.1", "tokio", - "tokio-rustls", + "tokio-rustls 0.26.0", "tower-service", "url", "wasm-bindgen", @@ -2263,9 +2404,22 @@ dependencies = [ [[package]] name = "rustls" -version = "0.23.12" +version = "0.22.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c58f8c84392efc0a126acce10fa59ff7b3d2ac06ab451a33f2741989b806b044" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls" +version = "0.23.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05cff451f60db80f490f3c182b77c35260baace73209e9cdbbe526bfe3a4d402" dependencies = [ "once_cell", "ring", @@ -2276,10 +2430,23 @@ dependencies = [ ] [[package]] -name = "rustls-pemfile" -version = "2.1.3" +name = "rustls-native-certs" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "196fe16b00e106300d3e45ecfcb764fa292a535d7326a29a5875c579c7417425" +checksum = "8f1fb85efa936c42c6d5fc28d2629bb51e4b2f4b8a5211e297d599cc5a093792" +dependencies = [ + "openssl-probe", + "rustls-pemfile", + "rustls-pki-types", + "schannel", + "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "2.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29993a25686778eb88d4189742cd713c9bce943bc54251a33509dc63cbacf73d" dependencies = [ "base64 0.22.1", "rustls-pki-types", @@ -2287,15 +2454,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.8.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.6" +version = "0.102.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" +checksum = "ff448f7e92e913c4b7d4c6d8e4540a1724b319b4152b8aef6d4cf8339712b33e" dependencies = [ "ring", "rustls-pki-types", @@ -2314,12 +2481,44 @@ version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7b4b9743ed687d4b4bcedf9ff5eaa7398495ae14e61cba0a295704edbc7decde" +[[package]] +name = "schannel" +version = "0.1.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbc91545643bcf3a0bbb6569265615222618bdf33ce4ffbbd13c4bbd4c093534" +dependencies = [ + "windows-sys 0.52.0", +] + [[package]] name = "scopeguard" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d29ab0c6d3fc0ee92fe66e2d99f700eab17a8d57d1c1d3b748380fb20baa78cd" +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.5.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "317936bbbd05227752583946b9e66d7ce3b489f84e11a94a510b4437fef407d7" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "semver" version = "1.0.16" @@ -2335,6 +2534,16 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-value" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3a1a3341211875ef120e117ea7fd5228530ae7e7036a779fdc9117be6b3282c" +dependencies = [ + "ordered-float", + "serde", +] + [[package]] name = "serde_derive" version = "1.0.203" @@ -2366,6 +2575,17 @@ dependencies = [ "serde", ] +[[package]] +name = "serde_repr" +version = "0.1.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.66", +] + [[package]] name = "serde_spanned" version = "0.6.8" @@ -2411,6 +2631,12 @@ dependencies = [ "digest 0.10.7", ] +[[package]] +name = "sha1_smol" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae1a47186c03a32177042e55dbc5fd5aee900b8e0069a8d70fba96a9375cd012" + [[package]] name = "sha2" version = "0.10.6" @@ -2431,6 +2657,16 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "signal-hook" +version = "0.3.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" +dependencies = [ + "libc", + "signal-hook-registry", +] + [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -2450,6 +2686,12 @@ dependencies = [ "rand_core", ] +[[package]] +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "sketches-ddsketch" version = "0.2.0" @@ -2828,6 +3070,36 @@ dependencies = [ "once_cell", ] +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + [[package]] name = "tiny-keccak" version = "2.0.2" @@ -2881,13 +3153,24 @@ dependencies = [ "syn 2.0.66", ] +[[package]] +name = "tokio-rustls" +version = "0.25.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "775e0c0f0adb3a2f22a00c4745d728b479985fc15ee7ca6a2608388c5569860f" +dependencies = [ + "rustls 0.22.4", + "rustls-pki-types", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.26.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" dependencies = [ - "rustls", + "rustls 0.23.10", "rustls-pki-types", "tokio", ] @@ -2930,6 +3213,30 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-websockets" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "988c6e20955aa5043e0822cb27093ebaabb430a126cda0223824b6d65ea900c1" +dependencies = [ + "base64 0.21.7", + "bytes", + "fastrand", + "futures-core", + "futures-sink", + "http 1.1.0", + "httparse", + "ring", + "rustls-native-certs", + "rustls-pki-types", + "sha1_smol", + "simdutf8", + "tokio", + "tokio-rustls 0.25.0", + "tokio-util 0.7.12", + "tracing", +] + [[package]] name = "toml" version = "0.8.19" @@ -3140,6 +3447,105 @@ version = "0.2.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3528ecfd12c466c6f163363caf2d02a71161dd5e1cc6ae7b34207ea2d42d81ed" +[[package]] +name = "twilight-cache-inmemory" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "bitflags 2.5.0", + "dashmap", + "serde", + "twilight-model", + "twilight-util", +] + +[[package]] +name = "twilight-gateway" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "bitflags 2.5.0", + "fastrand", + "flate2", + "futures-core", + "futures-sink", + "serde", + "serde_json", + "tokio", + "tokio-websockets", + "tracing", + "twilight-gateway-queue", + "twilight-http", + "twilight-model", +] + +[[package]] +name = "twilight-gateway-queue" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "tokio", + "tracing", +] + +[[package]] +name = "twilight-http" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "fastrand", + "http 1.1.0", + "http-body-util", + "hyper 1.3.1", + "hyper-rustls 0.26.0", + "hyper-util", + "percent-encoding", + "serde", + "serde_json", + "tokio", + "tracing", + "twilight-http-ratelimiting", + "twilight-model", + "twilight-validate", +] + +[[package]] +name = "twilight-http-ratelimiting" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "tokio", + "tracing", +] + +[[package]] +name = "twilight-model" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "bitflags 2.5.0", + "serde", + "serde-value", + "serde_repr", + "time", +] + +[[package]] +name = "twilight-util" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "twilight-model", +] + +[[package]] +name = "twilight-validate" +version = "0.16.0-rc.1" +source = "git+https://github.com/pluralkit/twilight#5027119d689c9c5aff8dac73b676995bb7e0e3b1" +dependencies = [ + "twilight-model", +] + [[package]] name = "typenum" version = "1.16.0" diff --git a/Cargo.toml b/Cargo.toml index d882f150..9428e560 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -2,22 +2,40 @@ members = [ "./lib/libpk", "./services/api", - "./services/dispatch" + "./services/dispatch", + "./services/gateway" ] [workspace.dependencies] anyhow = "1" axum = "0.7.5" +axum-macros = "0.4.1" +bytes = "1.6.0" +chrono = "0.4" fred = { version = "5.2.0", default-features = false, features = ["tracing", "pool-prefer-active"] } +futures = "0.3.30" lazy_static = "1.4.0" metrics = "0.23.0" serde = "1.0.152" serde_json = "1.0.117" +signal-hook = "0.3.17" sqlx = { version = "0.7.4", features = ["runtime-tokio", "postgres", "chrono", "macros"] } tokio = { version = "1.25.0", features = ["full"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["env-filter", "json"] } +twilight-gateway = { git = "https://github.com/pluralkit/twilight" } +twilight-cache-inmemory = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] } +twilight-util = { git = "https://github.com/pluralkit/twilight", features = ["permission-calculator"] } +twilight-model = { git = "https://github.com/pluralkit/twilight" } +twilight-http = { git = "https://github.com/pluralkit/twilight", default-features = false, features = ["rustls-native-roots"] } + +#twilight-gateway = { path = "../twilight/twilight-gateway" } +#twilight-cache-inmemory = { path = "../twilight/twilight-cache-inmemory", features = ["permission-calculator"] } +#twilight-util = { path = "../twilight/twilight-util", features = ["permission-calculator"] } +#twilight-model = { path = "../twilight/twilight-model" } +#twilight-http = { path = "../twilight/twilight-http", default-features = false, features = ["rustls-native-roots"] } + prost = "0.12" prost-types = "0.12" prost-build = "0.12" diff --git a/Dockerfile.rust b/Dockerfile.rust index 8f6341ed..84a33453 100644 --- a/Dockerfile.rust +++ b/Dockerfile.rust @@ -27,9 +27,12 @@ COPY proto/ /build/proto # this needs to match workspaces in Cargo.toml COPY lib/libpk /build/lib/libpk COPY services/api/ /build/services/api +COPY services/gateway/ /build/services/gateway RUN cargo build --bin api --release --target x86_64-unknown-linux-musl +RUN cargo build --bin gateway --release --target x86_64-unknown-linux-musl FROM scratch COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/api /api +COPY --from=binary-builder /build/target/x86_64-unknown-linux-musl/release/gateway /gateway diff --git a/Myriad/Cache/DiscordCacheExtensions.cs b/Myriad/Cache/DiscordCacheExtensions.cs index 6016bc84..08e2e971 100644 --- a/Myriad/Cache/DiscordCacheExtensions.cs +++ b/Myriad/Cache/DiscordCacheExtensions.cs @@ -100,15 +100,18 @@ public static class DiscordCacheExtensions await cache.SaveChannel(thread); } - public static async Task BotPermissionsIn(this IDiscordCache cache, ulong channelId) + public static async Task BotPermissionsIn(this IDiscordCache cache, ulong guildId, ulong channelId) { - var channel = await cache.GetRootChannel(channelId); + if (cache is HttpDiscordCache) + return await ((HttpDiscordCache)cache).BotPermissions(guildId, channelId); + + var channel = await cache.GetRootChannel(guildId, channelId); if (channel.GuildId != null) { var userId = cache.GetOwnUser(); var member = await cache.TryGetSelfMember(channel.GuildId.Value); - return await cache.PermissionsFor2(channelId, userId, member); + return await cache.PermissionsFor2(guildId, channelId, userId, member); } return PermissionSet.Dm; diff --git a/Myriad/Cache/HTTPDiscordCache.cs b/Myriad/Cache/HTTPDiscordCache.cs new file mode 100644 index 00000000..57d4e9b3 --- /dev/null +++ b/Myriad/Cache/HTTPDiscordCache.cs @@ -0,0 +1,75 @@ +using Serilog; +using System.Net; +using System.Text.Json; + +using Myriad.Serialization; +using Myriad.Types; + +namespace Myriad.Cache; + +public class HttpDiscordCache: IDiscordCache +{ + private readonly ILogger _logger; + private readonly HttpClient _client; + private readonly string _cacheEndpoint; + private readonly ulong _ownUserId; + + private readonly JsonSerializerOptions _jsonSerializerOptions; + + public HttpDiscordCache(ILogger logger, HttpClient client, string cacheEndpoint, ulong ownUserId) + { + _logger = logger; + _client = client; + _cacheEndpoint = cacheEndpoint; + _ownUserId = ownUserId; + _jsonSerializerOptions = new JsonSerializerOptions().ConfigureForMyriad(); + } + + public ValueTask SaveGuild(Guild guild) => default; + public ValueTask SaveChannel(Channel channel) => default; + public ValueTask SaveUser(User user) => default; + public ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member) => default; + public ValueTask SaveRole(ulong guildId, Myriad.Types.Role role) => default; + public ValueTask SaveDmChannelStub(ulong channelId) => default; + public ValueTask RemoveGuild(ulong guildId) => default; + public ValueTask RemoveChannel(ulong channelId) => default; + public ValueTask RemoveUser(ulong userId) => default; + public ValueTask RemoveRole(ulong guildId, ulong roleId) => default; + + public ulong GetOwnUser() => _ownUserId; + + // todo: cluster + private async Task QueryCache(string endpoint) + { + var response = await _client.GetAsync($"{_cacheEndpoint}{endpoint}"); + + if (response.StatusCode == HttpStatusCode.NotFound) + return default; + + if (response.StatusCode != HttpStatusCode.Found) + throw new Exception($"failed to query http cache: {response.StatusCode}"); + + var plaintext = await response.Content.ReadAsStringAsync(); + return JsonSerializer.Deserialize(plaintext, _jsonSerializerOptions); + } + + public Task TryGetGuild(ulong guildId) + => QueryCache($"/guilds/{guildId}"); + + public Task TryGetChannel(ulong guildId, ulong channelId) + => QueryCache($"/guilds/{guildId}/channels/{channelId}"); + + // this should be a GetUserCached method on nirn-proxy (it's always called as GetOrFetchUser) + // so just return nothing + public Task TryGetUser(ulong userId) + => Task.FromResult(null); + + public Task TryGetSelfMember(ulong guildId) + => QueryCache($"/guilds/{guildId}/members/@me"); + + public Task BotPermissions(ulong guildId, ulong channelId) + => QueryCache($"/guilds/{guildId}/channels/{channelId}/permissions/@me"); + + public Task> GetGuildChannels(ulong guildId) + => QueryCache>($"/guilds/{guildId}/channels"); +} \ No newline at end of file diff --git a/Myriad/Cache/IDiscordCache.cs b/Myriad/Cache/IDiscordCache.cs index e1b4fad4..f7a49bf6 100644 --- a/Myriad/Cache/IDiscordCache.cs +++ b/Myriad/Cache/IDiscordCache.cs @@ -18,11 +18,9 @@ public interface IDiscordCache internal ulong GetOwnUser(); public Task TryGetGuild(ulong guildId); - public Task TryGetChannel(ulong channelId); + public Task TryGetChannel(ulong guildId, ulong channelId); public Task TryGetUser(ulong userId); public Task TryGetSelfMember(ulong guildId); - public Task TryGetRole(ulong roleId); - public IAsyncEnumerable GetAllGuilds(); public Task> GetGuildChannels(ulong guildId); } \ No newline at end of file diff --git a/Myriad/Cache/MemoryDiscordCache.cs b/Myriad/Cache/MemoryDiscordCache.cs index 193d4606..e4ab34b9 100644 --- a/Myriad/Cache/MemoryDiscordCache.cs +++ b/Myriad/Cache/MemoryDiscordCache.cs @@ -137,7 +137,7 @@ public class MemoryDiscordCache: IDiscordCache return Task.FromResult(cg?.Guild); } - public Task TryGetChannel(ulong channelId) + public Task TryGetChannel(ulong _, ulong channelId) { _channels.TryGetValue(channelId, out var channel); return Task.FromResult(channel); @@ -155,19 +155,6 @@ public class MemoryDiscordCache: IDiscordCache return Task.FromResult(guildMember); } - public Task TryGetRole(ulong roleId) - { - _roles.TryGetValue(roleId, out var role); - return Task.FromResult(role); - } - - public IAsyncEnumerable GetAllGuilds() - { - return _guilds.Values - .Select(g => g.Guild) - .ToAsyncEnumerable(); - } - public Task> GetGuildChannels(ulong guildId) { if (!_guilds.TryGetValue(guildId, out var guild)) diff --git a/Myriad/Cache/RedisDiscordCache.cs b/Myriad/Cache/RedisDiscordCache.cs deleted file mode 100644 index fff7920e..00000000 --- a/Myriad/Cache/RedisDiscordCache.cs +++ /dev/null @@ -1,340 +0,0 @@ -using Google.Protobuf; - -using StackExchange.Redis; -using StackExchange.Redis.KeyspaceIsolation; - -using Serilog; - -using Myriad.Types; - -namespace Myriad.Cache; - -#pragma warning disable 4014 -public class RedisDiscordCache: IDiscordCache -{ - private readonly ILogger _logger; - private readonly ulong _ownUserId; - public RedisDiscordCache(ILogger logger, ulong ownUserId) - { - _logger = logger; - _ownUserId = ownUserId; - } - - private ConnectionMultiplexer _redis { get; set; } - - public async Task InitAsync(string addr) - { - _redis = await ConnectionMultiplexer.ConnectAsync(addr); - } - - private IDatabase db => _redis.GetDatabase().WithKeyPrefix("discord:"); - - public async ValueTask SaveGuild(Guild guild) - { - _logger.Verbose("Saving guild {GuildId} to redis", guild.Id); - - var g = new CachedGuild(); - g.Id = guild.Id; - g.Name = guild.Name; - g.OwnerId = guild.OwnerId; - g.PremiumTier = (int)guild.PremiumTier; - - var tr = db.CreateTransaction(); - - tr.HashSetAsync("guilds", guild.Id.HashWrapper(g)); - - foreach (var role in guild.Roles) - { - // Don't call SaveRole because that updates guild state - // and we just got a brand new one :) - // actually with redis it doesn't update guild state, but we're still doing it here because transaction - tr.HashSetAsync("roles", role.Id.HashWrapper(new CachedRole() - { - Id = role.Id, - Name = role.Name, - Position = role.Position, - Permissions = (ulong)role.Permissions, - Mentionable = role.Mentionable, - })); - - tr.HashSetAsync($"guild_roles:{guild.Id}", role.Id, true, When.NotExists); - } - - await tr.ExecuteAsync(); - } - - public async ValueTask SaveChannel(Channel channel) - { - _logger.Verbose("Saving channel {ChannelId} to redis", channel.Id); - - await db.HashSetAsync("channels", channel.Id.HashWrapper(channel.ToProtobuf())); - - if (channel.GuildId != null) - await db.HashSetAsync($"guild_channels:{channel.GuildId.Value}", channel.Id, true, When.NotExists); - - // todo: use a transaction for this? - if (channel.Recipients != null) - foreach (var recipient in channel.Recipients) - await SaveUser(recipient); - } - - public async ValueTask SaveUser(User user) - { - _logger.Verbose("Saving user {UserId} to redis", user.Id); - - var u = new CachedUser() - { - Id = user.Id, - Username = user.Username, - Discriminator = user.Discriminator, - Bot = user.Bot, - }; - - if (user.Avatar != null) - u.Avatar = user.Avatar; - - await db.HashSetAsync("users", user.Id.HashWrapper(u)); - } - - public async ValueTask SaveSelfMember(ulong guildId, GuildMemberPartial member) - { - _logger.Verbose("Saving self member for guild {GuildId} to redis", guildId); - - var gm = new CachedGuildMember(); - foreach (var role in member.Roles) - gm.Roles.Add(role); - - await db.HashSetAsync("members", guildId.HashWrapper(gm)); - } - - public async ValueTask SaveRole(ulong guildId, Myriad.Types.Role role) - { - _logger.Verbose("Saving role {RoleId} in {GuildId} to redis", role.Id, guildId); - - await db.HashSetAsync("roles", role.Id.HashWrapper(new CachedRole() - { - Id = role.Id, - Mentionable = role.Mentionable, - Name = role.Name, - Permissions = (ulong)role.Permissions, - Position = role.Position, - })); - - await db.HashSetAsync($"guild_roles:{guildId}", role.Id, true, When.NotExists); - } - - public async ValueTask SaveDmChannelStub(ulong channelId) - { - // Use existing channel object if present, otherwise add a stub - // We may get a message create before channel create and we want to have it saved - - if (await TryGetChannel(channelId) == null) - await db.HashSetAsync("channels", channelId.HashWrapper(new CachedChannel() - { - Id = channelId, - Type = (int)Channel.ChannelType.Dm, - })); - } - - public async ValueTask RemoveGuild(ulong guildId) - => await db.HashDeleteAsync("guilds", guildId); - - public async ValueTask RemoveChannel(ulong channelId) - { - var oldChannel = await TryGetChannel(channelId); - - if (oldChannel == null) - return; - - await db.HashDeleteAsync("channels", channelId); - - if (oldChannel.GuildId != null) - await db.HashDeleteAsync($"guild_channels:{oldChannel.GuildId.Value}", oldChannel.Id); - } - - public async ValueTask RemoveUser(ulong userId) - => await db.HashDeleteAsync("users", userId); - - public ulong GetOwnUser() => _ownUserId; - - public async ValueTask RemoveRole(ulong guildId, ulong roleId) - { - await db.HashDeleteAsync("roles", roleId); - await db.HashDeleteAsync($"guild_roles:{guildId}", roleId); - } - - public async Task TryGetGuild(ulong guildId) - { - var redisGuild = await db.HashGetAsync("guilds", guildId); - if (redisGuild.IsNullOrEmpty) - return null; - - var guild = ((byte[])redisGuild).Unmarshal(); - - var redisRoles = await db.HashGetAllAsync($"guild_roles:{guildId}"); - - // todo: put this in a transaction or something - var roles = await Task.WhenAll(redisRoles.Select(r => TryGetRole((ulong)r.Name))); - -#pragma warning disable 8619 - return guild.FromProtobuf() with { Roles = roles }; -#pragma warning restore 8619 - } - - public async Task TryGetChannel(ulong channelId) - { - var redisChannel = await db.HashGetAsync("channels", channelId); - if (redisChannel.IsNullOrEmpty) - return null; - - return ((byte[])redisChannel).Unmarshal().FromProtobuf(); - } - - public async Task TryGetUser(ulong userId) - { - var redisUser = await db.HashGetAsync("users", userId); - if (redisUser.IsNullOrEmpty) - return null; - - return ((byte[])redisUser).Unmarshal().FromProtobuf(); - } - - public async Task TryGetSelfMember(ulong guildId) - { - var redisMember = await db.HashGetAsync("members", guildId); - if (redisMember.IsNullOrEmpty) - return null; - - return new GuildMemberPartial() - { - Roles = ((byte[])redisMember).Unmarshal().Roles.ToArray() - }; - } - - public async Task TryGetRole(ulong roleId) - { - var redisRole = await db.HashGetAsync("roles", roleId); - if (redisRole.IsNullOrEmpty) - return null; - - var role = ((byte[])redisRole).Unmarshal(); - - return new Myriad.Types.Role() - { - Id = role.Id, - Name = role.Name, - Position = role.Position, - Permissions = (PermissionSet)role.Permissions, - Mentionable = role.Mentionable, - }; - } - - public IAsyncEnumerable GetAllGuilds() - { - // return _guilds.Values - // .Select(g => g.Guild) - // .ToAsyncEnumerable(); - return new Guild[] { }.ToAsyncEnumerable(); - } - - public async Task> GetGuildChannels(ulong guildId) - { - var redisChannels = await db.HashGetAllAsync($"guild_channels:{guildId}"); - if (redisChannels.Length == 0) - throw new ArgumentException("Guild not found", nameof(guildId)); - -#pragma warning disable 8619 - return await Task.WhenAll(redisChannels.Select(c => TryGetChannel((ulong)c.Name))); -#pragma warning restore 8619 - } -} - -internal static class CacheProtoExt -{ - public static Guild FromProtobuf(this CachedGuild guild) - => new Guild() - { - Id = guild.Id, - Name = guild.Name, - OwnerId = guild.OwnerId, - PremiumTier = (PremiumTier)guild.PremiumTier, - }; - - public static CachedChannel ToProtobuf(this Channel channel) - { - var c = new CachedChannel(); - c.Id = channel.Id; - c.Type = (int)channel.Type; - if (channel.Position != null) - c.Position = channel.Position.Value; - c.Name = channel.Name; - if (channel.PermissionOverwrites != null) - foreach (var overwrite in channel.PermissionOverwrites) - c.PermissionOverwrites.Add(new Overwrite() - { - Id = overwrite.Id, - Type = (int)overwrite.Type, - Allow = (ulong)overwrite.Allow, - Deny = (ulong)overwrite.Deny, - }); - if (channel.GuildId != null) - c.GuildId = channel.GuildId.Value; - - return c; - } - - public static Channel FromProtobuf(this CachedChannel channel) - => new Channel() - { - Id = channel.Id, - Type = (Channel.ChannelType)channel.Type, - Position = channel.Position, - Name = channel.Name, - PermissionOverwrites = channel.PermissionOverwrites - .Select(x => new Channel.Overwrite() - { - Id = x.Id, - Type = (Channel.OverwriteType)x.Type, - Allow = (PermissionSet)x.Allow, - Deny = (PermissionSet)x.Deny, - }).ToArray(), - GuildId = channel.HasGuildId ? channel.GuildId : null, - ParentId = channel.HasParentId ? channel.ParentId : null, - }; - - public static User FromProtobuf(this CachedUser user) - => new User() - { - Id = user.Id, - Username = user.Username, - Discriminator = user.Discriminator, - Avatar = user.HasAvatar ? user.Avatar : null, - Bot = user.Bot, - }; -} - -internal static class RedisExt -{ - // convenience method - public static HashEntry[] HashWrapper(this ulong key, T value) where T : IMessage - => new[] { new HashEntry(key, value.ToByteArray()) }; -} - -public static class ProtobufExt -{ - private static Dictionary _parser = new(); - - public static byte[] Marshal(this IMessage message) => message.ToByteArray(); - - public static T Unmarshal(this byte[] message) where T : IMessage, new() - { - var type = typeof(T).ToString(); - if (_parser.ContainsKey(type)) - return (T)_parser[type].ParseFrom(message); - else - { - _parser.Add(type, new MessageParser(() => new T())); - return Unmarshal(message); - } - } -} \ No newline at end of file diff --git a/Myriad/Extensions/CacheExtensions.cs b/Myriad/Extensions/CacheExtensions.cs index 17660002..0f6fb931 100644 --- a/Myriad/Extensions/CacheExtensions.cs +++ b/Myriad/Extensions/CacheExtensions.cs @@ -13,27 +13,13 @@ public static class CacheExtensions return guild; } - public static async Task GetChannel(this IDiscordCache cache, ulong channelId) + public static async Task GetChannel(this IDiscordCache cache, ulong guildId, ulong channelId) { - if (!(await cache.TryGetChannel(channelId) is Channel channel)) + if (!(await cache.TryGetChannel(guildId, channelId) is Channel channel)) throw new KeyNotFoundException($"Channel {channelId} not found in cache"); return channel; } - public static async Task GetUser(this IDiscordCache cache, ulong userId) - { - if (!(await cache.TryGetUser(userId) is User user)) - throw new KeyNotFoundException($"User {userId} not found in cache"); - return user; - } - - public static async Task GetRole(this IDiscordCache cache, ulong roleId) - { - if (!(await cache.TryGetRole(roleId) is Role role)) - throw new KeyNotFoundException($"Role {roleId} not found in cache"); - return role; - } - public static async ValueTask GetOrFetchUser(this IDiscordCache cache, DiscordApiClient rest, ulong userId) { @@ -47,9 +33,9 @@ public static class CacheExtensions } public static async ValueTask GetOrFetchChannel(this IDiscordCache cache, DiscordApiClient rest, - ulong channelId) + ulong guildId, ulong channelId) { - if (await cache.TryGetChannel(channelId) is { } cacheChannel) + if (await cache.TryGetChannel(guildId, channelId) is { } cacheChannel) return cacheChannel; var restChannel = await rest.GetChannel(channelId); @@ -58,13 +44,13 @@ public static class CacheExtensions return restChannel; } - public static async Task GetRootChannel(this IDiscordCache cache, ulong channelOrThread) + public static async Task GetRootChannel(this IDiscordCache cache, ulong guildId, ulong channelOrThread) { - var channel = await cache.GetChannel(channelOrThread); + var channel = await cache.GetChannel(guildId, channelOrThread); if (!channel.IsThread()) return channel; - var parent = await cache.GetChannel(channel.ParentId!.Value); + var parent = await cache.GetChannel(guildId, channel.ParentId!.Value); return parent; } } \ No newline at end of file diff --git a/Myriad/Extensions/PermissionExtensions.cs b/Myriad/Extensions/PermissionExtensions.cs index 503c5bf2..c55a73d3 100644 --- a/Myriad/Extensions/PermissionExtensions.cs +++ b/Myriad/Extensions/PermissionExtensions.cs @@ -32,23 +32,23 @@ public static class PermissionExtensions PermissionSet.EmbedLinks; public static Task PermissionsForMCE(this IDiscordCache cache, MessageCreateEvent message) => - PermissionsFor2(cache, message.ChannelId, message.Author.Id, message.Member, message.WebhookId != null); + PermissionsFor2(cache, message.GuildId ?? 0, message.ChannelId, message.Author.Id, message.Member, message.WebhookId != null); public static Task - PermissionsForMemberInChannel(this IDiscordCache cache, ulong channelId, GuildMember member) => - PermissionsFor2(cache, channelId, member.User.Id, member); + PermissionsForMemberInChannel(this IDiscordCache cache, ulong guildId, ulong channelId, GuildMember member) => + PermissionsFor2(cache, guildId, channelId, member.User.Id, member); - public static async Task PermissionsFor2(this IDiscordCache cache, ulong channelId, ulong userId, + public static async Task PermissionsFor2(this IDiscordCache cache, ulong guildId, ulong channelId, ulong userId, GuildMemberPartial? member, bool isThread = false) { - if (!(await cache.TryGetChannel(channelId) is Channel channel)) + if (!(await cache.TryGetChannel(guildId, channelId) is Channel channel)) // todo: handle channel not found better return PermissionSet.Dm; if (channel.GuildId == null) return PermissionSet.Dm; - var rootChannel = await cache.GetRootChannel(channelId); + var rootChannel = await cache.GetRootChannel(guildId, channelId); var guild = await cache.GetGuild(channel.GuildId.Value); diff --git a/PluralKit.Bot/ApplicationCommands/Message.cs b/PluralKit.Bot/ApplicationCommands/Message.cs index e40a24d2..d610fb13 100644 --- a/PluralKit.Bot/ApplicationCommands/Message.cs +++ b/PluralKit.Bot/ApplicationCommands/Message.cs @@ -63,14 +63,14 @@ public class ApplicationCommandProxiedMessage var messageId = ctx.Event.Data!.TargetId!.Value; // check for command messages - var (authorId, channelId) = await ctx.Services.Resolve().GetCommandMessage(messageId); - if (authorId != null) + var cmessage = await ctx.Services.Resolve().GetCommandMessage(messageId); + if (cmessage != null) { - if (authorId != ctx.User.Id) + if (cmessage.AuthorId != ctx.User.Id) throw new PKError("You can only delete command messages queried by this account."); - var isDM = (await _repo.GetDmChannel(ctx.User!.Id)) == channelId; - await DeleteMessageInner(ctx, channelId!.Value, messageId, isDM); + var isDM = (await _repo.GetDmChannel(ctx.User!.Id)) == cmessage.ChannelId; + await DeleteMessageInner(ctx, cmessage.GuildId, cmessage.ChannelId, messageId, isDM); return; } @@ -81,7 +81,7 @@ public class ApplicationCommandProxiedMessage if (message.System?.Id != ctx.System.Id && message.Message.Sender != ctx.User.Id) throw new PKError("You can only delete your own messages."); - await DeleteMessageInner(ctx, message.Message.Channel, message.Message.Mid, false); + await DeleteMessageInner(ctx, message.Message.Guild ?? 0, message.Message.Channel, message.Message.Mid, false); return; } @@ -89,9 +89,9 @@ public class ApplicationCommandProxiedMessage throw Errors.MessageNotFound(messageId); } - internal async Task DeleteMessageInner(InteractionContext ctx, ulong channelId, ulong messageId, bool isDM = false) + internal async Task DeleteMessageInner(InteractionContext ctx, ulong guildId, ulong channelId, ulong messageId, bool isDM = false) { - if (!((await _cache.BotPermissionsIn(channelId)).HasFlag(PermissionSet.ManageMessages) || isDM)) + if (!((await _cache.BotPermissionsIn(guildId, channelId)).HasFlag(PermissionSet.ManageMessages) || isDM)) throw new PKError("PluralKit does not have the *Manage Messages* permission in this channel, and thus cannot delete the message." + " Please contact a server administrator to remedy this."); @@ -110,7 +110,7 @@ public class ApplicationCommandProxiedMessage // (if not, PK shouldn't send messages on their behalf) var member = await _rest.GetGuildMember(ctx.GuildId, ctx.User.Id); var requiredPerms = PermissionSet.ViewChannel | PermissionSet.SendMessages; - if (member == null || !(await _cache.PermissionsForMemberInChannel(ctx.ChannelId, member)).HasFlag(requiredPerms)) + if (member == null || !(await _cache.PermissionsForMemberInChannel(ctx.GuildId, ctx.ChannelId, member)).HasFlag(requiredPerms)) { throw new PKError("You do not have permission to send messages in this channel."); }; diff --git a/PluralKit.Bot/Bot.cs b/PluralKit.Bot/Bot.cs index 506382ef..ec33705a 100644 --- a/PluralKit.Bot/Bot.cs +++ b/PluralKit.Bot/Bot.cs @@ -99,11 +99,13 @@ public class Bot private async Task OnEventReceived(int shardId, IGatewayEvent evt) { - // we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent - await _cache.HandleGatewayEvent(evt); - - await _cache.TryUpdateSelfMember(_config.ClientId, evt); + if (_cache is MemoryDiscordCache) + { + // we HandleGatewayEvent **before** getting the own user, because the own user is set in HandleGatewayEvent for ReadyEvent + await _cache.HandleGatewayEvent(evt); + await _cache.TryUpdateSelfMember(_config.ClientId, evt); + } await OnEventReceivedInner(shardId, evt); } @@ -175,7 +177,16 @@ public class Bot } using var _ = LogContext.PushProperty("EventId", Guid.NewGuid()); - using var __ = LogContext.Push(await serviceScope.Resolve().GetEnricher(shardId, evt)); + // this fails when cache lookup fails, so put it in a try-catch + try + { + using var __ = LogContext.Push(await serviceScope.Resolve().GetEnricher(shardId, evt)); + } + catch (Exception exc) + { + + await HandleError(handler, evt, serviceScope, exc); + } _logger.Verbose("Received gateway event: {@Event}", evt); try @@ -243,7 +254,7 @@ public class Bot if (!exc.ShowToUser()) return; // Once we've sent it to Sentry, report it to the user (if we have permission to) - var reportChannel = handler.ErrorChannelFor(evt, _config.ClientId); + var (guildId, reportChannel) = handler.ErrorChannelFor(evt, _config.ClientId); if (reportChannel == null) { if (evt is InteractionCreateEvent ice && ice.Type == Interaction.InteractionType.ApplicationCommand) @@ -251,7 +262,7 @@ public class Bot return; } - var botPerms = await _cache.BotPermissionsIn(reportChannel.Value); + var botPerms = await _cache.BotPermissionsIn(guildId ?? 0, reportChannel.Value); if (botPerms.HasFlag(PermissionSet.SendMessages | PermissionSet.EmbedLinks)) await _errorMessageService.SendErrorMessage(reportChannel.Value, sentryEvent.EventId.ToString()); } diff --git a/PluralKit.Bot/BotConfig.cs b/PluralKit.Bot/BotConfig.cs index 4621717e..05fc3943 100644 --- a/PluralKit.Bot/BotConfig.cs +++ b/PluralKit.Bot/BotConfig.cs @@ -20,7 +20,8 @@ public class BotConfig public string? GatewayQueueUrl { get; set; } public bool UseRedisRatelimiter { get; set; } = false; - public bool UseRedisCache { get; set; } = false; + + public string? HttpCacheUrl { get; set; } public string? RedisGatewayUrl { get; set; } diff --git a/PluralKit.Bot/CommandSystem/Context/Context.cs b/PluralKit.Bot/CommandSystem/Context/Context.cs index 99d4a39b..4719e57a 100644 --- a/PluralKit.Bot/CommandSystem/Context/Context.cs +++ b/PluralKit.Bot/CommandSystem/Context/Context.cs @@ -62,7 +62,7 @@ public class Context public readonly int ShardId; public readonly Cluster Cluster; - public Task BotPermissions => Cache.BotPermissionsIn(Channel.Id); + public Task BotPermissions => Cache.BotPermissionsIn(Guild?.Id ?? 0, Channel.Id); public Task UserPermissions => Cache.PermissionsForMCE((MessageCreateEvent)Message); @@ -100,7 +100,7 @@ public class Context // { // Sensitive information that might want to be deleted by :x: reaction is typically in an embed format (member cards, for example) // but since we can, we just store all sent messages for possible deletion - await _commandMessageService.RegisterMessage(msg.Id, msg.ChannelId, Author.Id); + await _commandMessageService.RegisterMessage(msg.Id, Guild?.Id ?? 0, msg.ChannelId, Author.Id); // } return msg; diff --git a/PluralKit.Bot/CommandSystem/Context/ContextEntityArgumentsExt.cs b/PluralKit.Bot/CommandSystem/Context/ContextEntityArgumentsExt.cs index 8bf273f1..c3bb1cb2 100644 --- a/PluralKit.Bot/CommandSystem/Context/ContextEntityArgumentsExt.cs +++ b/PluralKit.Bot/CommandSystem/Context/ContextEntityArgumentsExt.cs @@ -188,7 +188,8 @@ public static class ContextEntityArgumentsExt if (!MentionUtils.TryParseChannel(ctx.PeekArgument(), out var id)) return null; - var channel = await ctx.Cache.TryGetChannel(id); + // todo: match channels in other guilds + var channel = await ctx.Cache.TryGetChannel(ctx.Guild!.Id, id); if (channel == null) channel = await ctx.Rest.GetChannelOrNull(id); if (channel == null) diff --git a/PluralKit.Bot/Commands/Checks.cs b/PluralKit.Bot/Commands/Checks.cs index 8d88177a..6b68c7b3 100644 --- a/PluralKit.Bot/Commands/Checks.cs +++ b/PluralKit.Bot/Commands/Checks.cs @@ -143,6 +143,7 @@ public class Checks var error = "Channel not found or you do not have permissions to access it."; // todo: this breaks if channel is not in cache and bot does not have View Channel permissions + // with new cache it breaks if channel is not in current guild var channel = await ctx.MatchChannel(); if (channel == null || channel.GuildId == null) throw new PKError(error); @@ -156,7 +157,8 @@ public class Checks if (!await ctx.CheckPermissionsInGuildChannel(channel, PermissionSet.ViewChannel)) throw new PKError(error); - var botPermissions = await _cache.BotPermissionsIn(channel.Id); + // todo: permcheck channel outside of guild? + var botPermissions = await _cache.BotPermissionsIn(ctx.Guild.Id, channel.Id); // We use a bitfield so we can set individual permission bits ulong missingPermissions = 0; @@ -231,11 +233,11 @@ public class Checks var channel = await _rest.GetChannelOrNull(channelId.Value); if (channel == null) throw new PKError("Unable to get the channel associated with this message."); - - var rootChannel = await _cache.GetRootChannel(channel.Id); if (channel.GuildId == null) throw new PKError("PluralKit is not able to proxy messages in DMs."); + var rootChannel = await _cache.GetRootChannel(channel.GuildId!.Value, channel.Id); + // using channel.GuildId here since _rest.GetMessage() doesn't return the GuildId var context = await ctx.Repository.GetMessageContext(msg.Author.Id, channel.GuildId.Value, rootChannel.Id, msg.ChannelId); var members = (await ctx.Repository.GetProxyMembers(msg.Author.Id, channel.GuildId.Value)).ToList(); diff --git a/PluralKit.Bot/Commands/Message.cs b/PluralKit.Bot/Commands/Message.cs index ca423f94..3e02fede 100644 --- a/PluralKit.Bot/Commands/Message.cs +++ b/PluralKit.Bot/Commands/Message.cs @@ -218,7 +218,7 @@ public class ProxiedMessage try { var editedMsg = - await _webhookExecutor.EditWebhookMessage(msg.Channel, msg.Mid, newContent, clearEmbeds); + await _webhookExecutor.EditWebhookMessage(msg.Guild ?? 0, msg.Channel, msg.Mid, newContent, clearEmbeds); if (ctx.Guild == null) await _rest.CreateReaction(ctx.Channel.Id, ctx.Message.Id, new Emoji { Name = Emojis.Success }); @@ -436,14 +436,14 @@ public class ProxiedMessage private async Task DeleteCommandMessage(Context ctx, ulong messageId) { - var (authorId, channelId) = await ctx.Services.Resolve().GetCommandMessage(messageId); - if (authorId == null) + var cmessage = await ctx.Services.Resolve().GetCommandMessage(messageId); + if (cmessage == null) throw Errors.MessageNotFound(messageId); - if (authorId != ctx.Author.Id) + if (cmessage!.AuthorId != ctx.Author.Id) throw new PKError("You can only delete command messages queried by this account."); - await ctx.Rest.DeleteMessage(channelId!.Value, messageId); + await ctx.Rest.DeleteMessage(cmessage.ChannelId, messageId); if (ctx.Guild != null) await ctx.Rest.DeleteMessage(ctx.Message); diff --git a/PluralKit.Bot/Commands/ServerConfig.cs b/PluralKit.Bot/Commands/ServerConfig.cs index a33d39d3..f411be77 100644 --- a/PluralKit.Bot/Commands/ServerConfig.cs +++ b/PluralKit.Bot/Commands/ServerConfig.cs @@ -49,7 +49,7 @@ public class ServerConfig if (channel.Type != Channel.ChannelType.GuildText && channel.Type != Channel.ChannelType.GuildPublicThread && channel.Type != Channel.ChannelType.GuildPrivateThread) throw new PKError("PluralKit cannot log messages to this type of channel."); - var perms = await _cache.BotPermissionsIn(channel.Id); + var perms = await _cache.BotPermissionsIn(ctx.Guild.Id, channel.Id); if (!perms.HasFlag(PermissionSet.SendMessages)) throw new PKError("PluralKit is missing **Send Messages** permissions in the new log channel."); if (!perms.HasFlag(PermissionSet.EmbedLinks)) @@ -104,7 +104,7 @@ public class ServerConfig // Resolve all channels from the cache and order by position var channels = (await Task.WhenAll(blacklist.Blacklist - .Select(id => _cache.TryGetChannel(id)))) + .Select(id => _cache.TryGetChannel(ctx.Guild.Id, id)))) .Where(c => c != null) .OrderBy(c => c.Position) .ToList(); @@ -121,7 +121,7 @@ public class ServerConfig async (eb, l) => { async Task CategoryName(ulong? id) => - id != null ? (await _cache.GetChannel(id.Value)).Name : "(no category)"; + id != null ? (await _cache.GetChannel(ctx.Guild.Id, id.Value)).Name : "(no category)"; ulong? lastCategory = null; @@ -153,8 +153,9 @@ public class ServerConfig var config = await ctx.Repository.GetGuild(ctx.Guild.Id); // Resolve all channels from the cache and order by position + // todo: GetAllChannels? var channels = (await Task.WhenAll(config.LogBlacklist - .Select(id => _cache.TryGetChannel(id)))) + .Select(id => _cache.TryGetChannel(ctx.Guild.Id, id)))) .Where(c => c != null) .OrderBy(c => c.Position) .ToList(); @@ -171,7 +172,7 @@ public class ServerConfig async (eb, l) => { async Task CategoryName(ulong? id) => - id != null ? (await _cache.GetChannel(id.Value)).Name : "(no category)"; + id != null ? (await _cache.GetChannel(ctx.Guild.Id, id.Value)).Name : "(no category)"; ulong? lastCategory = null; diff --git a/PluralKit.Bot/Handlers/IEventHandler.cs b/PluralKit.Bot/Handlers/IEventHandler.cs index fd18849e..f6b6ba01 100644 --- a/PluralKit.Bot/Handlers/IEventHandler.cs +++ b/PluralKit.Bot/Handlers/IEventHandler.cs @@ -6,5 +6,5 @@ public interface IEventHandler where T : IGatewayEvent { Task Handle(int shardId, T evt); - ulong? ErrorChannelFor(T evt, ulong userId) => null; + (ulong?, ulong?) ErrorChannelFor(T evt, ulong userId) => (null, null); } \ No newline at end of file diff --git a/PluralKit.Bot/Handlers/MessageCreated.cs b/PluralKit.Bot/Handlers/MessageCreated.cs index aad80bd2..e0387505 100644 --- a/PluralKit.Bot/Handlers/MessageCreated.cs +++ b/PluralKit.Bot/Handlers/MessageCreated.cs @@ -52,7 +52,7 @@ public class MessageCreated: IEventHandler _dmCache = dmCache; } - public ulong? ErrorChannelFor(MessageCreateEvent evt, ulong userId) => evt.ChannelId; + public (ulong?, ulong?) ErrorChannelFor(MessageCreateEvent evt, ulong userId) => (evt.GuildId, evt.ChannelId); private bool IsDuplicateMessage(Message msg) => // We consider a message duplicate if it has the same ID as the previous message that hit the gateway _lastMessageCache.GetLastMessage(msg.ChannelId)?.Current.Id == msg.Id; @@ -63,7 +63,7 @@ public class MessageCreated: IEventHandler if (evt.Type != Message.MessageType.Default && evt.Type != Message.MessageType.Reply) return; if (IsDuplicateMessage(evt)) return; - var botPermissions = await _cache.BotPermissionsIn(evt.ChannelId); + var botPermissions = await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId); if (!botPermissions.HasFlag(PermissionSet.SendMessages)) return; // spawn off saving the private channel into another thread @@ -71,8 +71,8 @@ public class MessageCreated: IEventHandler _ = _dmCache.TrySavePrivateChannel(evt); var guild = evt.GuildId != null ? await _cache.GetGuild(evt.GuildId.Value) : null; - var channel = await _cache.GetChannel(evt.ChannelId); - var rootChannel = await _cache.GetRootChannel(evt.ChannelId); + var channel = await _cache.GetChannel(evt.GuildId ?? 0, evt.ChannelId); + var rootChannel = await _cache.GetRootChannel(evt.GuildId ?? 0, evt.ChannelId); // Log metrics and message info _metrics.Measure.Meter.Mark(BotMetrics.MessagesReceived); @@ -90,7 +90,8 @@ public class MessageCreated: IEventHandler if (await TryHandleCommand(shardId, evt, guild, channel)) return; - await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions); + if (evt.GuildId != null) + await TryHandleProxy(evt, guild, channel, rootChannel.Id, botPermissions); } private async Task TryHandleLogClean(Channel channel, MessageCreateEvent evt) diff --git a/PluralKit.Bot/Handlers/MessageEdited.cs b/PluralKit.Bot/Handlers/MessageEdited.cs index bb33bc15..2ba523f2 100644 --- a/PluralKit.Bot/Handlers/MessageEdited.cs +++ b/PluralKit.Bot/Handlers/MessageEdited.cs @@ -52,10 +52,12 @@ public class MessageEdited: IEventHandler if (!evt.Content.HasValue || !evt.Author.HasValue || !evt.Member.HasValue) return; - var channel = await _cache.GetChannel(evt.ChannelId); + var guildIdMaybe = evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0; + + var channel = await _cache.GetChannel(guildIdMaybe, evt.ChannelId); // todo: is this correct for message update? if (!DiscordUtils.IsValidGuildChannel(channel)) return; - var rootChannel = await _cache.GetRootChannel(channel.Id); + var rootChannel = await _cache.GetRootChannel(guildIdMaybe, channel.Id); var guild = await _cache.GetGuild(channel.GuildId!.Value); var lastMessage = _lastMessageCache.GetLastMessage(evt.ChannelId)?.Current; @@ -69,7 +71,7 @@ public class MessageEdited: IEventHandler ctx = await _repo.GetMessageContext(evt.Author.Value!.Id, channel.GuildId!.Value, rootChannel.Id, evt.ChannelId); var equivalentEvt = await GetMessageCreateEvent(evt, lastMessage, channel); - var botPermissions = await _cache.BotPermissionsIn(channel.Id); + var botPermissions = await _cache.BotPermissionsIn(guildIdMaybe, channel.Id); try { @@ -91,7 +93,7 @@ public class MessageEdited: IEventHandler private async Task GetMessageCreateEvent(MessageUpdateEvent evt, CachedMessage lastMessage, Channel channel) { - var referencedMessage = await GetReferencedMessage(evt.ChannelId, lastMessage.ReferencedMessage); + var referencedMessage = await GetReferencedMessage(evt.GuildId.HasValue ? evt.GuildId.Value ?? 0 : 0, evt.ChannelId, lastMessage.ReferencedMessage); var messageReference = lastMessage.ReferencedMessage != null ? new Message.Reference(channel.GuildId, evt.ChannelId, lastMessage.ReferencedMessage.Value) @@ -118,12 +120,12 @@ public class MessageEdited: IEventHandler return equivalentEvt; } - private async Task GetReferencedMessage(ulong channelId, ulong? referencedMessageId) + private async Task GetReferencedMessage(ulong guildId, ulong channelId, ulong? referencedMessageId) { if (referencedMessageId == null) return null; - var botPermissions = await _cache.BotPermissionsIn(channelId); + var botPermissions = await _cache.BotPermissionsIn(guildId, channelId); if (!botPermissions.HasFlag(PermissionSet.ReadMessageHistory)) { _logger.Warning( diff --git a/PluralKit.Bot/Handlers/ReactionAdded.cs b/PluralKit.Bot/Handlers/ReactionAdded.cs index de6a2c69..633d8f01 100644 --- a/PluralKit.Bot/Handlers/ReactionAdded.cs +++ b/PluralKit.Bot/Handlers/ReactionAdded.cs @@ -62,7 +62,7 @@ public class ReactionAdded: IEventHandler // but we aren't able to get DMs from bots anyway, so it's not really needed if (evt.GuildId != null && (evt.Member?.User?.Bot ?? false)) return; - var channel = await _cache.GetChannel(evt.ChannelId); + var channel = await _cache.GetChannel(evt.GuildId ?? 0, evt.ChannelId); // check if it's a command message first // since this can happen in DMs as well @@ -75,10 +75,10 @@ public class ReactionAdded: IEventHandler return; } - var (authorId, _) = await _commandMessageService.GetCommandMessage(evt.MessageId); - if (authorId != null) + var cmessage = await _commandMessageService.GetCommandMessage(evt.MessageId); + if (cmessage != null) { - await HandleCommandDeleteReaction(evt, authorId.Value, false); + await HandleCommandDeleteReaction(evt, cmessage.AuthorId, false); return; } } @@ -123,7 +123,7 @@ public class ReactionAdded: IEventHandler private async ValueTask HandleProxyDeleteReaction(MessageReactionAddEvent evt, PKMessage msg) { - if (!(await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) + if (!(await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) return; var isSameSystem = msg.Member != null && await _repo.IsMemberOwnedByAccount(msg.Member.Value, evt.UserId); @@ -150,7 +150,7 @@ public class ReactionAdded: IEventHandler if (authorId != null && authorId != evt.UserId) return; - if (!((await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages) || isDM)) + if (!((await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages) || isDM)) return; // todo: don't try to delete the user's own messages in DMs @@ -206,14 +206,14 @@ public class ReactionAdded: IEventHandler private async ValueTask HandlePingReaction(MessageReactionAddEvent evt, FullMessage msg) { - if (!(await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) + if (!(await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) return; // Check if the "pinger" has permission to send messages in this channel // (if not, PK shouldn't send messages on their behalf) var member = await _rest.GetGuildMember(evt.GuildId!.Value, evt.UserId); var requiredPerms = PermissionSet.ViewChannel | PermissionSet.SendMessages; - if (member == null || !(await _cache.PermissionsForMemberInChannel(evt.ChannelId, member)).HasFlag(requiredPerms)) return; + if (member == null || !(await _cache.PermissionsForMemberInChannel(evt.GuildId ?? 0, evt.ChannelId, member)).HasFlag(requiredPerms)) return; if (msg.Member == null) return; @@ -266,7 +266,7 @@ public class ReactionAdded: IEventHandler private async Task TryRemoveOriginalReaction(MessageReactionAddEvent evt) { - if ((await _cache.BotPermissionsIn(evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) + if ((await _cache.BotPermissionsIn(evt.GuildId ?? 0, evt.ChannelId)).HasFlag(PermissionSet.ManageMessages)) await _rest.DeleteUserReaction(evt.ChannelId, evt.MessageId, evt.Emoji, evt.UserId); } } \ No newline at end of file diff --git a/PluralKit.Bot/Init.cs b/PluralKit.Bot/Init.cs index 9e882370..28bc6ed9 100644 --- a/PluralKit.Bot/Init.cs +++ b/PluralKit.Bot/Init.cs @@ -56,8 +56,6 @@ public class Init await redis.InitAsync(coreConfig); var cache = services.Resolve(); - if (cache is RedisDiscordCache) - await (cache as RedisDiscordCache).InitAsync(coreConfig.RedisAddr); if (config.Cluster == null) { diff --git a/PluralKit.Bot/Modules.cs b/PluralKit.Bot/Modules.cs index 730908a9..f3cca508 100644 --- a/PluralKit.Bot/Modules.cs +++ b/PluralKit.Bot/Modules.cs @@ -48,8 +48,10 @@ public class BotModule: Module { var botConfig = c.Resolve(); - if (botConfig.UseRedisCache) - return new RedisDiscordCache(c.Resolve(), botConfig.ClientId); + if (botConfig.HttpCacheUrl != null) + return new HttpDiscordCache(c.Resolve(), + c.Resolve(), botConfig.HttpCacheUrl, botConfig.ClientId); + return new MemoryDiscordCache(botConfig.ClientId); }).AsSelf().SingleInstance(); builder.RegisterType().AsSelf().SingleInstance(); diff --git a/PluralKit.Bot/Proxy/ProxyService.cs b/PluralKit.Bot/Proxy/ProxyService.cs index 4c9b701a..83299d09 100644 --- a/PluralKit.Bot/Proxy/ProxyService.cs +++ b/PluralKit.Bot/Proxy/ProxyService.cs @@ -59,7 +59,7 @@ public class ProxyService public async Task HandleIncomingMessage(MessageCreateEvent message, MessageContext ctx, Guild guild, Channel channel, bool allowAutoproxy, PermissionSet botPermissions) { - var rootChannel = await _cache.GetRootChannel(message.ChannelId); + var rootChannel = await _cache.GetRootChannel(message.GuildId!.Value, message.ChannelId); if (!ShouldProxy(channel, rootChannel, message, ctx)) return false; @@ -207,8 +207,8 @@ public class ProxyService var content = match.ProxyContent; if (!allowEmbeds) content = content.BreakLinkEmbeds(); - var messageChannel = await _cache.GetChannel(trigger.ChannelId); - var rootChannel = await _cache.GetRootChannel(trigger.ChannelId); + var messageChannel = await _cache.GetChannel(trigger.GuildId!.Value, trigger.ChannelId); + var rootChannel = await _cache.GetRootChannel(trigger.GuildId!.Value, trigger.ChannelId); var threadId = messageChannel.IsThread() ? messageChannel.Id : (ulong?)null; var guild = await _cache.GetGuild(trigger.GuildId.Value); var guildMember = await _rest.GetGuildMember(trigger.GuildId!.Value, trigger.Author.Id); diff --git a/PluralKit.Bot/Services/CommandMessageService.cs b/PluralKit.Bot/Services/CommandMessageService.cs index 5ef84ad1..796f7f0a 100644 --- a/PluralKit.Bot/Services/CommandMessageService.cs +++ b/PluralKit.Bot/Services/CommandMessageService.cs @@ -18,7 +18,7 @@ public class CommandMessageService _logger = logger.ForContext(); } - public async Task RegisterMessage(ulong messageId, ulong channelId, ulong authorId) + public async Task RegisterMessage(ulong messageId, ulong guildId, ulong channelId, ulong authorId) { if (_redis.Connection == null) return; @@ -27,17 +27,19 @@ public class CommandMessageService messageId, authorId, channelId ); - await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}", expiry: CommandMessageRetention); + await _redis.Connection.GetDatabase().StringSetAsync(messageId.ToString(), $"{authorId}-{channelId}-{guildId}", expiry: CommandMessageRetention); } - public async Task<(ulong?, ulong?)> GetCommandMessage(ulong messageId) + public async Task GetCommandMessage(ulong messageId) { var str = await _redis.Connection.GetDatabase().StringGetAsync(messageId.ToString()); if (str.HasValue) { var split = ((string)str).Split("-"); - return (ulong.Parse(split[0]), ulong.Parse(split[1])); + return new CommandMessage(ulong.Parse(split[0]), ulong.Parse(split[1]), ulong.Parse(split[2])); } - return (null, null); + return null; } -} \ No newline at end of file +} + +public record CommandMessage(ulong AuthorId, ulong ChannelId, ulong GuildId); \ No newline at end of file diff --git a/PluralKit.Bot/Services/EmbedService.cs b/PluralKit.Bot/Services/EmbedService.cs index fa9d9a90..73874c7e 100644 --- a/PluralKit.Bot/Services/EmbedService.cs +++ b/PluralKit.Bot/Services/EmbedService.cs @@ -336,7 +336,7 @@ public class EmbedService public async Task CreateMessageInfoEmbed(FullMessage msg, bool showContent, SystemConfig? ccfg = null) { - var channel = await _cache.GetOrFetchChannel(_rest, msg.Message.Channel); + var channel = await _cache.GetOrFetchChannel(_rest, msg.Message.Guild ?? 0, msg.Message.Channel); var ctx = LookupContext.ByNonOwner; var serverMsg = await _rest.GetMessageOrNull(msg.Message.Channel, msg.Message.Mid); @@ -403,14 +403,15 @@ public class EmbedService var roles = memberInfo?.Roles?.ToList(); if (roles != null && roles.Count > 0 && showContent) { - var rolesString = string.Join(", ", (await Task.WhenAll(roles - .Select(async id => + var guild = await _cache.GetGuild(channel.GuildId!.Value); + var rolesString = string.Join(", ", (roles + .Select(id => { - var role = await _cache.TryGetRole(id); + var role = Array.Find(guild.Roles, r => r.Id == id); if (role != null) return role; return new Role { Name = "*(unknown role)*", Position = 0 }; - }))) + })) .OrderByDescending(role => role.Position) .Select(role => role.Name)); eb.Field(new Embed.Field($"Account roles ({roles.Count})", rolesString.Truncate(1024))); diff --git a/PluralKit.Bot/Services/LogChannelService.cs b/PluralKit.Bot/Services/LogChannelService.cs index c8b21282..f7149e75 100644 --- a/PluralKit.Bot/Services/LogChannelService.cs +++ b/PluralKit.Bot/Services/LogChannelService.cs @@ -42,7 +42,7 @@ public class LogChannelService if (logChannelId == null) return; - var triggerChannel = await _cache.GetChannel(proxiedMessage.Channel); + var triggerChannel = await _cache.GetChannel(proxiedMessage.Guild!.Value, proxiedMessage.Channel); var member = await _repo.GetMember(proxiedMessage.Member!.Value); var system = await _repo.GetSystem(member.System); @@ -63,7 +63,7 @@ public class LogChannelService return null; var guildId = proxiedMessage.Guild ?? trigger.GuildId.Value; - var rootChannel = await _cache.GetRootChannel(trigger.ChannelId); + var rootChannel = await _cache.GetRootChannel(guildId, trigger.ChannelId); // get log channel info from the database var guild = await _repo.GetGuild(guildId); @@ -109,7 +109,7 @@ public class LogChannelService private async Task FindLogChannel(ulong guildId, ulong channelId) { // TODO: fetch it directly on cache miss? - if (await _cache.TryGetChannel(channelId) is Channel channel) + if (await _cache.TryGetChannel(guildId, channelId) is Channel channel) return channel; if (await _rest.GetChannelOrNull(channelId) is Channel restChannel) diff --git a/PluralKit.Bot/Services/LoggerCleanService.cs b/PluralKit.Bot/Services/LoggerCleanService.cs index 6346b8b2..120bcda8 100644 --- a/PluralKit.Bot/Services/LoggerCleanService.cs +++ b/PluralKit.Bot/Services/LoggerCleanService.cs @@ -100,10 +100,10 @@ public class LoggerCleanService public async ValueTask HandleLoggerBotCleanup(Message msg) { - var channel = await _cache.GetChannel(msg.ChannelId); + var channel = await _cache.GetChannel(msg.GuildId!.Value, msg.ChannelId!); if (channel.Type != Channel.ChannelType.GuildText) return; - if (!(await _cache.BotPermissionsIn(channel.Id)).HasFlag(PermissionSet.ManageMessages)) return; + if (!(await _cache.BotPermissionsIn(msg.GuildId!.Value, channel.Id)).HasFlag(PermissionSet.ManageMessages)) return; // If this message is from a *webhook*, check if the application ID matches one of the bots we know // If it's from a *bot*, check the bot ID to see if we know it. diff --git a/PluralKit.Bot/Services/PeriodicStatCollector.cs b/PluralKit.Bot/Services/PeriodicStatCollector.cs index 92c5ea05..c73294a6 100644 --- a/PluralKit.Bot/Services/PeriodicStatCollector.cs +++ b/PluralKit.Bot/Services/PeriodicStatCollector.cs @@ -54,33 +54,6 @@ public class PeriodicStatCollector var stopwatch = new Stopwatch(); stopwatch.Start(); - // Aggregate guild/channel stats - var guildCount = 0; - var channelCount = 0; - - // No LINQ today, sorry - await foreach (var guild in _cache.GetAllGuilds()) - { - guildCount++; - foreach (var channel in await _cache.GetGuildChannels(guild.Id)) - if (DiscordUtils.IsValidGuildChannel(channel)) - channelCount++; - } - - if (_config.UseRedisMetrics) - { - var db = _redis.Connection.GetDatabase(); - await db.HashSetAsync("pluralkit:cluster_stats", new StackExchange.Redis.HashEntry[] { - new(_botConfig.Cluster.NodeIndex, JsonConvert.SerializeObject(new ClusterMetricInfo - { - GuildCount = guildCount, - ChannelCount = channelCount, - DatabaseConnectionCount = _countHolder.ConnectionCount, - WebhookCacheSize = _webhookCache.CacheSize, - })), - }); - } - // Process info var process = Process.GetCurrentProcess(); _metrics.Measure.Gauge.SetValue(CoreMetrics.ProcessPhysicalMemory, process.WorkingSet64); diff --git a/PluralKit.Bot/Services/WebhookExecutorService.cs b/PluralKit.Bot/Services/WebhookExecutorService.cs index e45ca67c..fc23388c 100644 --- a/PluralKit.Bot/Services/WebhookExecutorService.cs +++ b/PluralKit.Bot/Services/WebhookExecutorService.cs @@ -87,7 +87,7 @@ public class WebhookExecutorService return webhookMessage; } - public async Task EditWebhookMessage(ulong channelId, ulong messageId, string newContent, bool clearEmbeds = false) + public async Task EditWebhookMessage(ulong guildId, ulong channelId, ulong messageId, string newContent, bool clearEmbeds = false) { var allowedMentions = newContent.ParseMentions() with { @@ -96,7 +96,7 @@ public class WebhookExecutorService }; ulong? threadId = null; - var channel = await _cache.GetOrFetchChannel(_rest, channelId); + var channel = await _cache.GetOrFetchChannel(_rest, guildId, channelId); if (channel.IsThread()) { threadId = channelId; diff --git a/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs b/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs index 4fba3f24..76e00a80 100644 --- a/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs +++ b/PluralKit.Bot/Utils/SerilogGatewayEnricherFactory.cs @@ -38,9 +38,11 @@ public class SerilogGatewayEnricherFactory { props.Add(new LogEventProperty("ChannelId", new ScalarValue(channel.Value))); - if (await _cache.TryGetChannel(channel.Value) != null) + var guildIdForCache = guild != null ? guild.Value : 0; + + if (await _cache.TryGetChannel(guildIdForCache, channel.Value) != null) { - var botPermissions = await _cache.BotPermissionsIn(channel.Value); + var botPermissions = await _cache.BotPermissionsIn(guildIdForCache, channel.Value); props.Add(new LogEventProperty("BotPermissions", new ScalarValue(botPermissions))); } } diff --git a/PluralKit.Core/CoreConfig.cs b/PluralKit.Core/CoreConfig.cs index cac76e3d..4adf815b 100644 --- a/PluralKit.Core/CoreConfig.cs +++ b/PluralKit.Core/CoreConfig.cs @@ -8,7 +8,6 @@ public class CoreConfig public string? MessagesDatabase { get; set; } public string? DatabasePassword { get; set; } public string RedisAddr { get; set; } - public bool UseRedisMetrics { get; set; } = false; public string SentryUrl { get; set; } public string InfluxUrl { get; set; } public string InfluxDb { get; set; } diff --git a/lib/libpk/Cargo.toml b/lib/libpk/Cargo.toml index a5ec39c5..4372a68c 100644 --- a/lib/libpk/Cargo.toml +++ b/lib/libpk/Cargo.toml @@ -17,6 +17,7 @@ tokio = { workspace = true } tracing = { workspace = true } tracing-gelf = "0.7.1" tracing-subscriber = { workspace = true} +twilight-model = { workspace = true } prost = { workspace = true } prost-types = { workspace = true } diff --git a/lib/libpk/src/_config.rs b/lib/libpk/src/_config.rs index 5392371f..7c1364c1 100644 --- a/lib/libpk/src/_config.rs +++ b/lib/libpk/src/_config.rs @@ -3,11 +3,23 @@ use lazy_static::lazy_static; use serde::Deserialize; use std::sync::Arc; +use twilight_model::id::{marker::UserMarker, Id}; + +#[derive(Clone, Deserialize, Debug)] +pub struct ClusterSettings { + pub node_id: u32, + pub total_shards: u32, + pub total_nodes: u32, +} + #[derive(Deserialize, Debug)] pub struct DiscordConfig { - pub client_id: u32, + pub client_id: Id, pub bot_token: String, pub client_secret: String, + pub max_concurrency: u32, + pub cluster: Option, + pub api_base_url: Option, } #[derive(Deserialize, Debug)] @@ -41,6 +53,9 @@ pub struct ApiConfig { fn _metrics_default() -> bool { false } +fn _json_log_default() -> bool { + false +} #[derive(Deserialize, Debug)] pub struct PKConfig { @@ -52,13 +67,20 @@ pub struct PKConfig { #[serde(default = "_metrics_default")] pub run_metrics_server: bool, - pub(crate) gelf_log_url: Option, + #[serde(default = "_json_log_default")] + pub(crate) json_log: bool, } lazy_static! { #[derive(Debug)] - pub static ref CONFIG: Arc = Arc::new(Config::builder() - .add_source(config::Environment::with_prefix("pluralkit").separator("__")) - .build().unwrap() - .try_deserialize::().unwrap()); + pub static ref CONFIG: Arc = { + if let Ok(var) = std::env::var("NOMAD_ALLOC_INDEX") { + std::env::set_var("pluralkit__discord__cluster__node_id", var); + } + + Arc::new(Config::builder() + .add_source(config::Environment::with_prefix("pluralkit").separator("__")) + .build().unwrap() + .try_deserialize::().unwrap()) + }; } diff --git a/lib/libpk/src/lib.rs b/lib/libpk/src/lib.rs index daf01959..9b945db4 100644 --- a/lib/libpk/src/lib.rs +++ b/lib/libpk/src/lib.rs @@ -1,30 +1,24 @@ -use gethostname::gethostname; use metrics_exporter_prometheus::PrometheusBuilder; -use tracing_subscriber::{prelude::__tracing_subscriber_SubscriberExt, EnvFilter, Registry}; +use tracing_subscriber::{EnvFilter, Registry}; pub mod db; pub mod proto; +pub mod util; pub mod _config; pub use crate::_config::CONFIG as config; pub fn init_logging(component: &str) -> anyhow::Result<()> { - let subscriber = Registry::default() - .with(EnvFilter::from_default_env()) - .with(tracing_subscriber::fmt::layer()); - - if let Some(gelf_url) = &config.gelf_log_url { - let gelf_logger = tracing_gelf::Logger::builder() - .additional_field("component", component) - .additional_field("hostname", gethostname().to_str()); - let mut conn_handle = gelf_logger - .init_udp_with_subscriber(gelf_url, subscriber) - .unwrap(); - tokio::spawn(async move { conn_handle.connect().await }); + // todo: fix component + if config.json_log { + tracing_subscriber::fmt() + .json() + .with_env_filter(EnvFilter::from_default_env()) + .init(); } else { - // gelf_logger internally sets the global subscriber - tracing::subscriber::set_global_default(subscriber) - .expect("unable to set global subscriber"); + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); } Ok(()) diff --git a/lib/libpk/src/util/mod.rs b/lib/libpk/src/util/mod.rs new file mode 100644 index 00000000..027fbef5 --- /dev/null +++ b/lib/libpk/src/util/mod.rs @@ -0,0 +1 @@ +pub mod redis; diff --git a/lib/libpk/src/util/redis.rs b/lib/libpk/src/util/redis.rs new file mode 100644 index 00000000..25a5acdb --- /dev/null +++ b/lib/libpk/src/util/redis.rs @@ -0,0 +1,15 @@ +use fred::error::RedisError; + +pub trait RedisErrorExt { + fn to_option_or_error(self) -> Result, RedisError>; +} + +impl RedisErrorExt for Result { + fn to_option_or_error(self) -> Result, RedisError> { + match self { + Ok(v) => Ok(Some(v)), + Err(error) if error.is_not_found() => Ok(None), + Err(error) => Err(error), + } + } +} diff --git a/services/api/src/main.rs b/services/api/src/main.rs index f0452e54..a8b5a9ff 100644 --- a/services/api/src/main.rs +++ b/services/api/src/main.rs @@ -147,6 +147,7 @@ async fn main() -> anyhow::Result<()> { let addr: &str = libpk::config.api.addr.as_ref(); let listener = tokio::net::TcpListener::bind(addr).await?; + info!("listening on {}", addr); axum::serve(listener, app).await?; Ok(()) diff --git a/services/gateway/Cargo.toml b/services/gateway/Cargo.toml new file mode 100644 index 00000000..ac6e457b --- /dev/null +++ b/services/gateway/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "gateway" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = { workspace = true } +axum = { workspace = true } +bytes = { workspace = true } +chrono = { workspace = true } +fred = { workspace = true } +futures = { workspace = true } +lazy_static = { workspace = true } +libpk = { path = "../../lib/libpk" } +prost = { workspace = true } +serde_json = { workspace = true } +signal-hook = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } + +twilight-gateway = { workspace = true } +twilight-cache-inmemory = { workspace = true } +twilight-util = { workspace = true } +twilight-model = { workspace = true } +twilight-http = { workspace = true } diff --git a/services/gateway/src/cache_api.rs b/services/gateway/src/cache_api.rs new file mode 100644 index 00000000..15066053 --- /dev/null +++ b/services/gateway/src/cache_api.rs @@ -0,0 +1,168 @@ +use axum::{ + extract::{Path, State}, + http::StatusCode, + response::{IntoResponse, Response}, + routing::get, + Router, +}; +use serde_json::to_string; +use tracing::{error, info}; +use twilight_model::guild::Permissions; +use twilight_model::id::Id; + +use crate::discord::cache::{dm_channel, DiscordCache, DM_PERMISSIONS}; +use std::sync::Arc; + +fn status_code(code: StatusCode, body: String) -> Response { + (code, body).into_response() +} + +// this function is manually formatted for easier legibility of route_services +#[rustfmt::skip] +pub async fn run_server(cache: Arc) -> anyhow::Result<()> { + let app = Router::new() + .route( + "/guilds/:guild_id", + get(|State(cache): State>, Path(guild_id): Path| async move { + match cache.guild(Id::new(guild_id)) { + Some(guild) => status_code(StatusCode::FOUND, to_string(&guild).unwrap()), + None => status_code(StatusCode::NOT_FOUND, "".to_string()), + } + }), + ) + .route( + "/guilds/:guild_id/members/@me", + get(|State(cache): State>, Path(guild_id): Path| async move { + match cache.0.member(Id::new(guild_id), libpk::config.discord.client_id) { + Some(member) => status_code(StatusCode::FOUND, to_string(member.value()).unwrap()), + None => status_code(StatusCode::NOT_FOUND, "".to_string()), + } + }), + ) + .route( + "/guilds/:guild_id/permissions/@me", + get(|State(cache): State>, Path(guild_id): Path| async move { + match cache.guild_permissions(Id::new(guild_id), libpk::config.discord.client_id).await { + Ok(val) => { + println!("hh {}", Permissions::all().bits()); + status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()) + }, + Err(err) => { + error!(?err, ?guild_id, "failed to get own guild member permissions"); + status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string()) + }, + } + }), + ) + .route( + "/guilds/:guild_id/permissions/:user_id", + get(|State(cache): State>, Path((guild_id, user_id)): Path<(u64, u64)>| async move { + match cache.guild_permissions(Id::new(guild_id), Id::new(user_id)).await { + Ok(val) => status_code(StatusCode::FOUND, to_string(&val.bits()).unwrap()), + Err(err) => { + error!(?err, ?guild_id, ?user_id, "failed to get guild member permissions"); + status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string()) + }, + } + }), + ) + + .route( + "/guilds/:guild_id/channels", + get(|State(cache): State>, Path(guild_id): Path| async move { + let channel_ids = match cache.0.guild_channels(Id::new(guild_id)) { + Some(channels) => channels.to_owned(), + None => return status_code(StatusCode::NOT_FOUND, "".to_string()), + }; + + let mut channels = Vec::new(); + for id in channel_ids { + match cache.0.channel(id) { + Some(channel) => channels.push(channel.to_owned()), + None => { + tracing::error!( + channel_id = id.get(), + "referenced channel {} from guild {} not found in cache", + id.get(), guild_id, + ); + return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string()); + } + } + } + + status_code(StatusCode::FOUND, to_string(&channels).unwrap()) + }) + ) + .route( + "/guilds/:guild_id/channels/:channel_id", + get(|State(cache): State>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move { + if guild_id == 0 { + return status_code(StatusCode::FOUND, to_string(&dm_channel(Id::new(channel_id))).unwrap()); + } + match cache.0.channel(Id::new(channel_id)) { + Some(channel) => status_code(StatusCode::FOUND, to_string(channel.value()).unwrap()), + None => status_code(StatusCode::NOT_FOUND, "".to_string()) + } + }) + ) + .route( + "/guilds/:guild_id/channels/:channel_id/permissions/@me", + get(|State(cache): State>, Path((guild_id, channel_id)): Path<(u64, u64)>| async move { + if guild_id == 0 { + return status_code(StatusCode::FOUND, to_string(&*DM_PERMISSIONS).unwrap()); + } + match cache.channel_permissions(Id::new(channel_id), libpk::config.discord.client_id).await { + Ok(val) => status_code(StatusCode::FOUND, to_string(&val).unwrap()), + Err(err) => { + error!(?err, ?channel_id, ?guild_id, "failed to get own channelpermissions"); + status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string()) + }, + } + }), + ) + .route( + "/guilds/:guild_id/channels/:channel_id/permissions/:user_id", + get(|| async { "todo" }), + ) + .route( + "/guilds/:guild_id/channels/:channel_id/last_message", + get(|| async { status_code(StatusCode::NOT_IMPLEMENTED, "".to_string()) }), + ) + + .route( + "/guilds/:guild_id/roles", + get(|State(cache): State>, Path(guild_id): Path| async move { + let role_ids = match cache.0.guild_roles(Id::new(guild_id)) { + Some(roles) => roles.to_owned(), + None => return status_code(StatusCode::NOT_FOUND, "".to_string()), + }; + + let mut roles = Vec::new(); + for id in role_ids { + match cache.0.role(id) { + Some(role) => roles.push(role.value().resource().to_owned()), + None => { + tracing::error!( + role_id = id.get(), + "referenced role {} from guild {} not found in cache", + id.get(), guild_id, + ); + return status_code(StatusCode::INTERNAL_SERVER_ERROR, "".to_string()); + } + } + } + + status_code(StatusCode::FOUND, to_string(&roles).unwrap()) + }) + ) + + .layer(axum::middleware::from_fn(crate::logger::logger)) + .with_state(cache); + + let addr: &str = libpk::config.api.addr.as_ref(); + let listener = tokio::net::TcpListener::bind(addr).await?; + info!("listening on {}", addr); + axum::serve(listener, app).await?; + + Ok(()) +} diff --git a/services/gateway/src/discord/cache.rs b/services/gateway/src/discord/cache.rs new file mode 100644 index 00000000..5ada74b8 --- /dev/null +++ b/services/gateway/src/discord/cache.rs @@ -0,0 +1,339 @@ +use anyhow::format_err; +use lazy_static::lazy_static; +use std::sync::Arc; +use twilight_cache_inmemory::{ + model::CachedMember, + permission::{MemberRoles, RootError}, + traits::CacheableChannel, + InMemoryCache, ResourceType, +}; +use twilight_model::{ + channel::{Channel, ChannelType}, + guild::{Guild, Member, Permissions}, + id::{ + marker::{ChannelMarker, GuildMarker, UserMarker}, + Id, + }, +}; +use twilight_util::permission_calculator::PermissionCalculator; + +lazy_static! { + pub static ref DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL + | Permissions::SEND_MESSAGES + | Permissions::READ_MESSAGE_HISTORY + | Permissions::ADD_REACTIONS + | Permissions::ATTACH_FILES + | Permissions::EMBED_LINKS + | Permissions::USE_EXTERNAL_EMOJIS + | Permissions::CONNECT + | Permissions::SPEAK + | Permissions::USE_VAD; +} + +pub fn dm_channel(id: Id) -> Channel { + Channel { + id, + kind: ChannelType::Private, + + application_id: None, + applied_tags: None, + available_tags: None, + bitrate: None, + default_auto_archive_duration: None, + default_forum_layout: None, + default_reaction_emoji: None, + default_sort_order: None, + default_thread_rate_limit_per_user: None, + flags: None, + guild_id: None, + icon: None, + invitable: None, + last_message_id: None, + last_pin_timestamp: None, + managed: None, + member: None, + member_count: None, + message_count: None, + name: None, + newly_created: None, + nsfw: None, + owner_id: None, + parent_id: None, + permission_overwrites: None, + position: None, + rate_limit_per_user: None, + recipients: None, + rtc_region: None, + thread_metadata: None, + topic: None, + user_limit: None, + video_quality_mode: None, + } +} + +fn member_to_cached_member(item: Member, id: Id) -> CachedMember { + CachedMember { + avatar: item.avatar, + communication_disabled_until: item.communication_disabled_until, + deaf: Some(item.deaf), + flags: item.flags, + joined_at: item.joined_at, + mute: Some(item.mute), + nick: item.nick, + premium_since: item.premium_since, + roles: item.roles, + pending: false, + user_id: id, + } +} + +pub fn new() -> DiscordCache { + let mut client_builder = + twilight_http::Client::builder().token(libpk::config.discord.bot_token.clone()); + + if let Some(base_url) = libpk::config.discord.api_base_url.clone() { + client_builder = client_builder.proxy(base_url, true); + } + + let client = Arc::new(client_builder.build()); + + let cache = Arc::new( + InMemoryCache::builder() + .resource_types( + ResourceType::GUILD + | ResourceType::CHANNEL + | ResourceType::ROLE + | ResourceType::USER_CURRENT + | ResourceType::MEMBER_CURRENT, + ) + .message_cache_size(0) + .build(), + ); + + DiscordCache(cache, client) +} + +pub struct DiscordCache(pub Arc, pub Arc); + +impl DiscordCache { + pub async fn guild_permissions( + &self, + guild_id: Id, + user_id: Id, + ) -> anyhow::Result { + if self + .0 + .guild(guild_id) + .ok_or_else(|| format_err!("guild not found"))? + .owner_id() + == user_id + { + return Ok(Permissions::all()); + } + + let member = if user_id == libpk::config.discord.client_id { + self.0 + .member(guild_id, user_id) + .ok_or(format_err!("self member not found"))? + .value() + .to_owned() + } else { + member_to_cached_member( + self.1 + .guild_member(guild_id, user_id) + .await? + .model() + .await?, + user_id, + ) + }; + + let MemberRoles { assigned, everyone } = self + .0 + .permissions() + .member_roles(guild_id, &member) + .map_err(RootError::from_member_roles)?; + let calculator = + PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice()); + + let permissions = calculator.root(); + + Ok(self + .0 + .permissions() + .disable_member_communication(&member, permissions)) + } + + pub async fn channel_permissions( + &self, + channel_id: Id, + user_id: Id, + ) -> anyhow::Result { + let channel = self + .0 + .channel(channel_id) + .ok_or(format_err!("channel not found"))?; + + if channel.value().guild_id.is_none() { + return Ok(*DM_PERMISSIONS); + } + + let guild_id = channel.value().guild_id.unwrap(); + + if self + .0 + .guild(guild_id) + .ok_or_else(|| { + tracing::error!( + channel_id = channel_id.get(), + guild_id = guild_id.get(), + "referenced guild from cached channel {channel_id} not found in cache" + ); + format_err!("internal cache error") + })? + .owner_id() + == user_id + { + return Ok(Permissions::all()); + } + + let member = if user_id == libpk::config.discord.client_id { + self.0 + .member(guild_id, user_id) + .ok_or_else(|| { + tracing::error!( + guild_id = guild_id.get(), + "self member for cached guild {guild_id} not found in cache" + ); + format_err!("internal cache error") + })? + .value() + .to_owned() + } else { + member_to_cached_member( + self.1 + .guild_member(guild_id, user_id) + .await? + .model() + .await?, + user_id, + ) + }; + + let MemberRoles { assigned, everyone } = self + .0 + .permissions() + .member_roles(guild_id, &member) + .map_err(RootError::from_member_roles)?; + + let overwrites = match channel.kind { + ChannelType::AnnouncementThread + | ChannelType::PrivateThread + | ChannelType::PublicThread => self.0.permissions().parent_overwrites(&channel)?, + _ => channel + .value() + .permission_overwrites() + .unwrap_or_default() + .to_vec(), + }; + + let calculator = + PermissionCalculator::new(guild_id, user_id, everyone, assigned.as_slice()); + + let permissions = calculator.in_channel(channel.kind(), overwrites.as_slice()); + + Ok(self + .0 + .permissions() + .disable_member_communication(&member, permissions)) + } + + // from https://github.com/Gelbpunkt/gateway-proxy/blob/5bcb080a1fcb09f6fafecad7736819663a625d84/src/cache.rs + pub fn guild(&self, id: Id) -> Option { + self.0.guild(id).map(|guild| { + let channels = self + .0 + .guild_channels(id) + .map(|reference| { + reference + .iter() + .filter_map(|channel_id| { + let channel = self.0.channel(*channel_id)?; + + if channel.kind.is_thread() { + None + } else { + Some(channel.value().clone()) + } + }) + .collect() + }) + .unwrap_or_default(); + + let roles = self + .0 + .guild_roles(id) + .map(|reference| { + reference + .iter() + .filter_map(|role_id| { + Some(self.0.role(*role_id)?.value().resource().clone()) + }) + .collect() + }) + .unwrap_or_default(); + + Guild { + afk_channel_id: guild.afk_channel_id(), + afk_timeout: guild.afk_timeout(), + application_id: guild.application_id(), + approximate_member_count: None, // Only present in with_counts HTTP endpoint + banner: guild.banner().map(ToOwned::to_owned), + approximate_presence_count: None, // Only present in with_counts HTTP endpoint + channels, + default_message_notifications: guild.default_message_notifications(), + description: guild.description().map(ToString::to_string), + discovery_splash: guild.discovery_splash().map(ToOwned::to_owned), + emojis: vec![], + explicit_content_filter: guild.explicit_content_filter(), + features: guild.features().cloned().collect(), + icon: guild.icon().map(ToOwned::to_owned), + id: guild.id(), + joined_at: guild.joined_at(), + large: guild.large(), + max_members: guild.max_members(), + max_presences: guild.max_presences(), + max_video_channel_users: guild.max_video_channel_users(), + member_count: guild.member_count(), + members: vec![], + mfa_level: guild.mfa_level(), + name: guild.name().to_string(), + nsfw_level: guild.nsfw_level(), + owner_id: guild.owner_id(), + owner: guild.owner(), + permissions: guild.permissions(), + public_updates_channel_id: guild.public_updates_channel_id(), + preferred_locale: guild.preferred_locale().to_string(), + premium_progress_bar_enabled: guild.premium_progress_bar_enabled(), + premium_subscription_count: guild.premium_subscription_count(), + premium_tier: guild.premium_tier(), + presences: vec![], + roles, + rules_channel_id: guild.rules_channel_id(), + safety_alerts_channel_id: guild.safety_alerts_channel_id(), + splash: guild.splash().map(ToOwned::to_owned), + stage_instances: vec![], + stickers: vec![], + system_channel_flags: guild.system_channel_flags(), + system_channel_id: guild.system_channel_id(), + threads: vec![], + unavailable: false, + vanity_url_code: guild.vanity_url_code().map(ToString::to_string), + verification_level: guild.verification_level(), + voice_states: vec![], + widget_channel_id: guild.widget_channel_id(), + widget_enabled: guild.widget_enabled(), + } + }) + } +} diff --git a/services/gateway/src/discord/gateway.rs b/services/gateway/src/discord/gateway.rs new file mode 100644 index 00000000..e8bfb26f --- /dev/null +++ b/services/gateway/src/discord/gateway.rs @@ -0,0 +1,121 @@ +use std::sync::{mpsc::Sender, Arc}; +use tracing::{info, warn}; +use twilight_gateway::{ + create_iterator, ConfigBuilder, Event, EventTypeFlags, Shard, ShardId, StreamExt, +}; +use twilight_model::gateway::{ + payload::outgoing::update_presence::UpdatePresencePayload, + presence::{Activity, ActivityType, Status}, + Intents, +}; + +use crate::discord::identify_queue::{self, RedisQueue}; + +use super::{cache::DiscordCache, shard_state::ShardStateManager}; + +pub fn create_shards(redis: fred::pool::RedisPool) -> anyhow::Result>> { + let intents = Intents::GUILDS + | Intents::DIRECT_MESSAGES + | Intents::DIRECT_MESSAGE_REACTIONS + | Intents::GUILD_MESSAGES + | Intents::GUILD_MESSAGE_REACTIONS + | Intents::MESSAGE_CONTENT; + + let queue = identify_queue::new(redis); + + let cluster_settings = + libpk::config + .discord + .cluster + .clone() + .unwrap_or(libpk::_config::ClusterSettings { + node_id: 0, + total_shards: 1, + total_nodes: 1, + }); + + let (start_shard, end_shard): (u32, u32) = if cluster_settings.total_shards < 16 { + warn!("we have less than 16 shards, assuming single gateway process"); + (0, (cluster_settings.total_shards - 1).into()) + } else { + ( + (cluster_settings.node_id * 16).into(), + (((cluster_settings.node_id + 1) * 16) - 1).into(), + ) + }; + + let shards = create_iterator( + start_shard..end_shard + 1, + cluster_settings.total_shards, + ConfigBuilder::new(libpk::config.discord.bot_token.to_owned(), intents) + .presence(presence("pk;help", false)) + .queue(queue.clone()) + .build(), + |_, builder| builder.build(), + ); + + let mut shards_vec = Vec::new(); + shards_vec.extend(shards); + + Ok(shards_vec) +} + +pub async fn runner( + mut shard: Shard, + tx: Sender<(ShardId, Event)>, + shard_state: ShardStateManager, + cache: Arc, +) { + //let _span = info_span!("shard_runner", shard_id = shard.id().number()).entered(); + info!("waiting for events"); + while let Some(item) = shard.next_event(EventTypeFlags::all()).await { + match item { + Ok(event) => { + if let Err(error) = shard_state + .handle_event(shard.id().number(), event.clone()) + .await + { + tracing::warn!(?error, "error updating redis state") + } + cache.0.update(&event); + //if let Err(error) = tx.send((shard.id(), event)) { + // tracing::warn!(?error, "error sending event to global handler: {error}",); + //} + } + Err(error) => { + tracing::warn!(?error, "error receiving event from shard {}", shard.id()); + continue; + } + } + } +} + +pub fn presence(status: &str, going_away: bool) -> UpdatePresencePayload { + UpdatePresencePayload { + activities: vec![Activity { + application_id: None, + assets: None, + buttons: vec![], + created_at: None, + details: None, + id: None, + state: None, + url: None, + emoji: None, + flags: None, + instance: None, + kind: ActivityType::Playing, + name: status.to_string(), + party: None, + secrets: None, + timestamps: None, + }], + afk: false, + since: None, + status: if going_away { + Status::Idle + } else { + Status::Online + }, + } +} diff --git a/services/gateway/src/discord/identify_queue.rs b/services/gateway/src/discord/identify_queue.rs new file mode 100644 index 00000000..a98ad967 --- /dev/null +++ b/services/gateway/src/discord/identify_queue.rs @@ -0,0 +1,87 @@ +use fred::{ + error::RedisError, + interfaces::KeysInterface, + pool::RedisPool, + types::{Expiration, SetOptions}, +}; +use std::fmt::Debug; +use std::time::Duration; +use tokio::sync::oneshot; +use tracing::{error, info}; +use twilight_gateway::queue::Queue; + +use libpk::util::redis::RedisErrorExt; + +pub fn new(redis: RedisPool) -> RedisQueue { + RedisQueue { + redis, + concurrency: libpk::config.discord.max_concurrency, + } +} + +#[derive(Clone)] +pub struct RedisQueue { + pub redis: RedisPool, + pub concurrency: u32, +} + +impl Debug for RedisQueue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("RedisQueue") + .field("concurrency", &self.concurrency) + .finish() + } +} + +impl Queue for RedisQueue { + fn enqueue<'a>(&'a self, shard_id: u32) -> oneshot::Receiver<()> { + let (tx, rx) = oneshot::channel(); + + tokio::spawn(request_inner( + self.redis.clone(), + self.concurrency, + shard_id, + tx, + )); + + rx + } +} + +const EXPIRY: i64 = 6; +const RETRY_INTERVAL: u64 = 500; + +async fn request_inner(redis: RedisPool, concurrency: u32, shard_id: u32, tx: oneshot::Sender<()>) { + let bucket = shard_id % concurrency; + let key = format!("pluralkit:identify:{}", bucket); + + info!(shard_id, bucket, "waiting for allowance..."); + loop { + let done: Result, RedisError> = redis + .set( + key.to_string(), + "1", + Some(Expiration::EX(EXPIRY)), + Some(SetOptions::NX), + false, + ) + .await + .to_option_or_error(); + match done { + Ok(Some(_)) => { + info!(shard_id, bucket, "got allowance!"); + // if this fails, it's probably already doing something else + let _ = tx.send(()); + return; + } + Ok(None) => { + // not allowed yet, waiting + } + Err(e) => { + error!(shard_id, bucket, "error getting shard allowance: {}", e) + } + } + + tokio::time::sleep(Duration::from_millis(RETRY_INTERVAL)).await; + } +} diff --git a/services/gateway/src/discord/mod.rs b/services/gateway/src/discord/mod.rs new file mode 100644 index 00000000..2c4da3be --- /dev/null +++ b/services/gateway/src/discord/mod.rs @@ -0,0 +1,4 @@ +pub mod cache; +pub mod gateway; +pub mod identify_queue; +pub mod shard_state; diff --git a/services/gateway/src/discord/shard_state.rs b/services/gateway/src/discord/shard_state.rs new file mode 100644 index 00000000..8c2369ca --- /dev/null +++ b/services/gateway/src/discord/shard_state.rs @@ -0,0 +1,84 @@ +use bytes::Bytes; +use fred::{interfaces::HashesInterface, pool::RedisPool}; +use prost::Message; +use tracing::info; +use twilight_gateway::Event; + +use libpk::{proto::*, util::redis::*}; + +#[derive(Clone)] +pub struct ShardStateManager { + redis: RedisPool, +} + +pub fn new(redis: RedisPool) -> ShardStateManager { + ShardStateManager { redis } +} + +impl ShardStateManager { + pub async fn handle_event(&self, shard_id: u32, event: Event) -> anyhow::Result<()> { + match event { + Event::Ready(_) => self.ready_or_resumed(shard_id).await, + Event::Resumed => self.ready_or_resumed(shard_id).await, + Event::GatewayClose(_) => self.socket_closed(shard_id).await, + Event::GatewayHeartbeat(_) => self.heartbeated(shard_id).await, + _ => Ok(()), + } + } + + async fn get_shard(&self, shard_id: u32) -> anyhow::Result { + let data: Option> = self + .redis + .hget("pluralkit:shardstatus", shard_id) + .await + .to_option_or_error()?; + match data { + Some(buf) => { + Ok(ShardState::decode(buf.as_slice()).expect("could not decode shard data!")) + } + None => Ok(ShardState::default()), + } + } + + async fn save_shard(&self, shard_id: u32, info: ShardState) -> anyhow::Result<()> { + self.redis + .hset( + "pluralkit:shardstatus", + ( + shard_id.to_string(), + Bytes::copy_from_slice(&info.encode_to_vec()), + ), + ) + .await?; + Ok(()) + } + + async fn ready_or_resumed(&self, shard_id: u32) -> anyhow::Result<()> { + info!("shard {} ready", shard_id); + let mut info = self.get_shard(shard_id).await?; + info.last_connection = chrono::offset::Utc::now().timestamp() as i32; + info.up = true; + self.save_shard(shard_id, info).await?; + Ok(()) + } + + async fn socket_closed(&self, shard_id: u32) -> anyhow::Result<()> { + info!("shard {} closed", shard_id); + let mut info = self.get_shard(shard_id).await?; + info.up = false; + info.disconnection_count += 1; + self.save_shard(shard_id, info).await?; + Ok(()) + } + + async fn heartbeated(&self, shard_id: u32) -> anyhow::Result<()> { + let mut info = self.get_shard(shard_id).await?; + info.up = true; + info.last_heartbeat = chrono::offset::Utc::now().timestamp() as i32; + // todo + // info.latency = latency.recent().front().map_or_else(|| 0, |d| d.as_millis()) as i32; + info.latency = 1; + self.save_shard(shard_id, info).await?; + Ok(()) + } +} diff --git a/services/gateway/src/logger.rs b/services/gateway/src/logger.rs new file mode 100644 index 00000000..aa65bc67 --- /dev/null +++ b/services/gateway/src/logger.rs @@ -0,0 +1,52 @@ +use std::time::Instant; + +use axum::{extract::MatchedPath, extract::Request, middleware::Next, response::Response}; +use tracing::{info, span, warn, Instrument, Level}; + +// log any requests that take longer than 2 seconds +// todo: change as necessary +const MIN_LOG_TIME: u128 = 2_000; + +pub async fn logger(request: Request, next: Next) -> Response { + let method = request.method().clone(); + + let endpoint = request + .extensions() + .get::() + .cloned() + .map(|v| v.as_str().to_string()) + .unwrap_or("unknown".to_string()); + + let uri = request.uri().clone(); + + let request_id_span = span!( + Level::INFO, + "request", + method = method.as_str(), + endpoint = endpoint.clone(), + ); + + let start = Instant::now(); + let response = next.run(request).instrument(request_id_span).await; + let elapsed = start.elapsed().as_millis(); + + info!( + "{} handled request for {} {} in {}ms", + response.status(), + method, + uri.path(), + elapsed + ); + + if elapsed > MIN_LOG_TIME { + warn!( + "request to {} full path {} (endpoint {}) took a long time ({}ms)!", + method, + uri.path(), + endpoint, + elapsed + ) + } + + response +} diff --git a/services/gateway/src/main.rs b/services/gateway/src/main.rs new file mode 100644 index 00000000..7ac18fa3 --- /dev/null +++ b/services/gateway/src/main.rs @@ -0,0 +1,141 @@ +use chrono::Timelike; +use fred::{interfaces::*, pool::RedisPool}; +use signal_hook::{ + consts::{SIGINT, SIGTERM}, + iterator::Signals, +}; +use std::{ + sync::{mpsc::channel, Arc}, + time::Duration, + vec::Vec, +}; +use tokio::task::JoinSet; +use tracing::{info, warn}; +use twilight_gateway::{MessageSender, ShardId}; +use twilight_model::gateway::payload::outgoing::UpdatePresence; + +mod cache_api; +mod discord; +mod logger; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + libpk::init_logging("gateway")?; + libpk::init_metrics()?; + info!("hello world"); + + let (shutdown_tx, shutdown_rx) = channel::<()>(); + let shutdown_tx = Arc::new(shutdown_tx); + + let redis = libpk::db::init_redis().await?; + + let shard_state = discord::shard_state::new(redis.clone()); + let cache = Arc::new(discord::cache::new()); + + let shards = discord::gateway::create_shards(redis.clone())?; + + let (event_tx, _event_rx) = channel(); + + let mut senders = Vec::new(); + let mut signal_senders = Vec::new(); + + let mut set = JoinSet::new(); + for shard in shards { + senders.push((shard.id(), shard.sender())); + signal_senders.push(shard.sender()); + set.spawn(tokio::spawn(discord::gateway::runner( + shard, + event_tx.clone(), + shard_state.clone(), + cache.clone(), + ))); + } + + set.spawn(tokio::spawn( + async move { scheduled_task(redis, senders).await }, + )); + + // todo: probably don't do it this way + let api_shutdown_tx = shutdown_tx.clone(); + set.spawn(tokio::spawn(async move { + match cache_api::run_server(cache).await { + Err(error) => { + tracing::error!(?error, "failed to serve cache api"); + let _ = api_shutdown_tx.send(()); + } + _ => unreachable!(), + } + })); + + let mut signals = Signals::new(&[SIGINT, SIGTERM])?; + + tokio::spawn(async move { + for sig in signals.forever() { + info!("received signal {:?}", sig); + + let presence = UpdatePresence { + op: twilight_model::gateway::OpCode::PresenceUpdate, + d: discord::gateway::presence("Restarting... (please wait)", true), + }; + + for sender in signal_senders.iter() { + let presence = presence.clone(); + let _ = sender.command(&presence); + } + + let _ = shutdown_tx.send(()); + break; + } + }); + + let _ = shutdown_rx.recv(); + + // sleep 500ms to allow everything to clean up properly + tokio::time::sleep(Duration::from_millis(500)).await; + + set.abort_all(); + + info!("gateway exiting, have a nice day!"); + + Ok(()) +} + +async fn scheduled_task(redis: RedisPool, senders: Vec<(ShardId, MessageSender)>) { + loop { + tokio::time::sleep(Duration::from_secs( + (60 - chrono::offset::Utc::now().second()).into(), + )) + .await; + info!("running per-minute scheduled tasks"); + + let status: Option = match redis.get("pluralkit:botstatus").await { + Ok(val) => Some(val), + Err(error) => { + tracing::warn!(?error, "failed to fetch bot status from redis"); + None + } + }; + + let presence = UpdatePresence { + op: twilight_model::gateway::OpCode::PresenceUpdate, + d: discord::gateway::presence( + if let Some(status) = status { + format!("pk;help | {}", status) + } else { + "pk;help".to_string() + } + .as_str(), + false, + ), + }; + + for sender in senders.iter() { + match sender.1.command(&presence) { + Err(error) => { + warn!(?error, "could not update presence on shard {}", sender.0) + } + _ => {} + }; + } + } +}