rustproxy: initial commit

This commit is contained in:
Ske 2022-06-22 19:28:48 +02:00
parent b47694edc1
commit 9d90de45a6
23 changed files with 4686 additions and 0 deletions

3
rustproxy/.gitignore vendored Normal file
View file

@ -0,0 +1,3 @@
/target
.idea/
config.toml

2625
rustproxy/Cargo.lock generated Normal file

File diff suppressed because it is too large Load diff

5
rustproxy/Cargo.toml Normal file
View file

@ -0,0 +1,5 @@
[workspace]
members = [
"pk_bot",
"pk_command_parser"
]

View file

@ -0,0 +1,7 @@
token = "put your token here"
database = "postgres://postgres:password@localhost/postgres"
redis_addr = "redis://127.0.0.1/"
redis_gateway_queue_addr = "redis://127.0.0.1/"
shard_count = 1

View file

@ -0,0 +1,35 @@
[package]
name = "pk_bot"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
anyhow = "1.0.57"
async-trait = "0.1.56"
chrono = "0.4.19"
config = { version = "0.13.1", default-features = false, features = ["toml"] }
dashmap = "5.3.4"
futures = "0.3.21"
moka = { version = "0.8.5", features = ["future"] }
once_cell = "1.12.0"
redis = { version = "0.21.5", features = ["connection-manager", "tokio-comp"] }
regex = "1.5.6"
serde = "1.0.137"
serde_json = "1.0.81"
smallvec = "1.8.0"
smol_str = "0.1.23"
sqlx = { version = "0.6.0", features = ["runtime-tokio-native-tls", "postgres", "chrono", "macros"] }
thiserror = "1.0.31"
tokio = { version = "1.19.1", features = ["full"] }
tracing = "0.1.34"
tracing-subscriber = "0.3.11"
twilight-gateway = "0.11.0"
twilight-gateway-queue = "0.11.0"
twilight-http = "0.11.0"
twilight-model = "0.11.0"
twilight-util = { version = "0.11.0", features = ["builder", "permission-calculator"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies]
tikv-jemallocator = "0.5"

View file

@ -0,0 +1,265 @@
use std::sync::{Arc, RwLock};
use dashmap::mapref::one::Ref;
use dashmap::DashMap;
use smol_str::SmolStr;
use twilight_model::channel::permission_overwrite::PermissionOverwrite;
use twilight_model::channel::{Channel, ChannelType};
use twilight_model::gateway::event::Event;
use twilight_model::gateway::payload::incoming::ThreadListSync;
use twilight_model::guild::{Guild, PartialMember, Permissions, PremiumTier, Role};
use twilight_model::id::marker::{ChannelMarker, GuildMarker, RoleMarker, UserMarker};
use twilight_model::id::Id;
use twilight_util::permission_calculator::PermissionCalculator;
const DM_PERMISSIONS: Permissions = Permissions::VIEW_CHANNEL
.union(Permissions::SEND_MESSAGES)
.union(Permissions::READ_MESSAGE_HISTORY)
.union(Permissions::ADD_REACTIONS)
.union(Permissions::ATTACH_FILES)
.union(Permissions::EMBED_LINKS)
.union(Permissions::USE_EXTERNAL_EMOJIS);
#[derive(Debug, Clone)]
pub struct CachedGuild {
owner_id: u64,
_premium_tier: PremiumTier,
}
#[derive(Debug, Clone)]
pub struct CachedChannel {
_name: Option<SmolStr>, // stores strings 22 characters or less inline, which is a large portion of channel names
_parent_id: Option<Id<ChannelMarker>>,
guild_id: Option<Id<GuildMarker>>,
kind: ChannelType,
overwrites: Vec<PermissionOverwrite>,
}
#[derive(Debug, Clone)]
pub struct CachedRole {
permissions: Permissions,
_mentionable: bool,
}
#[derive(Debug, Clone)]
pub struct CachedBotMember {
roles: Vec<Id<RoleMarker>>,
}
pub struct DiscordCache {
bot_user: Arc<RwLock<Option<Id<UserMarker>>>>,
guilds: DashMap<u64, CachedGuild>,
channels: DashMap<u64, CachedChannel>,
roles: DashMap<u64, CachedRole>,
bot_members: DashMap<u64, CachedBotMember>,
}
impl DiscordCache {
pub fn new() -> DiscordCache {
DiscordCache {
bot_user: Arc::new(RwLock::new(None)),
guilds: DashMap::new(),
channels: DashMap::new(),
roles: DashMap::new(),
bot_members: DashMap::new(),
}
}
pub fn handle_event(&self, event: &Event) {
match event {
Event::Ready(ref r) => {
let mut bot_user = self.bot_user.write().unwrap();
*bot_user = Some(r.user.id);
}
Event::GuildCreate(ref g) => self.update_guild(g),
Event::ChannelCreate(ref ch) => self.update_channel(ch),
Event::ChannelUpdate(ref ch) => self.update_channel(ch),
Event::ChannelDelete(ref ch) => self.delete_channel(ch.id),
Event::ThreadCreate(ref ch) => self.update_channel(ch),
Event::ThreadUpdate(ref ch) => self.update_channel(ch),
Event::ThreadDelete(ref ch) => self.delete_channel(ch.id),
Event::ThreadListSync(ref ts) => self.update_threads(ts),
Event::RoleCreate(ref r) => self.update_role(&r.role),
Event::RoleUpdate(ref r) => self.update_role(&r.role),
Event::RoleDelete(ref r) => self.delete_role(r.role_id),
Event::MemberUpdate(ref member) => {
let current_user = self.bot_user_id();
if Some(member.user.id) == current_user {
self.update_bot_member(member.guild_id, &member.roles);
}
}
_ => {}
}
}
fn update_guild(&self, guild: &Guild) {
for channel in &guild.channels {
self.update_channel(channel);
}
for role in &guild.roles {
self.update_role(role);
}
let current_user = self.bot_user_id();
for member in &guild.members {
if Some(member.user.id) == current_user {
self.update_bot_member(member.guild_id, &member.roles);
}
}
self.guilds.insert(
guild.id.get(),
CachedGuild {
owner_id: guild.owner_id.get(),
_premium_tier: guild.premium_tier,
},
);
}
fn bot_user_id(&self) -> Option<Id<UserMarker>> {
*self.bot_user.read().unwrap()
}
fn update_channel(&self, channel: &Channel) {
self.channels.insert(
channel.id.get(),
CachedChannel {
_name: channel.name.as_deref().map(SmolStr::new),
_parent_id: channel.parent_id,
guild_id: channel.guild_id,
kind: channel.kind,
overwrites: channel.permission_overwrites.clone().unwrap_or_default(),
},
);
}
fn delete_channel(&self, id: Id<ChannelMarker>) {
self.channels.remove(&id.get());
}
fn update_role(&self, role: &Role) {
self.roles.insert(
role.id.get(),
CachedRole {
permissions: role.permissions,
_mentionable: role.mentionable,
},
);
}
fn delete_role(&self, id: Id<RoleMarker>) {
self.roles.remove(&id.get());
}
fn update_threads(&self, evt: &ThreadListSync) {
for thread in &evt.threads {
self.update_channel(thread);
}
}
fn update_bot_member(&self, guild_id: Id<GuildMarker>, roles: &[Id<RoleMarker>]) {
self.bot_members.insert(
guild_id.get(),
CachedBotMember {
roles: roles.to_vec(),
},
);
}
pub fn get_guild(&self, guild_id: Id<GuildMarker>) -> anyhow::Result<Ref<u64, CachedGuild>> {
self.guilds
.get(&guild_id.get())
.ok_or_else(|| anyhow::anyhow!("could not find guild in cache: {}", guild_id))
}
pub fn get_channel(
&self,
channel_id: Id<ChannelMarker>,
) -> anyhow::Result<Ref<u64, CachedChannel>> {
self.channels
.get(&channel_id.get())
.ok_or_else(|| anyhow::anyhow!("could not find channel in cache: {}", channel_id))
}
pub fn get_role(&self, role_id: Id<RoleMarker>) -> anyhow::Result<Ref<u64, CachedRole>> {
self.roles
.get(&role_id.get())
.ok_or_else(|| anyhow::anyhow!("could not find role in cache: {}", role_id))
}
pub fn get_bot_member(
&self,
guild_id: Id<GuildMarker>,
) -> anyhow::Result<Ref<u64, CachedBotMember>> {
self.bot_members.get(&guild_id.get()).ok_or_else(|| {
anyhow::anyhow!("could not find bot member in cache for guild: {}", guild_id)
})
}
fn calculate_permissions_in(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
roles: &[Id<RoleMarker>],
) -> anyhow::Result<Permissions> {
let channel = self.get_channel(channel_id)?;
if let Some(guild_id) = channel.guild_id {
let guild = self.get_guild(guild_id)?;
let everyone_role = self.get_role(guild_id.cast())?;
let mut member_roles = Vec::with_capacity(roles.len());
for role_id in roles {
let role = self.get_role(*role_id)?;
member_roles.push((role_id.cast(), role.permissions));
}
let calc = PermissionCalculator::new(
guild_id,
user_id,
everyone_role.permissions,
&member_roles,
)
.owner_id(Id::new(guild.owner_id));
Ok(calc.in_channel(channel.kind, &channel.overwrites))
} else {
Ok(DM_PERMISSIONS)
}
}
pub fn member_permissions(
&self,
channel_id: Id<ChannelMarker>,
user_id: Id<UserMarker>,
member: Option<&PartialMember>,
) -> anyhow::Result<Permissions> {
if let Some(member) = member {
self.calculate_permissions_in(channel_id, user_id, &member.roles)
} else {
// this should just be dm perms, probably?
self.calculate_permissions_in(channel_id, user_id, &[])
}
}
pub fn bot_permissions(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<Permissions> {
let channel = self.get_channel(channel_id)?;
if let Some(guild_id) = channel.guild_id {
let member = self.get_bot_member(guild_id)?;
let user_id = self
.bot_user_id()
.ok_or_else(|| anyhow::anyhow!("haven't received bot user id yet"))?;
self.calculate_permissions_in(channel_id, user_id, &member.roles)
} else {
Ok(DM_PERMISSIONS)
}
}
pub fn channel_type(&self, channel_id: Id<ChannelMarker>) -> anyhow::Result<ChannelType> {
let channel = self.get_channel(channel_id)?;
Ok(channel.kind)
}
}

View file

@ -0,0 +1,22 @@
use config::{Config, Environment, File, FileFormat};
use serde::Deserialize;
#[derive(Deserialize, Debug)]
pub struct BotConfig {
pub token: String,
pub max_concurrency: Option<u64>,
pub database: String,
pub redis_addr: Option<String>,
pub redis_gateway_queue_addr: Option<String>,
pub shard_count: Option<u64>,
}
// todo: should this be a once_cell::Lazy global const or something
pub fn load_config() -> anyhow::Result<BotConfig> {
let builder = Config::builder()
.add_source(Environment::default())
.add_source(File::new("config", FileFormat::Toml));
Ok(builder.build()?.try_deserialize()?)
}

174
rustproxy/pk_bot/src/db.rs Normal file
View file

@ -0,0 +1,174 @@
use std::str::FromStr;
use crate::{
config::BotConfig,
model::{PKMember, PKMemberGuild, PKMessage, PKSystem, PKSystemGuild},
};
use chrono::{DateTime, Utc};
use sqlx::{
postgres::{PgConnectOptions, PgPoolOptions},
ConnectOptions, FromRow, PgPool,
};
use tracing::info;
#[derive(FromRow, Debug, Default)]
pub struct MessageContext {
// being defensive with these values - we need to be explicit with Option<T>
// when the database might return null, and some of these don't have proper default values set
// most of the Option<T>s can probably get removed with a few changes to the db function
pub system_id: Option<i32>,
pub is_deleting: Option<bool>,
pub in_blacklist: Option<bool>,
pub in_log_blacklist: Option<bool>,
pub proxy_enabled: Option<bool>,
pub last_switch: Option<i32>,
pub last_switch_members: Option<Vec<i32>>,
pub last_switch_timestamp: Option<DateTime<Utc>>,
pub system_tag: Option<String>,
pub system_guild_tag: Option<String>,
pub tag_enabled: Option<bool>,
pub system_avatar: Option<String>,
pub allow_autoproxy: Option<bool>,
pub latch_timeout: Option<i32>,
}
pub async fn get_message_context(
pool: &PgPool,
account_id: i64,
guild_id: i64,
channel_id: i64,
) -> anyhow::Result<MessageContext> {
Ok(sqlx::query_as("select * from message_context($1, $2, $3)")
.bind(account_id)
.bind(guild_id)
.bind(channel_id)
.fetch_one(pool)
.await?)
}
#[derive(FromRow, Debug, Clone)]
pub struct ProxyTagEntry {
pub prefix: String,
pub suffix: String,
pub member_id: i32,
}
impl From<(&str, &str, i32)> for ProxyTagEntry {
fn from((prefix, suffix, member_id): (&str, &str, i32)) -> Self {
ProxyTagEntry {
prefix: prefix.to_string(),
suffix: suffix.to_string(),
member_id,
}
}
}
pub async fn get_proxy_tags(pool: &PgPool, system_id: i32) -> anyhow::Result<Vec<ProxyTagEntry>> {
Ok(sqlx::query_as("select coalesce((i.tags).prefix, '') as prefix, coalesce((i.tags).suffix, '') as suffix, member_id from (select unnest(proxy_tags) as tags, id as member_id from members where system = $1) as i;")
.bind(system_id)
.fetch_all(pool)
.await?)
}
#[repr(i32)]
#[derive(sqlx::Type, Debug, Copy, Clone)]
pub enum AutoproxyMode {
Off = 1,
Front = 2,
Latch = 3,
Member = 4,
}
#[derive(FromRow, Debug, Clone)]
pub struct AutoproxyState {
pub autoproxy_mode: AutoproxyMode,
pub autoproxy_member: Option<i32>,
pub last_latch_timestamp: Option<DateTime<Utc>>,
}
pub async fn get_autoproxy_state(
pool: &PgPool,
system_id: i32,
guild_id: i64,
channel_id: i64,
) -> anyhow::Result<Option<AutoproxyState>> {
Ok(sqlx::query_as(
"select * from autoproxy where system = $1 and guild_id = $2 and channel_id = $3;",
)
.bind(system_id)
.bind(guild_id)
.bind(channel_id)
.fetch_optional(pool)
.await?)
}
pub async fn get_system_by_id(pool: &PgPool, system_id: i32) -> anyhow::Result<Option<PKSystem>> {
Ok(sqlx::query_as("select * from systems where id = $1")
.bind(system_id)
.fetch_optional(pool)
.await?)
}
pub async fn get_member_by_id(pool: &PgPool, member_id: i32) -> anyhow::Result<Option<PKMember>> {
Ok(sqlx::query_as("select * from members where id = $1")
.bind(member_id)
.fetch_optional(pool)
.await?)
}
pub async fn get_system_guild(
pool: &PgPool,
system_id: i32,
guild_id: i64,
) -> anyhow::Result<Option<PKSystemGuild>> {
Ok(
sqlx::query_as("select * from system_guild where system = $1 and guild = $2")
.bind(system_id)
.bind(guild_id)
.fetch_optional(pool)
.await?,
)
}
pub async fn get_member_guild(
pool: &PgPool,
member_id: i32,
guild_id: i64,
) -> anyhow::Result<Option<PKMemberGuild>> {
Ok(
sqlx::query_as("select * from member_guild where member = $1 and guild = $2")
.bind(member_id)
.bind(guild_id)
.fetch_optional(pool)
.await?,
)
}
pub async fn insert_message(pool: &PgPool, message: PKMessage) -> anyhow::Result<()> {
sqlx::query("insert into messages (mid, guild, channel, member, sender, original_mid) values ($1, $2, $3, $4, $5, $6)")
.bind(message.mid)
.bind(message.guild)
.bind(message.channel)
.bind(message.member_id)
.bind(message.sender)
.bind(message.original_mid)
.execute(pool)
.await?;
Ok(())
}
pub async fn init_db(config: &BotConfig) -> anyhow::Result<PgPool> {
info!("connecting to database");
let options = PgConnectOptions::from_str(&config.database)
.unwrap()
.disable_statement_logging()
.clone();
let pool = PgPoolOptions::new()
.max_connections(32)
.connect_with(options)
.await?;
Ok(pool)
}

View file

@ -0,0 +1,63 @@
use crate::config::BotConfig;
use crate::redis;
use std::{env, sync::Arc};
use tracing::info;
use twilight_gateway::{
cluster::{Events, ShardScheme},
Cluster, EventTypeFlags, Intents,
};
use twilight_http::Client;
pub async fn init_gateway(
http: Arc<Client>,
config: &BotConfig,
) -> anyhow::Result<(Arc<Cluster>, Events)> {
let mut builder = Cluster::builder(
config.token.clone(),
Intents::GUILDS
| Intents::DIRECT_MESSAGES
| Intents::GUILD_MESSAGES
| Intents::MESSAGE_CONTENT,
);
builder = builder.http_client(http);
builder = builder.event_types(EventTypeFlags::all());
if let Some(scheme) = get_shard_scheme(config)? {
info!("using shard scheme: {:?}", scheme);
builder = builder.shard_scheme(scheme);
}
if let Some(queue) = redis::init_gateway_queue(config).await? {
info!("using redis gateway queue");
builder = builder.queue(Arc::new(queue));
}
let (cluster, events) = builder.build().await?;
let cluster = Arc::new(cluster);
let cluster_spawn = Arc::clone(&cluster);
tokio::spawn(async move {
info!("starting shards...");
cluster_spawn.up().await;
});
Ok((cluster, events))
}
fn get_cluster_id() -> anyhow::Result<u64> {
Ok(env::var("NOMAD_ALLOC_INDEX")
.unwrap_or_else(|_| "0".to_string())
.parse::<u64>()?)
}
fn get_shard_scheme(config: &BotConfig) -> anyhow::Result<Option<ShardScheme>> {
let shard_count = config.shard_count.unwrap_or(1);
let scheme = if shard_count >= 16 {
let cluster_id = get_cluster_id()?;
let first_shard_id = 16 * cluster_id;
let shard_range = first_shard_id..first_shard_id + 16;
Some(ShardScheme::try_from((shard_range, shard_count))?)
} else {
None
};
Ok(scheme)
}

View file

@ -0,0 +1,109 @@
#[cfg(not(target_env = "msvc"))]
use tikv_jemallocator::Jemalloc;
#[cfg(not(target_env = "msvc"))]
#[global_allocator]
static GLOBAL: Jemalloc = Jemalloc;
use crate::cache::DiscordCache;
use crate::redis::RedisEventProxy;
use futures::StreamExt;
use sqlx::PgPool;
use std::sync::Arc;
use tracing::{error, info};
use twilight_gateway::Event;
use twilight_http::Client as HttpClient;
mod cache;
mod config;
mod db;
mod gateway;
mod model;
mod proxy;
mod redis;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt().init();
let config = config::load_config()?;
info!("loaded config: {:?}", config);
let pool = db::init_db(&config).await?;
let http = Arc::new(HttpClient::new(config.token.clone()));
let (_cluster, mut events) = gateway::init_gateway(Arc::clone(&http), &config).await?;
let cache = Arc::new(DiscordCache::new());
let redis = redis::init_event_proxy(&config).await?;
while let Some((shard_id, event)) = events.next().await {
let http = Arc::clone(&http);
let pool = pool.clone();
let cache = Arc::clone(&cache);
let redis = redis.clone();
tokio::spawn(async move {
cache.handle_event(&event);
let res = handle_event(shard_id, event, http, pool, cache, redis).await;
if let Err(e) = res {
error!("error handling event: {:?}", e);
}
});
}
Ok(())
}
async fn handle_event(
shard_id: u64,
event: Event,
http: Arc<HttpClient>,
pool: PgPool,
cache: Arc<DiscordCache>,
mut redis: RedisEventProxy,
) -> anyhow::Result<()> {
match event {
Event::MessageCreate(msg) => {
if msg.content.starts_with("pk;") || msg.content.starts_with("pk!") {
redis
.send_event_parsed(shard_id, Event::MessageCreate(msg))
.await?;
return Ok(());
}
let channel_type = cache.channel_type(msg.channel_id)?;
let ctx = db::get_message_context(
&pool,
msg.author.id.get() as i64,
msg.guild_id.map(|x| x.get()).unwrap_or_default() as i64,
msg.channel_id.get() as i64,
)
.await?;
let _member_permissions =
cache.member_permissions(msg.channel_id, msg.author.id, msg.member.as_ref())?;
let bot_permissions = cache.bot_permissions(msg.channel_id)?;
match proxy::check_preconditions(&msg, channel_type, bot_permissions, &ctx) {
Ok(_) => {
info!("attempting to proxy");
proxy::do_proxy(&http, &pool, &msg, &ctx).await?;
}
Err(reason) => {
info!("skipping proxy because: {}", reason)
}
}
}
Event::ShardConnected(_) => {
info!("connected on shard {}", shard_id);
}
Event::ShardPayload(payload) => {
redis.send_event_raw(shard_id, &payload.bytes).await?;
}
// Other events here...
_ => {}
}
Ok(())
}

View file

@ -0,0 +1,49 @@
use sqlx::FromRow;
#[derive(FromRow, Debug, Clone)]
pub struct PKSystem {
pub id: i32,
pub name: Option<String>,
pub tag: Option<String>,
pub avatar_url: Option<String>,
}
#[derive(FromRow, Debug, Clone)]
pub struct PKMember {
pub id: i32,
pub system: i32,
pub name: String,
pub color: Option<String>,
pub avatar_url: Option<String>,
pub display_name: Option<String>,
pub pronouns: Option<String>,
pub description: Option<String>,
}
#[derive(FromRow, Debug, Clone)]
pub struct PKMessage {
pub mid: i64,
pub guild: Option<i64>,
pub channel: i64,
pub member_id: i32,
pub sender: i64,
pub original_mid: Option<i64>,
}
#[derive(FromRow, Debug, Clone)]
pub struct PKSystemGuild {
pub system: i32,
pub guild: i64,
pub proxy_enabled: bool,
pub tag: Option<String>,
pub tag_enabled: bool,
}
#[derive(FromRow, Debug, Clone)]
pub struct PKMemberGuild {
pub member: i32,
pub guild: i64,
pub display_name: Option<String>,
pub avatar_url: Option<String>,
}

View file

@ -0,0 +1,31 @@
use crate::db::{AutoproxyMode, AutoproxyState, MessageContext};
pub fn resolve_autoproxy_member(
ctx: &MessageContext,
state: &AutoproxyState,
content: &str,
) -> Option<i32> {
if !ctx.allow_autoproxy.unwrap_or(true) {
return None;
}
if is_escape(content) {
return None;
}
let first_fronter = ctx.last_switch_members.iter().flatten().cloned().next();
match (state.autoproxy_mode, state.autoproxy_member, first_fronter) {
(AutoproxyMode::Latch, Some(m), _) => Some(m),
(AutoproxyMode::Member, Some(m), _) => Some(m),
(AutoproxyMode::Front, _, Some(f)) => Some(f),
_ => None,
}
}
fn _is_unlatch(content: &str) -> bool {
content.starts_with("\\\\") || content.starts_with("\\\u{200b}\\")
}
fn is_escape(content: &str) -> bool {
content.starts_with('\\')
}

View file

@ -0,0 +1,227 @@
use self::{post_proxy::ProxyResult, webhook::WebhookExecuteRequest};
use crate::db::{self, MessageContext};
use sqlx::PgPool;
use thiserror::Error;
use twilight_http::Client;
use twilight_model::{
channel::{message::MessageType, ChannelType, Message},
guild::Permissions,
user::User,
};
mod autoproxy;
mod post_proxy;
mod profile;
mod reply;
mod tags;
mod webhook;
#[derive(Error, Debug)]
// use this for pk;proxycheck or something
pub enum PreconditionFailure {
#[error("bot is missing permissions (has: {has:?}, needs: {needs:?})")]
BotMissingPermission {
has: Permissions,
needs: Permissions,
},
#[error("invalid channel type {0:?}")]
InvalidChannelType(ChannelType),
#[error("invalid message type {0:?}")]
InvalidMessageType(MessageType),
#[error("user is bot")]
UserIsBot,
#[error("user is webhook")]
UserIsWebhook,
#[error("user is discord system")]
UserIsDiscordSystem,
#[error("user has no system")]
UserHasNoSystem,
#[error("proxy disabled for system")]
ProxyDisabledForSystem,
#[error("proxy disabled in channel")]
ProxyDisabledInChannel,
#[error("message contains activity")]
MessageContainsActivity,
#[error("message contains sticker")]
MessageContainsSticker,
#[error("message is empty and has no attachments")]
MessageIsEmpty,
}
// todo: the parameters here are nasty, refactor+put this code somewhere else maybe
pub fn check_preconditions(
msg: &Message,
channel_type: ChannelType,
bot_permissions: Permissions,
ctx: &MessageContext,
) -> Result<(), PreconditionFailure> {
let required_permissions =
Permissions::SEND_MESSAGES | Permissions::MANAGE_WEBHOOKS | Permissions::MANAGE_MESSAGES;
if !bot_permissions.contains(required_permissions) {
return Err(PreconditionFailure::BotMissingPermission {
has: bot_permissions,
needs: required_permissions,
});
}
match channel_type {
ChannelType::GuildText
| ChannelType::GuildNews
| ChannelType::GuildPrivateThread
| ChannelType::GuildPublicThread
| ChannelType::GuildNewsThread => Ok(()),
wrong_type => Err(PreconditionFailure::InvalidChannelType(wrong_type)),
}?;
match msg.kind {
MessageType::Regular | MessageType::Reply => Ok(()),
wrong_type => Err(PreconditionFailure::InvalidMessageType(wrong_type)),
}?;
match msg {
Message {
author: User {
system: Some(true), ..
},
..
} => Err(PreconditionFailure::UserIsDiscordSystem),
Message {
author: User { bot: true, .. },
..
} => Err(PreconditionFailure::UserIsBot),
Message {
webhook_id: Some(_),
..
} => Err(PreconditionFailure::UserIsWebhook),
Message {
activity: Some(_), ..
} => Err(PreconditionFailure::MessageContainsActivity),
Message {
sticker_items: s, ..
} if !s.is_empty() => Err(PreconditionFailure::MessageContainsSticker),
Message {
content: c,
attachments: a,
..
} if c.trim().is_empty() && a.is_empty() => Err(PreconditionFailure::MessageIsEmpty),
_ => Ok(()),
}?;
match ctx {
MessageContext {
system_id: None, ..
} => Err(PreconditionFailure::UserHasNoSystem),
MessageContext {
in_blacklist: Some(true),
..
} => Err(PreconditionFailure::ProxyDisabledInChannel),
MessageContext {
proxy_enabled: Some(false),
..
} => Err(PreconditionFailure::ProxyDisabledForSystem),
_ => Ok(()),
}?;
Ok(())
}
struct ProxyMatchResult {
member_id: i32,
inner_content: String,
_tags: Option<(String, String)>, // todo: need this for keepproxy
}
async fn match_tags_or_autoproxy(
pool: &PgPool,
msg: &Message,
ctx: &MessageContext,
) -> anyhow::Result<Option<ProxyMatchResult>> {
let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;
let tags = db::get_proxy_tags(pool, system_id).await?;
let ap_state = db::get_autoproxy_state(
pool,
system_id,
guild_id.get() as i64,
0, // all autoproxy has channel id 0? o.o
)
.await?;
let tag_match = tags::match_proxy_tags(&tags, &msg.content);
if let Some(tag_match) = tag_match {
return Ok(Some(ProxyMatchResult {
member_id: tag_match.member_id,
inner_content: tag_match.inner_content,
_tags: Some(tag_match.tags),
}));
}
if let Some(ap_state) = ap_state {
let res = autoproxy::resolve_autoproxy_member(ctx, &ap_state, &msg.content);
if let Some(member_id) = res {
return Ok(Some(ProxyMatchResult {
inner_content: msg.content.clone(),
member_id,
_tags: None,
}));
}
}
Ok(None)
}
// todo: this shouldn't depend on a Message object (for reproxy/proxy command/etc)
pub async fn do_proxy(
http: &Client,
pool: &PgPool,
msg: &Message,
ctx: &MessageContext,
) -> anyhow::Result<()> {
let guild_id = msg.guild_id.ok_or_else(|| anyhow::anyhow!("no guild id"))?;
let system_id = ctx.system_id.ok_or_else(|| anyhow::anyhow!("no system"))?;
// todo: unlatch check/exec should probably go in here somewhere
let proxy_match = match_tags_or_autoproxy(pool, msg, ctx).await?;
if let Some(result) = proxy_match {
let profile =
profile::fetch_proxy_profile(pool, guild_id.get(), system_id, result.member_id).await?;
let webhook_req = WebhookExecuteRequest {
channel_id: msg.channel_id.get(),
avatar_url: profile.avatar_url().map(|s| s.to_string()),
content: Some(result.inner_content.clone()),
username: profile.formatted_name(),
embed: msg
.referenced_message
.as_deref()
.map(|msg| reply::create_reply_embed(guild_id, msg))
.transpose()?,
};
let webhook_res = webhook::execute_webhook(http, &webhook_req).await?;
let proxy_res = ProxyResult {
channel_id: msg.channel_id.get(),
guild_id: guild_id.get(),
member_id: result.member_id,
original_message_id: msg.id.get(),
proxy_message_id: webhook_res.message_id,
sender: msg.author.id.get(),
};
post_proxy::handle_post_proxy(http, pool, &proxy_res).await?;
}
Ok(())
}

View file

@ -0,0 +1,66 @@
use crate::db;
use crate::model::PKMessage;
use futures::TryFutureExt;
use sqlx::PgPool;
use tracing::error;
use twilight_http::Client;
#[derive(Debug, Clone)]
pub struct ProxyResult {
pub guild_id: u64,
pub channel_id: u64,
pub proxy_message_id: u64,
pub original_message_id: u64,
pub sender: u64,
pub member_id: i32,
}
pub async fn handle_post_proxy(
http: &Client,
pool: &PgPool,
res: &ProxyResult,
) -> anyhow::Result<()> {
// todo: log channel
let _ = futures::join!(
delete_original_message(http, res)
.inspect_err(|e| error!("error deleting original message: {}", e)),
insert_message_in_db(pool, res)
.inspect_err(|e| error!("error deleting original message: {}", e))
);
Ok(())
}
async fn delete_original_message(http: &Client, res: &ProxyResult) -> anyhow::Result<()> {
// todo: sleep some amount
// (do we still need to do that or did discord fix that client bug?)
http.delete_message(
res.channel_id.try_into().unwrap(),
res.original_message_id.try_into().unwrap(),
)
.exec()
.await?;
Ok(())
}
async fn insert_message_in_db(pool: &PgPool, res: &ProxyResult) -> anyhow::Result<()> {
db::insert_message(
pool,
PKMessage {
mid: res.proxy_message_id as i64,
original_mid: Some(res.original_message_id as i64),
sender: res.sender as i64,
guild: Some(res.guild_id as i64),
channel: res.channel_id as i64,
member_id: res.member_id,
},
)
.await?;
Ok(())
}
#[cfg(test)]
mod tests {}

View file

@ -0,0 +1,90 @@
use futures::try_join;
use sqlx::PgPool;
use crate::{
db,
model::{PKMember, PKMemberGuild, PKSystem, PKSystemGuild},
};
// inventing the term "proxy profile" here to describe the info needed to work out the webhook name+avatar
// arbitrary choice to put the source models in the struct and logic in methods, could just as well have had a function to do the math and put the results in a struct
pub struct ProxyProfile {
system: PKSystem,
member: PKMember,
system_guild: Option<PKSystemGuild>,
member_guild: Option<PKMemberGuild>,
}
impl ProxyProfile {
pub fn name(&self) -> &str {
let member_name = &self.member.name;
let display_name = self.member.display_name.as_deref();
let server_name = self
.member_guild
.as_ref()
.and_then(|x| x.display_name.as_deref());
server_name.or(display_name).unwrap_or(member_name)
}
pub fn avatar_url(&self) -> Option<&str> {
let system_avatar = self.system.avatar_url.as_deref();
let member_avatar = self.member.avatar_url.as_deref();
let server_avatar = self
.member_guild
.as_ref()
.and_then(|x| x.avatar_url.as_deref());
server_avatar.or(member_avatar).or(system_avatar)
}
pub fn tag(&self) -> Option<&str> {
let server_tag = self.system_guild.as_ref().and_then(|x| x.tag.as_deref());
let system_tag = self.system.tag.as_deref();
server_tag.or(system_tag)
}
pub fn formatted_name(&self) -> String {
let mut name = if let Some(tag) = self.tag() {
format!("{} {}", self.name(), tag)
} else {
self.name().to_string()
};
if name.len() == 1 {
name.push('\u{17b5}');
}
name
}
}
pub async fn fetch_proxy_profile(
pool: &PgPool,
guild_id: u64,
system_id: i32,
member_id: i32,
) -> anyhow::Result<ProxyProfile> {
// todo: this should be a db view with joins
// this is all the info that proxy_members returned, so a single-member version of that could work nicely
let system = db::get_system_by_id(pool, system_id);
let member = db::get_member_by_id(pool, member_id);
let system_guild = db::get_system_guild(pool, system_id, guild_id as i64);
let member_guild = db::get_member_guild(pool, member_id, guild_id as i64);
let (system, member, system_guild, member_guild) =
try_join!(system, member, system_guild, member_guild)?;
let system = system.ok_or_else(|| anyhow::anyhow!("could not find system"))?;
let member = member.ok_or_else(|| anyhow::anyhow!("could not find member"))?;
Ok(ProxyProfile {
system,
member,
system_guild,
member_guild,
})
}
#[cfg(test)]
mod tests {
// todo: this code is gonna be easy to unit test so we should do that
}

View file

@ -0,0 +1,57 @@
use twilight_http::Client;
use twilight_model::{
channel::{embed::Embed, Message},
id::{
marker::{GuildMarker, UserMarker},
Id,
},
util::ImageHash,
};
use twilight_util::builder::embed::{EmbedAuthorBuilder, EmbedBuilder, ImageSource};
async fn _fetch_additional_reply_info(_http: &Client) -> anyhow::Result<()> {
Ok(())
}
pub fn create_reply_embed(
guild_id: Id<GuildMarker>,
replied_to: &Message,
) -> anyhow::Result<Embed> {
// todo: guild avatars, guild nicknames
// probably put this in fetch_additional_reply_info
let author = {
let icon = replied_to
.author
.avatar
.map(|hash| get_avatar_url(replied_to.author.id, hash))
.and_then(|url| ImageSource::url(url).ok());
let mut builder = EmbedAuthorBuilder::new(replied_to.author.name.clone());
if let Some(icon) = icon {
builder = builder.icon_url(icon);
};
builder.build()
};
let content = {
let jump_link = format!(
"https://discord.com/channels/{}/{}/{}",
guild_id, replied_to.channel_id, replied_to.id
);
let content = format!("**[Reply to:]({})** ", jump_link);
// todo: properly add truncated content (including handling links/spoilers/etc)
content
};
let builder = EmbedBuilder::new().description(content).author(author);
Ok(builder.build())
}
fn get_avatar_url(user_id: Id<UserMarker>, hash: ImageHash) -> String {
format!(
"https://cdn.discordapp.com/avatars/{}/{}.png",
user_id, hash
)
}

View file

@ -0,0 +1,93 @@
use crate::db::ProxyTagEntry;
use tracing::info;
#[derive(Debug)]
pub struct ProxyTagMatch {
pub inner_content: String,
pub tags: (String, String),
pub member_id: i32,
}
pub fn match_proxy_tags(tags: &[ProxyTagEntry], content: &str) -> Option<ProxyTagMatch> {
let content = content.trim();
let mut sorted_entries = tags.to_vec();
sorted_entries.sort_by_key(|x| -((x.prefix.len() + x.suffix.len()) as i32));
for entry in sorted_entries {
let is_tag_match = content.starts_with(&entry.prefix) && content.ends_with(&entry.suffix);
info!(
"prefix: {}, suffix: {}, content: {}, is_match: {}",
entry.prefix, entry.suffix, content, is_tag_match
);
// todo: extract leading mentions
// todo: allow empty matches only if we're proxying an attachment
// todo: properly handle <>s etc, there's some regex stuff in there i don't entirely understand
// todo: there's some weird edge cases with various unicode control characters and emoji joiners and whatever, should figure that out + unit test it
if is_tag_match {
let inner_content =
&content[entry.prefix.len()..(content.len() - entry.suffix.len())].trim();
return Some(ProxyTagMatch {
inner_content: inner_content.to_string(),
tags: (entry.prefix, entry.suffix),
member_id: entry.member_id,
});
}
}
None
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn basic_match() {
let tags = vec![
("[", "]", 0).into(),
("[[", "]]", 1).into(),
("P:", "", 2).into(),
("", "-P", 3).into(),
("+ ", "", 4).into(),
];
assert_no_match(&tags, "hello world");
assert_match(&tags, "[hello world]", "hello world", 0);
assert_match(&tags, "[ hello world ]", "hello world", 0);
assert_match(&tags, " [ hello world ] ", "hello world", 0);
assert_match(&tags, "[\nhello\n]", "hello", 0);
assert_match(&tags, "[\nhello\nworld\n]", "hello\nworld", 0);
assert_match(&tags, "[[text]]", "text", 1);
assert_match(&tags, "[text]]", "text]", 0);
assert_match(&tags, "[[[text]]]", "[text]", 1);
assert_match(&tags, "P:text", "text", 2);
assert_match(&tags, "text -P", "text", 3);
assert_match(&tags, "+ hello", "hello", 4);
assert_no_match(&tags, "+hello"); // (prefix contains trailing space)
assert_match(&tags, "[]", "", 0);
// edge case: the c# implementation currently does what the commented out test does
// *if* the message doesn't have an attachment. not sure if we should mirror this here.
// assert_match(&tags, "[[]]", "[]", 0);
assert_match(&tags, "[[]]", "", 1);
}
fn assert_match(tags: &[ProxyTagEntry], message: &str, inner: &str, member: i32) {
let res = match_proxy_tags(tags, message);
assert_eq!(
res.as_ref()
.map(|x| (x.inner_content.as_str(), x.member_id)),
Some((inner, member))
);
}
fn assert_no_match(tags: &[ProxyTagEntry], message: &str) {
let res = match_proxy_tags(tags, message);
assert!(res.is_none());
}
}

View file

@ -0,0 +1,143 @@
use moka::future::Cache;
use once_cell::sync::Lazy;
use tracing::info;
use twilight_http::Client;
use twilight_model::channel::{embed::Embed, Webhook, WebhookType};
// space for 1 million is probably way overkill, this is a LRU cache so it's okay to evict occasionally
static WEBHOOK_CACHE: Lazy<Cache<u64, CachedWebhook>> = Lazy::new(|| Cache::new(1024 * 1024));
const WEBHOOK_NAME: &str = "PluralKit Proxy Webhook";
#[derive(Debug, Clone)]
pub struct CachedWebhook {
pub id: u64,
pub token: String,
}
pub async fn get_webhook_cached(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
let res = WEBHOOK_CACHE
.try_get_with(channel_id, fetch_or_create_pk_webhook(http, channel_id))
.await;
// todo: what happens if fetch_or_create_pk_webhook errors? i think moka handles it properly and just retries
// but i'm not entiiiirely sure
// https://docs.rs/moka/0.8.5/moka/future/struct.Cache.html#method.try_get_with
// error is Arc<Error> here and it's hard to convert that into an owned ref so we just make a new error lmao
res.map_err(|_e| anyhow::anyhow!(
"could not fetch webhook: {}", _e
))
}
async fn fetch_or_create_pk_webhook(
http: &Client,
channel_id: u64,
) -> anyhow::Result<CachedWebhook> {
match fetch_pk_webhook(http, channel_id).await? {
Some(hook) => Ok(hook),
None => create_pk_webhook(http, channel_id).await,
}
}
async fn fetch_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<Option<CachedWebhook>> {
info!("cache miss, fetching webhook for channel {}", channel_id);
let webhooks = http
.channel_webhooks(channel_id.try_into().unwrap())
.exec()
.await?
.models()
.await?;
webhooks
.iter()
.find(|wh| is_proxy_webhook(wh))
.map(|x| {
let token = x
.token
.as_ref()
.map(|x| x.to_string())
.ok_or_else(|| anyhow::anyhow!("webhook should contain token"));
token.map(|token| CachedWebhook {
id: x.id.get(),
token,
})
})
.transpose()
}
async fn create_pk_webhook(http: &Client, channel_id: u64) -> anyhow::Result<CachedWebhook> {
let response = http
.create_webhook(channel_id.try_into().unwrap(), WEBHOOK_NAME)?
.exec()
.await?;
// todo: error handling here
let val = response.model().await?;
Ok(CachedWebhook {
id: val.id.get(),
token: val
.token
.ok_or_else(|| anyhow::anyhow!("webhook should contain token"))?,
})
}
fn is_proxy_webhook(wh: &Webhook) -> bool {
wh.kind == WebhookType::Incoming
&& wh.token.is_some()
&& wh.name.as_deref() == Some(WEBHOOK_NAME)
}
#[derive(Debug)]
pub struct WebhookExecuteRequest {
pub channel_id: u64,
pub username: String,
pub avatar_url: Option<String>,
pub content: Option<String>,
pub embed: Option<Embed>,
}
#[derive(Debug)]
pub struct WebhookExecuteResult {
pub message_id: u64,
}
pub async fn execute_webhook(
http: &Client,
req: &WebhookExecuteRequest,
) -> anyhow::Result<WebhookExecuteResult> {
let webhook = get_webhook_cached(http, req.channel_id).await?;
let mut request = http
.execute_webhook(webhook.id.try_into().unwrap(), &webhook.token)
.username(&req.username)?;
if let Some(ref content) = req.content {
request = request.content(content)?;
}
if let Some(ref avatar_url) = req.avatar_url {
request = request.avatar_url(avatar_url);
}
let mut embeds = Vec::new();
if let Some(ref embed) = req.embed {
embeds.push(embed.clone());
request = request.embeds(&embeds)?;
}
// todo: handle error if webhook was deleted, should invalidate and retry
let result = request.wait().exec().await?;
let model = result.model().await?;
if model.channel_id != req.channel_id {
// it's possible for someone to "redirect" a webhook to another channel
// and the only way we find out is when we send a message.
// if this has happened remove it from cache and refetch later
WEBHOOK_CACHE.invalidate(&req.channel_id).await;
}
Ok(WebhookExecuteResult {
message_id: model.id.get(),
})
}

View file

@ -0,0 +1,168 @@
use crate::config::BotConfig;
use once_cell::sync::Lazy;
use redis::aio::ConnectionManager;
use std::fmt::Debug;
use std::time::Duration;
use tracing::{error, info};
use twilight_gateway::Event;
use twilight_gateway_queue::Queue;
use twilight_model::gateway::event::{DispatchEvent, GatewayEvent, GatewayEventDeserializer};
#[derive(Clone)]
pub struct RedisEventProxy {
// todo: don't know if i want this struct to have responsibility for ignoring calls if redis is disabled
inner: Option<ConnectionManager>,
}
// events that should be sent in the raw handler
// does not include message create/update since we want first dibs on those and pass them on later
static ALLOWED_EVENTS: Lazy<Vec<&'static str>> = Lazy::new(|| {
vec![
"INTERACTION_CREATE",
"MESSAGE_DELETE",
"MESSAGE_DELETE_BULK",
"MESSAGE_REACTION_ADD",
"READY",
"GUILD_CREATE",
"GUILD_UPDATE",
"GUILD_DELETE",
"GUILD_ROLE_CREATE",
"GUILD_ROLE_UPDATE",
"GUILD_ROLE_DELETE",
"CHANNEL_CREATE",
"CHANNEL_UPDATE",
"CHANNEL_DELETE",
"THREAD_CREATE",
"THREAD_UPDATE",
"THREAD_DELETE",
"THREAD_LIST_SYNC",
]
});
impl RedisEventProxy {
async fn send_event_raw_inner(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
if let Some(ref mut redis) = self.inner {
info!(shard_id = shard_id, "publishing event");
let key = format!("evt-{}", shard_id);
redis::cmd("PUBLISH")
.arg(&key[..])
.arg(payload)
.query_async(redis)
.await?;
}
Ok(())
}
pub async fn send_event_raw(&mut self, shard_id: u64, payload: &[u8]) -> anyhow::Result<()> {
let payload_str = std::str::from_utf8(payload)?;
if let Some(deser) = GatewayEventDeserializer::from_json(payload_str) {
if let Some(event_type) = deser.event_type_ref() {
if ALLOWED_EVENTS.contains(&event_type) {
self.send_event_raw_inner(shard_id, payload).await?;
}
}
}
Ok(())
}
pub async fn send_event_parsed(&mut self, shard_id: u64, evt: Event) -> anyhow::Result<()> {
info!(shard_id, "sending parsed: {:?}", evt.kind());
let dispatch_event = DispatchEvent::try_from(evt)?;
let gateway_event = GatewayEvent::Dispatch(0, Box::new(dispatch_event));
let buf = serde_json::to_vec(&gateway_event)?;
self.send_event_raw_inner(shard_id, &buf).await?;
Ok(())
}
}
pub async fn connect_to_redis(addr: &str) -> anyhow::Result<ConnectionManager> {
let client = redis::Client::open(addr)?;
info!("connecting to redis at {}...", addr);
Ok(ConnectionManager::new(client).await?)
}
pub async fn init_event_proxy(config: &BotConfig) -> anyhow::Result<RedisEventProxy> {
let mgr = if let Some(redis_addr) = &config.redis_addr {
Some(connect_to_redis(redis_addr).await?)
} else {
info!("no redis address specified, skipping");
None
};
Ok(RedisEventProxy { inner: mgr })
}
#[derive(Clone)]
pub struct RedisQueue {
pub redis: ConnectionManager,
pub concurrency: u64,
}
impl Debug for RedisQueue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("RedisQueue")
.field("concurrency", &self.concurrency)
.finish()
}
}
impl Queue for RedisQueue {
fn request<'a>(
&'a self,
shard_id: [u64; 2],
) -> std::pin::Pin<Box<dyn futures::Future<Output = ()> + Send + 'a>> {
Box::pin(request_inner(
self.redis.clone(),
self.concurrency,
*shard_id.first().unwrap(),
))
}
}
async fn request_inner(mut client: ConnectionManager, concurrency: u64, shard_id: u64) {
let bucket = shard_id % concurrency;
let key = format!("pluralkit:identify:{}", bucket);
// SET bucket 1 EX 6 NX = write a key expiring after 6 seconds if there's not already one
let mut cmd = redis::cmd("SET");
cmd.arg(key).arg("1").arg("EX").arg(6i8).arg("NX");
info!(shard_id, bucket, "waiting for allowance...");
loop {
let done = cmd
.clone()
.query_async::<_, Option<String>>(&mut client)
.await;
match done {
Ok(Some(_)) => {
info!(shard_id, bucket, "got allowance!");
return;
}
Ok(None) => {
// not allowed yet, waiting
}
Err(e) => {
error!(shard_id, bucket, "error getting shard allowance: {}", e)
}
}
tokio::time::sleep(Duration::from_millis(500)).await;
}
}
pub async fn init_gateway_queue(config: &BotConfig) -> anyhow::Result<Option<RedisQueue>> {
let queue = if let Some(ref addr) = config.redis_gateway_queue_addr {
let redis = connect_to_redis(addr).await?;
let concurrency = config.max_concurrency.unwrap_or(1);
Some(RedisQueue { redis, concurrency })
} else {
None
};
Ok(queue)
}

View file

@ -0,0 +1,9 @@
[package]
name = "pk_command_parser"
version = "0.1.0"
edition = "2021"
[dependencies]
anyhow = "1.0.57"
matches = "0.1.9"
slab = "0.4.6"

View file

@ -0,0 +1,2 @@
mod matcher;
mod tokenizer;

View file

@ -0,0 +1,177 @@
use std::ops::Range;
use crate::tokenizer::{Token, Tokenizer};
#[derive(Debug)]
pub enum Segment {
Word(Vec<String>),
Parameter { name: String, optional: bool },
}
#[derive(Debug)]
pub struct Pattern {
segments: Vec<Segment>,
}
#[derive(Debug, PartialEq)]
pub struct ParameterMatch {
name: String,
value: String,
span: Range<usize>,
}
#[derive(Debug, PartialEq)]
pub struct FlagMatch {
name: String,
value: Option<String>,
}
#[derive(Debug, PartialEq)]
pub struct MatchResult {
parameters: Vec<ParameterMatch>,
flags: Vec<FlagMatch>,
remainder: Option<String>,
}
pub fn does_match(s: &str, pat: &Pattern) -> Option<MatchResult> {
let mut flags = Vec::new();
let mut parameters = Vec::new();
let mut remainder_pos = None;
let mut segments = pat.segments.iter().peekable();
let mut tokenizer = Tokenizer::new(s);
// loop until we find a keyword token
while let Some(token) = tokenizer.next() {
match token {
Token::Flag { name, .. } => {
// flags are set aside
flags.push(FlagMatch { name, value: None });
}
Token::Keyword {
value,
quoted: _,
span,
} => {
let mut next_segment = segments.next();
match next_segment {
Some(Segment::Word(options)) => {
// keyword doesn't match? definitely not a match then
if !matches_word(&value, &options) {
return None;
}
}
Some(Segment::Parameter { name, optional }) => {
// for an optional parameter, check the next token instead and consume
if let Some(Segment::Word(options)) = segments.peek() {
if *optional {
if !matches_word(&value, &options) {
return None;
}
segments.next();
continue;
}
}
// set parameter aside for later
parameters.push(ParameterMatch {
name: name.clone(),
span: span,
value: value.clone(),
});
}
None => {
// out of segments to match, but we already consumed the next token
// so set position aside for remainder and exit
remainder_pos = Some(span.start);
break;
}
}
}
}
}
Some(MatchResult {
parameters,
flags,
remainder: remainder_pos.map(|x| s[x..].to_string()),
})
}
fn matches_word(word: &str, options: &[String]) -> bool {
options.iter().any(|o| o.eq_ignore_ascii_case(word))
}
struct PatternBuilder {
segments: Vec<Segment>,
}
impl PatternBuilder {
fn new() -> PatternBuilder {
PatternBuilder {
segments: Vec::new(),
}
}
fn word(mut self, options: &[&str]) -> PatternBuilder {
self.segments.push(Segment::Word(
options.iter().map(|x| x.to_string()).collect(),
));
self
}
fn param(mut self, name: &str) -> PatternBuilder {
self.segments.push(Segment::Parameter {
name: name.to_string(),
optional: false,
});
self
}
fn param_opt(mut self, name: &str) -> PatternBuilder {
self.segments.push(Segment::Parameter {
name: name.to_string(),
optional: true,
});
self
}
fn build(self) -> Pattern {
Pattern {
segments: self.segments,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn hi() {
let pat = PatternBuilder::new()
.word(&["m", "member"])
.param("member_ref")
.word(&["desc", "d"])
.build();
assert_eq!(
does_match("member \"Hello World\" -raw desc More text goes here", &pat),
Some(MatchResult {
parameters: vec![ParameterMatch {
name: "member_ref".to_string(),
value: "Hello World".to_string(),
span: 7..20
}],
flags: vec![FlagMatch {
name: "raw".to_string(),
value: None
}],
remainder: Some("More text goes here".to_string())
})
);
}
}

View file

@ -0,0 +1,266 @@
use std::ops::Range;
#[derive(Debug, Clone, PartialEq)]
pub enum Token {
Keyword {
value: String,
quoted: bool,
span: Range<usize>,
},
Flag {
name: String,
value: Option<String>,
span: Range<usize>,
},
}
pub struct Tokenizer<'a> {
inner: &'a str,
words: WordSpanIterator<'a>,
}
impl<'a> Tokenizer<'a> {
pub fn new(s: &str) -> Tokenizer {
Tokenizer {
inner: s,
words: WordSpanIterator::new(s),
}
}
fn next_token(&mut self) -> Option<Token> {
self.words.next().map(|span| {
let word = &self.inner[span.clone()];
if let Some((inner, quoted_span)) = self.try_read_quoted_token(span.clone()) {
Token::Keyword {
value: inner.to_string(),
quoted: true,
span: quoted_span,
}
} else if word.starts_with("-") {
let flag_name = word.trim_start_matches('-');
let (flag_name, flag_value) = match flag_name.split_once('=') {
Some((flag_name, flag_value)) => {
(flag_name.to_string(), Some(flag_value.to_string()))
}
None => (flag_name.to_string(), None),
};
Token::Flag {
name: flag_name,
value: flag_value,
span,
}
} else {
Token::Keyword {
value: word.to_string(),
quoted: false,
span: span,
}
}
})
}
fn try_read_quoted_token(&mut self, span: Range<usize>) -> Option<(&str, Range<usize>)> {
let start_pos = span.start;
find_quote_pair(&self.inner[span.start..]).map(|(left_quote, right_quotes)| {
let mut word_span = span;
word_span.start += left_quote.len_utf8();
// effectively do-while-let but rust doesn't have that :/
loop {
let end_word = &self.inner[word_span.clone()];
for right_quote in right_quotes {
if end_word.ends_with(*right_quote) {
let end_pos = word_span.end;
let inner_span =
(start_pos + left_quote.len_utf8())..(end_pos - right_quote.len_utf8());
let inner_str = &self.inner[inner_span];
return (inner_str, start_pos..end_pos);
}
}
if let Some(next_word_span) = self.words.next() {
word_span = next_word_span;
} else {
break;
}
}
(&self.inner[start_pos..], start_pos..self.inner.len())
})
}
}
impl<'a> Iterator for Tokenizer<'a> {
type Item = Token;
fn next(&mut self) -> Option<Self::Item> {
self.next_token()
}
}
struct WordSpanIterator<'a> {
iter: std::str::SplitInclusive<'a, fn(char) -> bool>,
pos: usize,
}
impl<'a> WordSpanIterator<'a> {
fn new(s: &'a str) -> WordSpanIterator<'a> {
WordSpanIterator {
iter: s.split_inclusive(char::is_whitespace),
pos: 0,
}
}
}
impl<'a> Iterator for WordSpanIterator<'a> {
type Item = Range<usize>;
fn next(&mut self) -> Option<Self::Item> {
while let Some(word) = self.iter.next() {
let word_start = self.pos;
self.pos += word.len();
let trimmed = word.trim_end();
if word.trim_end().len() > 0 {
let trimmed_span = word_start..(word_start + trimmed.len());
return Some(trimmed_span);
}
}
None
}
}
fn find_quote_pair(s: &str) -> Option<(char, &'static [char])> {
s.chars()
.next()
.and_then(|c| matching_quotes(c).map(|x| (c, x)))
}
fn matching_quotes(c: char) -> Option<&'static [char]> {
match c {
// Basic
'"' => Some(&['"']),
'\'' => Some(&['\'']),
// "Smart quotes"
// Specifically ignore the left/right status of the quotes and match any combination of them
// Left string also includes "low" quotes to allow for the low-high style used in some locales
'\u{201c}' | '\u{201d}' | '\u{201f}' | '\u{201e}' => {
Some(&['\u{201c}', '\u{201d}', '\u{201f}'])
} // double
'\u{2018}' | '\u{2019}' | '\u{201b}' | '\u{201a}' => {
Some(&['\u{2018}', '\u{2019}', '\u{201b}'])
} // single
// Chevrons (normal and "fullwidth" variants)
'\u{00ab}' | '\u{300a}' => Some(&['\u{00bb}', '\u{300b}']), // double chevrons, pointing away (<<text>>)
'\u{00bb}' | '\u{300b}' => Some(&['\u{00aa}', '\u{300a}']), // double chevrons, pointing together (>>text<<)
'\u{2039}' | '\u{3008}' => Some(&['\u{203a}', '\u{3009}']), // single chevrons, pointing away (<text>)
'\u{203a}' | '\u{3009}' => Some(&['\u{2039}', '\u{3008}']), // single chevrons, pointing together (>text<)
// Other
'\u{300c}' | '\u{300e}' => Some(&['\u{300d}', '\u{300f}']), // corner brackets (Japanese/Chinese)
_ => None,
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn basic_words() {
let s = "hello world abcdefg";
let mut tk = Tokenizer::new(s);
assert_word(tk.next(), "hello", false, 0..5);
assert_word(tk.next(), "world", false, 6..11);
assert_word(tk.next(), "abcdefg", false, 12..19);
}
#[test]
fn ignore_whitespace() {
// U+2003 EM SPACE is 3 utf-8 bytes
let s = " lotsa \u{2003} spaces \t and \t\n\t stuff \n";
let mut tk = Tokenizer::new(s);
assert_word(tk.next(), "lotsa", false, 4..9);
assert_word(tk.next(), "spaces", false, 19..25);
assert_word(tk.next(), "and", false, 29..32);
assert_word(tk.next(), "stuff", false, 37..42);
}
#[test]
fn quoted_words() {
let mut tk = Tokenizer::new("hello \"in double quotes\" 'and single quotes'");
assert_word(tk.next(), "hello", false, 0..5);
assert_word(tk.next(), "in double quotes", true, 6..24);
assert_word(tk.next(), "and single quotes", true, 25..44);
let mut tk = Tokenizer::new("\"quote at start of\" string");
assert_word(tk.next(), "quote at start of", true, 0..19);
assert_word(tk.next(), "string", false, 20..26);
let mut tk = Tokenizer::new("\"\n include whitespace\nin quotes\n\"");
assert_word(
tk.next(),
"\n include whitespace\nin quotes\n",
true,
0..34,
);
let mut tk = Tokenizer::new("'it's 5 o'clock' said o'brian");
assert_word(tk.next(), "it's 5 o'clock", true, 0..16);
assert_word(tk.next(), "said", false, 17..21);
assert_word(tk.next(), "o'brian", false, 22..29);
}
#[test]
fn flags() {
let mut tk = Tokenizer::new("word -flag and-word");
assert_word(tk.next(), "word", false, 0..4);
assert_flag(tk.next(), "flag", None, 5..10);
assert_word(tk.next(), "and-word", false, 11..19);
let mut tk = Tokenizer::new("-lots --of ---dashes");
assert_flag(tk.next(), "lots", None, 0..5);
assert_flag(tk.next(), "of", None, 6..10);
assert_flag(tk.next(), "dashes", None, 11..20);
}
#[test]
fn flag_values() {
let mut tk = Tokenizer::new("-flag=value --flag2=value2 -flag3=value3=more");
assert_flag(tk.next(), "flag", Some("value"), 0..11);
assert_flag(tk.next(), "flag2", Some("value2"), 12..26);
assert_flag(tk.next(), "flag3", Some("value3=more"), 27..45);
}
fn assert_word(tk: Option<Token>, s: &str, quoted: bool, span: Range<usize>) {
assert_eq!(
tk,
Some(Token::Keyword {
value: s.to_string(),
quoted: quoted,
span: span
})
);
}
fn assert_flag(tk: Option<Token>, name: &str, value: Option<&str>, span: Range<usize>) {
assert_eq!(
tk,
Some(Token::Flag {
name: name.to_string(),
span: span,
value: value.map(|x| x.to_string())
})
);
}
}