mirror of
https://github.com/PluralKit/PluralKit.git
synced 2026-02-12 16:50:10 +00:00
Major command handling refactor
This commit is contained in:
parent
0869f94cdf
commit
f067485e88
15 changed files with 463 additions and 355 deletions
|
|
@ -1,84 +1,111 @@
|
|||
import logging
|
||||
from collections import namedtuple
|
||||
|
||||
import asyncpg
|
||||
import discord
|
||||
import logging
|
||||
import re
|
||||
from typing import Tuple, Optional
|
||||
|
||||
import pluralkit
|
||||
from pluralkit import db
|
||||
from pluralkit.bot import utils, embeds
|
||||
from pluralkit import db, System, Member
|
||||
from pluralkit.bot import embeds, utils
|
||||
|
||||
logger = logging.getLogger("pluralkit.bot.commands")
|
||||
|
||||
command_list = {}
|
||||
|
||||
class NoSystemRegistered(Exception):
|
||||
pass
|
||||
def next_arg(arg_string: str) -> Tuple[str, Optional[str]]:
|
||||
if arg_string.startswith("\""):
|
||||
end_quote = arg_string.find("\"", start=1)
|
||||
if end_quote > 0:
|
||||
return arg_string[1:end_quote], arg_string[end_quote + 1:].strip()
|
||||
else:
|
||||
return arg_string[1:], None
|
||||
|
||||
class CommandContext(namedtuple("CommandContext", ["client", "conn", "message", "system"])):
|
||||
client: discord.Client
|
||||
conn: asyncpg.Connection
|
||||
message: discord.Message
|
||||
system: pluralkit.System
|
||||
next_space = arg_string.find(" ")
|
||||
if next_space >= 0:
|
||||
return arg_string[:next_space].strip(), arg_string[next_space:].strip()
|
||||
else:
|
||||
return arg_string.strip(), None
|
||||
|
||||
async def reply(self, message=None, embed=None):
|
||||
return await self.client.send_message(self.message.channel, message, embed=embed)
|
||||
|
||||
class MemberCommandContext(namedtuple("MemberCommandContext", CommandContext._fields + ("member",)), CommandContext):
|
||||
client: discord.Client
|
||||
conn: asyncpg.Connection
|
||||
message: discord.Message
|
||||
system: pluralkit.System
|
||||
member: pluralkit.Member
|
||||
class CommandResponse:
|
||||
def to_embed(self):
|
||||
pass
|
||||
|
||||
class CommandEntry(namedtuple("CommandEntry", ["command", "function", "usage", "description", "category"])):
|
||||
pass
|
||||
|
||||
def command(cmd, usage=None, description=None, category=None, system_required=True):
|
||||
def wrap(func):
|
||||
async def wrapper(client, conn, message, args):
|
||||
system = await db.get_system_by_account(conn, message.author.id)
|
||||
class CommandSuccess(CommandResponse):
|
||||
def __init__(self, text):
|
||||
self.text = text
|
||||
|
||||
if system_required and system is None:
|
||||
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one."))
|
||||
return
|
||||
|
||||
ctx = CommandContext(client=client, conn=conn, message=message, system=system)
|
||||
try:
|
||||
res = await func(ctx, args)
|
||||
def to_embed(self):
|
||||
return embeds.success("\u2705 " + self.text)
|
||||
|
||||
if res:
|
||||
embed = res if isinstance(res, discord.Embed) else utils.make_default_embed(res)
|
||||
await client.send_message(message.channel, embed=embed)
|
||||
except NoSystemRegistered:
|
||||
await client.send_message(message.channel, embed=utils.make_error_embed("No system registered to this account. Use `pk;system new` to register one."))
|
||||
except Exception:
|
||||
logger.exception("Exception while handling command {} (args={}, system={})".format(cmd, args, system.hid if system else "(none)"))
|
||||
|
||||
# Put command in map
|
||||
command_list[cmd] = CommandEntry(command=cmd, function=wrapper, usage=usage, description=description, category=category)
|
||||
return wrapper
|
||||
return wrap
|
||||
class CommandError(Exception, CommandResponse):
|
||||
def __init__(self, embed: str, help: Tuple[str, str] = None):
|
||||
self.text = embed
|
||||
self.help = help
|
||||
|
||||
def member_command(cmd, usage=None, description=None, category=None, system_only=True):
|
||||
def wrap(func):
|
||||
async def wrapper(ctx: CommandContext, args):
|
||||
# Return if no member param
|
||||
if len(args) == 0:
|
||||
return embeds.error("You must pass a member name or ID.")
|
||||
def to_embed(self):
|
||||
return embeds.error("\u274c " + self.text, self.help)
|
||||
|
||||
# System is allowed to be none if not system_only
|
||||
system_id = ctx.system.id if ctx.system else None
|
||||
# And find member by key
|
||||
member = await utils.get_member_fuzzy(ctx.conn, system_id=system_id, key=args[0], system_only=system_only)
|
||||
|
||||
if member is None:
|
||||
return embeds.error("Can't find member \"{}\".".format(args[0]))
|
||||
class CommandContext:
|
||||
def __init__(self, client: discord.Client, message: discord.Message, conn, args: str):
|
||||
self.client = client
|
||||
self.message = message
|
||||
self.conn = conn
|
||||
self.args = args
|
||||
|
||||
async def get_system(self) -> Optional[System]:
|
||||
return await db.get_system_by_account(self.conn, self.message.author.id)
|
||||
|
||||
async def ensure_system(self) -> System:
|
||||
system = await self.get_system()
|
||||
|
||||
if not system:
|
||||
raise CommandError(
|
||||
embeds.error("No system registered to this account. Use `pk;system new` to register one."))
|
||||
|
||||
return system
|
||||
|
||||
def has_next(self) -> bool:
|
||||
return bool(self.args)
|
||||
|
||||
def pop_str(self, error: CommandError = None) -> str:
|
||||
if not self.args:
|
||||
if error:
|
||||
raise error
|
||||
return None
|
||||
|
||||
popped, self.args = next_arg(self.args)
|
||||
return popped
|
||||
|
||||
async def pop_system(self, error: CommandError = None) -> System:
|
||||
name = self.pop_str(error)
|
||||
system = await utils.get_system_fuzzy(self.conn, self.client, name)
|
||||
|
||||
if not system:
|
||||
raise CommandError("Unable to find system '{}'.".format(name))
|
||||
|
||||
return system
|
||||
|
||||
async def pop_member(self, error: CommandError = None, system_only: bool = True) -> Member:
|
||||
name = self.pop_str(error)
|
||||
|
||||
if system_only:
|
||||
system = await self.ensure_system()
|
||||
else:
|
||||
system = await self.get_system()
|
||||
|
||||
member = await utils.get_member_fuzzy(self.conn, system.id if system else None, name, system_only)
|
||||
if not member:
|
||||
raise CommandError("Unable to find member '{}'{}.".format(name, " in your system" if system_only else ""))
|
||||
|
||||
return member
|
||||
|
||||
def remaining(self):
|
||||
return self.args
|
||||
|
||||
async def reply(self, content=None, embed=None):
|
||||
return await self.client.send_message(self.message.channel, content=content, embed=embed)
|
||||
|
||||
ctx = MemberCommandContext(client=ctx.client, conn=ctx.conn, message=ctx.message, system=ctx.system, member=member)
|
||||
return await func(ctx, args[1:])
|
||||
return command(cmd=cmd, usage="<name|id> {}".format(usage or ""), description=description, category=category, system_required=False)(wrapper)
|
||||
return wrap
|
||||
|
||||
import pluralkit.bot.commands.import_commands
|
||||
import pluralkit.bot.commands.member_commands
|
||||
|
|
@ -87,3 +114,69 @@ import pluralkit.bot.commands.misc_commands
|
|||
import pluralkit.bot.commands.mod_commands
|
||||
import pluralkit.bot.commands.switch_commands
|
||||
import pluralkit.bot.commands.system_commands
|
||||
|
||||
|
||||
async def run_command(ctx: CommandContext, func):
|
||||
try:
|
||||
result = await func(ctx)
|
||||
if isinstance(result, CommandResponse):
|
||||
await ctx.reply(embed=result.to_embed())
|
||||
except CommandError as e:
|
||||
await ctx.reply(embed=e.to_embed())
|
||||
except Exception:
|
||||
logger.exception("Exception while dispatching command")
|
||||
|
||||
|
||||
async def command_dispatch(client: discord.Client, message: discord.Message, conn) -> bool:
|
||||
prefix = "^pk(;|!)"
|
||||
commands = [
|
||||
(r"system (new|register|create|init)", system_commands.new_system),
|
||||
(r"system set", system_commands.system_set),
|
||||
(r"system link", system_commands.system_link),
|
||||
(r"system unlink", system_commands.system_unlink),
|
||||
(r"system fronter", system_commands.system_fronter),
|
||||
(r"system fronthistory", system_commands.system_fronthistory),
|
||||
(r"system (delete|remove|destroy|erase)", system_commands.system_delete),
|
||||
(r"system frontpercent(age)?", system_commands.system_frontpercent),
|
||||
(r"system", system_commands.system_info),
|
||||
|
||||
(r"import tupperware", import_commands.import_tupperware),
|
||||
|
||||
(r"member (new|create|add|register)", member_commands.new_member),
|
||||
(r"member set", member_commands.member_set),
|
||||
(r"member proxy", member_commands.member_proxy),
|
||||
(r"member (delete|remove|destroy|erase)", member_commands.member_delete),
|
||||
(r"member", member_commands.member_info),
|
||||
|
||||
(r"message", message_commands.message_info),
|
||||
|
||||
(r"mod log", mod_commands.set_log),
|
||||
|
||||
(r"invite", misc_commands.invite_link),
|
||||
(r"export", misc_commands.export),
|
||||
|
||||
(r"help", misc_commands.show_help),
|
||||
|
||||
(r"switch move", switch_commands.switch_move),
|
||||
(r"switch out", switch_commands.switch_out),
|
||||
(r"switch", switch_commands.switch_member)
|
||||
]
|
||||
|
||||
for pattern, func in commands:
|
||||
regex = re.compile(prefix + pattern, re.IGNORECASE)
|
||||
|
||||
cmd = message.content
|
||||
match = regex.match(cmd)
|
||||
if match:
|
||||
remaining_string = cmd[match.span()[1]:].strip()
|
||||
|
||||
ctx = CommandContext(
|
||||
client=client,
|
||||
message=message,
|
||||
conn=conn,
|
||||
args=remaining_string
|
||||
)
|
||||
|
||||
await run_command(ctx, func)
|
||||
return True
|
||||
return False
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue