mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-04 04:56:49 +00:00
rustproxy: initial commit
This commit is contained in:
parent
b47694edc1
commit
9d90de45a6
23 changed files with 4686 additions and 0 deletions
3
rustproxy/.gitignore
vendored
Normal file
3
rustproxy/.gitignore
vendored
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
/target
|
||||
.idea/
|
||||
config.toml
|
||||
2625
rustproxy/Cargo.lock
generated
Normal file
2625
rustproxy/Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
5
rustproxy/Cargo.toml
Normal file
5
rustproxy/Cargo.toml
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
[workspace]
|
||||
members = [
|
||||
"pk_bot",
|
||||
"pk_command_parser"
|
||||
]
|
||||
7
rustproxy/config.example.toml
Normal file
7
rustproxy/config.example.toml
Normal file
|
|
@ -0,0 +1,7 @@
|
|||
token = "put your token here"
|
||||
|
||||
database = "postgres://postgres:password@localhost/postgres"
|
||||
|
||||
redis_addr = "redis://127.0.0.1/"
|
||||
redis_gateway_queue_addr = "redis://127.0.0.1/"
|
||||
shard_count = 1
|
||||
35
rustproxy/pk_bot/Cargo.toml
Normal file
35
rustproxy/pk_bot/Cargo.toml
Normal file
|
|
@ -0,0 +1,35 @@
|
|||
[package]
|
||||
name = "pk_bot"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.57"
|
||||
async-trait = "0.1.56"
|
||||
chrono = "0.4.19"
|
||||
config = { version = "0.13.1", default-features = false, features = ["toml"] }
|
||||
dashmap = "5.3.4"
|
||||
futures = "0.3.21"
|
||||
moka = { version = "0.8.5", features = ["future"] }
|
||||
once_cell = "1.12.0"
|
||||
redis = { version = "0.21.5", features = ["connection-manager", "tokio-comp"] }
|
||||
regex = "1.5.6"
|
||||
serde = "1.0.137"
|
||||
serde_json = "1.0.81"
|
||||
smallvec = "1.8.0"
|
||||
smol_str = "0.1.23"
|
||||
sqlx = { version = "0.6.0", features = ["runtime-tokio-native-tls", "postgres", "chrono", "macros"] }
|
||||
thiserror = "1.0.31"
|
||||
tokio = { version = "1.19.1", features = ["full"] }
|
||||
tracing = "0.1.34"
|
||||
tracing-subscriber = "0.3.11"
|
||||
twilight-gateway = "0.11.0"
|
||||
twilight-gateway-queue = "0.11.0"
|
||||
twilight-http = "0.11.0"
|
||||
twilight-model = "0.11.0"
|
||||
twilight-util = { version = "0.11.0", features = ["builder", "permission-calculator"] }
|
||||
|
||||
[target.'cfg(not(target_env = "msvc"))'.dependencies]
|
||||
tikv-jemallocator = "0.5"
|
||||
265
rustproxy/pk_bot/src/cache.rs
Normal file
265
rustproxy/pk_bot/src/cache.rs
Normal file
|
|
@ -0,0 +1,265 @@
|
|||
use std::sync::{Arc, RwLock};
|
||||
|
||||
use dashmap::mapref::one::Ref;
|
||||
use dashmap::DashMap;
|
||||
use smol_str::SmolStr;
|
||||
use twilight_model::channel::permission_overwrite::PermissionOverwrite;
|
||||
use twilight_model::channel::{Channel, ChannelType};
|
||||
use twilight_model::gateway::event::Event;
|
||||
use twilight_model::gateway::payload::incoming::ThreadListSync;
|
||||
use twilight_model::guild::{Guild, PartialMember, Permissions, PremiumTier, Role};
|
||||
use twilight_model::id::marker::{ChannelMarker, GuildMarker, RoleMarker, UserMarker};
|
||||
use twilight_model::id::Id;
|
||||
use twilight_util::permission_calculator::PermissionCalculator;
|
||||
|
||||
/// Permission set returned for channels that have no guild id attached
/// (the non-guild branch of the permission helpers below), where role and
/// overwrite based calculation is skipped entirely.
const DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
    .union(Permissions::SEND_MESSAGES)
    .union(Permissions::READ_MESSAGE_HISTORY)
    .union(Permissions::ADD_REACTIONS)
    .union(Permissions::ATTACH_FILES)
    .union(Permissions::EMBED_LINKS)
    .union(Permissions::USE_EXTERNAL_EMOJIS);
