Merge branch 'rewrite-port'

This commit is contained in:
Ske 2018-11-08 16:43:09 +01:00
commit a72a7c3de9
13 changed files with 369 additions and 450 deletions

View file

@ -1,131 +1,134 @@
import asyncio
import json
import logging
import os
import time
import asyncpg
import sys
import traceback
from datetime import datetime
import asyncio
import os
import logging
import discord
import traceback
from pluralkit import db
from pluralkit.bot import channel_logger, commands, proxy, embeds
from pluralkit.bot import commands, proxy, channel_logger, embeds
logging.basicConfig(level=logging.INFO, format="[%(asctime)s] [%(name)s] [%(levelname)s] %(message)s")
# logging.getLogger("pluralkit").setLevel(logging.DEBUG)
def connect_to_database() -> asyncpg.pool.Pool:
username = os.environ["DATABASE_USER"]
password = os.environ["DATABASE_PASS"]
name = os.environ["DATABASE_NAME"]
host = os.environ["DATABASE_HOST"]
port = os.environ["DATABASE_PORT"]
class PluralKitBot:
def __init__(self, token):
self.token = token
self.logger = logging.getLogger("pluralkit.bot")
if username is None or password is None or name is None or host is None or port is None:
print(
"Database credentials not specified. Please pass valid PostgreSQL database credentials in the DATABASE_[USER|PASS|NAME|HOST|PORT] environment variable.",
file=sys.stderr)
sys.exit(1)
self.client = discord.Client()
self.client.event(self.on_error)
self.client.event(self.on_ready)
self.client.event(self.on_message)
self.client.event(self.on_socket_raw_receive)
try:
port = int(port)
except ValueError:
print("Please pass a valid integer as the DATABASE_PORT environment variable.", file=sys.stderr)
sys.exit(1)
self.channel_logger = channel_logger.ChannelLogger(self.client)
return asyncio.get_event_loop().run_until_complete(db.connect(
username=username,
password=password,
database=name,
host=host,
port=port
))
self.proxy = proxy.Proxy(self.client, token, self.channel_logger)
async def on_error(self, evt, *args, **kwargs):
self.logger.exception("Error while handling event {} with arguments {}:".format(evt, args))
def run():
pool = connect_to_database()
async def on_ready(self):
self.logger.info("Connected to Discord.")
self.logger.info("- Account: {}#{}".format(self.client.user.name, self.client.user.discriminator))
self.logger.info("- User ID: {}".format(self.client.user.id))
self.logger.info("- {} servers".format(len(self.client.servers)))
async def create_tables():
async with pool.acquire() as conn:
await db.create_tables(conn)
# Set playing message
# TODO: change this when merging rewrite-port branch, kwarg game -> activity
await self.client.change_presence(game=discord.Game(name="pk;help"))
asyncio.get_event_loop().run_until_complete(create_tables())
async def on_message(self, message):
# Ignore bot messages
client = discord.Client()
logger = channel_logger.ChannelLogger(client)
@client.event
async def on_ready():
print("PluralKit started.")
print("User: {}#{} (ID: {})".format(client.user.name, client.user.discriminator, client.user.id))
print("{} servers".format(len(client.guilds)))
print("{} shards".format(client.shard_count or 1))
await client.change_presence(activity=discord.Game(name="pk;help"))
@client.event
async def on_message(message: discord.Message):
# Ignore messages from bots
if message.author.bot:
return
try:
if await self.handle_command_dispatch(message):
# Grab a database connection from the pool
async with pool.acquire() as conn:
# First pass: do command handling
did_run_command = await commands.command_dispatch(client, message, conn)
if did_run_command:
return
if await self.handle_proxy_dispatch(message):
return
except Exception:
await self.log_error_in_channel(message)
# Second pass: do proxy matching
await proxy.try_proxy_message(conn, message, logger)
async def on_socket_raw_receive(self, msg):
# Since on_reaction_add is buggy (only works for messages the bot's already cached, ie. no old messages)
# we parse socket data manually for the reaction add event
if isinstance(msg, str):
try:
msg_data = json.loads(msg)
if msg_data.get("t") == "MESSAGE_REACTION_ADD":
evt_data = msg_data.get("d")
if evt_data:
user_id = evt_data["user_id"]
message_id = evt_data["message_id"]
emoji = evt_data["emoji"]["name"]
@client.event
async def on_raw_message_delete(payload: discord.RawMessageDeleteEvent):
async with pool.acquire() as conn:
await proxy.handle_deleted_message(conn, client, payload.message_id, None, logger)
async with self.pool.acquire() as conn:
await self.proxy.handle_reaction(conn, user_id, message_id, emoji)
elif msg_data.get("t") == "MESSAGE_DELETE":
evt_data = msg_data.get("d")
if evt_data:
message_id = evt_data["id"]
async with self.pool.acquire() as conn:
await self.proxy.handle_deletion(conn, message_id)
except ValueError:
pass
@client.event
async def on_raw_bulk_message_delete(payload: discord.RawBulkMessageDeleteEvent):
async with pool.acquire() as conn:
for message_id in payload.message_ids:
await proxy.handle_deleted_message(conn, client, message_id, None, logger)
async def handle_command_dispatch(self, message):
async with self.pool.acquire() as conn:
result = await commands.command_dispatch(self.client, message, conn)
return result
@client.event
async def on_raw_reaction_add(payload: discord.RawReactionActionEvent):
if payload.emoji.name == "\u274c": # Red X
async with pool.acquire() as conn:
await proxy.try_delete_by_reaction(conn, client, payload.message_id, payload.user_id, logger)
async def handle_proxy_dispatch(self, message):
# Try doing proxy parsing
async with self.pool.acquire() as conn:
return await self.proxy.try_proxy_message(conn, message)
async def log_error_in_channel(self, message):
channel_id = os.environ["LOG_CHANNEL"]
if not channel_id:
@client.event
async def on_error(event_name, *args, **kwargs):
log_channel_id = os.environ["LOG_CHANNEL"]
if not log_channel_id:
return
channel = self.client.get_channel(channel_id)
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.server.id if message.server else None,
message.channel.id
)
log_channel = client.get_channel(int(log_channel_id))
await self.client.send_message(channel, "```python\n{}```".format(traceback.format_exc()), embed=embed)
async def run(self):
try:
self.logger.info("Connecting to database...")
self.pool = await db.connect(
os.environ["DATABASE_USER"],
os.environ["DATABASE_PASS"],
os.environ["DATABASE_NAME"],
os.environ["DATABASE_HOST"],
int(os.environ["DATABASE_PORT"])
# If this is a message event, we can attach additional information in an event
# ie. username, channel, content, etc
if args and isinstance(args[0], discord.Message):
message: discord.Message = args[0]
embed = embeds.exception_log(
message.content,
message.author.name,
message.author.discriminator,
message.author.id,
message.guild.id if message.guild else None,
message.channel.id
)
else:
# If not, just post the string itself
embed = None
self.logger.info("Attempting to create tables...")
async with self.pool.acquire() as conn:
await db.create_tables(conn)
traceback_str = "```python\n{}```".format(traceback.format_exc())
await log_channel.send(content=traceback_str, embed=embed)
self.logger.info("Connecting to Discord...")
await self.client.start(self.token)
finally:
self.logger.info("Logging out from Discord...")
await self.client.logout()
bot_token = os.environ["TOKEN"]
if not bot_token:
print("No token specified. Please pass a valid Discord bot token in the TOKEN environment variable.",
file=sys.stderr)
sys.exit(1)
client.run(bot_token)