diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index afd41ce..e06a4ff 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -34,8 +34,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo raise NoResultFoundException # Check if the alias exists (either as an alias or as a name) - alias_instance = await get_command(session, alias) - if alias_instance is not None: + if await get_command(session, alias) is not None: raise DuplicateInsertException alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) @@ -47,7 +46,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: """Try to get a command out of a message""" - # Search lowercase & without spaces, and strip the prefix + # Search lowercase & without spaces message = clean_name(message) return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index af50633..1e2d6b3 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -1,9 +1,14 @@ from typing import Optional import discord +from discord import app_commands from discord.ext import commands +from database.crud import custom_commands +from database.exceptions.constraints import DuplicateInsertException +from database.exceptions.not_found import NoResultFoundException from didier import Didier +from didier.data.modals.custom_commands import CreateCustomCommand class Owner(commands.Cog): @@ -11,16 +16,76 @@ class Owner(commands.Cog): client: Didier + # Slash groups + add_slash = app_commands.Group(name="add", description="Add something new to the database") + def __init__(self, client: Didier): self.client = client + async def cog_check(self, ctx: commands.Context) -> bool: + """Global check for every command in this cog, so we don't have to add + is_owner() to every single command separately + """ + # pylint: disable=W0236 # Pylint thinks this can't be async, but it can + return await self.client.is_owner(ctx.author) + @commands.command(name="Sync") - @commands.is_owner() async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): """Sync all application-commands in Discord""" - await self.client.tree.sync(guild=guild) + if guild is not None: + self.client.tree.copy_global_to(guild=guild) + await self.client.tree.sync(guild=guild) + else: + self.client.tree.clear_commands(guild=None) + await self.client.tree.sync() + await ctx.message.add_reaction("🔄") + @commands.group(name="Add", case_insensitive=True, invoke_without_command=False) + async def add_msg(self, ctx: commands.Context): + """Command group for [add X] message commands""" + + @add_msg.command(name="Custom") + async def add_custom(self, ctx: commands.Context, name: str, *, response: str): + """Add a new custom command""" + async with self.client.db_session as session: + try: + await custom_commands.create_command(session, name, response) + await self.client.confirm_message(ctx.message) + except DuplicateInsertException: + await ctx.reply("Er bestaat al een commando met deze naam.") + await self.client.reject_message(ctx.message) + + @add_msg.command(name="Alias") + async def add_alias(self, ctx: commands.Context, command: str, alias: str): + """Add a new alias for a custom command""" + async with self.client.db_session as session: + try: + await custom_commands.create_alias(session, command, alias) + await self.client.confirm_message(ctx.message) + except NoResultFoundException: + await ctx.reply(f'Geen commando gevonden voor "{command}".') + await self.client.reject_message(ctx.message) + except DuplicateInsertException: + await ctx.reply("Er bestaat al een commando met deze naam.") + await self.client.reject_message(ctx.message) + + @add_slash.command(name="custom", description="Add a custom command") + async def add_custom_slash(self, interaction: discord.Interaction): + """Slash command to add a custom command""" + if not self.client.is_owner(interaction.user): + return interaction.response.send_message( + "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True + ) + + await interaction.response.defer(ephemeral=True) + modal = CreateCustomCommand() + await interaction.response.send_message(modal) + + @commands.group(name="Edit") + async def edit(self, ctx: commands.Context): + """Command group for [edit X] commands""" + async def setup(client: Didier): """Load the cog""" diff --git a/didier/data/modals/__init__.py b/didier/data/modals/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/data/modals/custom_commands.py b/didier/data/modals/custom_commands.py new file mode 100644 index 0000000..18d9809 --- /dev/null +++ b/didier/data/modals/custom_commands.py @@ -0,0 +1,20 @@ +import traceback + +import discord + + +class CreateCustomCommand(discord.ui.Modal, title="Custom Command"): + """Modal shown to visually create custom commands""" + + name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command...") + + response: discord.ui.TextInput = discord.ui.TextInput( + label="Response", style=discord.TextStyle.long, placeholder="Response of the command...", max_length=2000 + ) + + async def on_submit(self, interaction: discord.Interaction) -> None: + await interaction.response.send_message("Submitted", ephemeral=True) + + async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore + await interaction.response.send_message("Errored", ephemeral=True) + traceback.print_tb(error.__traceback__) diff --git a/didier/didier.py b/didier/didier.py index b42dfcc..a36e57c 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -3,11 +3,11 @@ import sys import traceback import discord -from discord import Message from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings +from database.crud import custom_commands from database.engine import DBSession from didier.utils.discord.prefix import get_prefix @@ -34,24 +34,17 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) + @property + def db_session(self) -> AsyncSession: + """Obtain a database session""" + return DBSession() + async def setup_hook(self) -> None: """Hook called once the bot is initialised""" # Load extensions await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") - # Sync application commands to the test guild - for guild in settings.DISCORD_TEST_GUILDS: - guild_object = discord.Object(id=guild) - - self.tree.copy_global_to(guild=guild_object) - await self.tree.sync(guild=guild_object) - - @property - def db_session(self) -> AsyncSession: - """Obtain a database session""" - return DBSession() - async def _load_initial_extensions(self): """Load all extensions that should be loaded before the others""" for extension in self.initial_extensions: @@ -85,11 +78,19 @@ class Didier(commands.Bot): channel = self.get_channel(reference.channel_id) return await channel.fetch_message(reference.message_id) + async def confirm_message(self, message: discord.Message): + """Add a checkmark to a message""" + await message.add_reaction("✅") + + async def reject_message(self, message: discord.Message): + """Add an X to a message""" + await message.add_reaction("❌") + async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) - async def on_message(self, message: Message, /) -> None: + async def on_message(self, message: discord.Message, /) -> None: """Event triggered when a message is sent""" # Ignore messages by bots if message.author.bot: @@ -100,18 +101,32 @@ class Didier(commands.Bot): await message.add_reaction(settings.DISCORD_BOOS_REACT) # Potential custom command - if self._try_invoke_custom_command(message): + if await self._try_invoke_custom_command(message): return await self.process_commands(message) - async def _try_invoke_custom_command(self, message: Message) -> bool: + async def _try_invoke_custom_command(self, message: discord.Message) -> bool: """Check if the message tries to invoke a custom command If it does, send the reply associated with it """ + # Doesn't start with the custom command prefix if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): return False + async with self.db_session as session: + # Remove the prefix + content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] + command = await custom_commands.get_command(session, content) + + # Command found + if command is not None: + await message.reply(command.response, mention_author=False) + return True + + # Nothing found + return False + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" # If developing, print everything to stdout so you don't have to diff --git a/tests/conftest.py b/tests/conftest.py index a74ba4c..c2c5e0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ -from typing import AsyncGenerator +import asyncio +from typing import AsyncGenerator, Generator from unittest.mock import MagicMock import pytest @@ -11,7 +12,14 @@ from didier import Didier @pytest.fixture(scope="session") -def tables(): +def event_loop() -> Generator: + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +@pytest.fixture(scope="session") +def tables(event_loop): """Initialize a database before the tests, and then tear it down again Starts from an empty database and runs through all the migrations to check those as well while we're at it @@ -23,7 +31,7 @@ def tables(): @pytest.fixture -async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: +async def database_session(tables, event_loop) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test Rollbacks the transaction afterwards so that the future tests start with a clean database """