diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c68c2d4..33624a9 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -42,7 +42,7 @@ jobs: - name: Run Pytest run: pytest tests linting: - needs: [tests] + needs: [dependencies] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -62,7 +62,7 @@ jobs: - name: Linting run: pylint didier database typing: - needs: [tests] + needs: [dependencies] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 @@ -82,7 +82,7 @@ jobs: - name: Typing run: mypy didier database formatting: - needs: [tests] + needs: [dependencies] runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 diff --git a/didier/cogs/school.py b/didier/cogs/school.py new file mode 100644 index 0000000..9f9cadf --- /dev/null +++ b/didier/cogs/school.py @@ -0,0 +1,63 @@ +from typing import Optional + +import discord +from discord import app_commands +from discord.ext import commands + +from didier import Didier + + +class School(commands.Cog): + """School-related commands""" + + client: Didier + + # Context-menu references + _pin_ctx_menu: app_commands.ContextMenu + + def __init__(self, client: Didier): + self.client = client + + self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx) + self.client.tree.add_command(self._pin_ctx_menu) + + async def cog_unload(self) -> None: + """Remove the commands when the cog is unloaded""" + self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + + @commands.command(name="Pin", usage="[Message]") + async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None): + """Pin a message in the current channel""" + # If no message was passed, allow replying to the message that should be pinned + if message is None and ctx.message.reference is not None: + message = await self.client.resolve_message(ctx.message.reference) + + # Didn't fix it, sad + if message is None: + return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10) + + if message.is_system(): + return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.") + + await message.pin(reason=f"Didier Pin door {ctx.author.display_name}") + await message.add_reaction("📌") + + async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message): + """Pin a message in the current channel""" + # Is already pinned + if message.pinned: + return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True) + + if message.is_system(): + return await interaction.response.send_message( + "Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True + ) + + await message.pin(reason=f"Didier Pin door {interaction.user.display_name}") + await message.add_reaction("📌") + return await interaction.response.send_message("📌", ephemeral=True) + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(School(client)) diff --git a/didier/cogs/test_cog.py b/didier/cogs/test_cog.py deleted file mode 100644 index 093aaf7..0000000 --- a/didier/cogs/test_cog.py +++ /dev/null @@ -1,14 +0,0 @@ -from discord.ext import commands - -from didier import Didier - - -class TestCog(commands.Cog): - client: Didier - - def __init__(self, client: Didier): - self.client = client - - -async def setup(client: Didier): - await client.add_cog(TestCog(client)) diff --git a/didier/didier.py b/didier/didier.py index f4b36c5..0b5b457 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -1,4 +1,6 @@ import os +import sys +import traceback import discord from discord.ext import commands @@ -12,7 +14,7 @@ from didier.utils.prefix import get_prefix class Didier(commands.Bot): """DIDIER <3""" - initial_extensions: tuple[str] = () + initial_extensions: tuple[str, ...] = () def __init__(self): activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE) @@ -33,29 +35,63 @@ class Didier(commands.Bot): async def setup_hook(self) -> None: """Hook called once the bot is initialised""" - await self._load_initial_cogs() - await self._load_directory_cogs("didier/cogs") + # Load extensions + await self._load_initial_extensions() + await self._load_directory_extensions("didier/cogs") + + # Sync application commands to the test guild + for guild in settings.DISCORD_TEST_GUILDS: + guild_object = discord.Object(id=guild) + + self.tree.copy_global_to(guild=guild_object) + await self.tree.sync(guild=guild_object) @property def db_session(self) -> AsyncSession: """Obtain a database session""" return DBSession() + async def _load_initial_extensions(self): + """Load all extensions that should be loaded before the others""" + for extension in self.initial_extensions: + await self.load_extension(f"didier.{extension}") + + async def _load_directory_extensions(self, path: str): + """Load all extensions in a given directory""" + load_path = path.removeprefix("./").replace("/", ".") + parent_path = load_path.removeprefix("didier.") + + # Check every file in the directory + for file in os.listdir(path): + # Construct a path that includes all parent packages in order to + # Allow checking against initial extensions more easily + full_name = parent_path + file + + # Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__) + # Also ignore the files that we have already loaded previously + if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions): + await self.load_extension(f"{load_path}.{file[:-3]}") + elif os.path.isdir(new_path := f"{path}/{file}"): + await self._load_directory_extensions(new_path) + + async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: + """Fetch a message from a reference""" + # Message is in the cache, return it + if reference.cached_message is not None: + return reference.cached_message + + # For older messages: fetch them from the API + channel = self.get_channel(reference.channel_id) + return await channel.fetch_message(reference.message_id) + async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) - async def _load_initial_cogs(self): - """Load all cogs""" - for extension in self.initial_extensions: - await self.load_extension(f"didier.cogs.{extension}") - - async def _load_directory_cogs(self, path: str): - """Load all cogs in a given directory""" - load_path = path.removeprefix("./").replace("/", ".") - - for file in os.listdir(path): - if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions): - await self.load_extension(f"{load_path}.{file[:-3]}") - elif os.path.isdir(new_path := f"{path}/{file}"): - await self._load_directory_cogs(new_path) + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: + """Event triggered when a regular command errors""" + # If developing, print everything to stdout so you don't have to + # check the logs all the time + if settings.SANDBOX: + print(traceback.format_exc(), file=sys.stderr) + return diff --git a/didier/utils/prefix.py b/didier/utils/prefix.py index 53e5a3e..a096920 100644 --- a/didier/utils/prefix.py +++ b/didier/utils/prefix.py @@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str: This is done dynamically to allow variable amounts of whitespace, and through regexes to allow case-insensitivity among other things. """ - mention = f"<@!{client.user.id}>" + mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" # Check which prefix was used diff --git a/pyproject.toml b/pyproject.toml index 5ca139e..a5ef1e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,8 @@ plugins = [ [tool.pylint.master] disable = [ "missing-module-docstring", - "too-few-public-methods" + "too-few-public-methods", + "too-many-arguments" ] [tool.pylint.format]