|
||||
|
||||
/// The subset of a guild's data that the cache keeps.
#[derive(Debug, Clone)]
pub struct CachedGuild {
    /// Raw id of the guild owner; handed to the permission calculator.
    owner_id: u64,
    /// Copied from the guild payload; the underscore marks it as stored but unused for now.
    _premium_tier: PremiumTier,
}
|
||||
|
||||
/// The subset of a channel's (or thread's) data that the cache keeps.
#[derive(Debug, Clone)]
pub struct CachedChannel {
    _name: Option<SmolStr>, // stores strings 22 characters or less inline, which is a large portion of channel names
    /// Id of the parent channel, if the payload carried one.
    _parent_id: Option<Id<ChannelMarker>>,
    /// Owning guild; `None` is treated as "not a guild channel" by the permission helpers.
    guild_id: Option<Id<GuildMarker>>,
    /// Channel type as reported by the gateway.
    kind: ChannelType,
    /// Permission overwrites from the payload; empty when the payload had none.
    overwrites: Vec<PermissionOverwrite>,
}
|
||||
|
||||
/// The subset of a role's data that the cache keeps.
#[derive(Debug, Clone)]
pub struct CachedRole {
    /// Base permissions granted by the role; input to the permission calculator.
    permissions: Permissions,
    /// Copied from the role payload; the underscore marks it as stored but unused for now.
    _mentionable: bool,
}
|
||||
|
||||
/// The bot's own membership in one guild.
#[derive(Debug, Clone)]
pub struct CachedBotMember {
    /// Role ids the bot holds in that guild.
    roles: Vec<Id<RoleMarker>>,
}
|
||||
|
||||
/// In-memory cache of the Discord state needed to compute permissions.
///
/// All maps are keyed by the raw `u64` snowflake of the entity.
pub struct DiscordCache {
    /// The bot's own user id; `None` until a `Ready` event has been handled.
    bot_user: Arc<RwLock<Option<Id<UserMarker>>>>,
    /// Guilds keyed by guild id.
    guilds: DashMap<u64, CachedGuild>,
    /// Channels and threads keyed by channel id.
    channels: DashMap<u64, CachedChannel>,
    /// Roles keyed by role id.
    roles: DashMap<u64, CachedRole>,
    /// The bot's member entry per guild, keyed by guild id.
    bot_members: DashMap<u64, CachedBotMember>,
}
|
||||
|
||||
impl DiscordCache {
|
||||
pub fn new() -> DiscordCache {
|
||||
DiscordCache {
|
||||
bot_user: Arc::new(RwLock::new(None)),
|
||||
guilds: DashMap::new(),
|
||||
channels: DashMap::new(),
|
||||
roles: DashMap::new(),
|
||||
bot_members: DashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn handle_event(&self, event: &Event) {
|
||||
match event {
|
||||
Event::Ready(ref r) => {
|
||||
let mut bot_user = self.bot_user.write().unwrap();
|
||||
*bot_user = Some(r.user.id);
|
||||
}
|
||||
Event::GuildCreate(ref g) => self.update_guild(g),
|
||||
Event::ChannelCreate(ref ch) => self.update_channel(ch),
|
||||
Event::ChannelUpdate(ref ch) => self.update_channel(ch),
|
||||
Event::ChannelDelete(ref ch) => self.delete_channel(ch.id),
|
||||
Event::ThreadCreate(ref ch) => self.update_channel(ch),
|
||||
Event::ThreadUpdate(ref ch) => self.update_channel(ch),
|
||||
Event::ThreadDelete(ref ch) => self.delete_channel(ch.id),
|
||||
Event::ThreadListSync(ref ts) => self.update_threads(ts),
|
||||
Event::RoleCreate(ref r) => self.update_role(&r.role),
|
||||
Event::RoleUpdate(ref r) => self.update_role(&r.role),
|
||||
Event::RoleDelete(ref r) => self.delete_role(r.role_id),
|
||||
Event::MemberUpdate(ref member) => {
|
||||
let current_user = self.bot_user_id();
|
||||
if Some(member.user.id) == current_user {
|
||||
self.update_bot_member(member.guild_id, &member.roles);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
fn update_guild(&self, guild: &Guild) {
|
||||
for channel in &guild.channels {
|
||||
self.update_channel(channel);
|
||||
}
|
||||
|
||||
for role in &guild.roles {
|
||||
self.update_role(role);
|
||||
}
|
||||
|
||||
let current_user = self.bot_user_id();
|
||||
for member in &guild.members {
|
||||
if Some(member.user.id) == current_user {
|
||||
self.update_bot_member(member.guild_id, &member.roles);
|
||||
}
|
||||
}
|
||||
|
||||
self.guilds.insert(
|
||||
guild.id.get(),
|
||||
CachedGuild {
|
||||
owner_id: guild.owner_id.get(),
|
||||
_premium_tier: guild.premium_tier,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn bot_user_id(&self) -> Option<Id<UserMarker>> {
|
||||
*self.bot_user.read().unwrap()
|
||||
}
|
||||
|
||||
fn update_channel(&self, channel: &Channel) {
|
||||
self.channels.insert(
|
||||
channel.id.get(),
|
||||
CachedChannel {
|
||||
_name: channel.name.as_deref().map(SmolStr::new),
|
||||
_parent_id: channel.parent_id,
|
||||
guild_id: channel.guild_id,
|
||||
kind: channel.kind,
|
||||
overwrites: channel.permission_overwrites.clone().unwrap_or_default(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn delete_channel(&self, id: Id<ChannelMarker>) {
|
||||
self.channels.remove(&id.get());
|
||||
}
|
||||
|
||||
fn update_role(&self, role: &Role) {
|
||||
self.roles.insert(
|
||||
role.id.get(),
|
||||
CachedRole {
|
||||
permissions: role.permissions,
|
||||
_mentionable: role.mentionable,
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
fn delete_role(&self, id: Id<RoleMarker>) {
|
||||
self.roles.remove(&id.get());
|
||||
}
|
||||
|
||||
fn update_threads(&self, evt: &ThreadListSync) {
|
||||
for thread in &evt.threads {
|
||||
self.update_channel(thread);
|
||||
}
|
||||
}
|
||||
|
||||
fn update_bot_member(&self, guild_id: Id<GuildMarker>, roles: &[Id<RoleMarker>]) {
|
||||
self.bot_members.insert(
|
||||
guild_id.get(),
|
||||
CachedBotMember {
|
||||
roles: roles.to_vec(),
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
pub fn get_guild(&self, guild_id: Id<GuildMarker>) -> anyhow::Result<Ref<u64, CachedGuild>> {
|
||||
self.guilds
|
||||
.get(&guild_id.get())
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find guild in cache: {}", guild_id))
|
||||
}
|
||||
|
||||
pub fn get_channel(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
) -> anyhow::Result<Ref<u64, CachedChannel>> {
|
||||
self.channels
|
||||
.get(&channel_id.get())
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find channel in cache: {}", channel_id))
|
||||
}
|
||||
|
||||
pub fn get_role(&self, role_id: Id<RoleMarker>) -> anyhow::Result<Ref<u64, CachedRole>> {
|
||||
self.roles
|
||||
.get(&role_id.get())
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find role in cache: {}", role_id))
|
||||
}
|
||||
|
||||
pub fn get_bot_member(
|
||||
&self,
|
||||
guild_id: Id<GuildMarker>,
|
||||
) -> anyhow::Result<Ref<u64, CachedBotMember>> {
|
||||
self.bot_members.get(&guild_id.get()).ok_or_else(|| {
|
||||
anyhow::anyhow!("could not find bot member in cache for guild: {}", guild_id)
|
||||
})
|
||||
}
|
||||
|
||||
fn calculate_permissions_in(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
roles: &[Id<RoleMarker>],
|
||||
) -> anyhow::Result<Permissions> {
|
||||
let channel = self.get_channel(channel_id)?;
|
||||
|
||||
if let Some(guild_id) = channel.guild_id {
|
||||
let guild = self.get_guild(guild_id)?;
|
||||
let everyone_role = self.get_role(guild_id.cast())?;
|
||||
|
||||
let mut member_roles = Vec::with_capacity(roles.len());
|
||||
for role_id in roles {
|
||||
let role = self.get_role(*role_id)?;
|
||||
member_roles.push((role_id.cast(), role.permissions));
|
||||
}
|
||||
|
||||
let calc = PermissionCalculator::new(
|
||||
guild_id,
|
||||
user_id,
|
||||
everyone_role.permissions,
|
||||
&member_roles,
|
||||
)
|
||||
.owner_id(Id::new(guild.owner_id));
|
||||
|
||||
Ok(calc.in_channel(channel.kind, &channel.overwrites))
|
||||
} else {
|
||||
Ok(DM_PERMISSIONS)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn member_permissions(
|
||||
&self,
|
||||
channel_id: Id<ChannelMarker>,
|
||||
user_id: Id<UserMarker>,
|
||||
member: Option<&PartialMember>,
|
||||
) -> anyhow::Result<Permissions> {
|
||||
if let Some(member) = member {
|
||||
self.calculate_permissions_in(channel_id, user_id, &member.roles)
|
||||
} else {
|
||||
// this should just be dm perms, probably?
|
||||
self.calculate_permissions_in(channel_id, user_id, &[])
|
||||
}
|
||||
}
|
||||
|
||||
pub fn bot_permissions(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<Permissions> {
|
||||
let channel = self.get_channel(channel_id)?;
|
||||
if let Some(guild_id) = channel.guild_id {
|
||||
let member = self.get_bot_member(guild_id)?;
|
||||
|
||||
let user_id = self
|
||||
.bot_user_id()
|
||||
.ok_or_else(|| anyhow::anyhow!("haven't received bot user id yet"))?;
|
||||
|
||||
self.calculate_permissions_in(channel_id, user_id, &member.roles)
|
||||
} else {
|
||||
Ok(DM_PERMISSIONS)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn channel_type(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<ChannelType> {
|
||||
let channel = self.get_channel(channel_id)?;
|
||||
Ok(channel.kind)
|
||||
}
|
||||
}
|
||||
22
rustproxy/pk_bot/src/config.rs
Normal file
22
rustproxy/pk_bot/src/config.rs
Normal file
|
|
@ -0,0 +1,22 @@
|
|||
use config::{Config, Environment, File, FileFormat};
|
||||
use serde::Deserialize;
|
||||
|
||||
#[derive(Deserialize, Debug)]
|
||||
pub struct BotConfig {
|
||||
pub token: String,
|
||||
|
||||
pub max_concurrency: Option<u64>,
|
||||
pub database: String,
|
||||
pub redis_addr: Option<String>,
|
||||
pub redis_gateway_queue_addr: Option<String>,
|
||||
pub shard_count: Option<u64>,
|
||||
}
|
||||
|
||||
// todo: should this be a once_cell::Lazy global const or something
|
||||
pub fn load_config() -> anyhow::Result<BotConfig> {
|
||||
let builder = Config::builder()
|
||||
.add_source(Environment::default())
|
||||
.add_source(File::new("config", FileFormat::Toml));
|
||||
|
||||
Ok(builder.build()?.try_deserialize()?)
|
||||
}
|
||||
174
rustproxy/pk_bot/src/db.rs
Normal file
174
rustproxy/pk_bot/src/db.rs
Normal file
|
|
@ -0,0 +1,174 @@
|
|||
use std::str::FromStr;
|
||||
|
||||
use crate::{
|
||||
config::BotConfig,
|
||||
model::{PKMember, PKMemberGuild, PKMessage, PKSystem, PKSystemGuild},
|
||||
};
|
||||
use chrono::{DateTime, Utc};
|
||||
use sqlx::{
|
||||
postgres::{PgConnectOptions, PgPoolOptions},
|
||||
ConnectOptions, FromRow, PgPool,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
/// One row returned by the `message_context` database function, describing
/// the sender's system and the proxy-relevant settings for the message's
/// guild and channel.
#[derive(FromRow, Debug, Default)]
pub struct MessageContext {
    // being defensive with these values - we need to be explicit with Option<T>
    // when the database might return null, and some of these don't have proper default values set
    // most of the Option<T>s can probably get removed with a few changes to the db function
    pub system_id: Option<i32>,
    pub is_deleting: Option<bool>,
    pub in_blacklist: Option<bool>,
    pub in_log_blacklist: Option<bool>,
    pub proxy_enabled: Option<bool>,
    pub last_switch: Option<i32>,
    pub last_switch_members: Option<Vec<i32>>,
    pub last_switch_timestamp: Option<DateTime<Utc>>,
    pub system_tag: Option<String>,
    pub system_guild_tag: Option<String>,
    pub tag_enabled: Option<bool>,
    pub system_avatar: Option<String>,
    pub allow_autoproxy: Option<bool>,
    pub latch_timeout: Option<i32>,
}
|
||||
|
||||
pub async fn get_message_context(
|
||||
pool: &PgPool,
|
||||
account_id: i64,
|
||||
guild_id: i64,
|
||||
channel_id: i64,
|
||||
) -> anyhow::Result<MessageContext> {
|
||||
Ok(sqlx::query_as("select * from message_context($1, $2, $3)")
|
||||
.bind(account_id)
|
||||
.bind(guild_id)
|
||||
.bind(channel_id)
|
||||
.fetch_one(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// One proxy tag (prefix/suffix pair) together with the member it belongs to.
#[derive(FromRow, Debug, Clone)]
pub struct ProxyTagEntry {
    /// Text before the message content; the query coalesces a missing prefix to `""`.
    pub prefix: String,
    /// Text after the message content; the query coalesces a missing suffix to `""`.
    pub suffix: String,
    /// Id of the member that owns this tag.
    pub member_id: i32,
}
|
||||
|
||||
impl From<(&str, &str, i32)> for ProxyTagEntry {
|
||||
fn from((prefix, suffix, member_id): (&str, &str, i32)) -> Self {
|
||||
ProxyTagEntry {
|
||||
prefix: prefix.to_string(),
|
||||
suffix: suffix.to_string(),
|
||||
member_id,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_proxy_tags(pool: &PgPool, system_id: i32) -> anyhow::Result<Vec<ProxyTagEntry>> {
|
||||
Ok(sqlx::query_as("select coalesce((i.tags).prefix, '') as prefix, coalesce((i.tags).suffix, '') as suffix, member_id from (select unnest(proxy_tags) as tags, id as member_id from members where system = $1) as i;")
|
||||
.bind(system_id)
|
||||
.fetch_all(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
/// Autoproxy mode, stored in the database as its `i32` discriminant.
#[repr(i32)]
#[derive(sqlx::Type, Debug, Copy, Clone)]
pub enum AutoproxyMode {
    /// Autoproxy never selects a member.
    Off = 1,
    /// Autoproxy selects the first member of the last switch.
    Front = 2,
    /// Autoproxy selects the member stored in the autoproxy state.
    Latch = 3,
    /// Autoproxy selects the member stored in the autoproxy state.
    Member = 4,
}
|
||||
|
||||
/// The columns of an `autoproxy` table row that the bot reads.
#[derive(FromRow, Debug, Clone)]
pub struct AutoproxyState {
    /// Which autoproxy mode is active.
    pub autoproxy_mode: AutoproxyMode,
    /// Member used by the latch and member modes, if any is set.
    pub autoproxy_member: Option<i32>,
    /// NOTE(review): presumably when the latch was last refreshed; not read in the visible code.
    pub last_latch_timestamp: Option<DateTime<Utc>>,
}
|
||||
|
||||
pub async fn get_autoproxy_state(
|
||||
pool: &PgPool,
|
||||
system_id: i32,
|
||||
guild_id: i64,
|
||||
channel_id: i64,
|
||||
) -> anyhow::Result<Option<AutoproxyState>> {
|
||||
Ok(sqlx::query_as(
|
||||
"select * from autoproxy where system = $1 and guild_id = $2 and channel_id = $3;",
|
||||
)
|
||||
.bind(system_id)
|
||||
.bind(guild_id)
|
||||
.bind(channel_id)
|
||||
.fetch_optional(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_system_by_id(pool: &PgPool, system_id: i32) -> anyhow::Result<Option<PKSystem>> {
|
||||
Ok(sqlx::query_as("select * from systems where id = $1")
|
||||
.bind(system_id)
|
||||
.fetch_optional(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_member_by_id(pool: &PgPool, member_id: i32) -> anyhow::Result<Option<PKMember>> {
|
||||
Ok(sqlx::query_as("select * from members where id = $1")
|
||||
.bind(member_id)
|
||||
.fetch_optional(pool)
|
||||
.await?)
|
||||
}
|
||||
|
||||
pub async fn get_system_guild(
|
||||
pool: &PgPool,
|
||||
system_id: i32,
|
||||
guild_id: i64,
|
||||
) -> anyhow::Result<Option<PKSystemGuild>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from system_guild where system = $1 and guild = $2")
|
||||
.bind(system_id)
|
||||
.bind(guild_id)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn get_member_guild(
|
||||
pool: &PgPool,
|
||||
member_id: i32,
|
||||
guild_id: i64,
|
||||
) -> anyhow::Result<Option<PKMemberGuild>> {
|
||||
Ok(
|
||||
sqlx::query_as("select * from member_guild where member = $1 and guild = $2")
|
||||
.bind(member_id)
|
||||
.bind(guild_id)
|
||||
.fetch_optional(pool)
|
||||
.await?,
|
||||
)
|
||||
}
|
||||
|
||||
pub async fn insert_message(pool: &PgPool, message: PKMessage) -> anyhow::Result<()> {
|
||||
sqlx::query("insert into messages (mid, guild, channel, member, sender, original_mid) values ($1, $2, $3, $4, $5, $6)")
|
||||
.bind(message.mid)
|
||||
.bind(message.guild)
|
||||
.bind(message.channel)
|
||||
.bind(message.member_id)
|
||||
.bind(message.sender)
|
||||
.bind(message.original_mid)
|
||||
.execute(pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn init_db(config: &BotConfig) -> anyhow::Result<PgPool> {
|
||||
info!("connecting to database");
|
||||
let options = PgConnectOptions::from_str(&config.database)
|
||||
.unwrap()
|
||||
.disable_statement_logging()
|
||||
.clone();
|
||||
|
||||
let pool = PgPoolOptions::new()
|
||||
.max_connections(32)
|
||||
.connect_with(options)
|
||||
.await?;
|
||||
|
||||
Ok(pool)
|
||||
}
|
||||
63
rustproxy/pk_bot/src/gateway.rs
Normal file
63
rustproxy/pk_bot/src/gateway.rs
Normal file
|
|
@ -0,0 +1,63 @@
|
|||
use crate::config::BotConfig;
|
||||
use crate::redis;
|
||||
use std::{env, sync::Arc};
|
||||
use tracing::info;
|
||||
use twilight_gateway::{
|
||||
cluster::{Events, ShardScheme},
|
||||
Cluster, EventTypeFlags, Intents,
|
||||
};
|
||||
use twilight_http::Client;
|
||||
|
||||
pub async fn init_gateway(
|
||||
http: Arc<Client>,
|
||||
config: &BotConfig,
|
||||
) -> anyhow::Result<(Arc<Cluster>, Events)> {
|
||||
let mut builder = Cluster::builder(
|
||||
config.token.clone(),
|
||||
Intents::GUILDS
|
||||
| Intents::DIRECT_MESSAGES
|
||||
| Intents::GUILD_MESSAGES
|
||||
| Intents::MESSAGE_CONTENT,
|
||||
);
|
||||
builder = builder.http_client(http);
|
||||
builder = builder.event_types(EventTypeFlags::all());
|
||||
|
||||
if let Some(scheme) = get_shard_scheme(config)? {
|
||||
info!("using shard scheme: {:?}", scheme);
|
||||
builder = builder.shard_scheme(scheme);
|
||||
}
|
||||
|
||||
if let Some(queue) = redis::init_gateway_queue(config).await? {
|
||||
info!("using redis gateway queue");
|
||||
builder = builder.queue(Arc::new(queue));
|
||||
}
|
||||
|
||||
let (cluster, events) = builder.build().await?;
|
||||
let cluster = Arc::new(cluster);
|
||||
let cluster_spawn = Arc::clone(&cluster);
|
||||
tokio::spawn(async move {
|
||||
info!("starting shards...");
|
||||
cluster_spawn.up().await;
|
||||
});
|
||||
|
||||
Ok((cluster, events))
|
||||
}
|
||||
|
||||
fn get_cluster_id() -> anyhow::Result<u64> {
|
||||
Ok(env::var("NOMAD_ALLOC_INDEX")
|
||||
.unwrap_or_else(|_| "0".to_string())
|
||||
.parse::<u64>()?)
|
||||
}
|
||||
|
||||
fn get_shard_scheme(config: &BotConfig) -> anyhow::Result<Option<ShardScheme>> {
|
||||
let shard_count = config.shard_count.unwrap_or(1);
|
||||
let scheme = if shard_count >= 16 {
|
||||
let cluster_id = get_cluster_id()?;
|
||||
let first_shard_id = 16 * cluster_id;
|
||||
let shard_range = first_shard_id..first_shard_id + 16;
|
||||
Some(ShardScheme::try_from((shard_range, shard_count))?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(scheme)
|
||||
}
|
||||
109
rustproxy/pk_bot/src/main.rs
Normal file
109
rustproxy/pk_bot/src/main.rs
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
#[cfg(not(target_env = "msvc"))]
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
|
||||
// Use jemalloc as the global allocator on every target except MSVC.
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
|
||||
|
||||
use crate::cache::DiscordCache;
|
||||
use crate::redis::RedisEventProxy;
|
||||
use futures::StreamExt;
|
||||
use sqlx::PgPool;
|
||||
use std::sync::Arc;
|
||||
use tracing::{error, info};
|
||||
use twilight_gateway::Event;
|
||||
use twilight_http::Client as HttpClient;
|
||||
|
||||
mod cache;
|
||||
mod config;
|
||||
mod db;
|
||||
mod gateway;
|
||||
mod model;
|
||||
mod proxy;
|
||||
mod redis;
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
tracing_subscriber::fmt().init();
|
||||
|
||||
let config = config::load_config()?;
|
||||
info!("loaded config: {:?}", config);
|
||||
|
||||
let pool = db::init_db(&config).await?;
|
||||
let http = Arc::new(HttpClient::new(config.token.clone()));
|
||||
let (_cluster, mut events) = gateway::init_gateway(Arc::clone(&http), &config).await?;
|
||||
let cache = Arc::new(DiscordCache::new());
|
||||
let redis = redis::init_event_proxy(&config).await?;
|
||||
|
||||
while let Some((shard_id, event)) = events.next().await {
|
||||
let http = Arc::clone(&http);
|
||||
let pool = pool.clone();
|
||||
let cache = Arc::clone(&cache);
|
||||
let redis = redis.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
cache.handle_event(&event);
|
||||
|
||||
let res = handle_event(shard_id, event, http, pool, cache, redis).await;
|
||||
if let Err(e) = res {
|
||||
error!("error handling event: {:?}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn handle_event(
|
||||
shard_id: u64,
|
||||
event: Event,
|
||||
http: Arc<HttpClient>,
|
||||
pool: PgPool,
|
||||
cache: Arc<DiscordCache>,
|
||||
mut redis: RedisEventProxy,
|
||||
) -> anyhow::Result<()> {
|
||||
match event {
|
||||
Event::MessageCreate(msg) => {
|
||||
if msg.content.starts_with("pk;") || msg.content.starts_with("pk!") {
|
||||
redis
|
||||
.send_event_parsed(shard_id, Event::MessageCreate(msg))
|
||||
.await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let channel_type = cache.channel_type(msg.channel_id)?;
|
||||
|
||||
let ctx = db::get_message_context(
|
||||
&pool,
|
||||
msg.author.id.get() as i64,
|
||||
msg.guild_id.map(|x| x.get()).unwrap_or_default() as i64,
|
||||
msg.channel_id.get() as i64,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let _member_permissions =
|
||||
cache.member_permissions(msg.channel_id, msg.author.id, msg.member.as_ref())?;
|
||||
let bot_permissions = cache.bot_permissions(msg.channel_id)?;
|
||||
|
||||
match proxy::check_preconditions(&msg, channel_type, bot_permissions, &ctx) {
|
||||
Ok(_) => {
|
||||
info!("attempting to proxy");
|
||||
proxy::do_proxy(&http, &pool, &msg, &ctx).await?;
|
||||
}
|
||||
Err(reason) => {
|
||||
info!("skipping proxy because: {}", reason)
|
||||
}
|
||||
}
|
||||
}
|
||||
Event::ShardConnected(_) => {
|
||||
info!("connected on shard {}", shard_id);
|
||||
}
|
||||
Event::ShardPayload(payload) => {
|
||||
redis.send_event_raw(shard_id, &payload.bytes).await?;
|
||||
}
|
||||
// Other events here...
|
||||
_ => {}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
49
rustproxy/pk_bot/src/model.rs
Normal file
49
rustproxy/pk_bot/src/model.rs
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
use sqlx::FromRow;
|
||||
|
||||
/// The columns of a `systems` table row that the bot reads.
#[derive(FromRow, Debug, Clone)]
pub struct PKSystem {
    /// Numeric primary key.
    pub id: i32,
    pub name: Option<String>,
    pub tag: Option<String>,
    pub avatar_url: Option<String>,
}
|
||||
|
||||
/// The columns of a `members` table row that the bot reads.
#[derive(FromRow, Debug, Clone)]
pub struct PKMember {
    /// Numeric primary key.
    pub id: i32,
    /// Id of the system this member belongs to.
    pub system: i32,
    pub name: String,

    pub color: Option<String>,
    pub avatar_url: Option<String>,
    pub display_name: Option<String>,
    pub pronouns: Option<String>,
    pub description: Option<String>,
}
|
||||
|
||||
/// A proxied message record, as written to the `messages` table.
///
/// NOTE(review): the insert statement writes `member_id` into a column named
/// `member`; confirm the derived `FromRow` mapping if rows are ever read back
/// into this type.
#[derive(FromRow, Debug, Clone)]
pub struct PKMessage {
    /// Id of the proxied message.
    pub mid: i64,
    pub guild: Option<i64>,
    pub channel: i64,
    /// Id of the member the message was proxied as.
    pub member_id: i32,
    /// Account id of the user who sent the original message.
    pub sender: i64,
    /// Id of the original (trigger) message.
    pub original_mid: Option<i64>,
}
|
||||
|
||||
/// The columns of a `system_guild` table row that the bot reads: one
/// system's settings for one guild.
#[derive(FromRow, Debug, Clone)]
pub struct PKSystemGuild {
    pub system: i32,
    pub guild: i64,
    pub proxy_enabled: bool,
    pub tag: Option<String>,
    pub tag_enabled: bool,
}
|
||||
|
||||
/// The columns of a `member_guild` table row that the bot reads: one
/// member's settings for one guild.
#[derive(FromRow, Debug, Clone)]
pub struct PKMemberGuild {
    pub member: i32,
    pub guild: i64,
    pub display_name: Option<String>,
    pub avatar_url: Option<String>,
}
|
||||
31
rustproxy/pk_bot/src/proxy/autoproxy.rs
Normal file
31
rustproxy/pk_bot/src/proxy/autoproxy.rs
Normal file
|
|
@ -0,0 +1,31 @@
|
|||
use crate::db::{AutoproxyMode, AutoproxyState, MessageContext};
|
||||
|
||||
pub fn resolve_autoproxy_member(
|
||||
ctx: &MessageContext,
|
||||
state: &AutoproxyState,
|
||||
content: &str,
|
||||
) -> Option<i32> {
|
||||
if !ctx.allow_autoproxy.unwrap_or(true) {
|
||||
return None;
|
||||
}
|
||||
|
||||
if is_escape(content) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let first_fronter = ctx.last_switch_members.iter().flatten().cloned().next();
|
||||
match (state.autoproxy_mode, state.autoproxy_member, first_fronter) {
|
||||
(AutoproxyMode::Latch, Some(m), _) => Some(m),
|
||||
(AutoproxyMode::Member, Some(m), _) => Some(m),
|
||||
(AutoproxyMode::Front, _, Some(f)) => Some(f),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns true if the message starts with an unlatch marker: two
/// backslashes, optionally separated by a zero-width space (U+200B).
fn _is_unlatch(content: &str) -> bool {
    const UNLATCH_PREFIXES: [&str; 2] = ["\\\\", "\\\u{200b}\\"];
    UNLATCH_PREFIXES
        .iter()
        .any(|prefix| content.starts_with(*prefix))
}
|
||||
|
||||
/// Returns true if the message's first character is a backslash, which
/// escapes it from autoproxy.
fn is_escape(content: &str) -> bool {
    content.chars().next() == Some('\\')
}
|
||||
227
rustproxy/pk_bot/src/proxy/mod.rs
Normal file
227
rustproxy/pk_bot/src/proxy/mod.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
use self::{post_proxy::ProxyResult, webhook::WebhookExecuteRequest};
|
||||
use crate::db::{self, MessageContext};
|
||||
use sqlx::PgPool;
|
||||
use thiserror::Error;
|
||||
use twilight_http::Client;
|
||||
use twilight_model::{
|
||||
channel::{message::MessageType, ChannelType, Message},
|
||||
guild::Permissions,
|
||||
user::User,
|
||||
};
|
||||
|
||||
mod autoproxy;
|
||||
mod post_proxy;
|
||||
mod profile;
|
||||
mod reply;
|
||||
mod tags;
|
||||
mod webhook;
|
||||
|
||||
/// Reasons a message is not eligible for proxying, as reported by
/// [`check_preconditions`]. The `Display` text of each variant is the
/// human-readable reason.
#[derive(Error, Debug)]
// use this for pk;proxycheck or something
pub enum PreconditionFailure {
    #[error("bot is missing permissions (has: {has:?}, needs: {needs:?})")]
    BotMissingPermission {
        /// Permissions the bot actually has in the channel.
        has: Permissions,
        /// Permissions that proxying requires.
        needs: Permissions,
    },

    #[error("invalid channel type {0:?}")]
    InvalidChannelType(ChannelType),

    #[error("invalid message type {0:?}")]
    InvalidMessageType(MessageType),

    #[error("user is bot")]
    UserIsBot,

    #[error("user is webhook")]
    UserIsWebhook,

    #[error("user is discord system")]
    UserIsDiscordSystem,

    #[error("user has no system")]
    UserHasNoSystem,

    #[error("proxy disabled for system")]
    ProxyDisabledForSystem,

    #[error("proxy disabled in channel")]
    ProxyDisabledInChannel,

    #[error("message contains activity")]
    MessageContainsActivity,

    #[error("message contains sticker")]
    MessageContainsSticker,

    #[error("message is empty and has no attachments")]
    MessageIsEmpty,
}
|
||||
|
||||
// todo: the parameters here are nasty, refactor+put this code somewhere else maybe
|
||||
pub fn check_preconditions(
|
||||
msg: &Message,
|
||||
channel_type: ChannelType,
|
||||
bot_permissions: Permissions,
|
||||
ctx: &MessageContext,
|
||||
) -> Result<(), PreconditionFailure> {
|
||||
let required_permissions =
|
||||
Permissions::SEND_MESSAGES | Permissions::MANAGE_WEBHOOKS | Permissions::MANAGE_MESSAGES;
|
||||
if !bot_permissions.contains(required_permissions) {
|
||||
return Err(PreconditionFailure::BotMissingPermission {
|
||||
has: bot_permissions,
|
||||
needs: required_permissions,
|
||||
});
|
||||
}
|
||||
|
||||
match channel_type {
|
||||
ChannelType::GuildText
|
||||
| ChannelType::GuildNews
|
||||
| ChannelType::GuildPrivateThread
|
||||
| ChannelType::GuildPublicThread
|
||||
| ChannelType::GuildNewsThread => Ok(()),
|
||||
wrong_type => Err(PreconditionFailure::InvalidChannelType(wrong_type)),
|
||||
}?;
|
||||
|
||||
match msg.kind {
|
||||
MessageType::Regular | MessageType::Reply => Ok(()),
|
||||
wrong_type => Err(PreconditionFailure::InvalidMessageType(wrong_type)),
|
||||
}?;
|
||||
|
||||
match msg {
|
||||
Message {
|
||||
author: User {
|
||||
system: Some(true), ..
|
||||
},
|
||||
..
|
||||
} => Err(PreconditionFailure::UserIsDiscordSystem),
|
||||
Message {
|
||||
author: User { bot: true, .. },
|
||||
..
|
||||
} => Err(PreconditionFailure::UserIsBot),
|
||||
Message {
|
||||
webhook_id: Some(_),
|
||||
..
|
||||
} => Err(PreconditionFailure::UserIsWebhook),
|
||||
Message {
|
||||
activity: Some(_), ..
|
||||
} => Err(PreconditionFailure::MessageContainsActivity),
|
||||
Message {
|
||||
sticker_items: s, ..
|
||||
} if !s.is_empty() => Err(PreconditionFailure::MessageContainsSticker),
|
||||
Message {
|
||||
content: c,
|
||||
attachments: a,
|
||||
..
|
||||
} if c.trim().is_empty() && a.is_empty() => Err(PreconditionFailure::MessageIsEmpty),
|
||||
_ => Ok(()),
|
||||
}?;
|
||||
|
||||
match ctx {
|
||||
MessageContext {
|
||||
system_id: None, ..
|
||||
} => Err(PreconditionFailure::UserHasNoSystem),
|
||||
MessageContext {
|
||||
in_blacklist: Some(true),
|
||||
..
|
||||
} => Err(PreconditionFailure::ProxyDisabledInChannel),
|
||||
MessageContext {
|
||||
proxy_enabled: Some(false),
|
||||
..
|
||||
} => Err(PreconditionFailure::ProxyDisabledForSystem),
|
||||
_ => Ok(()),
|
||||
}?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Outcome of matching a message against proxy tags or autoproxy.
struct ProxyMatchResult {
    /// Member the message should be proxied as.
    member_id: i32,
    /// Message content to send through the webhook.
    inner_content: String,
    /// The matched tag pair; `None` when the match came from autoproxy.
    _tags: Option<(String, String)>, // todo: need this for keepproxy
}
|
||||
|
||||
async fn match_tags_or_autoproxy(
|
||||
pool: &PgPool,
|
||||
msg: &Message,
|
||||
ctx: &MessageContext,
|
||||
) -> anyhow::Result<Option<ProxyMatchResult>> {
|
||||
let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
|
||||
let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;
|
||||
|
||||
let tags = db::get_proxy_tags(pool, system_id).await?;
|
||||
let ap_state = db::get_autoproxy_state(
|
||||
pool,
|
||||
system_id,
|
||||
guild_id.get() as i64,
|
||||
0, // all autoproxy has channel id 0? o.o
|
||||
)
|
||||
.await?;
|
||||
|
||||
let tag_match = tags::match_proxy_tags(&tags, &msg.content);
|
||||
if let Some(tag_match) = tag_match {
|
||||
return Ok(Some(ProxyMatchResult {
|
||||
member_id: tag_match.member_id,
|
||||
inner_content: tag_match.inner_content,
|
||||
_tags: Some(tag_match.tags),
|
||||
}));
|
||||
}
|
||||
|
||||
if let Some(ap_state) = ap_state {
|
||||
let res = autoproxy::resolve_autoproxy_member(ctx, &ap_state, &msg.content);
|
||||
if let Some(member_id) = res {
|
||||
return Ok(Some(ProxyMatchResult {
|
||||
inner_content: msg.content.clone(),
|
||||
member_id,
|
||||
_tags: None,
|
||||
}));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
// todo: this shouldn't depend on a Message object (for reproxy/proxy command/etc)
|
||||
pub async fn do_proxy(
|
||||
http: &Client,
|
||||
pool: &PgPool,
|
||||
msg: &Message,
|
||||
ctx: &MessageContext,
|
||||
) -> anyhow::Result<()> {
|
||||
let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
|
||||
let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;
|
||||
|
||||
// todo: unlatch check/exec should probably go in here somewhere
|
||||
|
||||
let proxy_match = match_tags_or_autoproxy(pool, msg, ctx).await?;
|
||||
if let Some(result) = proxy_match {
|
||||
let profile =
|
||||
profile::fetch_proxy_profile(pool, guild_id.get(), system_id, result.member_id).await?;
|
||||
|
||||
let webhook_req = WebhookExecuteRequest {
|
||||
channel_id: msg.channel_id.get(),
|
||||
avatar_url: profile.avatar_url().map(|s| s.to_string()),
|
||||
content: Some(result.inner_content.clone()),
|
||||
username: profile.formatted_name(),
|
||||
embed: msg
|
||||
.referenced_message
|
||||
.as_deref()
|
||||
.map(|msg| reply::create_reply_embed(guild_id, msg))
|
||||
.transpose()?,
|
||||
};
|
||||
|
||||
let webhook_res = webhook::execute_webhook(http, &webhook_req).await?;
|
||||
|
||||
let proxy_res = ProxyResult {
|
||||
channel_id: msg.channel_id.get(),
|
||||
guild_id: guild_id.get(),
|
||||
member_id: result.member_id,
|
||||
original_message_id: msg.id.get(),
|
||||
proxy_message_id: webhook_res.message_id,
|
||||
sender: msg.author.id.get(),
|
||||
};
|
||||
post_proxy::handle_post_proxy(http, pool, &proxy_res).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
66
rustproxy/pk_bot/src/proxy/post_proxy.rs
Normal file
66
rustproxy/pk_bot/src/proxy/post_proxy.rs
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
use crate::db;
|
||||
use crate::model::PKMessage;
|
||||
use futures::TryFutureExt;
|
||||
use sqlx::PgPool;
|
||||
use tracing::error;
|
||||
use twilight_http::Client;
|
||||
|
||||
/// IDs describing one completed proxy operation, as consumed by post-proxy handling.
#[derive(Debug, Clone)]
pub struct ProxyResult {
    /// Guild the message was proxied in.
    pub guild_id: u64,
    /// Channel the message was proxied in.
    pub channel_id: u64,
    /// ID of the message sent through the webhook.
    pub proxy_message_id: u64,
    /// ID of the triggering message.
    pub original_message_id: u64,
    /// ID of the account that sent the triggering message.
    pub sender: u64,
    /// ID of the member the message was proxied as.
    pub member_id: i32,
}
|
||||
|
||||
pub async fn handle_post_proxy(
|
||||
http: &Client,
|
||||
pool: &PgPool,
|
||||
res: &ProxyResult,
|
||||
) -> anyhow::Result<()> {
|
||||
// todo: log channel
|
||||
|
||||
let _ = futures::join!(
|
||||
delete_original_message(http, res)
|
||||
.inspect_err(|e| error!("error deleting original message: {}", e)),
|
||||
insert_message_in_db(pool, res)
|
||||
.inspect_err(|e| error!("error deleting original message: {}", e))
|
||||
);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Deletes the message that triggered the proxy.
///
/// Panics if either ID in `res` cannot be converted into a twilight ID.
async fn delete_original_message(http: &Client, res: &ProxyResult) -> anyhow::Result<()> {
    // todo: sleep some amount
    // (do we still need to do that or did discord fix that client bug?)
    let channel = res.channel_id.try_into().unwrap();
    let message = res.original_message_id.try_into().unwrap();

    http.delete_message(channel, message).exec().await?;

    Ok(())
}
|
||||
|
||||
async fn insert_message_in_db(pool: &PgPool, res: &ProxyResult) -> anyhow::Result<()> {
|
||||
db::insert_message(
|
||||
pool,
|
||||
PKMessage {
|
||||
mid: res.proxy_message_id as i64,
|
||||
original_mid: Some(res.original_message_id as i64),
|
||||
sender: res.sender as i64,
|
||||
guild: Some(res.guild_id as i64),
|
||||
channel: res.channel_id as i64,
|
||||
member_id: res.member_id,
|
||||
},
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {}
|
||||
90
rustproxy/pk_bot/src/proxy/profile.rs
Normal file
90
rustproxy/pk_bot/src/proxy/profile.rs
Normal file
|
|
@ -0,0 +1,90 @@
|
|||
use futures::try_join;
|
||||
use sqlx::PgPool;
|
||||
|
||||
use crate::{
|
||||
db,
|
||||
model::{PKMember, PKMemberGuild, PKSystem, PKSystemGuild},
|
||||
};
|
||||
|
||||
// inventing the term "proxy profile" here to describe the info needed to work out the webhook name+avatar
|
||||
// arbitrary choice to put the source models in the struct and logic in methods, could just as well have had a function to do the math and put the results in a struct
|
||||
pub struct ProxyProfile {
|
||||
system: PKSystem,
|
||||
member: PKMember,
|
||||
system_guild: Option<PKSystemGuild>,
|
||||
member_guild: Option<PKMemberGuild>,
|
||||
}
|
||||
|
||||
impl ProxyProfile {
|
||||
pub fn name(&self) -> &str {
|
||||
let member_name = &self.member.name;
|
||||
let display_name = self.member.display_name.as_deref();
|
||||
let server_name = self
|
||||
.member_guild
|
||||
.as_ref()
|
||||
.and_then(|x| x.display_name.as_deref());
|
||||
server_name.or(display_name).unwrap_or(member_name)
|
||||
}
|
||||
|
||||
pub fn avatar_url(&self) -> Option<&str> {
|
||||
let system_avatar = self.system.avatar_url.as_deref();
|
||||
let member_avatar = self.member.avatar_url.as_deref();
|
||||
let server_avatar = self
|
||||
.member_guild
|
||||
.as_ref()
|
||||
.and_then(|x| x.avatar_url.as_deref());
|
||||
server_avatar.or(member_avatar).or(system_avatar)
|
||||
}
|
||||
|
||||
pub fn tag(&self) -> Option<&str> {
|
||||
let server_tag = self.system_guild.as_ref().and_then(|x| x.tag.as_deref());
|
||||
let system_tag = self.system.tag.as_deref();
|
||||
server_tag.or(system_tag)
|
||||
}
|
||||
|
||||
pub fn formatted_name(&self) -> String {
|
||||
let mut name = if let Some(tag) = self.tag() {
|
||||
format!("{} {}", self.name(), tag)
|
||||
} else {
|
||||
self.name().to_string()
|
||||
};
|
||||
|
||||
if name.len() == 1 {
|
||||
name.push('\u{17b5}');
|
||||
}
|
||||
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_proxy_profile(
|
||||
pool: &PgPool,
|
||||
guild_id: u64,
|
||||
system_id: i32,
|
||||
member_id: i32,
|
||||
) -> anyhow::Result<ProxyProfile> {
|
||||
// todo: this should be a db view with joins
|
||||
// this is all the info that proxy_members returned, so a single-member version of that could work nicely
|
||||
let system = db::get_system_by_id(pool, system_id);
|
||||
let member = db::get_member_by_id(pool, member_id);
|
||||
let system_guild = db::get_system_guild(pool, system_id, guild_id as i64);
|
||||
let member_guild = db::get_member_guild(pool, member_id, guild_id as i64);
|
||||
|
||||
let (system, member, system_guild, member_guild) =
|
||||
try_join!(system, member, system_guild, member_guild)?;
|
||||
|
||||
let system = system.ok_or_else(|| anyhow::anyhow!("could not find system"))?;
|
||||
let member = member.ok_or_else(|| anyhow::anyhow!("could not find member"))?;
|
||||
|
||||
Ok(ProxyProfile {
|
||||
system,
|
||||
member,
|
||||
system_guild,
|
||||
member_guild,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
// todo: this code is gonna be easy to unit test so we should do that
|
||||
}
|
||||
57
rustproxy/pk_bot/src/proxy/reply.rs
Normal file
57
rustproxy/pk_bot/src/proxy/reply.rs
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
use twilight_http::Client;
|
||||
use twilight_model::{
|
||||
channel::{embed::Embed, Message},
|
||||
id::{
|
||||
marker::{GuildMarker, UserMarker},
|
||||
Id,
|
||||
},
|
||||
util::ImageHash,
|
||||
};
|
||||
use twilight_util::builder::embed::{EmbedAuthorBuilder, EmbedBuilder, ImageSource};
|
||||
|
||||
/// Placeholder for fetching extra reply information; currently does nothing.
async fn _fetch_additional_reply_info(_http: &Client) -> anyhow::Result<()> {
    anyhow::Result::Ok(())
}
|
||||
|
||||
pub fn create_reply_embed(
|
||||
guild_id: Id<GuildMarker>,
|
||||
replied_to: &Message,
|
||||
) -> anyhow::Result<Embed> {
|
||||
// todo: guild avatars, guild nicknames
|
||||
// probably put this in fetch_additional_reply_info
|
||||
|
||||
let author = {
|
||||
let icon = replied_to
|
||||
.author
|
||||
.avatar
|
||||
.map(|hash| get_avatar_url(replied_to.author.id, hash))
|
||||
.and_then(|url| ImageSource::url(url).ok());
|
||||
|
||||
let mut builder = EmbedAuthorBuilder::new(replied_to.author.name.clone());
|
||||
if let Some(icon) = icon {
|
||||
builder = builder.icon_url(icon);
|
||||
};
|
||||
builder.build()
|
||||
};
|
||||
|
||||
let content = {
|
||||
let jump_link = format!(
|
||||
"https://discord.com/channels/{}/{}/{}",
|
||||
guild_id, replied_to.channel_id, replied_to.id
|
||||
);
|
||||
|
||||
let content = format!("**[Reply to:]({})** ", jump_link);
|
||||
// todo: properly add truncated content (including handling links/spoilers/etc)
|
||||
content
|
||||
};
|
||||
|
||||
let builder = EmbedBuilder::new().description(content).author(author);
|
||||
Ok(builder.build())
|
||||
}
|
||||
|
||||
/// Formats the CDN URL of a user's avatar as a PNG.
fn get_avatar_url(user_id: Id<UserMarker>, hash: ImageHash) -> String {
    let url = format!(
        "https://cdn.discordapp.com/avatars/{}/{}.png",
        user_id, hash
    );
    url
}
|
||||
93
rustproxy/pk_bot/src/proxy/tags.rs
Normal file
93
rustproxy/pk_bot/src/proxy/tags.rs
Normal file
|
|
@ -0,0 +1,93 @@
|
|||
use crate::db::ProxyTagEntry;
|
||||
use tracing::info;
|
||||
|
||||
/// A successful proxy-tag match.
#[derive(Debug)]
pub struct ProxyTagMatch {
    /// Message text with the tags stripped and surrounding whitespace trimmed.
    pub inner_content: String,
    /// The `(prefix, suffix)` pair that matched.
    pub tags: (String, String),
    /// ID of the member that owns the matched tags.
    pub member_id: i32,
}
|
||||
|
||||
pub fn match_proxy_tags(tags: &[ProxyTagEntry], content: &str) -> Option<ProxyTagMatch> {
|
||||
let content = content.trim();
|
||||
|
||||
let mut sorted_entries = tags.to_vec();
|
||||
sorted_entries.sort_by_key(|x| -((x.prefix.len() + x.suffix.len()) as i32));
|
||||
|
||||
for entry in sorted_entries {
|
||||
let is_tag_match = content.starts_with(&entry.prefix) && content.ends_with(&entry.suffix);
|
||||
info!(
|
||||
"prefix: {}, suffix: {}, content: {}, is_match: {}",
|
||||
entry.prefix, entry.suffix, content, is_tag_match
|
||||
);
|
||||
|
||||
// todo: extract leading mentions
|
||||
// todo: allow empty matches only if we're proxying an attachment
|
||||
// todo: properly handle <>s etc, there's some regex stuff in there i don't entirely understand
|
||||
// todo: there's some weird edge cases with various unicode control characters and emoji joiners and whatever, should figure that out + unit test it
|
||||
if is_tag_match {
|
||||
let inner_content =
|
||||
&content[entry.prefix.len()..(content.len() - entry.suffix.len())].trim();
|
||||
return Some(ProxyTagMatch {
|
||||
inner_content: inner_content.to_string(),
|
||||
tags: (entry.prefix, entry.suffix),
|
||||
member_id: entry.member_id,
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Table-driven check of tag matching against a fixed set of tags.
    #[test]
    fn basic_match() {
        let tags: Vec<ProxyTagEntry> = vec![
            ("[", "]", 0).into(),
            ("[[", "]]", 1).into(),
            ("P:", "", 2).into(),
            ("", "-P", 3).into(),
            ("+ ", "", 4).into(),
        ];

        // (message, expected inner content, expected member id)
        let expected_matches = [
            ("[hello world]", "hello world", 0),
            ("[ hello world ]", "hello world", 0),
            (" [ hello world ] ", "hello world", 0),
            ("[\nhello\n]", "hello", 0),
            ("[\nhello\nworld\n]", "hello\nworld", 0),
            ("[[text]]", "text", 1),
            ("[text]]", "text]", 0),
            ("[[[text]]]", "[text]", 1),
            ("P:text", "text", 2),
            ("text -P", "text", 3),
            ("+ hello", "hello", 4),
            ("[]", "", 0),
            // edge case: the c# implementation currently matches "[[]]" as ("[]", 0)
            // *if* the message doesn't have an attachment. not sure if we should mirror this here.
            ("[[]]", "", 1),
        ];
        for (message, inner, member) in expected_matches {
            assert_match(&tags, message, inner, member);
        }

        assert_no_match(&tags, "hello world");
        assert_no_match(&tags, "+hello"); // (prefix contains trailing space)
    }

    /// Asserts that `message` matches with the given inner content and member id.
    fn assert_match(tags: &[ProxyTagEntry], message: &str, inner: &str, member: i32) {
        let found = match_proxy_tags(tags, message);
        let actual = found
            .as_ref()
            .map(|x| (x.inner_content.as_str(), x.member_id));
        assert_eq!(actual, Some((inner, member)));
    }

    /// Asserts that `message` matches none of `tags`.
    fn assert_no_match(tags: &[ProxyTagEntry], message: &str) {
        assert!(match_proxy_tags(tags, message).is_none());
    }
}
|
||||
143
rustproxy/pk_bot/src/proxy/webhook.rs
Normal file
143
rustproxy/pk_bot/src/proxy/webhook.rs
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
use moka::future::Cache;
|
||||
use once_cell::sync::Lazy;
|
||||
use tracing::info;
|
||||
use twilight_http::Client;
|
||||
use twilight_model::channel::{embed::Embed, Webhook, WebhookType};
|
||||
|
||||
// space for 1 million is probably way overkill, this is a LRU cache so it's okay to evict occasionally
|
||||
static WEBHOOK_CACHE: Lazy<Cache<u64, CachedWebhook>> = Lazy::new(|| Cache::new(1024 * 1024));
|
||||
const WEBHOOK_NAME: &str = "PluralKit Proxy Webhook";
|
||||
|
||||
/// The parts of a webhook needed to execute it.
#[derive(Debug, Clone)]
pub struct CachedWebhook {
    /// Webhook ID.
    pub id: u64,
    /// Webhook token.
    pub token: String,
}
|
||||
|
||||
pub async fn get_webhook_cached(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
|
||||
let res = WEBHOOK_CACHE
|
||||
.try_get_with(channel_id, fetch_or_create_pk_webhook(http, channel_id))
|
||||
.await;
|
||||
|
||||
// todo: what happens if fetch_or_create_pk_webhook errors? i think moka handles it properly and just retries
|
||||
// but i'm not entiiiirely sure
|
||||
// https://docs.rs/moka/0.8.5/moka/future/struct.Cache.html#method.try_get_with
|
||||
|
||||
// error is Arc<Error> here and it's hard to convert that into an owned ref so we just make a new error lmao
|
||||
res.map_err(|_e| anyhow::anyhow!(
|
||||
"could not fetch webhook: {}", _e
|
||||
))
|
||||
}
|
||||
|
||||
async fn fetch_or_create_pk_webhook(
|
||||
http: &Client,
|
||||
channel_id: u64,
|
||||
) -> anyhow::Result<CachedWebhook> {
|
||||
match fetch_pk_webhook(http, channel_id).await? {
|
||||
Some(hook) => Ok(hook),
|
||||
None => create_pk_webhook(http, channel_id).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn fetch_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<Option<CachedWebhook>> {
|
||||
info!("cache miss, fetching webhook for channel {}", channel_id);
|
||||
|
||||
let webhooks = http
|
||||
.channel_webhooks(channel_id.try_into().unwrap())
|
||||
.exec()
|
||||
.await?
|
||||
.models()
|
||||
.await?;
|
||||
|
||||
webhooks
|
||||
.iter()
|
||||
.find(|wh| is_proxy_webhook(wh))
|
||||
.map(|x| {
|
||||
let token = x
|
||||
.token
|
||||
.as_ref()
|
||||
.map(|x| x.to_string())
|
||||
.ok_or_else(|| anyhow::anyhow!("webhook should contain token"));
|
||||
|
||||
token.map(|token| CachedWebhook {
|
||||
id: x.id.get(),
|
||||
token,
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
async fn create_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
|
||||
let response = http
|
||||
.create_webhook(channel_id.try_into().unwrap(), WEBHOOK_NAME)?
|
||||
.exec()
|
||||
.await?;
|
||||
|
||||
// todo: error handling here
|
||||
let val = response.model().await?;
|
||||
Ok(CachedWebhook {
|
||||
id: val.id.get(),
|
||||
token: val
|
||||
.token
|
||||
.ok_or_else(|| anyhow::anyhow!("webhook should contain token"))?,
|
||||
})
|
||||
}
|
||||
|
||||
fn is_proxy_webhook(wh: &Webhook) -> bool {
|
||||
wh.kind == WebhookType::Incoming
|
||||
&& wh.token.is_some()
|
||||
&& wh.name.as_deref() == Some(WEBHOOK_NAME)
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct WebhookExecuteRequest {
|
||||
pub channel_id: u64,
|
||||
pub username: String,
|
||||
pub avatar_url: Option<String>,
|
||||
pub content: Option<String>,
|
||||
pub embed: Option<Embed>,
|
||||
}
|
||||
|
||||
/// Result of a successful webhook execution.
#[derive(Debug)]
pub struct WebhookExecuteResult {
    /// ID of the message the webhook created.
    pub message_id: u64,
}
|
||||
|
||||
pub async fn execute_webhook(
|
||||
http: &Client,
|
||||
req: &WebhookExecuteRequest,
|
||||
) -> anyhow::Result<WebhookExecuteResult> {
|
||||
let webhook = get_webhook_cached(http, req.channel_id).await?;
|
||||
let mut request = http
|
||||
.execute_webhook(webhook.id.try_into().unwrap(), &webhook.token)
|
||||
.username(&req.username)?;
|
||||
|
||||
if let Some(ref content) = req.content {
|
||||
request = request.content(content)?;
|
||||
}
|
||||
|
||||
if let Some(ref avatar_url) = req.avatar_url {
|
||||
request = request.avatar_url(avatar_url);
|
||||
}
|
||||
|
||||
let mut embeds = Vec::new();
|
||||
if let Some(ref embed) = req.embed {
|
||||
embeds.push(embed.clone());
|
||||
request = request.embeds(&embeds)?;
|
||||
}
|
||||
|
||||
// todo: handle error if webhook was deleted, should invalidate and retry
|
||||
let result = request.wait().exec().await?;
|
||||
|
||||
let model = result.model().await?;
|
||||
if model.channel_id != req.channel_id {
|
||||
// it's possible for someone to "redirect" a webhook to another channel
|
||||
// and the only way we find out is when we send a message.
|
||||
// if this has happened remove it from cache and refetch later
|
||||
WEBHOOK_CACHE.invalidate(&req.channel_id).await;
|
||||
}
|
||||
|
||||
Ok(WebhookExecuteResult {
|
||||
message_id: model.id.get(),
|
||||
})
|
||||
}
|
||||
168
rustproxy/pk_bot/src/redis.rs
Normal file
168
rustproxy/pk_bot/src/redis.rs
Normal file
|
|
@ -0,0 +1,168 @@
|
|||
use crate::config::BotConfig;
|
||||
use once_cell::sync::Lazy;
|
||||
use redis::aio::ConnectionManager;
|
||||
use std::fmt::Debug;
|
||||
use std::time::Duration;
|
||||
use tracing::{error, info};
|
||||
use twilight_gateway::Event;
|
||||
use twilight_gateway_queue::Queue;
|
||||
use twilight_model::gateway::event::{DispatchEvent, GatewayEvent, GatewayEventDeserializer};
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisEventProxy {
|
||||
// todo: don't know if i want this struct to have responsibility for ignoring calls if redis is disabled
|
||||
inner: Option<ConnectionManager>,
|
||||
}
|
||||
|
||||
// events that should be sent in the raw handler
|
||||
// does not include message create/update since we want first dibs on those and pass them on later
|
||||
static ALLOWED_EVENTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
|
||||
vec![
|
||||
"INTERACTION_CREATE",
|
||||
"MESSAGE_DELETE",
|
||||
"MESSAGE_DELETE_BULK",
|
||||
"MESSAGE_REACTION_ADD",
|
||||
"READY",
|
||||
"GUILD_CREATE",
|
||||
"GUILD_UPDATE",
|
||||
"GUILD_DELETE",
|
||||
"GUILD_ROLE_CREATE",
|
||||
"GUILD_ROLE_UPDATE",
|
||||
"GUILD_ROLE_DELETE",
|
||||
"CHANNEL_CREATE",
|
||||
"CHANNEL_UPDATE",
|
||||
"CHANNEL_DELETE",
|
||||
"THREAD_CREATE",
|
||||
"THREAD_UPDATE",
|
||||
"THREAD_DELETE",
|
||||
"THREAD_LIST_SYNC",
|
||||
]
|
||||
});
|
||||
|
||||
impl RedisEventProxy {
|
||||
async fn send_event_raw_inner(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
|
||||
if let Some(ref mut redis) = self.inner {
|
||||
info!(shard_id = shard_id, "publishing event");
|
||||
let key = format!("evt-{}", shard_id);
|
||||
|
||||
redis::cmd("PUBLISH")
|
||||
.arg(&key[..])
|
||||
.arg(payload)
|
||||
.query_async(redis)
|
||||
.await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_event_raw(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
|
||||
let payload_str = std::str::from_utf8(payload)?;
|
||||
|
||||
if let Some(deser) = GatewayEventDeserializer::from_json(payload_str) {
|
||||
if let Some(event_type) = deser.event_type_ref() {
|
||||
if ALLOWED_EVENTS.contains(&event_type) {
|
||||
self.send_event_raw_inner(shard_id, payload).await?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn send_event_parsed(&mut self, shard_id: u64, evt: Event) -> anyhow::Result<()> {
|
||||
info!(shard_id, "sending parsed: {:?}", evt.kind());
|
||||
|
||||
let dispatch_event = DispatchEvent::try_from(evt)?;
|
||||
let gateway_event = GatewayEvent::Dispatch(0, Box::new(dispatch_event));
|
||||
let buf = serde_json::to_vec(&gateway_event)?;
|
||||
self.send_event_raw_inner(shard_id, &buf).await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn connect_to_redis(addr: &str) -> anyhow::Result<ConnectionManager> {
|
||||
let client = redis::Client::open(addr)?;
|
||||
info!("connecting to redis at {}...", addr);
|
||||
Ok(ConnectionManager::new(client).await?)
|
||||
}
|
||||
|
||||
pub async fn init_event_proxy(config: &BotConfig) -> anyhow::Result<RedisEventProxy> {
|
||||
let mgr = if let Some(redis_addr) = &config.redis_addr {
|
||||
Some(connect_to_redis(redis_addr).await?)
|
||||
} else {
|
||||
info!("no redis address specified, skipping");
|
||||
None
|
||||
};
|
||||
|
||||
Ok(RedisEventProxy { inner: mgr })
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RedisQueue {
|
||||
pub redis: ConnectionManager,
|
||||
pub concurrency: u64,
|
||||
}
|
||||
|
||||
impl Debug for RedisQueue {
    /// Prints only `concurrency`; the connection is left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("RedisQueue");
        out.field("concurrency", &self.concurrency);
        out.finish()
    }
}
|
||||
|
||||
impl Queue for RedisQueue {
|
||||
fn request<'a>(
|
||||
&'a self,
|
||||
shard_id: [u64; 2],
|
||||
) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + 'a>> {
|
||||
Box::pin(request_inner(
|
||||
self.redis.clone(),
|
||||
self.concurrency,
|
||||
*shard_id.first().unwrap(),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
async fn request_inner(mut client: ConnectionManager, concurrency: u64, shard_id: u64) {
|
||||
let bucket = shard_id % concurrency;
|
||||
let key = format!("pluralkit:identify:{}", bucket);
|
||||
|
||||
// SET bucket 1 EX 6 NX = write a key expiring after 6 seconds if there's not already one
|
||||
let mut cmd = redis::cmd("SET");
|
||||
cmd.arg(key).arg("1").arg("EX").arg(6i8).arg("NX");
|
||||
|
||||
info!(shard_id, bucket, "waiting for allowance...");
|
||||
loop {
|
||||
let done = cmd
|
||||
.clone()
|
||||
.query_async::<_, Option<String>>(&mut client)
|
||||
.await;
|
||||
match done {
|
||||
Ok(Some(_)) => {
|
||||
info!(shard_id, bucket, "got allowance!");
|
||||
return;
|
||||
}
|
||||
Ok(None) => {
|
||||
// not allowed yet, waiting
|
||||
}
|
||||
Err(e) => {
|
||||
error!(shard_id, bucket, "error getting shard allowance: {}", e)
|
||||
}
|
||||
}
|
||||
|
||||
tokio::time::sleep(Duration::from_millis(500)).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn init_gateway_queue(config: &BotConfig) -> anyhow::Result<Option<RedisQueue>> {
|
||||
let queue = if let Some(ref addr) = config.redis_gateway_queue_addr {
|
||||
let redis = connect_to_redis(addr).await?;
|
||||
let concurrency = config.max_concurrency.unwrap_or(1);
|
||||
Some(RedisQueue { redis, concurrency })
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(queue)
|
||||
}
|
||||
9
rustproxy/pk_command_parser/Cargo.toml
Normal file
9
rustproxy/pk_command_parser/Cargo.toml
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
[package]
|
||||
name = "pk_command_parser"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.57"
|
||||
matches = "0.1.9"
|
||||
slab = "0.4.6"
|
||||
2
rustproxy/pk_command_parser/src/lib.rs
Normal file
2
rustproxy/pk_command_parser/src/lib.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
mod matcher;
|
||||
mod tokenizer;
|
||||
177
rustproxy/pk_command_parser/src/matcher.rs
Normal file
177
rustproxy/pk_command_parser/src/matcher.rs
Normal file
|
|
@ -0,0 +1,177 @@
|
|||
use std::ops::Range;
|
||||
|
||||
use crate::tokenizer::{Token, Tokenizer};
|
||||
|
||||
/// One element of a command [`Pattern`].
#[derive(Debug)]
pub enum Segment {
    /// A literal keyword; any of the listed spellings is accepted.
    Word(Vec<String>),
    /// A named value taken from the input.
    Parameter {
        /// Name the captured value is reported under.
        name: String,
        /// Whether the parameter may be left out.
        optional: bool,
    },
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct Pattern {
|
||||
segments: Vec<Segment>,
|
||||
}
|
||||
|
||||
/// A parameter value captured while matching.
#[derive(Debug, PartialEq)]
pub struct ParameterMatch {
    /// Name of the parameter segment that captured the value.
    name: String,
    /// Captured text.
    value: String,
    /// Byte range of the captured token in the input.
    span: Range<usize>,
}
|
||||
|
||||
/// A flag found while matching.
#[derive(Debug, PartialEq)]
pub struct FlagMatch {
    /// Flag name.
    name: String,
    /// Flag value, if one was given.
    value: Option<String>,
}
|
||||
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct MatchResult {
|
||||
parameters: Vec<ParameterMatch>,
|
||||
flags: Vec<FlagMatch>,
|
||||
remainder: Option<String>,
|
||||
}
|
||||
|
||||
pub fn does_match(s: &str, pat: &Pattern) -> Option<MatchResult> {
|
||||
let mut flags = Vec::new();
|
||||
let mut parameters = Vec::new();
|
||||
|
||||
let mut remainder_pos = None;
|
||||
|
||||
let mut segments = pat.segments.iter().peekable();
|
||||
let mut tokenizer = Tokenizer::new(s);
|
||||
|
||||
// loop until we find a keyword token
|
||||
while let Some(token) = tokenizer.next() {
|
||||
match token {
|
||||
Token::Flag { name, .. } => {
|
||||
// flags are set aside
|
||||
flags.push(FlagMatch { name, value: None });
|
||||
}
|
||||
Token::Keyword {
|
||||
value,
|
||||
quoted: _,
|
||||
span,
|
||||
} => {
|
||||
let mut next_segment = segments.next();
|
||||
|
||||
match next_segment {
|
||||
Some(Segment::Word(options)) => {
|
||||
// keyword doesn't match? definitely not a match then
|
||||
if !matches_word(&value, &options) {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
Some(Segment::Parameter { name, optional }) => {
|
||||
// for an optional parameter, check the next token instead and consume
|
||||
if let Some(Segment::Word(options)) = segments.peek() {
|
||||
if *optional {
|
||||
if !matches_word(&value, &options) {
|
||||
return None;
|
||||
}
|
||||
|
||||
segments.next();
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// set parameter aside for later
|
||||
parameters.push(ParameterMatch {
|
||||
name: name.clone(),
|
||||
span: span,
|
||||
value: value.clone(),
|
||||
});
|
||||
}
|
||||
None => {
|
||||
// out of segments to match, but we already consumed the next token
|
||||
// so set position aside for remainder and exit
|
||||
remainder_pos = Some(span.start);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Some(MatchResult {
|
||||
parameters,
|
||||
flags,
|
||||
remainder: remainder_pos.map(|x| s[x..].to_string()),
|
||||
})
|
||||
}
|
||||
|
||||
/// Whether `word` equals any of `options`, ignoring ASCII case.
fn matches_word(word: &str, options: &[String]) -> bool {
    for candidate in options {
        if candidate.eq_ignore_ascii_case(word) {
            return true;
        }
    }
    false
}
|
||||
|
||||
struct PatternBuilder {
|
||||
segments: Vec<Segment>,
|
||||
}
|
||||
|
||||
impl PatternBuilder {
|
||||
fn new() -> PatternBuilder {
|
||||
PatternBuilder {
|
||||
segments: Vec::new(),
|
||||
}
|
||||
}
|
||||
|
||||
fn word(mut self, options: &[&str]) -> PatternBuilder {
|
||||
self.segments.push(Segment::Word(
|
||||
options.iter().map(|x| x.to_string()).collect(),
|
||||
));
|
||||
self
|
||||
}
|
||||
|
||||
fn param(mut self, name: &str) -> PatternBuilder {
|
||||
self.segments.push(Segment::Parameter {
|
||||
name: name.to_string(),
|
||||
optional: false,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
fn param_opt(mut self, name: &str) -> PatternBuilder {
|
||||
self.segments.push(Segment::Parameter {
|
||||
name: name.to_string(),
|
||||
optional: true,
|
||||
});
|
||||
self
|
||||
}
|
||||
|
||||
fn build(self) -> Pattern {
|
||||
Pattern {
|
||||
segments: self.segments,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// A quoted parameter, a flag and trailing text are all picked up.
    #[test]
    fn hi() {
        let pat = PatternBuilder::new()
            .word(&["m", "member"])
            .param("member_ref")
            .word(&["desc", "d"])
            .build();

        let expected = MatchResult {
            parameters: vec![ParameterMatch {
                name: "member_ref".to_string(),
                value: "Hello World".to_string(),
                span: 7..20,
            }],
            flags: vec![FlagMatch {
                name: "raw".to_string(),
                value: None,
            }],
            remainder: Some("More text goes here".to_string()),
        };

        let actual = does_match("member \"Hello World\" -raw desc More text goes here", &pat);
        assert_eq!(actual, Some(expected));
    }
}
|
||||
266
rustproxy/pk_command_parser/src/tokenizer.rs
Normal file
266
rustproxy/pk_command_parser/src/tokenizer.rs
Normal file
|
|
@ -0,0 +1,266 @@
|
|||
use std::ops::Range;
|
||||
|
||||
/// A lexical unit of a command string.
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
    /// A plain or quoted word.
    Keyword {
        /// Word text; for quoted tokens, the text between the quotes.
        value: String,
        /// Whether the token was written in quotes.
        quoted: bool,
        /// Byte range of the token in the input (including any quotes).
        span: Range<usize>,
    },
    /// A word starting with `-`, optionally of the form `name=value`.
    Flag {
        /// Flag name with leading dashes removed.
        name: String,
        /// Text after the first `=`, if any.
        value: Option<String>,
        /// Byte range of the token in the input.
        span: Range<usize>,
    },
}
|
||||
|
||||
pub struct Tokenizer<'a> {
|
||||
inner: &'a str,
|
||||
words: WordSpanIterator<'a>,
|
||||
}
|
||||
|
||||
impl<'a> Tokenizer<'a> {
|
||||
pub fn new(s: &str) -> Tokenizer {
|
||||
Tokenizer {
|
||||
inner: s,
|
||||
words: WordSpanIterator::new(s),
|
||||
}
|
||||
}
|
||||
|
||||
fn next_token(&mut self) -> Option<Token> {
|
||||
self.words.next().map(|span| {
|
||||
let word = &self.inner[span.clone()];
|
||||
|
||||
if let Some((inner, quoted_span)) = self.try_read_quoted_token(span.clone()) {
|
||||
Token::Keyword {
|
||||
value: inner.to_string(),
|
||||
quoted: true,
|
||||
span: quoted_span,
|
||||
}
|
||||
} else if word.starts_with("-") {
|
||||
let flag_name = word.trim_start_matches('-');
|
||||
|
||||
let (flag_name, flag_value) = match flag_name.split_once('=') {
|
||||
Some((flag_name, flag_value)) => {
|
||||
(flag_name.to_string(), Some(flag_value.to_string()))
|
||||
}
|
||||
None => (flag_name.to_string(), None),
|
||||
};
|
||||
|
||||
Token::Flag {
|
||||
name: flag_name,
|
||||
value: flag_value,
|
||||
span,
|
||||
}
|
||||
} else {
|
||||
Token::Keyword {
|
||||
value: word.to_string(),
|
||||
quoted: false,
|
||||
span: span,
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn try_read_quoted_token(&mut self, span: Range<usize>) -> Option<(&str, Range<usize>)> {
|
||||
let start_pos = span.start;
|
||||
find_quote_pair(&self.inner[span.start..]).map(|(left_quote, right_quotes)| {
|
||||
let mut word_span = span;
|
||||
word_span.start += left_quote.len_utf8();
|
||||
|
||||
// effectively do-while-let but rust doesn't have that :/
|
||||
loop {
|
||||
let end_word = &self.inner[word_span.clone()];
|
||||
|
||||
for right_quote in right_quotes {
|
||||
if end_word.ends_with(*right_quote) {
|
||||
let end_pos = word_span.end;
|
||||
let inner_span =
|
||||
(start_pos + left_quote.len_utf8())..(end_pos - right_quote.len_utf8());
|
||||
let inner_str = &self.inner[inner_span];
|
||||
return (inner_str, start_pos..end_pos);
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(next_word_span) = self.words.next() {
|
||||
word_span = next_word_span;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
(&self.inner[start_pos..], start_pos..self.inner.len())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Iterator for Tokenizer<'a> {
|
||||
type Item = Token;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.next_token()
|
||||
}
|
||||
}
|
||||
|
||||
/// Iterates over the byte spans of the whitespace-separated words in a string.
struct WordSpanIterator<'a> {
    /// Pieces of the input, each ending with the whitespace char that closed it (if any).
    iter: std::str::SplitInclusive<'a, fn(char) -> bool>,
    /// Byte offset of the next piece within the input.
    pos: usize,
}

impl<'a> WordSpanIterator<'a> {
    /// Creates an iterator over the words of `s`.
    fn new(s: &'a str) -> WordSpanIterator<'a> {
        let is_separator: fn(char) -> bool = char::is_whitespace;
        WordSpanIterator {
            pos: 0,
            iter: s.split_inclusive(is_separator),
        }
    }
}

impl<'a> Iterator for WordSpanIterator<'a> {
    type Item = Range<usize>;

    /// Returns the span of the next non-empty word, excluding trailing whitespace.
    fn next(&mut self) -> Option<Self::Item> {
        for piece in self.iter.by_ref() {
            let start = self.pos;
            self.pos += piece.len();

            // Pieces that are only whitespace are skipped.
            let word_len = piece.trim_end().len();
            if word_len > 0 {
                return Some(start..(start + word_len));
            }
        }

        None
    }
}
|
||||
|
||||
fn find_quote_pair(s: &str) -> Option<(char, &'static [char])> {
|
||||
s.chars()
|
||||
.next()
|
||||
.and_then(|c| matching_quotes(c).map(|x| (c, x)))
|
||||
}
|
||||
|
||||
/// Returns the characters that may close a quote opened with `c`, or `None`
/// if `c` is not a recognised opening quote.
fn matching_quotes(c: char) -> Option<&'static [char]> {
    match c {
        // Basic
        '"' => Some(&['"']),
        '\'' => Some(&['\'']),

        // "Smart quotes"
        // Specifically ignore the left/right status of the quotes and match any combination of them
        // Left string also includes "low" quotes to allow for the low-high style used in some locales
        '\u{201c}' | '\u{201d}' | '\u{201f}' | '\u{201e}' => {
            Some(&['\u{201c}', '\u{201d}', '\u{201f}'])
        } // double
        '\u{2018}' | '\u{2019}' | '\u{201b}' | '\u{201a}' => {
            Some(&['\u{2018}', '\u{2019}', '\u{201b}'])
        } // single

        // Chevrons (normal and "fullwidth" variants)
        '\u{00ab}' | '\u{300a}' => Some(&['\u{00bb}', '\u{300b}']), // double chevrons, pointing away (<<text>>)
        // closes with U+00AB («); the previous U+00AA (ª) was a typo that made »text« unclosable
        '\u{00bb}' | '\u{300b}' => Some(&['\u{00ab}', '\u{300a}']), // double chevrons, pointing together (>>text<<)
        '\u{2039}' | '\u{3008}' => Some(&['\u{203a}', '\u{3009}']), // single chevrons, pointing away (<text>)
        '\u{203a}' | '\u{3009}' => Some(&['\u{2039}', '\u{3008}']), // single chevrons, pointing together (>text<)

        // Other
        '\u{300c}' | '\u{300e}' => Some(&['\u{300d}', '\u{300f}']), // corner brackets (Japanese/Chinese)

        _ => None,
    }
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain words separated by single spaces become unquoted keywords.
    #[test]
    fn basic_words() {
        let s = "hello world abcdefg";
        let mut tk = Tokenizer::new(s);

        assert_word(tk.next(), "hello", false, 0..5);
        assert_word(tk.next(), "world", false, 6..11);
        assert_word(tk.next(), "abcdefg", false, 12..19);
    }

    /// Runs of mixed whitespace (including multi-byte whitespace) are skipped.
    #[test]
    fn ignore_whitespace() {
        // U+2003 EM SPACE is 3 utf-8 bytes
        // The whitespace run lengths matter: the spans asserted below are byte
        // offsets into this exact literal.
        let s = "    lotsa   \u{2003}    spaces  \t and \t\n\t stuff \n";
        let mut tk = Tokenizer::new(s);

        assert_word(tk.next(), "lotsa", false, 4..9);
        assert_word(tk.next(), "spaces", false, 19..25);
        assert_word(tk.next(), "and", false, 29..32);
        assert_word(tk.next(), "stuff", false, 37..42);
    }

    /// Quoted sections form a single keyword whose span includes the quotes.
    #[test]
    fn quoted_words() {
        let mut tk = Tokenizer::new("hello \"in double quotes\" 'and single quotes'");
        assert_word(tk.next(), "hello", false, 0..5);
        assert_word(tk.next(), "in double quotes", true, 6..24);
        assert_word(tk.next(), "and single quotes", true, 25..44);

        let mut tk = Tokenizer::new("\"quote at start of\" string");
        assert_word(tk.next(), "quote at start of", true, 0..19);
        assert_word(tk.next(), "string", false, 20..26);

        // Whitespace inside quotes is kept verbatim; the literal is 34 bytes
        // long, matching the asserted span.
        let mut tk = Tokenizer::new("\"\n  include whitespace\nin quotes\n\"");
        assert_word(
            tk.next(),
            "\n  include whitespace\nin quotes\n",
            true,
            0..34,
        );

        // Apostrophes inside words must not open or close a quoted section.
        let mut tk = Tokenizer::new("'it's 5 o'clock' said o'brian");
        assert_word(tk.next(), "it's 5 o'clock", true, 0..16);
        assert_word(tk.next(), "said", false, 17..21);
        assert_word(tk.next(), "o'brian", false, 22..29);
    }

    /// Words starting with one or more dashes are flags; inner dashes are not.
    #[test]
    fn flags() {
        let mut tk = Tokenizer::new("word -flag and-word");
        assert_word(tk.next(), "word", false, 0..4);
        assert_flag(tk.next(), "flag", None, 5..10);
        assert_word(tk.next(), "and-word", false, 11..19);

        let mut tk = Tokenizer::new("-lots --of ---dashes");
        assert_flag(tk.next(), "lots", None, 0..5);
        assert_flag(tk.next(), "of", None, 6..10);
        assert_flag(tk.next(), "dashes", None, 11..20);
    }

    /// `-name=value` flags split on the first `=` only.
    #[test]
    fn flag_values() {
        let mut tk = Tokenizer::new("-flag=value --flag2=value2 -flag3=value3=more");
        assert_flag(tk.next(), "flag", Some("value"), 0..11);
        assert_flag(tk.next(), "flag2", Some("value2"), 12..26);
        assert_flag(tk.next(), "flag3", Some("value3=more"), 27..45);
    }

    /// Asserts that `tk` is a `Token::Keyword` with the given value, quoted
    /// state and byte span.
    fn assert_word(tk: Option<Token>, s: &str, quoted: bool, span: Range<usize>) {
        assert_eq!(
            tk,
            Some(Token::Keyword {
                value: s.to_string(),
                quoted,
                span,
            })
        );
    }

    /// Asserts that `tk` is a `Token::Flag` with the given name, optional value
    /// and byte span.
    fn assert_flag(tk: Option<Token>, name: &str, value: Option<&str>, span: Range<usize>) {
        assert_eq!(
            tk,
            Some(Token::Flag {
                name: name.to_string(),
                span,
                value: value.map(str::to_string),
            })
        );
    }
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue