diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index e06a4ff..4f5dd2f 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -65,3 +65,23 @@ async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[ return None return alias.command + + +async def edit_command( + session: AsyncSession, original_name: str, new_name: Optional[str] = None, new_response: Optional[str] = None +) -> CustomCommand: + """Edit an existing command""" + # Check if the command exists + command = await get_command(session, original_name) + if command is None: + raise NoResultFoundException + + if new_name is not None: + command.name = new_name + if new_response is not None: + command.response = new_response + + session.add(command) + await session.commit() + + return command diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 91fe500..0bde065 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -8,7 +8,8 @@ from database.crud import custom_commands from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.data.modals.custom_commands import CreateCustomCommand +from didier.data.modals.custom_commands import CreateCustomCommand, EditCustomCommand +from didier.data.flags.owner import EditCustomFlags class Owner(commands.Cog): @@ -18,6 +19,7 @@ class Owner(commands.Cog): # Slash groups add_slash = app_commands.Group(name="add", description="Add something new to the database") + edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry") def __init__(self, client: Didier): self.client = client @@ -29,6 +31,11 @@ class Owner(commands.Cog): # pylint: disable=W0236 # Pylint thinks this can't be async, but it can return await self.client.is_owner(ctx.author) + @commands.command(name="Error") + async def _error(self, ctx: commands.Context): + """Raise an exception for debugging purposes""" + raise Exception("Debug") + @commands.command(name="Sync") async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): """Sync all application-commands in Discord""" @@ -77,14 +84,43 @@ class Owner(commands.Cog): "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True ) - # await interaction.response.defer(ephemeral=True) - modal = CreateCustomCommand() + modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) - @commands.group(name="Edit") - async def edit(self, ctx: commands.Context): + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) + async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" + @edit_msg.command(name="Custom") + async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): + """Edit an existing custom command""" + async with self.client.db_session as session: + try: + await custom_commands.edit_command(session, command, flags.name, flags.response) + return await self.client.confirm_message(ctx.message) + except NoResultFoundException: + await ctx.reply(f"Geen commando gevonden voor ``{command}``.") + return await self.client.reject_message(ctx.message) + + @edit_slash.command(name="custom", description="Edit a custom command") + @app_commands.describe(command="The name of the command to edit") + async def edit_custom_slash(self, interaction: discord.Interaction, command: str): + """Slash command to edit a custom command""" + if not await self.client.is_owner(interaction.user): + return interaction.response.send_message( + "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True + ) + + async with self.client.db_session as session: + _command = await custom_commands.get_command(session, command) + if _command is None: + return await interaction.response.send_message( + f"Geen commando gevonden voor ``{command}``.", ephemeral=True + ) + + modal = EditCustomCommand(self.client, _command.name, _command.response) + await interaction.response.send_modal(modal) + async def setup(client: Didier): """Load the cog""" diff --git a/didier/data/flags/__init__.py b/didier/data/flags/__init__.py new file mode 100644 index 0000000..1ef3b46 --- /dev/null +++ b/didier/data/flags/__init__.py @@ -0,0 +1 @@ +from .posix import PosixFlags diff --git a/didier/data/flags/owner.py b/didier/data/flags/owner.py new file mode 100644 index 0000000..e3e9789 --- /dev/null +++ b/didier/data/flags/owner.py @@ -0,0 +1,10 @@ +from typing import Optional + +from didier.data.flags import PosixFlags + + +class EditCustomFlags(PosixFlags): + """Flags for the edit custom command""" + + name: Optional[str] = None + response: Optional[str] = None diff --git a/didier/data/flags/posix.py b/didier/data/flags/posix.py new file mode 100644 index 0000000..582fd4e --- /dev/null +++ b/didier/data/flags/posix.py @@ -0,0 +1,14 @@ +from discord.ext import commands + + +class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): + """Base class to add POSIX-like flags to commands + + Example usage: + >>> class Flags(PosixFlags): + >>> name: str + >>> async def command(ctx, *, flags: Flags): + >>> ... + This can now be called in Discord as + command --name here-be-name + """ diff --git a/didier/data/modals/custom_commands.py b/didier/data/modals/custom_commands.py index 35f7f74..90f378c 100644 --- a/didier/data/modals/custom_commands.py +++ b/didier/data/modals/custom_commands.py @@ -2,19 +2,66 @@ import traceback import discord +from database.crud.custom_commands import create_command, edit_command +from didier import Didier -class CreateCustomCommand(discord.ui.Modal, title="Custom Command"): - """Modal shown to visually create custom commands""" - name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command") +class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): + """Modal to create new custom commands""" + + name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Didier") response: discord.ui.TextInput = discord.ui.TextInput( - label="Response", style=discord.TextStyle.long, placeholder="Response of the command", max_length=2000 + label="Response", style=discord.TextStyle.long, placeholder="Hmm?", max_length=2000 ) - async def on_submit(self, interaction: discord.Interaction) -> None: - await interaction.response.send_message("Submitted", ephemeral=True) + client: Didier - async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore - await interaction.response.send_message("Errored", ephemeral=True) + def __init__(self, client: Didier, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + + async def on_submit(self, interaction: discord.Interaction): + async with self.client.db_session as session: + command = await create_command(session, self.name.value, self.response.value) + + await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) + + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) + + +class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): + """Modal to edit an existing custom command + Fills in the current values as defaults + """ + + name: discord.ui.TextInput + response: discord.ui.TextInput + + original_name: str + + client: Didier + + def __init__(self, client: Didier, name: str, response: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.original_name = name + self.client = client + + self.name = self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name)) + self.response = self.add_item( + discord.ui.TextInput( + label="Response", placeholder="Hmm?", default=response, style=discord.TextStyle.long, max_length=2000 + ) + ) + + async def on_submit(self, interaction: discord.Interaction): + async with self.client.db_session as session: + await edit_command(session, self.original_name, self.name.value, self.response.value) + + await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) + + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/didier.py b/didier/didier.py index a36e57c..5cefe9f 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -129,8 +129,9 @@ class Didier(commands.Bot): async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" - # If developing, print everything to stdout so you don't have to - # check the logs all the time + # Print everything to the logs/stderr + await super().on_command_error(context, exception) + + # If developing, do nothing special if settings.SANDBOX: - print(traceback.format_exc(), file=sys.stderr) return diff --git a/main.py b/main.py index 8c8168b..34a2cf3 100644 --- a/main.py +++ b/main.py @@ -1,4 +1,5 @@ import logging +import sys from logging.handlers import RotatingFileHandler import asyncio @@ -18,15 +19,26 @@ def setup_logging(): """Configure custom loggers""" max_log_size = 32 * 1024 * 1024 + # Configure Didier handler didier_log = logging.getLogger("didier") - handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) - handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) + didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) + didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) - didier_log.addHandler(handler) + didier_log.addHandler(didier_handler) didier_log.setLevel(logging.INFO) - logging.getLogger("discord").setLevel(logging.ERROR) + # Configure discord handler + discord_log = logging.getLogger("discord") + + # Make dev print to stderr instead, so you don't have to watch the file + if settings.SANDBOX: + discord_handler = logging.StreamHandler(sys.stderr) + else: + discord_handler = RotatingFileHandler("discord.log", mode="a", maxBytes=max_log_size, backupCount=5) + + discord_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] %(name)s: %(message)s")) + discord_log.addHandler(discord_handler) async def main(): diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index 5f41983..a5c4092 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -4,7 +4,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException -from database.models import CustomCommand, CustomCommandAlias +from database.exceptions.not_found import NoResultFoundException +from database.models import CustomCommand async def test_create_command_non_existing(database_session: AsyncSession): @@ -33,7 +34,7 @@ async def test_create_command_name_is_alias(database_session: AsyncSession): await crud.create_command(database_session, "n", "other response") -async def test_create_alias_non_existing(database_session: AsyncSession): +async def test_create_alias(database_session: AsyncSession): """Test creating an alias when the name is still free""" command = await crud.create_command(database_session, "name", "response") await crud.create_alias(database_session, command.name, "n") @@ -43,6 +44,12 @@ async def test_create_alias_non_existing(database_session: AsyncSession): assert command.aliases[0].alias == "n" +async def test_create_alias_non_existing(database_session: AsyncSession): + """Test creating an alias when the command doesn't exist""" + with pytest.raises(NoResultFoundException): + await crud.create_alias(database_session, "name", "alias") + + async def test_create_alias_duplicate(database_session: AsyncSession): """Test creating an alias when another alias already has this name""" command = await crud.create_command(database_session, "name", "response") @@ -96,3 +103,17 @@ async def test_get_command_by_alias(database_session: AsyncSession): async def test_get_command_non_existing(database_session: AsyncSession): """Test getting a command when it doesn't exist""" assert await crud.get_command(database_session, "name") is None + + +async def test_edit_command(database_session: AsyncSession): + """Test editing an existing command""" + command = await crud.create_command(database_session, "name", "response") + await crud.edit_command(database_session, command.name, "new name", "new response") + assert command.name == "new name" + assert command.response == "new response" + + +async def test_edit_command_non_existing(database_session: AsyncSession): + """Test editing a command that doesn't exist""" + with pytest.raises(NoResultFoundException): + await crud.edit_command(database_session, "name", "n", "r")