mirror of https://github.com/stijndcl/didier
				
				
				
			Editing of custom commands, add posix flags
							parent
							
								
									257eae6fa7
								
							
						
					
					
						commit
						d6a560851b
					
				|  | @ -65,3 +65,23 @@ async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[ | ||||||
|         return None |         return None | ||||||
| 
 | 
 | ||||||
|     return alias.command |     return alias.command | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def edit_command( | ||||||
|  |     session: AsyncSession, original_name: str, new_name: Optional[str] = None, new_response: Optional[str] = None | ||||||
|  | ) -> CustomCommand: | ||||||
|  |     """Edit an existing command""" | ||||||
|  |     # Check if the command exists | ||||||
|  |     command = await get_command(session, original_name) | ||||||
|  |     if command is None: | ||||||
|  |         raise NoResultFoundException | ||||||
|  | 
 | ||||||
|  |     if new_name is not None: | ||||||
|  |         command.name = new_name | ||||||
|  |     if new_response is not None: | ||||||
|  |         command.response = new_response | ||||||
|  | 
 | ||||||
|  |     session.add(command) | ||||||
|  |     await session.commit() | ||||||
|  | 
 | ||||||
|  |     return command | ||||||
|  |  | ||||||
|  | @ -8,7 +8,8 @@ from database.crud import custom_commands | ||||||
| from database.exceptions.constraints import DuplicateInsertException | from database.exceptions.constraints import DuplicateInsertException | ||||||
| from database.exceptions.not_found import NoResultFoundException | from database.exceptions.not_found import NoResultFoundException | ||||||
| from didier import Didier | from didier import Didier | ||||||
| from didier.data.modals.custom_commands import CreateCustomCommand | from didier.data.modals.custom_commands import CreateCustomCommand, EditCustomCommand | ||||||
|  | from didier.data.flags.owner import EditCustomFlags | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| class Owner(commands.Cog): | class Owner(commands.Cog): | ||||||
|  | @ -18,6 +19,7 @@ class Owner(commands.Cog): | ||||||
| 
 | 
 | ||||||
|     # Slash groups |     # Slash groups | ||||||
|     add_slash = app_commands.Group(name="add", description="Add something new to the database") |     add_slash = app_commands.Group(name="add", description="Add something new to the database") | ||||||
|  |     edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry") | ||||||
| 
 | 
 | ||||||
|     def __init__(self, client: Didier): |     def __init__(self, client: Didier): | ||||||
|         self.client = client |         self.client = client | ||||||
|  | @ -29,6 +31,11 @@ class Owner(commands.Cog): | ||||||
|         # pylint: disable=W0236 # Pylint thinks this can't be async, but it can |         # pylint: disable=W0236 # Pylint thinks this can't be async, but it can | ||||||
|         return await self.client.is_owner(ctx.author) |         return await self.client.is_owner(ctx.author) | ||||||
| 
 | 
 | ||||||
|  |     @commands.command(name="Error") | ||||||
|  |     async def _error(self, ctx: commands.Context): | ||||||
|  |         """Raise an exception for debugging purposes""" | ||||||
|  |         raise Exception("Debug") | ||||||
|  | 
 | ||||||
|     @commands.command(name="Sync") |     @commands.command(name="Sync") | ||||||
|     async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): |     async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): | ||||||
|         """Sync all application-commands in Discord""" |         """Sync all application-commands in Discord""" | ||||||
|  | @ -77,14 +84,43 @@ class Owner(commands.Cog): | ||||||
|                 "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True |                 "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True | ||||||
|             ) |             ) | ||||||
| 
 | 
 | ||||||
|         # await interaction.response.defer(ephemeral=True) |         modal = CreateCustomCommand(self.client) | ||||||
|         modal = CreateCustomCommand() |  | ||||||
|         await interaction.response.send_modal(modal) |         await interaction.response.send_modal(modal) | ||||||
| 
 | 
 | ||||||
|     @commands.group(name="Edit") |     @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) | ||||||
|     async def edit(self, ctx: commands.Context): |     async def edit_msg(self, ctx: commands.Context): | ||||||
|         """Command group for [edit X] commands""" |         """Command group for [edit X] commands""" | ||||||
| 
 | 
 | ||||||
