diff --git a/didier/cogs/school.py b/didier/cogs/school.py new file mode 100644 index 0000000..9f9cadf --- /dev/null +++ b/didier/cogs/school.py @@ -0,0 +1,63 @@ +from typing import Optional + +import discord +from discord import app_commands +from discord.ext import commands + +from didier import Didier + + +class School(commands.Cog): + """School-related commands""" + + client: Didier + + # Context-menu references + _pin_ctx_menu: app_commands.ContextMenu + + def __init__(self, client: Didier): + self.client = client + + self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx) + self.client.tree.add_command(self._pin_ctx_menu) + + async def cog_unload(self) -> None: + """Remove the commands when the cog is unloaded""" + self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + + @commands.command(name="Pin", usage="[Message]") + async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None): + """Pin a message in the current channel""" + # If no message was passed, allow replying to the message that should be pinned + if message is None and ctx.message.reference is not None: + message = await self.client.resolve_message(ctx.message.reference) + + # Didn't fix it, sad + if message is None: + return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10) + + if message.is_system(): + return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.") + + await message.pin(reason=f"Didier Pin door {ctx.author.display_name}") + await message.add_reaction("📌") + + async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message): + """Pin a message in the current channel""" + # Is already pinned + if message.pinned: + return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True) + + if message.is_system(): + return await interaction.response.send_message( + "Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True + ) + + await message.pin(reason=f"Didier Pin door {interaction.user.display_name}") + await message.add_reaction("📌") + return await interaction.response.send_message("📌", ephemeral=True) + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(School(client)) diff --git a/didier/cogs/test_cog.py b/didier/cogs/test_cog.py deleted file mode 100644 index 093aaf7..0000000 --- a/didier/cogs/test_cog.py +++ /dev/null @@ -1,14 +0,0 @@ -from discord.ext import commands - -from didier import Didier - - -class TestCog(commands.Cog): - client: Didier - - def __init__(self, client: Didier): - self.client = client - - -async def setup(client: Didier): - await client.add_cog(TestCog(client)) diff --git a/didier/didier.py b/didier/didier.py index f4b36c5..dc08e53 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -1,4 +1,7 @@ import os +import sys +import traceback +from typing import Union, Optional import discord from discord.ext import commands @@ -33,29 +36,78 @@ class Didier(commands.Bot): async def setup_hook(self) -> None: """Hook called once the bot is initialised""" - await self._load_initial_cogs() - await self._load_directory_cogs("didier/cogs") + # Load extensions + await self._load_initial_extensions() + await self._load_directory_extensions("didier/cogs") + + # Sync application commands to the test guild + for guild in settings.DISCORD_TEST_GUILDS: + guild_object = discord.Object(id=guild) + + self.tree.copy_global_to(guild=guild_object) + await self.tree.sync(guild=guild_object) @property def db_session(self) -> AsyncSession: """Obtain a database session""" return DBSession() + async def _load_initial_extensions(self): + """Load all extensions that should be loaded before the others""" + for extension in self.initial_extensions: + await self.load_extension(f"didier.{extension}") + + async def _load_directory_extensions(self, path: str): + """Load all extensions in a given directory""" + load_path = path.removeprefix("./").replace("/", ".") + parent_path = load_path.removeprefix("didier.") + + # Check every file in the directory + for file in os.listdir(path): + # Construct a path that includes all parent packages in order to + # Allow checking against initial extensions more easily + full_name = parent_path + file + + # Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__) + # Also ignore the files that we have already loaded previously + if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions): + await self.load_extension(f"{load_path}.{file[:-3]}") + elif os.path.isdir(new_path := f"{path}/{file}"): + await self._load_directory_extensions(new_path) + + async def respond( + self, + context: Union[commands.Context, discord.Interaction], + message: str, + mention_author: bool = False, + ephemeral: bool = True, + embeds: Optional[list[discord.Embed]] = None, + ): + """Function to respond to both a normal message and an interaction""" + if isinstance(context, commands.Context): + return await context.reply(message, mention_author=mention_author, embeds=embeds) + + if isinstance(context, discord.Interaction): + return await context.response.send_message(message, ephemeral=ephemeral, embeds=embeds) + + async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: + """Fetch a message from a reference""" + # Message is in the cache, return it + if reference.cached_message is not None: + return reference.cached_message + + # For older messages: fetch them from the API + channel = self.get_channel(reference.channel_id) + return await channel.fetch_message(reference.message_id) + async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) - async def _load_initial_cogs(self): - """Load all cogs""" - for extension in self.initial_extensions: - await self.load_extension(f"didier.cogs.{extension}") - - async def _load_directory_cogs(self, path: str): - """Load all cogs in a given directory""" - load_path = path.removeprefix("./").replace("/", ".") - - for file in os.listdir(path): - if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions): - await self.load_extension(f"{load_path}.{file[:-3]}") - elif os.path.isdir(new_path := f"{path}/{file}"): - await self._load_directory_cogs(new_path) + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: + """Event triggered when a regular command errors""" + # If developing, print everything to stdout so you don't have to + # check the logs all the time + if settings.SANDBOX: + print(traceback.format_exc(), file=sys.stderr) + return diff --git a/didier/utils/prefix.py b/didier/utils/prefix.py index 53e5a3e..a096920 100644 --- a/didier/utils/prefix.py +++ b/didier/utils/prefix.py @@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str: This is done dynamically to allow variable amounts of whitespace, and through regexes to allow case-insensitivity among other things. """ - mention = f"<@!{client.user.id}>" + mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" # Check which prefix was used