Compare commits

...

5 Commits

Author SHA1 Message Date
stijndcl add9399944 Fix linting & typing 2022-06-22 02:09:16 +02:00
stijndcl 57e805e31c Try to fix async tests 2022-06-22 02:05:04 +02:00
stijndcl fc195e40b3 Fix syncing 2022-06-22 01:56:13 +02:00
stijndcl d8192cfa0a Adding custom commands & aliases 2022-06-22 00:49:00 +02:00
stijndcl efdc966611 Invoke custom commands 2022-06-22 00:22:26 +02:00
6 changed files with 131 additions and 24 deletions

View File

@ -34,8 +34,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
raise NoResultFoundException
# Check if the alias exists (either as an alias or as a name)
alias_instance = await get_command(session, alias)
if alias_instance is not None:
if await get_command(session, alias) is not None:
raise DuplicateInsertException
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
@ -47,7 +46,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message"""
# Search lowercase & without spaces, and strip the prefix
# Search lowercase & without spaces
message = clean_name(message)
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))

View File

@ -1,9 +1,14 @@
from typing import Optional
import discord
from discord import app_commands
from discord.ext import commands
from database.crud import custom_commands
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from didier import Didier
from didier.data.modals.custom_commands import CreateCustomCommand
class Owner(commands.Cog):
@ -11,16 +16,76 @@ class Owner(commands.Cog):
client: Didier
# Slash groups
add_slash = app_commands.Group(name="add", description="Add something new to the database")
def __init__(self, client: Didier):
self.client = client
async def cog_check(self, ctx: commands.Context) -> bool:
"""Global check for every command in this cog, so we don't have to add
is_owner() to every single command separately
"""
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
return await self.client.is_owner(ctx.author)
@commands.command(name="Sync")
@commands.is_owner()
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
"""Sync all application-commands in Discord"""
await self.client.tree.sync(guild=guild)
if guild is not None:
self.client.tree.copy_global_to(guild=guild)
await self.client.tree.sync(guild=guild)
else:
self.client.tree.clear_commands(guild=None)
await self.client.tree.sync()
await ctx.message.add_reaction("🔄")
@commands.group(name="Add", case_insensitive=True, invoke_without_command=False)
async def add_msg(self, ctx: commands.Context):
"""Command group for [add X] message commands"""
@add_msg.command(name="Custom")
async def add_custom(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command"""
async with self.client.db_session as session:
try:
await custom_commands.create_command(session, name, response)
await self.client.confirm_message(ctx.message)
except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.")
await self.client.reject_message(ctx.message)
@add_msg.command(name="Alias")
async def add_alias(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command"""
async with self.client.db_session as session:
try:
await custom_commands.create_alias(session, command, alias)
await self.client.confirm_message(ctx.message)
except NoResultFoundException:
await ctx.reply(f'Geen commando gevonden voor "{command}".')
await self.client.reject_message(ctx.message)
except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.")
await self.client.reject_message(ctx.message)
@add_slash.command(name="custom", description="Add a custom command")
async def add_custom_slash(self, interaction: discord.Interaction):
"""Slash command to add a custom command"""
if not self.client.is_owner(interaction.user):
return interaction.response.send_message(
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
)
await interaction.response.defer(ephemeral=True)
modal = CreateCustomCommand()
await interaction.response.send_message(modal)
@commands.group(name="Edit")
async def edit(self, ctx: commands.Context):
"""Command group for [edit X] commands"""
async def setup(client: Didier):
"""Load the cog"""

View File

View File

@ -0,0 +1,20 @@
import traceback
import discord
class CreateCustomCommand(discord.ui.Modal, title="Custom Command"):
"""Modal shown to visually create custom commands"""
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command...")
response: discord.ui.TextInput = discord.ui.TextInput(
label="Response", style=discord.TextStyle.long, placeholder="Response of the command...", max_length=2000
)
async def on_submit(self, interaction: discord.Interaction) -> None:
await interaction.response.send_message("Submitted", ephemeral=True)
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore
await interaction.response.send_message("Errored", ephemeral=True)
traceback.print_tb(error.__traceback__)

View File

@ -3,11 +3,11 @@ import sys
import traceback
import discord
from discord import Message
from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession
import settings
from database.crud import custom_commands
from database.engine import DBSession
from didier.utils.discord.prefix import get_prefix
@ -34,24 +34,17 @@ class Didier(commands.Bot):
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
)
@property
def db_session(self) -> AsyncSession:
"""Obtain a database session"""
return DBSession()
async def setup_hook(self) -> None:
"""Hook called once the bot is initialised"""
# Load extensions
await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs")
# Sync application commands to the test guild
for guild in settings.DISCORD_TEST_GUILDS:
guild_object = discord.Object(id=guild)
self.tree.copy_global_to(guild=guild_object)
await self.tree.sync(guild=guild_object)
@property
def db_session(self) -> AsyncSession:
"""Obtain a database session"""
return DBSession()
async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions:
@ -85,11 +78,19 @@ class Didier(commands.Bot):
channel = self.get_channel(reference.channel_id)
return await channel.fetch_message(reference.message_id)
async def confirm_message(self, message: discord.Message):
"""Add a checkmark to a message"""
await message.add_reaction("")
async def reject_message(self, message: discord.Message):
"""Add an X to a message"""
await message.add_reaction("")
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: Message, /) -> None:
async def on_message(self, message: discord.Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
@ -100,18 +101,32 @@ class Didier(commands.Bot):
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if self._try_invoke_custom_command(message):
if await self._try_invoke_custom_command(message):
return
await self.process_commands(message)
async def _try_invoke_custom_command(self, message: Message) -> bool:
async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
"""Check if the message tries to invoke a custom command
If it does, send the reply associated with it
"""
# Doesn't start with the custom command prefix
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False
async with self.db_session as session:
# Remove the prefix
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
command = await custom_commands.get_command(session, content)
# Command found
if command is not None:
await message.reply(command.response, mention_author=False)
return True
# Nothing found
return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to

View File

@ -1,4 +1,5 @@
from typing import AsyncGenerator
import asyncio
from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock
import pytest
@ -11,7 +12,14 @@ from didier import Didier
@pytest.fixture(scope="session")
def tables():
def event_loop() -> Generator:
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
def tables(event_loop):
"""Initialize a database before the tests, and then tear it down again
Starts from an empty database and runs through all the migrations to check those as well
while we're at it
@ -23,7 +31,7 @@ def tables():
@pytest.fixture
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
async def database_session(tables, event_loop) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database
"""