|  |     @edit_msg.command(name="Custom") | ||||||
|  |     async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): | ||||||
|  |         """Edit an existing custom command""" | ||||||
|  |         async with self.client.db_session as session: | ||||||
|  |             try: | ||||||
|  |                 await custom_commands.edit_command(session, command, flags.name, flags.response) | ||||||
|  |                 return await self.client.confirm_message(ctx.message) | ||||||
|  |             except NoResultFoundException: | ||||||
|  |                 await ctx.reply(f"Geen commando gevonden voor ``{command}``.") | ||||||
|  |                 return await self.client.reject_message(ctx.message) | ||||||
|  | 
 | ||||||
|  |     @edit_slash.command(name="custom", description="Edit a custom command") | ||||||
|  |     @app_commands.describe(command="The name of the command to edit") | ||||||
|  |     async def edit_custom_slash(self, interaction: discord.Interaction, command: str): | ||||||
|  |         """Slash command to edit a custom command""" | ||||||
|  |         if not await self.client.is_owner(interaction.user): | ||||||
|  |             return interaction.response.send_message( | ||||||
|  |                 "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True | ||||||
|  |             ) | ||||||
|  | 
 | ||||||
|  |         async with self.client.db_session as session: | ||||||
|  |             _command = await custom_commands.get_command(session, command) | ||||||
|  |             if _command is None: | ||||||
|  |                 return await interaction.response.send_message( | ||||||
|  |                     f"Geen commando gevonden voor ``{command}``.", ephemeral=True | ||||||
|  |                 ) | ||||||
|  | 
 | ||||||
|  |             modal = EditCustomCommand(self.client, _command.name, _command.response) | ||||||
|  |             await interaction.response.send_modal(modal) | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| async def setup(client: Didier): | async def setup(client: Didier): | ||||||
|     """Load the cog""" |     """Load the cog""" | ||||||
|  |  | ||||||
|  | @ -0,0 +1 @@ | ||||||
|  | from .posix import PosixFlags | ||||||
|  | @ -0,0 +1,10 @@ | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | from didier.data.flags import PosixFlags | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class EditCustomFlags(PosixFlags): | ||||||
|  |     """Flags for the edit custom command""" | ||||||
|  | 
 | ||||||
|  |     name: Optional[str] = None | ||||||
|  |     response: Optional[str] = None | ||||||
|  | @ -0,0 +1,14 @@ | ||||||
|  | from discord.ext import commands | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): | ||||||
|  |     """Base class to add POSIX-like flags to commands | ||||||
|  | 
 | ||||||
|  |     Example usage: | ||||||
|  |         >>> class Flags(PosixFlags): | ||||||
|  |         >>>     name: str | ||||||
|  |         >>> async def command(ctx, *, flags: Flags): | ||||||
|  |         >>>     ... | ||||||
|  |     This can now be called in Discord as | ||||||
|  |     command --name here-be-name | ||||||
|  |     """ | ||||||
|  | @ -2,19 +2,66 @@ import traceback | ||||||
| 
 | 
 | ||||||
| import discord | import discord | ||||||
| 
 | 
 | ||||||
|  | from database.crud.custom_commands import create_command, edit_command | ||||||
|  | from didier import Didier | ||||||
| 
 | 
 | ||||||
| class CreateCustomCommand(discord.ui.Modal, title="Custom Command"): |  | ||||||
|     """Modal shown to visually create custom commands""" |  | ||||||
| 
 | 
 | ||||||
|     name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command") | class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): | ||||||
|  |     """Modal to create new custom commands""" | ||||||
|  | 
 | ||||||
|  |     name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Didier") | ||||||
| 
 | 
 | ||||||
