didier/didier/didier.py

122 lines
4.5 KiB
Python
Raw Normal View History

2022-06-15 01:56:18 +02:00
import os
import sys
import traceback
2022-06-15 01:56:18 +02:00
import discord
2022-06-21 23:58:21 +02:00
from discord import Message
from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession
import settings
from database.engine import DBSession
from didier.utils.discord.prefix import get_prefix
class Didier(commands.Bot):
    """DIDIER <3

    The bot client: loads its extensions, syncs application commands to the
    configured test guilds and dispatches incoming messages.
    """

    # Extensions that must be loaded before everything else, written relative
    # to the "didier" package (they are loaded as "didier.<extension>")
    initial_extensions: tuple[str, ...] = ()

    def __init__(self):
        """Configure the presence, gateway intents and prefix of the bot"""
        activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
        status = discord.Status.online

        intents = discord.Intents(
            guilds=True,
            members=True,
            message_content=True,
            emojis=True,
            messages=True,
            reactions=True,
        )

        super().__init__(
            command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
        )

    async def setup_hook(self) -> None:
        """Hook called once the bot is initialised"""
        # Load extensions
        await self._load_initial_extensions()
        await self._load_directory_extensions("didier/cogs")

        # Sync application commands to the test guilds
        for guild in settings.DISCORD_TEST_GUILDS:
            guild_object = discord.Object(id=guild)
            self.tree.copy_global_to(guild=guild_object)
            await self.tree.sync(guild=guild_object)

    @property
    def db_session(self) -> AsyncSession:
        """Obtain a database session

        Every access creates a new session by calling DBSession().
        """
        return DBSession()

    async def _load_initial_extensions(self):
        """Load all extensions that should be loaded before the others"""
        for extension in self.initial_extensions:
            await self.load_extension(f"didier.{extension}")

    async def _load_directory_extensions(self, path: str):
        """Load all extensions in a given directory

        Subdirectories are searched recursively. `path` is a "/"-separated
        directory path (optionally starting with "./") which is also
        translated into the dotted module path used to load the extensions.
        """
        load_path = path.removeprefix("./").replace("/", ".")
        parent_path = load_path.removeprefix("didier.")

        # Check every file in the directory
        for file in os.listdir(path):
            # Construct a dotted name that includes all parent packages, in the
            # same format as the entries of initial_extensions ("cogs.name")
            module = file.removesuffix(".py")
            full_name = f"{parent_path}.{module}"

            # An initial extension is either this exact module or a package
            # that contains it; matching on the "." boundary keeps "cogs.foo"
            # from also excluding "cogs.foobar"
            already_loaded = any(
                full_name == extension or full_name.startswith(f"{extension}.")
                for extension in self.initial_extensions
            )

            # Only take Python files, and ignore the ones starting with an underscore (like __init__)
            # Also ignore the files that we have already loaded previously
            if file.endswith(".py") and not file.startswith("_") and not already_loaded:
                await self.load_extension(f"{load_path}.{module}")
            elif os.path.isdir(new_path := f"{path}/{file}"):
                await self._load_directory_extensions(new_path)

    async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
        """Fetch a message from a reference"""
        # Message is in the cache, return it
        if reference.cached_message is not None:
            return reference.cached_message

        # For older messages: fetch them from the API
        channel = self.get_channel(reference.channel_id)

        # get_channel only consults the cache and returns None on a miss,
        # so fall back to requesting the channel itself
        if channel is None:
            channel = await self.fetch_channel(reference.channel_id)

        return await channel.fetch_message(reference.message_id)

    async def on_ready(self):
        """Event triggered when the bot is ready"""
        print(settings.DISCORD_READY_MESSAGE)

    async def on_message(self, message: Message, /) -> None:
        """Event triggered when a message is sent"""
        # Ignore messages by bots
        if message.author.bot:
            return

        # Boos react to people that say Dider
        if "dider" in message.content.lower() and message.author.id != self.user.id:
            await message.add_reaction(settings.DISCORD_BOOS_REACT)

        # Potential custom command
        # The coroutine has to be awaited: an un-awaited coroutine object is
        # always truthy, which would skip regular command processing entirely
        if await self._try_invoke_custom_command(message):
            return

        await self.process_commands(message)

    async def _try_invoke_custom_command(self, message: Message) -> bool:
        """Check if the message tries to invoke a custom command

        If it does, send the reply associated with it

        Returns whether a custom command was invoked.
        """
        if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
            return False

        # NOTE(review): looking up and sending the reply is not implemented
        # here, so report that nothing was invoked and let the regular
        # command processing handle the message
        return False

    async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
        """Event triggered when a regular command errors"""
        # If developing, print everything to stderr so you don't have to
        # check the logs all the time
        if settings.SANDBOX:
            # No exception is being handled while this event runs, so the
            # traceback has to be taken from the exception argument itself
            traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
            return