mirror of https://github.com/stijndcl/didier
Editing of custom commands, add posix flags
parent
257eae6fa7
commit
d6a560851b
|
@ -65,3 +65,23 @@ async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return alias.command
|
return alias.command
|
||||||
|
|
||||||
|
|
||||||
|
async def edit_command(
|
||||||
|
session: AsyncSession, original_name: str, new_name: Optional[str] = None, new_response: Optional[str] = None
|
||||||
|
) -> CustomCommand:
|
||||||
|
"""Edit an existing command"""
|
||||||
|
# Check if the command exists
|
||||||
|
command = await get_command(session, original_name)
|
||||||
|
if command is None:
|
||||||
|
raise NoResultFoundException
|
||||||
|
|
||||||
|
if new_name is not None:
|
||||||
|
command.name = new_name
|
||||||
|
if new_response is not None:
|
||||||
|
command.response = new_response
|
||||||
|
|
||||||
|
session.add(command)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return command
|
||||||
|
|
|
@ -8,7 +8,8 @@ from database.crud import custom_commands
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.modals.custom_commands import CreateCustomCommand
|
from didier.data.modals.custom_commands import CreateCustomCommand, EditCustomCommand
|
||||||
|
from didier.data.flags.owner import EditCustomFlags
|
||||||
|
|
||||||
|
|
||||||
class Owner(commands.Cog):
|
class Owner(commands.Cog):
|
||||||
|
@ -18,6 +19,7 @@ class Owner(commands.Cog):
|
||||||
|
|
||||||
# Slash groups
|
# Slash groups
|
||||||
add_slash = app_commands.Group(name="add", description="Add something new to the database")
|
add_slash = app_commands.Group(name="add", description="Add something new to the database")
|
||||||
|
edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry")
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
@ -29,6 +31,11 @@ class Owner(commands.Cog):
|
||||||
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
|
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
|
||||||
return await self.client.is_owner(ctx.author)
|
return await self.client.is_owner(ctx.author)
|
||||||
|
|
||||||
|
@commands.command(name="Error")
|
||||||
|
async def _error(self, ctx: commands.Context):
|
||||||
|
"""Raise an exception for debugging purposes"""
|
||||||
|
raise Exception("Debug")
|
||||||
|
|
||||||
@commands.command(name="Sync")
|
@commands.command(name="Sync")
|
||||||
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
|
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
|
||||||
"""Sync all application-commands in Discord"""
|
"""Sync all application-commands in Discord"""
|
||||||
|
@ -77,14 +84,43 @@ class Owner(commands.Cog):
|
||||||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
||||||
)
|
)
|
||||||
|
|
||||||
# await interaction.response.defer(ephemeral=True)
|
modal = CreateCustomCommand(self.client)
|
||||||
modal = CreateCustomCommand()
|
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
@commands.group(name="Edit")
|
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
|
||||||
async def edit(self, ctx: commands.Context):
|
async def edit_msg(self, ctx: commands.Context):
|
||||||
"""Command group for [edit X] commands"""
|
"""Command group for [edit X] commands"""
|
||||||
|
|
||||||
|
@edit_msg.command(name="Custom")
|
||||||
|
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
|
||||||
|
"""Edit an existing custom command"""
|
||||||
|
async with self.client.db_session as session:
|
||||||
|
try:
|
||||||
|
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
||||||
|
return await self.client.confirm_message(ctx.message)
|
||||||
|
except NoResultFoundException:
|
||||||
|
await ctx.reply(f"Geen commando gevonden voor ``{command}``.")
|
||||||
|
return await self.client.reject_message(ctx.message)
|
||||||
|
|
||||||
|
@edit_slash.command(name="custom", description="Edit a custom command")
|
||||||
|
@app_commands.describe(command="The name of the command to edit")
|
||||||
|
async def edit_custom_slash(self, interaction: discord.Interaction, command: str):
|
||||||
|
"""Slash command to edit a custom command"""
|
||||||
|
if not await self.client.is_owner(interaction.user):
|
||||||
|
return interaction.response.send_message(
|
||||||
|
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async with self.client.db_session as session:
|
||||||
|
_command = await custom_commands.get_command(session, command)
|
||||||
|
if _command is None:
|
||||||
|
return await interaction.response.send_message(
|
||||||
|
f"Geen commando gevonden voor ``{command}``.", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
modal = EditCustomCommand(self.client, _command.name, _command.response)
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog"""
|
"""Load the cog"""
|
||||||
|
|
|
@ -0,0 +1 @@
|
||||||
|
from .posix import PosixFlags
|
|
@ -0,0 +1,10 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from didier.data.flags import PosixFlags
|
||||||
|
|
||||||
|
|
||||||
|
class EditCustomFlags(PosixFlags):
|
||||||
|
"""Flags for the edit custom command"""
|
||||||
|
|
||||||
|
name: Optional[str] = None
|
||||||
|
response: Optional[str] = None
|
|
@ -0,0 +1,14 @@
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
|
||||||
|
class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"):
|
||||||
|
"""Base class to add POSIX-like flags to commands
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
>>> class Flags(PosixFlags):
|
||||||
|
>>> name: str
|
||||||
|
>>> async def command(ctx, *, flags: Flags):
|
||||||
|
>>> ...
|
||||||
|
This can now be called in Discord as
|
||||||
|
command --name here-be-name
|
||||||
|
"""
|
|
@ -2,19 +2,66 @@ import traceback
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
|
from database.crud.custom_commands import create_command, edit_command
|
||||||
|
from didier import Didier
|
||||||
|
|
||||||
class CreateCustomCommand(discord.ui.Modal, title="Custom Command"):
|
|
||||||
"""Modal shown to visually create custom commands"""
|
|
||||||
|
|
||||||
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command")
|
class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
|
||||||
|
"""Modal to create new custom commands"""
|
||||||
|
|
||||||
|
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Didier")
|
||||||
|
|
||||||
response: discord.ui.TextInput = discord.ui.TextInput(
|
response: discord.ui.TextInput = discord.ui.TextInput(
|
||||||
label="Response", style=discord.TextStyle.long, placeholder="Response of the command", max_length=2000
|
label="Response", style=discord.TextStyle.long, placeholder="Hmm?", max_length=2000
|
||||||
)
|
)
|
||||||
|
|
||||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
client: Didier
|
||||||
await interaction.response.send_message("Submitted", ephemeral=True)
|
|
||||||
|
|
||||||
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore
|
def __init__(self, client: Didier, *args, **kwargs):
|
||||||
await interaction.response.send_message("Errored", ephemeral=True)
|
super().__init__(*args, **kwargs)
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
async def on_submit(self, interaction: discord.Interaction):
|
||||||
|
async with self.client.db_session as session:
|
||||||
|
command = await create_command(session, self.name.value, self.response.value)
|
||||||
|
|
||||||
|
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||||
|
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||||
|
traceback.print_tb(error.__traceback__)
|
||||||
|
|
||||||
|
|
||||||
|
class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
|
||||||
|
"""Modal to edit an existing custom command
|
||||||
|
Fills in the current values as defaults
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: discord.ui.TextInput
|
||||||
|
response: discord.ui.TextInput
|
||||||
|
|
||||||
|
original_name: str
|
||||||
|
|
||||||
|
client: Didier
|
||||||
|
|
||||||
|
def __init__(self, client: Didier, name: str, response: str, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.original_name = name
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
self.name = self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name))
|
||||||
|
self.response = self.add_item(
|
||||||
|
discord.ui.TextInput(
|
||||||
|
label="Response", placeholder="Hmm?", default=response, style=discord.TextStyle.long, max_length=2000
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def on_submit(self, interaction: discord.Interaction):
|
||||||
|
async with self.client.db_session as session:
|
||||||
|
await edit_command(session, self.original_name, self.name.value, self.response.value)
|
||||||
|
|
||||||
|
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)
|
||||||
|
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||||
|
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||||
traceback.print_tb(error.__traceback__)
|
traceback.print_tb(error.__traceback__)
|
||||||
|
|
|
@ -129,8 +129,9 @@ class Didier(commands.Bot):
|
||||||
|
|
||||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||||
"""Event triggered when a regular command errors"""
|
"""Event triggered when a regular command errors"""
|
||||||
# If developing, print everything to stdout so you don't have to
|
# Print everything to the logs/stderr
|
||||||
# check the logs all the time
|
await super().on_command_error(context, exception)
|
||||||
|
|
||||||
|
# If developing, do nothing special
|
||||||
if settings.SANDBOX:
|
if settings.SANDBOX:
|
||||||
print(traceback.format_exc(), file=sys.stderr)
|
|
||||||
return
|
return
|
||||||
|
|
20
main.py
20
main.py
|
@ -1,4 +1,5 @@
|
||||||
import logging
|
import logging
|
||||||
|
import sys
|
||||||
from logging.handlers import RotatingFileHandler
|
from logging.handlers import RotatingFileHandler
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
@ -18,15 +19,26 @@ def setup_logging():
|
||||||
"""Configure custom loggers"""
|
"""Configure custom loggers"""
|
||||||
max_log_size = 32 * 1024 * 1024
|
max_log_size = 32 * 1024 * 1024
|
||||||
|
|
||||||
|
# Configure Didier handler
|
||||||
didier_log = logging.getLogger("didier")
|
didier_log = logging.getLogger("didier")
|
||||||
|
|
||||||
handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
|
didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
|
||||||
handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))
|
didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))
|
||||||
|
|
||||||
didier_log.addHandler(handler)
|
didier_log.addHandler(didier_handler)
|
||||||
didier_log.setLevel(logging.INFO)
|
didier_log.setLevel(logging.INFO)
|
||||||
|
|
||||||
logging.getLogger("discord").setLevel(logging.ERROR)
|
# Configure discord handler
|
||||||
|
discord_log = logging.getLogger("discord")
|
||||||
|
|
||||||
|
# Make dev print to stderr instead, so you don't have to watch the file
|
||||||
|
if settings.SANDBOX:
|
||||||
|
discord_handler = logging.StreamHandler(sys.stderr)
|
||||||
|
else:
|
||||||
|
discord_handler = RotatingFileHandler("discord.log", mode="a", maxBytes=max_log_size, backupCount=5)
|
||||||
|
|
||||||
|
discord_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s] %(name)s: %(message)s"))
|
||||||
|
discord_log.addHandler(discord_handler)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
|
@ -4,7 +4,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import custom_commands as crud
|
from database.crud import custom_commands as crud
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.models import CustomCommand, CustomCommandAlias
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
|
from database.models import CustomCommand
|
||||||
|
|
||||||
|
|
||||||
async def test_create_command_non_existing(database_session: AsyncSession):
|
async def test_create_command_non_existing(database_session: AsyncSession):
|
||||||
|
@ -33,7 +34,7 @@ async def test_create_command_name_is_alias(database_session: AsyncSession):
|
||||||
await crud.create_command(database_session, "n", "other response")
|
await crud.create_command(database_session, "n", "other response")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_non_existing(database_session: AsyncSession):
|
async def test_create_alias(database_session: AsyncSession):
|
||||||
"""Test creating an alias when the name is still free"""
|
"""Test creating an alias when the name is still free"""
|
||||||
command = await crud.create_command(database_session, "name", "response")
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
await crud.create_alias(database_session, command.name, "n")
|
await crud.create_alias(database_session, command.name, "n")
|
||||||
|
@ -43,6 +44,12 @@ async def test_create_alias_non_existing(database_session: AsyncSession):
|
||||||
assert command.aliases[0].alias == "n"
|
assert command.aliases[0].alias == "n"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_alias_non_existing(database_session: AsyncSession):
|
||||||
|
"""Test creating an alias when the command doesn't exist"""
|
||||||
|
with pytest.raises(NoResultFoundException):
|
||||||
|
await crud.create_alias(database_session, "name", "alias")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_duplicate(database_session: AsyncSession):
|
async def test_create_alias_duplicate(database_session: AsyncSession):
|
||||||
"""Test creating an alias when another alias already has this name"""
|
"""Test creating an alias when another alias already has this name"""
|
||||||
command = await crud.create_command(database_session, "name", "response")
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
@ -96,3 +103,17 @@ async def test_get_command_by_alias(database_session: AsyncSession):
|
||||||
async def test_get_command_non_existing(database_session: AsyncSession):
|
async def test_get_command_non_existing(database_session: AsyncSession):
|
||||||
"""Test getting a command when it doesn't exist"""
|
"""Test getting a command when it doesn't exist"""
|
||||||
assert await crud.get_command(database_session, "name") is None
|
assert await crud.get_command(database_session, "name") is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_edit_command(database_session: AsyncSession):
|
||||||
|
"""Test editing an existing command"""
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.edit_command(database_session, command.name, "new name", "new response")
|
||||||
|
assert command.name == "new name"
|
||||||
|
assert command.response == "new response"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_edit_command_non_existing(database_session: AsyncSession):
|
||||||
|
"""Test editing a command that doesn't exist"""
|
||||||
|
with pytest.raises(NoResultFoundException):
|
||||||
|
await crud.edit_command(database_session, "name", "n", "r")
|
||||||
|
|
Loading…
Reference in New Issue