|     response: discord.ui.TextInput = discord.ui.TextInput( |     response: discord.ui.TextInput = discord.ui.TextInput( | ||||||
|         label="Response", style=discord.TextStyle.long, placeholder="Response of the command", max_length=2000 |         label="Response", style=discord.TextStyle.long, placeholder="Hmm?", max_length=2000 | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     async def on_submit(self, interaction: discord.Interaction) -> None: |     client: Didier | ||||||
|         await interaction.response.send_message("Submitted", ephemeral=True) |  | ||||||
| 
 | 
 | ||||||
|     async def on_error(self, interaction: discord.Interaction, error: Exception) -> None:  # type: ignore |     def __init__(self, client: Didier, *args, **kwargs): | ||||||
|         await interaction.response.send_message("Errored", ephemeral=True) |         super().__init__(*args, **kwargs) | ||||||
|  |         self.client = client | ||||||
|  | 
 | ||||||
|  |     async def on_submit(self, interaction: discord.Interaction): | ||||||
|  |         async with self.client.db_session as session: | ||||||
|  |             command = await create_command(session, self.name.value, self.response.value) | ||||||
|  | 
 | ||||||
|  |         await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) | ||||||
|  | 
 | ||||||
|  |     async def on_error(self, interaction: discord.Interaction, error: Exception):  # type: ignore | ||||||
|  |         await interaction.response.send_message("Something went wrong.", ephemeral=True) | ||||||
|  |         traceback.print_tb(error.__traceback__) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): | ||||||
|  |     """Modal to edit an existing custom command | ||||||
|  |     Fills in the current values as defaults | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     name: discord.ui.TextInput | ||||||
|  |     response: discord.ui.TextInput | ||||||
|  | 
 | ||||||
|  |     original_name: str | ||||||
|  | 
 | ||||||
|  |     client: Didier | ||||||
|  | 
 | ||||||
|  |     def __init__(self, client: Didier, name: str, response: str, *args, **kwargs): | ||||||
|  |         super().__init__(*args, **kwargs) | ||||||
|  |         self.original_name = name | ||||||
|  |         self.client = client | ||||||
|  | 
 | ||||||
|  |         self.name = self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name)) | ||||||
|  |         self.response = self.add_item( | ||||||
|  |             discord.ui.TextInput( | ||||||
|  |                 label="Response", placeholder="Hmm?", default=response, style=discord.TextStyle.long, max_length=2000 | ||||||
|  |             ) | ||||||
|  |         ) | ||||||
|  | 
 | ||||||
|  |     async def on_submit(self, interaction: discord.Interaction): | ||||||
|  |         async with self.client.db_session as session: | ||||||
|  |             await edit_command(session, self.original_name, self.name.value, self.response.value) | ||||||
|  | 
 | ||||||
|  |         await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) | ||||||
|  | 
 | ||||||
|  |     async def on_error(self, interaction: discord.Interaction, error: Exception):  # type: ignore | ||||||
|  |         await interaction.response.send_message("Something went wrong.", ephemeral=True) | ||||||
|         traceback.print_tb(error.__traceback__) |         traceback.print_tb(error.__traceback__) | ||||||
|  |  | ||||||
|  | @ -129,8 +129,9 @@ class Didier(commands.Bot): | ||||||
| 
 | 
 | ||||||
|     async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: |     async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: | ||||||
|         """Event triggered when a regular command errors""" |         """Event triggered when a regular command errors""" | ||||||
|         # If developing, print everything to stdout so you don't have to |         # Print everything to the logs/stderr | ||||||
|         # check the logs all the time |         await super().on_command_error(context, exception) | ||||||
|  | 
 | ||||||
|  |         # If developing, do nothing special | ||||||
|         if settings.SANDBOX: |         if settings.SANDBOX: | ||||||
|             print(traceback.format_exc(), file=sys.stderr) |  | ||||||
|             return |             return | ||||||
|  |  | ||||||
							
								
								
									
										20
									
								
								main.py
								
								
								
								
							
							
						
						
									
										20
									
								
								main.py
								
								
								
								
							|  | @ -1,4 +1,5 @@ | ||||||
| import logging | import logging | ||||||
|  | import sys | ||||||
| from logging.handlers import RotatingFileHandler | from logging.handlers import RotatingFileHandler | ||||||
| 
 | 
 | ||||||
| import asyncio | import asyncio | ||||||
|  | @ -18,15 +19,26 @@ def setup_logging(): | ||||||
|     """Configure custom loggers""" |     """Configure custom loggers""" | ||||||
|     max_log_size = 32 * 1024 * 1024 |     max_log_size = 32 * 1024 * 1024 | ||||||
| 
 | 
 | ||||||
|  |     # Configure Didier handler | ||||||
|     didier_log = logging.getLogger("didier") |     didier_log = logging.getLogger("didier") | ||||||
| 
 | 
 | ||||||
|     handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) |     didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) | ||||||
|     handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) |     didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) | ||||||
| 
 | 
 | ||||||
|     didier_log.addHandler(handler) |     didier_log.addHandler(didier_handler) | ||||||
|     didier_log.setLevel(logging.INFO) |     didier_log.setLevel(logging.INFO) | ||||||
| 
 | 
 | ||||||
|     logging.getLogger("discord").setLevel(logging.ERROR) |     # Configure discord handler | ||||||
|  |     discord_log = logging.getLogger("discord") | ||||||
|  | 
 | ||||||
|  |     # Make dev print to stderr instead, so you don't have to watch the file | ||||||
|  |     if settings.SANDBOX: | ||||||
|  |         discord_handler = logging.StreamHandler(sys.stderr) | ||||||
|  |     else: | ||||||
|  |         discord_handler = RotatingFileHandler("discord.log", mode="a", maxBytes=max_log_size, backupCount=5) | ||||||
|  | 
 | ||||||
|  |     discord_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] %(name)s: %(message)s")) | ||||||
|  |     discord_log.addHandler(discord_handler) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def main(): | async def main(): | ||||||
|  |  | ||||||
|  | @ -4,7 +4,8 @@ from sqlalchemy.ext.asyncio import AsyncSession | ||||||
| 
 | 
 | ||||||
| from database.crud import custom_commands as crud | from database.crud import custom_commands as crud | ||||||
| from database.exceptions.constraints import DuplicateInsertException | from database.exceptions.constraints import DuplicateInsertException | ||||||
| from database.models import CustomCommand, CustomCommandAlias | from database.exceptions.not_found import NoResultFoundException | ||||||
|  | from database.models import CustomCommand | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_create_command_non_existing(database_session: AsyncSession): | async def test_create_command_non_existing(database_session: AsyncSession): | ||||||
|  | @ -33,7 +34,7 @@ async def test_create_command_name_is_alias(database_session: AsyncSession): | ||||||
|         await crud.create_command(database_session, "n", "other response") |         await crud.create_command(database_session, "n", "other response") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_create_alias_non_existing(database_session: AsyncSession): | async def test_create_alias(database_session: AsyncSession): | ||||||
|     """Test creating an alias when the name is still free""" |     """Test creating an alias when the name is still free""" | ||||||
|     command = await crud.create_command(database_session, "name", "response") |     command = await crud.create_command(database_session, "name", "response") | ||||||
|     await crud.create_alias(database_session, command.name, "n") |     await crud.create_alias(database_session, command.name, "n") | ||||||
|  | @ -43,6 +44,12 @@ async def test_create_alias_non_existing(database_session: AsyncSession): | ||||||
|     assert command.aliases[0].alias == "n" |     assert command.aliases[0].alias == "n" | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | async def test_create_alias_non_existing(database_session: AsyncSession): | ||||||
|  |     """Test creating an alias when the command doesn't exist""" | ||||||
|  |     with pytest.raises(NoResultFoundException): | ||||||
|  |         await crud.create_alias(database_session, "name", "alias") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| async def test_create_alias_duplicate(database_session: AsyncSession): | async def test_create_alias_duplicate(database_session: AsyncSession): | ||||||
|     """Test creating an alias when another alias already has this name""" |     """Test creating an alias when another alias already has this name""" | ||||||
|     command = await crud.create_command(database_session, "name", "response") |     command = await crud.create_command(database_session, "name", "response") | ||||||
|  | @ -96,3 +103,17 @@ async def test_get_command_by_alias(database_session: AsyncSession): | ||||||
| async def test_get_command_non_existing(database_session: AsyncSession): | async def test_get_command_non_existing(database_session: AsyncSession): | ||||||
|     """Test getting a command when it doesn't exist""" |     """Test getting a command when it doesn't exist""" | ||||||
|     assert await crud.get_command(database_session, "name") is None |     assert await crud.get_command(database_session, "name") is None | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def test_edit_command(database_session: AsyncSession): | ||||||
|  |     """Test editing an existing command""" | ||||||
|  |     command = await crud.create_command(database_session, "name", "response") | ||||||
|  |     await crud.edit_command(database_session, command.name, "new name", "new response") | ||||||
|  |     assert command.name == "new name" | ||||||
|  |     assert command.response == "new response" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def test_edit_command_non_existing(database_session: AsyncSession): | ||||||
|  |     """Test editing a command that doesn't exist""" | ||||||
|  |     with pytest.raises(NoResultFoundException): | ||||||
|  |         await crud.edit_command(database_session, "name", "n", "r") | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue