rustproxy: initial commit
This commit is contained in:
parent b47694edc1
commit 9d90de45a6
23 changed files with 4686 additions and 0 deletions
35 rustproxy/pk_bot/Cargo.toml Normal file
@@ -0,0 +1,35 @@
[package]
name = "pk_bot"
version = "0.1.0"
edition = "2021"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
anyhow = "1.0.57"
async-trait = "0.1.56"
chrono = "0.4.19"
config = { version = "0.13.1", default-features = false, features = ["toml"] }
dashmap = "5.3.4"
futures = "0.3.21"
moka = { version = "0.8.5", features = ["future"] }
once_cell = "1.12.0"
redis = { version = "0.21.5", features = ["connection-manager", "tokio-comp"] }
regex = "1.5.6"
serde = "1.0.137"
serde_json = "1.0.81"
smallvec = "1.8.0"
smol_str = "0.1.23"
sqlx = { version = "0.6.0", features = ["runtime-tokio-native-tls", "postgres", "chrono", "macros"] }
thiserror = "1.0.31"
tokio = { version = "1.19.1", features = ["full"] }
tracing = "0.1.34"
tracing-subscriber = "0.3.11"
twilight-gateway = "0.11.0"
twilight-gateway-queue = "0.11.0"
twilight-http = "0.11.0"
twilight-model = "0.11.0"
twilight-util = { version = "0.11.0", features = ["builder", "permission-calculator"] }

[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = "0.5"

265 rustproxy/pk_bot/src/cache.rs Normal file
@@ -0,0 +1,265 @@
use std::sync::{Arc, RwLock};

use dashmap::mapref::one::Ref;
use dashmap::DashMap;
use smol_str::SmolStr;
use twilight_model::channel::permission_overwrite::PermissionOverwrite;
use twilight_model::channel::{Channel, ChannelType};
use twilight_model::gateway::event::Event;
use twilight_model::gateway::payload::incoming::ThreadListSync;
use twilight_model::guild::{Guild, PartialMember, Permissions, PremiumTier, Role};
use twilight_model::id::marker::{ChannelMarker, GuildMarker, RoleMarker, UserMarker};
use twilight_model::id::Id;
use twilight_util::permission_calculator::PermissionCalculator;

const DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
    .union(Permissions::SEND_MESSAGES)
    .union(Permissions::READ_MESSAGE_HISTORY)
    .union(Permissions::ADD_REACTIONS)
    .union(Permissions::ATTACH_FILES)
    .union(Permissions::EMBED_LINKS)
    .union(Permissions::USE_EXTERNAL_EMOJIS);

#[derive(Debug, Clone)]
pub struct CachedGuild {
    owner_id: u64,
    _premium_tier: PremiumTier,
}

#[derive(Debug, Clone)]
pub struct CachedChannel {
    _name: Option<SmolStr>, // stores strings 22 characters or less inline, which is a large portion of channel names
    _parent_id: Option<Id<ChannelMarker>>,
    guild_id: Option<Id<GuildMarker>>,
    kind: ChannelType,
    overwrites: Vec<PermissionOverwrite>,
}

#[derive(Debug, Clone)]
pub struct CachedRole {
    permissions: Permissions,
    _mentionable: bool,
}

#[derive(Debug, Clone)]
pub struct CachedBotMember {
    roles: Vec<Id<RoleMarker>>,
}

pub struct DiscordCache {
    bot_user: Arc<RwLock<Option<Id<UserMarker>>>>,
    guilds: DashMap<u64, CachedGuild>,
    channels: DashMap<u64, CachedChannel>,
    roles: DashMap<u64, CachedRole>,
    bot_members: DashMap<u64, CachedBotMember>,
}

impl DiscordCache {
    pub fn new() -> DiscordCache {
        DiscordCache {
            bot_user: Arc::new(RwLock::new(None)),
            guilds: DashMap::new(),
            channels: DashMap::new(),
            roles: DashMap::new(),
            bot_members: DashMap::new(),
        }
    }

    pub fn handle_event(&self, event: &Event) {
        match event {
            Event::Ready(ref r) => {
                let mut bot_user = self.bot_user.write().unwrap();
                *bot_user = Some(r.user.id);
            }
            Event::GuildCreate(ref g) => self.update_guild(g),
            Event::ChannelCreate(ref ch) => self.update_channel(ch),
            Event::ChannelUpdate(ref ch) => self.update_channel(ch),
            Event::ChannelDelete(ref ch) => self.delete_channel(ch.id),
            Event::ThreadCreate(ref ch) => self.update_channel(ch),
            Event::ThreadUpdate(ref ch) => self.update_channel(ch),
            Event::ThreadDelete(ref ch) => self.delete_channel(ch.id),
            Event::ThreadListSync(ref ts) => self.update_threads(ts),
            Event::RoleCreate(ref r) => self.update_role(&r.role),
            Event::RoleUpdate(ref r) => self.update_role(&r.role),
            Event::RoleDelete(ref r) => self.delete_role(r.role_id),
            Event::MemberUpdate(ref member) => {
                let current_user = self.bot_user_id();
                if Some(member.user.id) == current_user {
                    self.update_bot_member(member.guild_id, &member.roles);
                }
            }
            _ => {}
        }
    }

    fn update_guild(&self, guild: &Guild) {
        for channel in &guild.channels {
            self.update_channel(channel);
        }

        for role in &guild.roles {
            self.update_role(role);
        }

        let current_user = self.bot_user_id();
        for member in &guild.members {
            if Some(member.user.id) == current_user {
                self.update_bot_member(member.guild_id, &member.roles);
            }
        }

        self.guilds.insert(
            guild.id.get(),
            CachedGuild {
                owner_id: guild.owner_id.get(),
                _premium_tier: guild.premium_tier,
            },
        );
    }

    fn bot_user_id(&self) -> Option<Id<UserMarker>> {
        *self.bot_user.read().unwrap()
    }

    fn update_channel(&self, channel: &Channel) {
        self.channels.insert(
            channel.id.get(),
            CachedChannel {
                _name: channel.name.as_deref().map(SmolStr::new),
                _parent_id: channel.parent_id,
                guild_id: channel.guild_id,
                kind: channel.kind,
                overwrites: channel.permission_overwrites.clone().unwrap_or_default(),
            },
        );
    }

    fn delete_channel(&self, id: Id<ChannelMarker>) {
        self.channels.remove(&id.get());
    }

    fn update_role(&self, role: &Role) {
        self.roles.insert(
            role.id.get(),
            CachedRole {
                permissions: role.permissions,
                _mentionable: role.mentionable,
            },
        );
    }

    fn delete_role(&self, id: Id<RoleMarker>) {
        self.roles.remove(&id.get());
    }

    fn update_threads(&self, evt: &ThreadListSync) {
        for thread in &evt.threads {
            self.update_channel(thread);
        }
    }

    fn update_bot_member(&self, guild_id: Id<GuildMarker>, roles: &[Id<RoleMarker>]) {
        self.bot_members.insert(
            guild_id.get(),
            CachedBotMember {
                roles: roles.to_vec(),
            },
        );
    }

    pub fn get_guild(&self, guild_id: Id<GuildMarker>) -> anyhow::Result<Ref<u64, CachedGuild>> {
        self.guilds
            .get(&guild_id.get())
            .ok_or_else(|| anyhow::anyhow!("could not find guild in cache: {}", guild_id))
    }

    pub fn get_channel(
        &self,
        channel_id: Id<ChannelMarker>,
    ) -> anyhow::Result<Ref<u64, CachedChannel>> {
        self.channels
            .get(&channel_id.get())
            .ok_or_else(|| anyhow::anyhow!("could not find channel in cache: {}", channel_id))
    }

    pub fn get_role(&self, role_id: Id<RoleMarker>) -> anyhow::Result<Ref<u64, CachedRole>> {
        self.roles
            .get(&role_id.get())
            .ok_or_else(|| anyhow::anyhow!("could not find role in cache: {}", role_id))
    }

    pub fn get_bot_member(
        &self,
        guild_id: Id<GuildMarker>,
    ) -> anyhow::Result<Ref<u64, CachedBotMember>> {
        self.bot_members.get(&guild_id.get()).ok_or_else(|| {
            anyhow::anyhow!("could not find bot member in cache for guild: {}", guild_id)
        })
    }

    fn calculate_permissions_in(
        &self,
        channel_id: Id<ChannelMarker>,
        user_id: Id<UserMarker>,
        roles: &[Id<RoleMarker>],
    ) -> anyhow::Result<Permissions> {
        let channel = self.get_channel(channel_id)?;

        if let Some(guild_id) = channel.guild_id {
            let guild = self.get_guild(guild_id)?;
            let everyone_role = self.get_role(guild_id.cast())?;

            let mut member_roles = Vec::with_capacity(roles.len());
            for role_id in roles {
                let role = self.get_role(*role_id)?;
                member_roles.push((role_id.cast(), role.permissions));
            }

            let calc = PermissionCalculator::new(
                guild_id,
                user_id,
                everyone_role.permissions,
                &member_roles,
            )
            .owner_id(Id::new(guild.owner_id));

            Ok(calc.in_channel(channel.kind, &channel.overwrites))
        } else {
            Ok(DM_PERMISSIONS)
        }
    }

    pub fn member_permissions(
        &self,
        channel_id: Id<ChannelMarker>,
        user_id: Id<UserMarker>,
        member: Option<&PartialMember>,
    ) -> anyhow::Result<Permissions> {
        if let Some(member) = member {
            self.calculate_permissions_in(channel_id, user_id, &member.roles)
        } else {
            // this should just be dm perms, probably?
            self.calculate_permissions_in(channel_id, user_id, &[])
        }
    }

    pub fn bot_permissions(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<Permissions> {
        let channel = self.get_channel(channel_id)?;
        if let Some(guild_id) = channel.guild_id {
            let member = self.get_bot_member(guild_id)?;

            let user_id = self
                .bot_user_id()
                .ok_or_else(|| anyhow::anyhow!("haven't received bot user id yet"))?;

            self.calculate_permissions_in(channel_id, user_id, &member.roles)
        } else {
            Ok(DM_PERMISSIONS)
        }
    }

    pub fn channel_type(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<ChannelType> {
        let channel = self.get_channel(channel_id)?;
        Ok(channel.kind)
    }
}

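An aside on the guild_id.cast() call in calculate_permissions_in above: Discord's @everyone role always shares its snowflake with the guild itself, which is why casting the guild marker to a role marker is enough to look it up in the role cache. A minimal standalone sketch of that relationship, not part of this commit:

use twilight_model::id::{
    marker::{GuildMarker, RoleMarker},
    Id,
};

// the @everyone role shares its ID with the guild, so the lookup key for the
// cached everyone role is just the guild ID reinterpreted as a role ID
fn everyone_role_id(guild_id: Id<GuildMarker>) -> Id<RoleMarker> {
    guild_id.cast()
}
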
22 rustproxy/pk_bot/src/config.rs Normal file
@@ -0,0 +1,22 @@
use config::{Config, Environment, File, FileFormat};
use serde::Deserialize;

#[derive(Deserialize, Debug)]
pub struct BotConfig {
    pub token: String,

    pub max_concurrency: Option<u64>,
    pub database: String,
    pub redis_addr: Option<String>,
    pub redis_gateway_queue_addr: Option<String>,
    pub shard_count: Option<u64>,
}

// todo: should this be a once_cell::Lazy global const or something
pub fn load_config() -> anyhow::Result<BotConfig> {
    let builder = Config::builder()
        .add_source(Environment::default())
        .add_source(File::new("config", FileFormat::Toml));

    Ok(builder.build()?.try_deserialize()?)
}

174 rustproxy/pk_bot/src/db.rs Normal file
@@ -0,0 +1,174 @@
use std::str::FromStr;

use crate::{
    config::BotConfig,
    model::{PKMember, PKMemberGuild, PKMessage, PKSystem, PKSystemGuild},
};
use chrono::{DateTime, Utc};
use sqlx::{
    postgres::{PgConnectOptions, PgPoolOptions},
    ConnectOptions, FromRow, PgPool,
};
use tracing::info;

#[derive(FromRow, Debug, Default)]
pub struct MessageContext {
    // being defensive with these values - we need to be explicit with Option<T>
    // when the database might return null, and some of these don't have proper default values set
    // most of the Option<T>s can probably get removed with a few changes to the db function
    pub system_id: Option<i32>,
    pub is_deleting: Option<bool>,
    pub in_blacklist: Option<bool>,
    pub in_log_blacklist: Option<bool>,
    pub proxy_enabled: Option<bool>,
    pub last_switch: Option<i32>,
    pub last_switch_members: Option<Vec<i32>>,
    pub last_switch_timestamp: Option<DateTime<Utc>>,
    pub system_tag: Option<String>,
    pub system_guild_tag: Option<String>,
    pub tag_enabled: Option<bool>,
    pub system_avatar: Option<String>,
    pub allow_autoproxy: Option<bool>,
    pub latch_timeout: Option<i32>,
}

pub async fn get_message_context(
    pool: &PgPool,
    account_id: i64,
    guild_id: i64,
    channel_id: i64,
) -> anyhow::Result<MessageContext> {
    Ok(sqlx::query_as("select * from message_context($1, $2, $3)")
        .bind(account_id)
        .bind(guild_id)
        .bind(channel_id)
        .fetch_one(pool)
        .await?)
}

#[derive(FromRow, Debug, Clone)]
pub struct ProxyTagEntry {
    pub prefix: String,
    pub suffix: String,
    pub member_id: i32,
}

impl From<(&str, &str, i32)> for ProxyTagEntry {
    fn from((prefix, suffix, member_id): (&str, &str, i32)) -> Self {
        ProxyTagEntry {
            prefix: prefix.to_string(),
            suffix: suffix.to_string(),
            member_id,
        }
    }
}

pub async fn get_proxy_tags(pool: &PgPool, system_id: i32) -> anyhow::Result<Vec<ProxyTagEntry>> {
    Ok(sqlx::query_as("select coalesce((i.tags).prefix, '') as prefix, coalesce((i.tags).suffix, '') as suffix, member_id from (select unnest(proxy_tags) as tags, id as member_id from members where system = $1) as i;")
        .bind(system_id)
        .fetch_all(pool)
        .await?)
}

#[repr(i32)]
#[derive(sqlx::Type, Debug, Copy, Clone)]
pub enum AutoproxyMode {
    Off = 1,
    Front = 2,
    Latch = 3,
    Member = 4,
}

#[derive(FromRow, Debug, Clone)]
pub struct AutoproxyState {
    pub autoproxy_mode: AutoproxyMode,
    pub autoproxy_member: Option<i32>,
    pub last_latch_timestamp: Option<DateTime<Utc>>,
}

pub async fn get_autoproxy_state(
    pool: &PgPool,
    system_id: i32,
    guild_id: i64,
    channel_id: i64,
) -> anyhow::Result<Option<AutoproxyState>> {
    Ok(sqlx::query_as(
        "select * from autoproxy where system = $1 and guild_id = $2 and channel_id = $3;",
    )
    .bind(system_id)
    .bind(guild_id)
    .bind(channel_id)
    .fetch_optional(pool)
    .await?)
}

pub async fn get_system_by_id(pool: &PgPool, system_id: i32) -> anyhow::Result<Option<PKSystem>> {
    Ok(sqlx::query_as("select * from systems where id = $1")
        .bind(system_id)
        .fetch_optional(pool)
        .await?)
}

pub async fn get_member_by_id(pool: &PgPool, member_id: i32) -> anyhow::Result<Option<PKMember>> {
    Ok(sqlx::query_as("select * from members where id = $1")
        .bind(member_id)
        .fetch_optional(pool)
        .await?)
}

pub async fn get_system_guild(
    pool: &PgPool,
    system_id: i32,
    guild_id: i64,
) -> anyhow::Result<Option<PKSystemGuild>> {
    Ok(
        sqlx::query_as("select * from system_guild where system = $1 and guild = $2")
            .bind(system_id)
            .bind(guild_id)
            .fetch_optional(pool)
            .await?,
    )
}

pub async fn get_member_guild(
    pool: &PgPool,
    member_id: i32,
    guild_id: i64,
) -> anyhow::Result<Option<PKMemberGuild>> {
    Ok(
        sqlx::query_as("select * from member_guild where member = $1 and guild = $2")
            .bind(member_id)
            .bind(guild_id)
            .fetch_optional(pool)
            .await?,
    )
}

pub async fn insert_message(pool: &PgPool, message: PKMessage) -> anyhow::Result<()> {
    sqlx::query("insert into messages (mid, guild, channel, member, sender, original_mid) values ($1, $2, $3, $4, $5, $6)")
        .bind(message.mid)
        .bind(message.guild)
        .bind(message.channel)
        .bind(message.member_id)
        .bind(message.sender)
        .bind(message.original_mid)
        .execute(pool)
        .await?;

    Ok(())
}

pub async fn init_db(config: &BotConfig) -> anyhow::Result<PgPool> {
    info!("connecting to database");
    let options = PgConnectOptions::from_str(&config.database)
        .unwrap()
        .disable_statement_logging()
        .clone();

    let pool = PgPoolOptions::new()
        .max_connections(32)
        .connect_with(options)
        .await?;

    Ok(pool)
}

63 rustproxy/pk_bot/src/gateway.rs Normal file
@@ -0,0 +1,63 @@
use crate::config::BotConfig;
use crate::redis;
use std::{env, sync::Arc};
use tracing::info;
use twilight_gateway::{
    cluster::{Events, ShardScheme},
    Cluster, EventTypeFlags, Intents,
};
use twilight_http::Client;

pub async fn init_gateway(
    http: Arc<Client>,
    config: &BotConfig,
) -> anyhow::Result<(Arc<Cluster>, Events)> {
    let mut builder = Cluster::builder(
        config.token.clone(),
        Intents::GUILDS
            | Intents::DIRECT_MESSAGES
            | Intents::GUILD_MESSAGES
            | Intents::MESSAGE_CONTENT,
    );
    builder = builder.http_client(http);
    builder = builder.event_types(EventTypeFlags::all());

    if let Some(scheme) = get_shard_scheme(config)? {
        info!("using shard scheme: {:?}", scheme);
        builder = builder.shard_scheme(scheme);
    }

    if let Some(queue) = redis::init_gateway_queue(config).await? {
        info!("using redis gateway queue");
        builder = builder.queue(Arc::new(queue));
    }

    let (cluster, events) = builder.build().await?;
    let cluster = Arc::new(cluster);
    let cluster_spawn = Arc::clone(&cluster);
    tokio::spawn(async move {
        info!("starting shards...");
        cluster_spawn.up().await;
    });

    Ok((cluster, events))
}

fn get_cluster_id() -> anyhow::Result<u64> {
    Ok(env::var("NOMAD_ALLOC_INDEX")
        .unwrap_or_else(|_| "0".to_string())
        .parse::<u64>()?)
}

fn get_shard_scheme(config: &BotConfig) -> anyhow::Result<Option<ShardScheme>> {
    let shard_count = config.shard_count.unwrap_or(1);
    let scheme = if shard_count >= 16 {
        let cluster_id = get_cluster_id()?;
        let first_shard_id = 16 * cluster_id;
        let shard_range = first_shard_id..first_shard_id + 16;
        Some(ShardScheme::try_from((shard_range, shard_count))?)
    } else {
        None
    };
    Ok(scheme)
}

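The custom shard scheme above only kicks in at 16 or more shards: each Nomad allocation (NOMAD_ALLOC_INDEX) then owns a fixed window of 16 consecutive shard IDs. A small standalone sketch of that arithmetic, not part of this commit:

fn shard_window(cluster_id: u64) -> std::ops::Range<u64> {
    // cluster 0 runs shards 0..16, cluster 1 runs 16..32, and so on,
    // matching get_shard_scheme above
    let first_shard_id = 16 * cluster_id;
    first_shard_id..first_shard_id + 16
}
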
109 rustproxy/pk_bot/src/main.rs Normal file
@@ -0,0 +1,109 @@
#[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;

#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;

use crate::cache::DiscordCache;
use crate::redis::RedisEventProxy;
use futures::StreamExt;
use sqlx::PgPool;
use std::sync::Arc;
use tracing::{error, info};
use twilight_gateway::Event;
use twilight_http::Client as HttpClient;

mod cache;
mod config;
mod db;
mod gateway;
mod model;
mod proxy;
mod redis;

#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt().init();

    let config = config::load_config()?;
    info!("loaded config: {:?}", config);

    let pool = db::init_db(&config).await?;
    let http = Arc::new(HttpClient::new(config.token.clone()));
    let (_cluster, mut events) = gateway::init_gateway(Arc::clone(&http), &config).await?;
    let cache = Arc::new(DiscordCache::new());
    let redis = redis::init_event_proxy(&config).await?;

    while let Some((shard_id, event)) = events.next().await {
        let http = Arc::clone(&http);
        let pool = pool.clone();
        let cache = Arc::clone(&cache);
        let redis = redis.clone();

        tokio::spawn(async move {
            cache.handle_event(&event);

            let res = handle_event(shard_id, event, http, pool, cache, redis).await;
            if let Err(e) = res {
                error!("error handling event: {:?}", e);
            }
        });
    }

    Ok(())
}

async fn handle_event(
    shard_id: u64,
    event: Event,
    http: Arc<HttpClient>,
    pool: PgPool,
    cache: Arc<DiscordCache>,
    mut redis: RedisEventProxy,
) -> anyhow::Result<()> {
    match event {
        Event::MessageCreate(msg) => {
            if msg.content.starts_with("pk;") || msg.content.starts_with("pk!") {
                redis
                    .send_event_parsed(shard_id, Event::MessageCreate(msg))
                    .await?;
                return Ok(());
            }

            let channel_type = cache.channel_type(msg.channel_id)?;

            let ctx = db::get_message_context(
                &pool,
                msg.author.id.get() as i64,
                msg.guild_id.map(|x| x.get()).unwrap_or_default() as i64,
                msg.channel_id.get() as i64,
            )
            .await?;

            let _member_permissions =
                cache.member_permissions(msg.channel_id, msg.author.id, msg.member.as_ref())?;
            let bot_permissions = cache.bot_permissions(msg.channel_id)?;

            match proxy::check_preconditions(&msg, channel_type, bot_permissions, &ctx) {
                Ok(_) => {
                    info!("attempting to proxy");
                    proxy::do_proxy(&http, &pool, &msg, &ctx).await?;
                }
                Err(reason) => {
                    info!("skipping proxy because: {}", reason)
                }
            }
        }
        Event::ShardConnected(_) => {
            info!("connected on shard {}", shard_id);
        }
        Event::ShardPayload(payload) => {
            redis.send_event_raw(shard_id, &payload.bytes).await?;
        }
        // Other events here...
        _ => {}
    }

    Ok(())
}

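One routing rule in handle_event is worth calling out: messages starting with the command prefixes are never proxied here but forwarded over redis as parsed events, presumably so the existing command handler can consume them, while every other MessageCreate goes down the proxy path. A standalone sketch of that predicate, not part of this commit:

fn is_command_message(content: &str) -> bool {
    // mirrors the check at the top of the MessageCreate arm above
    content.starts_with("pk;") || content.starts_with("pk!")
}

// e.g. is_command_message("pk;system") == true, is_command_message("[hi]") == false
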
49 rustproxy/pk_bot/src/model.rs Normal file
@@ -0,0 +1,49 @@
use sqlx::FromRow;

#[derive(FromRow, Debug, Clone)]
pub struct PKSystem {
    pub id: i32,
    pub name: Option<String>,
    pub tag: Option<String>,
    pub avatar_url: Option<String>,
}

#[derive(FromRow, Debug, Clone)]
pub struct PKMember {
    pub id: i32,
    pub system: i32,
    pub name: String,

    pub color: Option<String>,
    pub avatar_url: Option<String>,
    pub display_name: Option<String>,
    pub pronouns: Option<String>,
    pub description: Option<String>,
}

#[derive(FromRow, Debug, Clone)]
pub struct PKMessage {
    pub mid: i64,
    pub guild: Option<i64>,
    pub channel: i64,
    pub member_id: i32,
    pub sender: i64,
    pub original_mid: Option<i64>,
}

#[derive(FromRow, Debug, Clone)]
pub struct PKSystemGuild {
    pub system: i32,
    pub guild: i64,
    pub proxy_enabled: bool,
    pub tag: Option<String>,
    pub tag_enabled: bool,
}

#[derive(FromRow, Debug, Clone)]
pub struct PKMemberGuild {
    pub member: i32,
    pub guild: i64,
    pub display_name: Option<String>,
    pub avatar_url: Option<String>,
}

31 rustproxy/pk_bot/src/proxy/autoproxy.rs Normal file
@@ -0,0 +1,31 @@
use crate::db::{AutoproxyMode, AutoproxyState, MessageContext};

pub fn resolve_autoproxy_member(
    ctx: &MessageContext,
    state: &AutoproxyState,
    content: &str,
) -> Option<i32> {
    if !ctx.allow_autoproxy.unwrap_or(true) {
        return None;
    }

    if is_escape(content) {
        return None;
    }

    let first_fronter = ctx.last_switch_members.iter().flatten().cloned().next();
    match (state.autoproxy_mode, state.autoproxy_member, first_fronter) {
        (AutoproxyMode::Latch, Some(m), _) => Some(m),
        (AutoproxyMode::Member, Some(m), _) => Some(m),
        (AutoproxyMode::Front, _, Some(f)) => Some(f),
        _ => None,
    }
}

fn _is_unlatch(content: &str) -> bool {
    content.starts_with("\\\\") || content.starts_with("\\\u{200b}\\")
}

fn is_escape(content: &str) -> bool {
    content.starts_with('\\')
}

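A sketch of how resolve_autoproxy_member behaves in latch mode, written as a test. This is not part of the commit; it would have to live inside this autoproxy module, and the member ID 42 is an arbitrary placeholder:

#[cfg(test)]
mod autoproxy_sketch {
    use super::resolve_autoproxy_member;
    use crate::db::{AutoproxyMode, AutoproxyState, MessageContext};

    #[test]
    fn latch_mode_uses_stored_member_unless_escaped() {
        // MessageContext derives Default, so only the relevant field is set
        let ctx = MessageContext {
            allow_autoproxy: Some(true),
            ..Default::default()
        };
        let state = AutoproxyState {
            autoproxy_mode: AutoproxyMode::Latch,
            autoproxy_member: Some(42),
            last_latch_timestamp: None,
        };

        // ordinary text latches onto the stored member...
        assert_eq!(resolve_autoproxy_member(&ctx, &state, "hello"), Some(42));
        // ...while a leading backslash escapes autoproxy entirely
        assert_eq!(resolve_autoproxy_member(&ctx, &state, r"\hello"), None);
    }
}
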
227 rustproxy/pk_bot/src/proxy/mod.rs Normal file
@@ -0,0 +1,227 @@
use self::{post_proxy::ProxyResult, webhook::WebhookExecuteRequest};
use crate::db::{self, MessageContext};
use sqlx::PgPool;
use thiserror::Error;
use twilight_http::Client;
use twilight_model::{
    channel::{message::MessageType, ChannelType, Message},
    guild::Permissions,
    user::User,
};

mod autoproxy;
mod post_proxy;
mod profile;
mod reply;
mod tags;
mod webhook;

#[derive(Error, Debug)]
// use this for pk;proxycheck or something
pub enum PreconditionFailure {
    #[error("bot is missing permissions (has: {has:?}, needs: {needs:?})")]
    BotMissingPermission {
        has: Permissions,
        needs: Permissions,
    },

    #[error("invalid channel type {0:?}")]
    InvalidChannelType(ChannelType),

    #[error("invalid message type {0:?}")]
    InvalidMessageType(MessageType),

    #[error("user is bot")]
    UserIsBot,

    #[error("user is webhook")]
    UserIsWebhook,

    #[error("user is discord system")]
    UserIsDiscordSystem,

    #[error("user has no system")]
    UserHasNoSystem,

    #[error("proxy disabled for system")]
    ProxyDisabledForSystem,

    #[error("proxy disabled in channel")]
    ProxyDisabledInChannel,

    #[error("message contains activity")]
    MessageContainsActivity,

    #[error("message contains sticker")]
    MessageContainsSticker,

    #[error("message is empty and has no attachments")]
    MessageIsEmpty,
}

// todo: the parameters here are nasty, refactor+put this code somewhere else maybe
pub fn check_preconditions(
    msg: &Message,
    channel_type: ChannelType,
    bot_permissions: Permissions,
    ctx: &MessageContext,
) -> Result<(), PreconditionFailure> {
    let required_permissions =
        Permissions::SEND_MESSAGES | Permissions::MANAGE_WEBHOOKS | Permissions::MANAGE_MESSAGES;
    if !bot_permissions.contains(required_permissions) {
        return Err(PreconditionFailure::BotMissingPermission {
            has: bot_permissions,
            needs: required_permissions,
        });
    }

    match channel_type {
        ChannelType::GuildText
        | ChannelType::GuildNews
        | ChannelType::GuildPrivateThread
        | ChannelType::GuildPublicThread
        | ChannelType::GuildNewsThread => Ok(()),
        wrong_type => Err(PreconditionFailure::InvalidChannelType(wrong_type)),
    }?;

    match msg.kind {
        MessageType::Regular | MessageType::Reply => Ok(()),
        wrong_type => Err(PreconditionFailure::InvalidMessageType(wrong_type)),
    }?;

    match msg {
        Message {
            author: User {
                system: Some(true), ..
            },
            ..
        } => Err(PreconditionFailure::UserIsDiscordSystem),
        Message {
            author: User { bot: true, .. },
            ..
        } => Err(PreconditionFailure::UserIsBot),
        Message {
            webhook_id: Some(_),
            ..
        } => Err(PreconditionFailure::UserIsWebhook),
        Message {
            activity: Some(_), ..
        } => Err(PreconditionFailure::MessageContainsActivity),
        Message {
            sticker_items: s, ..
        } if !s.is_empty() => Err(PreconditionFailure::MessageContainsSticker),
        Message {
            content: c,
            attachments: a,
            ..
        } if c.trim().is_empty() && a.is_empty() => Err(PreconditionFailure::MessageIsEmpty),
        _ => Ok(()),
    }?;

    match ctx {
        MessageContext {
            system_id: None, ..
        } => Err(PreconditionFailure::UserHasNoSystem),
        MessageContext {
            in_blacklist: Some(true),
            ..
        } => Err(PreconditionFailure::ProxyDisabledInChannel),
        MessageContext {
            proxy_enabled: Some(false),
            ..
        } => Err(PreconditionFailure::ProxyDisabledForSystem),
        _ => Ok(()),
    }?;

    Ok(())
}

struct ProxyMatchResult {
    member_id: i32,
    inner_content: String,
    _tags: Option<(String, String)>, // todo: need this for keepproxy
}

async fn match_tags_or_autoproxy(
    pool: &PgPool,
    msg: &Message,
    ctx: &MessageContext,
) -> anyhow::Result<Option<ProxyMatchResult>> {
    let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
    let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;

    let tags = db::get_proxy_tags(pool, system_id).await?;
    let ap_state = db::get_autoproxy_state(
        pool,
        system_id,
        guild_id.get() as i64,
        0, // all autoproxy has channel id 0? o.o
    )
    .await?;

    let tag_match = tags::match_proxy_tags(&tags, &msg.content);
    if let Some(tag_match) = tag_match {
        return Ok(Some(ProxyMatchResult {
            member_id: tag_match.member_id,
            inner_content: tag_match.inner_content,
            _tags: Some(tag_match.tags),
        }));
    }

    if let Some(ap_state) = ap_state {
        let res = autoproxy::resolve_autoproxy_member(ctx, &ap_state, &msg.content);
        if let Some(member_id) = res {
            return Ok(Some(ProxyMatchResult {
                inner_content: msg.content.clone(),
                member_id,
                _tags: None,
            }));
        }
    }

    Ok(None)
}

// todo: this shouldn't depend on a Message object (for reproxy/proxy command/etc)
pub async fn do_proxy(
    http: &Client,
    pool: &PgPool,
    msg: &Message,
    ctx: &MessageContext,
) -> anyhow::Result<()> {
    let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
    let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;

    // todo: unlatch check/exec should probably go in here somewhere

    let proxy_match = match_tags_or_autoproxy(pool, msg, ctx).await?;
    if let Some(result) = proxy_match {
        let profile =
            profile::fetch_proxy_profile(pool, guild_id.get(), system_id, result.member_id).await?;

        let webhook_req = WebhookExecuteRequest {
            channel_id: msg.channel_id.get(),
            avatar_url: profile.avatar_url().map(|s| s.to_string()),
            content: Some(result.inner_content.clone()),
            username: profile.formatted_name(),
            embed: msg
                .referenced_message
                .as_deref()
                .map(|msg| reply::create_reply_embed(guild_id, msg))
                .transpose()?,
        };

        let webhook_res = webhook::execute_webhook(http, &webhook_req).await?;

        let proxy_res = ProxyResult {
            channel_id: msg.channel_id.get(),
            guild_id: guild_id.get(),
            member_id: result.member_id,
            original_message_id: msg.id.get(),
            proxy_message_id: webhook_res.message_id,
            sender: msg.author.id.get(),
        };
        post_proxy::handle_post_proxy(http, pool, &proxy_res).await?;
    }
    Ok(())
}

66 rustproxy/pk_bot/src/proxy/post_proxy.rs Normal file
@@ -0,0 +1,66 @@
use crate::db;
use crate::model::PKMessage;
use futures::TryFutureExt;
use sqlx::PgPool;
use tracing::error;
use twilight_http::Client;

#[derive(Debug, Clone)]
pub struct ProxyResult {
    pub guild_id: u64,
    pub channel_id: u64,
    pub proxy_message_id: u64,
    pub original_message_id: u64,
    pub sender: u64,
    pub member_id: i32,
}

pub async fn handle_post_proxy(
    http: &Client,
    pool: &PgPool,
    res: &ProxyResult,
) -> anyhow::Result<()> {
    // todo: log channel

    let _ = futures::join!(
        delete_original_message(http, res)
            .inspect_err(|e| error!("error deleting original message: {}", e)),
        insert_message_in_db(pool, res)
            .inspect_err(|e| error!("error inserting message in db: {}", e))
    );

    Ok(())
}

async fn delete_original_message(http: &Client, res: &ProxyResult) -> anyhow::Result<()> {
    // todo: sleep some amount
    // (do we still need to do that or did discord fix that client bug?)
    http.delete_message(
        res.channel_id.try_into().unwrap(),
        res.original_message_id.try_into().unwrap(),
    )
    .exec()
    .await?;

    Ok(())
}

async fn insert_message_in_db(pool: &PgPool, res: &ProxyResult) -> anyhow::Result<()> {
    db::insert_message(
        pool,
        PKMessage {
            mid: res.proxy_message_id as i64,
            original_mid: Some(res.original_message_id as i64),
            sender: res.sender as i64,
            guild: Some(res.guild_id as i64),
            channel: res.channel_id as i64,
            member_id: res.member_id,
        },
    )
    .await?;

    Ok(())
}

#[cfg(test)]
mod tests {}

90 rustproxy/pk_bot/src/proxy/profile.rs Normal file
@@ -0,0 +1,90 @@
use futures::try_join;
use sqlx::PgPool;

use crate::{
    db,
    model::{PKMember, PKMemberGuild, PKSystem, PKSystemGuild},
};

// inventing the term "proxy profile" here to describe the info needed to work out the webhook name+avatar
// arbitrary choice to put the source models in the struct and logic in methods, could just as well have had a function to do the math and put the results in a struct
pub struct ProxyProfile {
    system: PKSystem,
    member: PKMember,
    system_guild: Option<PKSystemGuild>,
    member_guild: Option<PKMemberGuild>,
}

impl ProxyProfile {
    pub fn name(&self) -> &str {
        let member_name = &self.member.name;
        let display_name = self.member.display_name.as_deref();
        let server_name = self
            .member_guild
            .as_ref()
            .and_then(|x| x.display_name.as_deref());
        server_name.or(display_name).unwrap_or(member_name)
    }

    pub fn avatar_url(&self) -> Option<&str> {
        let system_avatar = self.system.avatar_url.as_deref();
        let member_avatar = self.member.avatar_url.as_deref();
        let server_avatar = self
            .member_guild
            .as_ref()
            .and_then(|x| x.avatar_url.as_deref());
        server_avatar.or(member_avatar).or(system_avatar)
    }

    pub fn tag(&self) -> Option<&str> {
        let server_tag = self.system_guild.as_ref().and_then(|x| x.tag.as_deref());
        let system_tag = self.system.tag.as_deref();
        server_tag.or(system_tag)
    }

    pub fn formatted_name(&self) -> String {
        let mut name = if let Some(tag) = self.tag() {
            format!("{} {}", self.name(), tag)
        } else {
            self.name().to_string()
        };

        if name.len() == 1 {
            name.push('\u{17b5}');
        }

        name
    }
}

pub async fn fetch_proxy_profile(
    pool: &PgPool,
    guild_id: u64,
    system_id: i32,
    member_id: i32,
) -> anyhow::Result<ProxyProfile> {
    // todo: this should be a db view with joins
    // this is all the info that proxy_members returned, so a single-member version of that could work nicely
    let system = db::get_system_by_id(pool, system_id);
    let member = db::get_member_by_id(pool, member_id);
    let system_guild = db::get_system_guild(pool, system_id, guild_id as i64);
    let member_guild = db::get_member_guild(pool, member_id, guild_id as i64);

    let (system, member, system_guild, member_guild) =
        try_join!(system, member, system_guild, member_guild)?;

    let system = system.ok_or_else(|| anyhow::anyhow!("could not find system"))?;
    let member = member.ok_or_else(|| anyhow::anyhow!("could not find member"))?;

    Ok(ProxyProfile {
        system,
        member,
        system_guild,
        member_guild,
    })
}

#[cfg(test)]
mod tests {
    // todo: this code is gonna be easy to unit test so we should do that
}

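The todo in the tests module above is easy to act on, since ProxyProfile is plain data. A sketch of one such unit test, not part of the commit; it would need to sit inside the tests module of this file because the ProxyProfile fields are private, and all IDs and URLs below are made-up placeholders:

use super::*;
use crate::model::{PKMember, PKMemberGuild, PKSystem};

#[test]
fn guild_overrides_win_and_tags_are_appended() {
    let profile = ProxyProfile {
        system: PKSystem {
            id: 1,
            name: Some("Example System".to_string()),
            tag: Some("| PK".to_string()),
            avatar_url: None,
        },
        member: PKMember {
            id: 2,
            system: 1,
            name: "Myriad".to_string(),
            color: None,
            avatar_url: Some("https://example.invalid/member.png".to_string()),
            display_name: Some("Myriad Kit".to_string()),
            pronouns: None,
            description: None,
        },
        system_guild: None,
        member_guild: Some(PKMemberGuild {
            member: 2,
            guild: 3,
            display_name: Some("Myriad (guild)".to_string()),
            avatar_url: None,
        }),
    };

    // the per-guild display name beats the member display name and base name
    assert_eq!(profile.name(), "Myriad (guild)");
    // no per-guild tag, so the system tag is appended
    assert_eq!(profile.formatted_name(), "Myriad (guild) | PK");
    // no per-guild avatar, so it falls back to the member avatar
    assert_eq!(
        profile.avatar_url(),
        Some("https://example.invalid/member.png")
    );
}
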
57 rustproxy/pk_bot/src/proxy/reply.rs Normal file
@@ -0,0 +1,57 @@
use twilight_http::Client;
use twilight_model::{
    channel::{embed::Embed, Message},
    id::{
        marker::{GuildMarker, UserMarker},
        Id,
    },
    util::ImageHash,
};
use twilight_util::builder::embed::{EmbedAuthorBuilder, EmbedBuilder, ImageSource};

async fn _fetch_additional_reply_info(_http: &Client) -> anyhow::Result<()> {
    Ok(())
}

pub fn create_reply_embed(
    guild_id: Id<GuildMarker>,
    replied_to: &Message,
) -> anyhow::Result<Embed> {
    // todo: guild avatars, guild nicknames
    // probably put this in fetch_additional_reply_info

    let author = {
        let icon = replied_to
            .author
            .avatar
            .map(|hash| get_avatar_url(replied_to.author.id, hash))
            .and_then(|url| ImageSource::url(url).ok());

        let mut builder = EmbedAuthorBuilder::new(replied_to.author.name.clone());
        if let Some(icon) = icon {
            builder = builder.icon_url(icon);
        };
        builder.build()
    };

    let content = {
        let jump_link = format!(
            "https://discord.com/channels/{}/{}/{}",
            guild_id, replied_to.channel_id, replied_to.id
        );

        let content = format!("**[Reply to:]({})** ", jump_link);
        // todo: properly add truncated content (including handling links/spoilers/etc)
        content
    };

    let builder = EmbedBuilder::new().description(content).author(author);
    Ok(builder.build())
}

fn get_avatar_url(user_id: Id<UserMarker>, hash: ImageHash) -> String {
    format!(
        "https://cdn.discordapp.com/avatars/{}/{}.png",
        user_id, hash
    )
}

93 rustproxy/pk_bot/src/proxy/tags.rs Normal file
@@ -0,0 +1,93 @@
use crate::db::ProxyTagEntry;
use tracing::info;

#[derive(Debug)]
pub struct ProxyTagMatch {
    pub inner_content: String,
    pub tags: (String, String),
    pub member_id: i32,
}

pub fn match_proxy_tags(tags: &[ProxyTagEntry], content: &str) -> Option<ProxyTagMatch> {
    let content = content.trim();

    let mut sorted_entries = tags.to_vec();
    sorted_entries.sort_by_key(|x| -((x.prefix.len() + x.suffix.len()) as i32));

    for entry in sorted_entries {
        let is_tag_match = content.starts_with(&entry.prefix) && content.ends_with(&entry.suffix);
        info!(
            "prefix: {}, suffix: {}, content: {}, is_match: {}",
            entry.prefix, entry.suffix, content, is_tag_match
        );

        // todo: extract leading mentions
        // todo: allow empty matches only if we're proxying an attachment
        // todo: properly handle <>s etc, there's some regex stuff in there i don't entirely understand
        // todo: there's some weird edge cases with various unicode control characters and emoji joiners and whatever, should figure that out + unit test it
        if is_tag_match {
            let inner_content =
                &content[entry.prefix.len()..(content.len() - entry.suffix.len())].trim();
            return Some(ProxyTagMatch {
                inner_content: inner_content.to_string(),
                tags: (entry.prefix, entry.suffix),
                member_id: entry.member_id,
            });
        }
    }

    None
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn basic_match() {
        let tags = vec![
            ("[", "]", 0).into(),
            ("[[", "]]", 1).into(),
            ("P:", "", 2).into(),
            ("", "-P", 3).into(),
            ("+ ", "", 4).into(),
        ];
        assert_no_match(&tags, "hello world");
        assert_match(&tags, "[hello world]", "hello world", 0);
        assert_match(&tags, "[ hello world ]", "hello world", 0);
        assert_match(&tags, " [ hello world ] ", "hello world", 0);
        assert_match(&tags, "[\nhello\n]", "hello", 0);
        assert_match(&tags, "[\nhello\nworld\n]", "hello\nworld", 0);

        assert_match(&tags, "[[text]]", "text", 1);
        assert_match(&tags, "[text]]", "text]", 0);
        assert_match(&tags, "[[[text]]]", "[text]", 1);

        assert_match(&tags, "P:text", "text", 2);
        assert_match(&tags, "text -P", "text", 3);

        assert_match(&tags, "+ hello", "hello", 4);
        assert_no_match(&tags, "+hello"); // (prefix contains trailing space)

        assert_match(&tags, "[]", "", 0);

        // edge case: the c# implementation currently does what the commented out test does
        // *if* the message doesn't have an attachment. not sure if we should mirror this here.
        // assert_match(&tags, "[[]]", "[]", 0);
        assert_match(&tags, "[[]]", "", 1);
    }

    fn assert_match(tags: &[ProxyTagEntry], message: &str, inner: &str, member: i32) {
        let res = match_proxy_tags(tags, message);
        assert_eq!(
            res.as_ref()
                .map(|x| (x.inner_content.as_str(), x.member_id)),
            Some((inner, member))
        );
    }

    fn assert_no_match(tags: &[ProxyTagEntry], message: &str) {
        let res = match_proxy_tags(tags, message);
        assert!(res.is_none());
    }
}

143 rustproxy/pk_bot/src/proxy/webhook.rs Normal file
@@ -0,0 +1,143 @@
use moka::future::Cache;
use once_cell::sync::Lazy;
use tracing::info;
use twilight_http::Client;
use twilight_model::channel::{embed::Embed, Webhook, WebhookType};

// space for 1 million is probably way overkill, this is a LRU cache so it's okay to evict occasionally
static WEBHOOK_CACHE: Lazy<Cache<u64, CachedWebhook>> = Lazy::new(|| Cache::new(1024 * 1024));
const WEBHOOK_NAME: &str = "PluralKit Proxy Webhook";

#[derive(Debug, Clone)]
pub struct CachedWebhook {
    pub id: u64,
    pub token: String,
}

pub async fn get_webhook_cached(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
    let res = WEBHOOK_CACHE
        .try_get_with(channel_id, fetch_or_create_pk_webhook(http, channel_id))
        .await;

    // todo: what happens if fetch_or_create_pk_webhook errors? i think moka handles it properly and just retries
    // but i'm not entiiiirely sure
    // https://docs.rs/moka/0.8.5/moka/future/struct.Cache.html#method.try_get_with

    // error is Arc<Error> here and it's hard to convert that into an owned ref so we just make a new error lmao
    res.map_err(|_e| anyhow::anyhow!(
        "could not fetch webhook: {}", _e
    ))
}

async fn fetch_or_create_pk_webhook(
    http: &Client,
    channel_id: u64,
) -> anyhow::Result<CachedWebhook> {
    match fetch_pk_webhook(http, channel_id).await? {
        Some(hook) => Ok(hook),
        None => create_pk_webhook(http, channel_id).await,
    }
}

async fn fetch_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<Option<CachedWebhook>> {
    info!("cache miss, fetching webhook for channel {}", channel_id);

    let webhooks = http
        .channel_webhooks(channel_id.try_into().unwrap())
        .exec()
        .await?
        .models()
        .await?;

    webhooks
        .iter()
        .find(|wh| is_proxy_webhook(wh))
        .map(|x| {
            let token = x
                .token
                .as_ref()
                .map(|x| x.to_string())
                .ok_or_else(|| anyhow::anyhow!("webhook should contain token"));

            token.map(|token| CachedWebhook {
                id: x.id.get(),
                token,
            })
        })
        .transpose()
}

async fn create_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
    let response = http
        .create_webhook(channel_id.try_into().unwrap(), WEBHOOK_NAME)?
        .exec()
        .await?;

    // todo: error handling here
    let val = response.model().await?;
    Ok(CachedWebhook {
        id: val.id.get(),
        token: val
            .token
            .ok_or_else(|| anyhow::anyhow!("webhook should contain token"))?,
    })
}

fn is_proxy_webhook(wh: &Webhook) -> bool {
    wh.kind == WebhookType::Incoming
        && wh.token.is_some()
        && wh.name.as_deref() == Some(WEBHOOK_NAME)
}

#[derive(Debug)]
pub struct WebhookExecuteRequest {
    pub channel_id: u64,
    pub username: String,
    pub avatar_url: Option<String>,
    pub content: Option<String>,
    pub embed: Option<Embed>,
}

#[derive(Debug)]
pub struct WebhookExecuteResult {
    pub message_id: u64,
}

pub async fn execute_webhook(
    http: &Client,
    req: &WebhookExecuteRequest,
) -> anyhow::Result<WebhookExecuteResult> {
    let webhook = get_webhook_cached(http, req.channel_id).await?;
    let mut request = http
        .execute_webhook(webhook.id.try_into().unwrap(), &webhook.token)
        .username(&req.username)?;

    if let Some(ref content) = req.content {
        request = request.content(content)?;
    }

    if let Some(ref avatar_url) = req.avatar_url {
        request = request.avatar_url(avatar_url);
    }

    let mut embeds = Vec::new();
    if let Some(ref embed) = req.embed {
        embeds.push(embed.clone());
        request = request.embeds(&embeds)?;
    }

    // todo: handle error if webhook was deleted, should invalidate and retry
    let result = request.wait().exec().await?;

    let model = result.model().await?;
    if model.channel_id != req.channel_id {
        // it's possible for someone to "redirect" a webhook to another channel
        // and the only way we find out is when we send a message.
        // if this has happened remove it from cache and refetch later
        WEBHOOK_CACHE.invalidate(&req.channel_id).await;
    }

    Ok(WebhookExecuteResult {
        message_id: model.id.get(),
    })
}

168 rustproxy/pk_bot/src/redis.rs Normal file
@@ -0,0 +1,168 @@
use crate::config::BotConfig;
use once_cell::sync::Lazy;
use redis::aio::ConnectionManager;
use std::fmt::Debug;
use std::time::Duration;
use tracing::{error, info};
use twilight_gateway::Event;
use twilight_gateway_queue::Queue;
use twilight_model::gateway::event::{DispatchEvent, GatewayEvent, GatewayEventDeserializer};

#[derive(Clone)]
pub struct RedisEventProxy {
    // todo: don't know if i want this struct to have responsibility for ignoring calls if redis is disabled
    inner: Option<ConnectionManager>,
}

// events that should be sent in the raw handler
// does not include message create/update since we want first dibs on those and pass them on later
static ALLOWED_EVENTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
    vec![
        "INTERACTION_CREATE",
        "MESSAGE_DELETE",
        "MESSAGE_DELETE_BULK",
        "MESSAGE_REACTION_ADD",
        "READY",
        "GUILD_CREATE",
        "GUILD_UPDATE",
        "GUILD_DELETE",
        "GUILD_ROLE_CREATE",
        "GUILD_ROLE_UPDATE",
        "GUILD_ROLE_DELETE",
        "CHANNEL_CREATE",
        "CHANNEL_UPDATE",
        "CHANNEL_DELETE",
        "THREAD_CREATE",
        "THREAD_UPDATE",
        "THREAD_DELETE",
        "THREAD_LIST_SYNC",
    ]
});

impl RedisEventProxy {
    async fn send_event_raw_inner(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
        if let Some(ref mut redis) = self.inner {
            info!(shard_id = shard_id, "publishing event");
            let key = format!("evt-{}", shard_id);

            redis::cmd("PUBLISH")
                .arg(&key[..])
                .arg(payload)
                .query_async(redis)
                .await?;
        }

        Ok(())
    }

    pub async fn send_event_raw(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
        let payload_str = std::str::from_utf8(payload)?;

        if let Some(deser) = GatewayEventDeserializer::from_json(payload_str) {
            if let Some(event_type) = deser.event_type_ref() {
                if ALLOWED_EVENTS.contains(&event_type) {
                    self.send_event_raw_inner(shard_id, payload).await?;
                }
            }
        }

        Ok(())
    }

    pub async fn send_event_parsed(&mut self, shard_id: u64, evt: Event) -> anyhow::Result<()> {
        info!(shard_id, "sending parsed: {:?}", evt.kind());

        let dispatch_event = DispatchEvent::try_from(evt)?;
        let gateway_event = GatewayEvent::Dispatch(0, Box::new(dispatch_event));
        let buf = serde_json::to_vec(&gateway_event)?;
        self.send_event_raw_inner(shard_id, &buf).await?;
        Ok(())
    }
}

pub async fn connect_to_redis(addr: &str) -> anyhow::Result<ConnectionManager> {
    let client = redis::Client::open(addr)?;
    info!("connecting to redis at {}...", addr);
    Ok(ConnectionManager::new(client).await?)
}

pub async fn init_event_proxy(config: &BotConfig) -> anyhow::Result<RedisEventProxy> {
    let mgr = if let Some(redis_addr) = &config.redis_addr {
        Some(connect_to_redis(redis_addr).await?)
    } else {
        info!("no redis address specified, skipping");
        None
    };

    Ok(RedisEventProxy { inner: mgr })
}

#[derive(Clone)]
pub struct RedisQueue {
    pub redis: ConnectionManager,
    pub concurrency: u64,
}

impl Debug for RedisQueue {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RedisQueue")
            .field("concurrency", &self.concurrency)
            .finish()
    }
}

impl Queue for RedisQueue {
    fn request<'a>(
        &'a self,
        shard_id: [u64; 2],
    ) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + 'a>> {
        Box::pin(request_inner(
            self.redis.clone(),
            self.concurrency,
            *shard_id.first().unwrap(),
        ))
    }
}

async fn request_inner(mut client: ConnectionManager, concurrency: u64, shard_id: u64) {
    let bucket = shard_id % concurrency;
    let key = format!("pluralkit:identify:{}", bucket);

    // SET bucket 1 EX 6 NX = write a key expiring after 6 seconds if there's not already one
    let mut cmd = redis::cmd("SET");
    cmd.arg(key).arg("1").arg("EX").arg(6i8).arg("NX");

    info!(shard_id, bucket, "waiting for allowance...");
    loop {
        let done = cmd
            .clone()
            .query_async::<_, Option<String>>(&mut client)
            .await;
        match done {
            Ok(Some(_)) => {
                info!(shard_id, bucket, "got allowance!");
                return;
            }
            Ok(None) => {
                // not allowed yet, waiting
            }
            Err(e) => {
                error!(shard_id, bucket, "error getting shard allowance: {}", e)
            }
        }

        tokio::time::sleep(Duration::from_millis(500)).await;
    }
}

pub async fn init_gateway_queue(config: &BotConfig) -> anyhow::Result<Option<RedisQueue>> {
    let queue = if let Some(ref addr) = config.redis_gateway_queue_addr {
        let redis = connect_to_redis(addr).await?;
        let concurrency = config.max_concurrency.unwrap_or(1);
        Some(RedisQueue { redis, concurrency })
    } else {
        None
    };

    Ok(queue)
}

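The identify queue above is a distributed take on the gateway's max_concurrency rule: shards whose IDs are congruent modulo the concurrency share a bucket, and the 6-second NX key keeps two shards in the same bucket from identifying back to back. A standalone sketch of the key/bucket derivation used by request_inner, not part of this commit:

fn identify_key(shard_id: u64, max_concurrency: u64) -> String {
    // same bucket => same redis key => shards take turns on the SET ... NX lock
    let bucket = shard_id % max_concurrency;
    format!("pluralkit:identify:{}", bucket)
}

// e.g. with max_concurrency = 16, shards 0 and 16 contend for
// "pluralkit:identify:0", while shard 1 uses "pluralkit:identify:1"
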