diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index e06a4ff..afd41ce 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -34,7 +34,8 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo raise NoResultFoundException # Check if the alias exists (either as an alias or as a name) - if await get_command(session, alias) is not None: + alias_instance = await get_command(session, alias) + if alias_instance is not None: raise DuplicateInsertException alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) @@ -46,7 +47,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: """Try to get a command out of a message""" - # Search lowercase & without spaces + # Search lowercase & without spaces, and strip the prefix message = clean_name(message) return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 1e2d6b3..af50633 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -1,14 +1,9 @@ from typing import Optional import discord -from discord import app_commands from discord.ext import commands -from database.crud import custom_commands -from database.exceptions.constraints import DuplicateInsertException -from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.data.modals.custom_commands import CreateCustomCommand class Owner(commands.Cog): @@ -16,76 +11,16 @@ class Owner(commands.Cog): client: Didier - # Slash groups - add_slash = app_commands.Group(name="add", description="Add something new to the database") - def __init__(self, client: Didier): self.client = client - async def cog_check(self, ctx: commands.Context) -> bool: - """Global check for every command in this cog, so we don't have to add - is_owner() to every single command separately - """ - # pylint: disable=W0236 # Pylint thinks this can't be async, but it can - return await self.client.is_owner(ctx.author) - @commands.command(name="Sync") + @commands.is_owner() async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): """Sync all application-commands in Discord""" - if guild is not None: - self.client.tree.copy_global_to(guild=guild) - await self.client.tree.sync(guild=guild) - else: - self.client.tree.clear_commands(guild=None) - await self.client.tree.sync() - + await self.client.tree.sync(guild=guild) await ctx.message.add_reaction("🔄") - @commands.group(name="Add", case_insensitive=True, invoke_without_command=False) - async def add_msg(self, ctx: commands.Context): - """Command group for [add X] message commands""" - - @add_msg.command(name="Custom") - async def add_custom(self, ctx: commands.Context, name: str, *, response: str): - """Add a new custom command""" - async with self.client.db_session as session: - try: - await custom_commands.create_command(session, name, response) - await self.client.confirm_message(ctx.message) - except DuplicateInsertException: - await ctx.reply("Er bestaat al een commando met deze naam.") - await self.client.reject_message(ctx.message) - - @add_msg.command(name="Alias") - async def add_alias(self, ctx: commands.Context, command: str, alias: str): - """Add a new alias for a custom command""" - async with self.client.db_session as session: - try: - await custom_commands.create_alias(session, command, alias) - await self.client.confirm_message(ctx.message) - except NoResultFoundException: - await ctx.reply(f'Geen commando gevonden voor "{command}".') - await self.client.reject_message(ctx.message) - except DuplicateInsertException: - await ctx.reply("Er bestaat al een commando met deze naam.") - await self.client.reject_message(ctx.message) - - @add_slash.command(name="custom", description="Add a custom command") - async def add_custom_slash(self, interaction: discord.Interaction): - """Slash command to add a custom command""" - if not self.client.is_owner(interaction.user): - return interaction.response.send_message( - "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True - ) - - await interaction.response.defer(ephemeral=True) - modal = CreateCustomCommand() - await interaction.response.send_message(modal) - - @commands.group(name="Edit") - async def edit(self, ctx: commands.Context): - """Command group for [edit X] commands""" - async def setup(client: Didier): """Load the cog""" diff --git a/didier/data/modals/__init__.py b/didier/data/modals/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/didier/data/modals/custom_commands.py b/didier/data/modals/custom_commands.py deleted file mode 100644 index 18d9809..0000000 --- a/didier/data/modals/custom_commands.py +++ /dev/null @@ -1,20 +0,0 @@ -import traceback - -import discord - - -class CreateCustomCommand(discord.ui.Modal, title="Custom Command"): - """Modal shown to visually create custom commands""" - - name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command...") - - response: discord.ui.TextInput = discord.ui.TextInput( - label="Response", style=discord.TextStyle.long, placeholder="Response of the command...", max_length=2000 - ) - - async def on_submit(self, interaction: discord.Interaction) -> None: - await interaction.response.send_message("Submitted", ephemeral=True) - - async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore - await interaction.response.send_message("Errored", ephemeral=True) - traceback.print_tb(error.__traceback__) diff --git a/didier/didier.py b/didier/didier.py index a36e57c..b42dfcc 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -3,11 +3,11 @@ import sys import traceback import discord +from discord import Message from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings -from database.crud import custom_commands from database.engine import DBSession from didier.utils.discord.prefix import get_prefix @@ -34,17 +34,24 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) - @property - def db_session(self) -> AsyncSession: - """Obtain a database session""" - return DBSession() - async def setup_hook(self) -> None: """Hook called once the bot is initialised""" # Load extensions await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") + # Sync application commands to the test guild + for guild in settings.DISCORD_TEST_GUILDS: + guild_object = discord.Object(id=guild) + + self.tree.copy_global_to(guild=guild_object) + await self.tree.sync(guild=guild_object) + + @property + def db_session(self) -> AsyncSession: + """Obtain a database session""" + return DBSession() + async def _load_initial_extensions(self): """Load all extensions that should be loaded before the others""" for extension in self.initial_extensions: @@ -78,19 +85,11 @@ class Didier(commands.Bot): channel = self.get_channel(reference.channel_id) return await channel.fetch_message(reference.message_id) - async def confirm_message(self, message: discord.Message): - """Add a checkmark to a message""" - await message.add_reaction("✅") - - async def reject_message(self, message: discord.Message): - """Add an X to a message""" - await message.add_reaction("❌") - async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) - async def on_message(self, message: discord.Message, /) -> None: + async def on_message(self, message: Message, /) -> None: """Event triggered when a message is sent""" # Ignore messages by bots if message.author.bot: @@ -101,32 +100,18 @@ class Didier(commands.Bot): await message.add_reaction(settings.DISCORD_BOOS_REACT) # Potential custom command - if await self._try_invoke_custom_command(message): + if self._try_invoke_custom_command(message): return await self.process_commands(message) - async def _try_invoke_custom_command(self, message: discord.Message) -> bool: + async def _try_invoke_custom_command(self, message: Message) -> bool: """Check if the message tries to invoke a custom command If it does, send the reply associated with it """ - # Doesn't start with the custom command prefix if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): return False - async with self.db_session as session: - # Remove the prefix - content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] - command = await custom_commands.get_command(session, content) - - # Command found - if command is not None: - await message.reply(command.response, mention_author=False) - return True - - # Nothing found - return False - async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" # If developing, print everything to stdout so you don't have to diff --git a/tests/conftest.py b/tests/conftest.py index c2c5e0a..a74ba4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ -import asyncio -from typing import AsyncGenerator, Generator +from typing import AsyncGenerator from unittest.mock import MagicMock import pytest @@ -12,14 +11,7 @@ from didier import Didier @pytest.fixture(scope="session") -def event_loop() -> Generator: - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - -@pytest.fixture(scope="session") -def tables(event_loop): +def tables(): """Initialize a database before the tests, and then tear it down again Starts from an empty database and runs through all the migrations to check those as well while we're at it @@ -31,7 +23,7 @@ def tables(event_loop): @pytest.fixture -async def database_session(tables, event_loop) -> AsyncGenerator[AsyncSession, None]: +async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test Rollbacks the transaction afterwards so that the future tests start with a clean database """