Compare commits

..

No commits in common. "add939994481cdac714efd443b11bb876f6284ae" and "fd57b5a79b954237af08136d71aa8108ed6d2318" have entirely different histories.

6 changed files with 24 additions and 131 deletions

View File

@ -34,7 +34,8 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
raise NoResultFoundException raise NoResultFoundException
# Check if the alias exists (either as an alias or as a name) # Check if the alias exists (either as an alias or as a name)
if await get_command(session, alias) is not None: alias_instance = await get_command(session, alias)
if alias_instance is not None:
raise DuplicateInsertException raise DuplicateInsertException
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
@ -46,7 +47,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message""" """Try to get a command out of a message"""
# Search lowercase & without spaces # Search lowercase & without spaces, and strip the prefix
message = clean_name(message) message = clean_name(message)
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))

View File

@ -1,14 +1,9 @@
from typing import Optional from typing import Optional
import discord import discord
from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud import custom_commands
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from didier import Didier from didier import Didier
from didier.data.modals.custom_commands import CreateCustomCommand
class Owner(commands.Cog): class Owner(commands.Cog):
@ -16,76 +11,16 @@ class Owner(commands.Cog):
client: Didier client: Didier
# Slash groups
add_slash = app_commands.Group(name="add", description="Add something new to the database")
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
async def cog_check(self, ctx: commands.Context) -> bool:
"""Global check for every command in this cog, so we don't have to add
is_owner() to every single command separately
"""
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
return await self.client.is_owner(ctx.author)
@commands.command(name="Sync") @commands.command(name="Sync")
@commands.is_owner()
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
"""Sync all application-commands in Discord""" """Sync all application-commands in Discord"""
if guild is not None: await self.client.tree.sync(guild=guild)
self.client.tree.copy_global_to(guild=guild)
await self.client.tree.sync(guild=guild)
else:
self.client.tree.clear_commands(guild=None)
await self.client.tree.sync()
await ctx.message.add_reaction("🔄") await ctx.message.add_reaction("🔄")
@commands.group(name="Add", case_insensitive=True, invoke_without_command=False)
async def add_msg(self, ctx: commands.Context):
"""Command group for [add X] message commands"""
@add_msg.command(name="Custom")
async def add_custom(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command"""
async with self.client.db_session as session:
try:
await custom_commands.create_command(session, name, response)
await self.client.confirm_message(ctx.message)
except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.")
await self.client.reject_message(ctx.message)
@add_msg.command(name="Alias")
async def add_alias(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command"""
async with self.client.db_session as session:
try:
await custom_commands.create_alias(session, command, alias)
await self.client.confirm_message(ctx.message)
except NoResultFoundException:
await ctx.reply(f'Geen commando gevonden voor "{command}".')
await self.client.reject_message(ctx.message)
except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.")
await self.client.reject_message(ctx.message)
@add_slash.command(name="custom", description="Add a custom command")
async def add_custom_slash(self, interaction: discord.Interaction):
"""Slash command to add a custom command"""
if not self.client.is_owner(interaction.user):
return interaction.response.send_message(
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
)
await interaction.response.defer(ephemeral=True)
modal = CreateCustomCommand()
await interaction.response.send_message(modal)
@commands.group(name="Edit")
async def edit(self, ctx: commands.Context):
"""Command group for [edit X] commands"""
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -1,20 +0,0 @@
import traceback
import discord
class CreateCustomCommand(discord.ui.Modal, title="Custom Command"):
"""Modal shown to visually create custom commands"""
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command...")
response: discord.ui.TextInput = discord.ui.TextInput(
label="Response", style=discord.TextStyle.long, placeholder="Response of the command...", max_length=2000
)
async def on_submit(self, interaction: discord.Interaction) -> None:
await interaction.response.send_message("Submitted", ephemeral=True)
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore
await interaction.response.send_message("Errored", ephemeral=True)
traceback.print_tb(error.__traceback__)

View File

@ -3,11 +3,11 @@ import sys
import traceback import traceback
import discord import discord
from discord import Message
from discord.ext import commands from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import custom_commands
from database.engine import DBSession from database.engine import DBSession
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
@ -34,17 +34,24 @@ class Didier(commands.Bot):
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
) )
@property
def db_session(self) -> AsyncSession:
"""Obtain a database session"""
return DBSession()
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Hook called once the bot is initialised""" """Hook called once the bot is initialised"""
# Load extensions # Load extensions
await self._load_initial_extensions() await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs") await self._load_directory_extensions("didier/cogs")
# Sync application commands to the test guild
for guild in settings.DISCORD_TEST_GUILDS:
guild_object = discord.Object(id=guild)
self.tree.copy_global_to(guild=guild_object)
await self.tree.sync(guild=guild_object)
@property
def db_session(self) -> AsyncSession:
"""Obtain a database session"""
return DBSession()
async def _load_initial_extensions(self): async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others""" """Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions: for extension in self.initial_extensions:
@ -78,19 +85,11 @@ class Didier(commands.Bot):
channel = self.get_channel(reference.channel_id) channel = self.get_channel(reference.channel_id)
return await channel.fetch_message(reference.message_id) return await channel.fetch_message(reference.message_id)
async def confirm_message(self, message: discord.Message):
"""Add a checkmark to a message"""
await message.add_reaction("")
async def reject_message(self, message: discord.Message):
"""Add an X to a message"""
await message.add_reaction("")
async def on_ready(self): async def on_ready(self):
"""Event triggered when the bot is ready""" """Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE) print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: discord.Message, /) -> None: async def on_message(self, message: Message, /) -> None:
"""Event triggered when a message is sent""" """Event triggered when a message is sent"""
# Ignore messages by bots # Ignore messages by bots
if message.author.bot: if message.author.bot:
@ -101,32 +100,18 @@ class Didier(commands.Bot):
await message.add_reaction(settings.DISCORD_BOOS_REACT) await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command # Potential custom command
if await self._try_invoke_custom_command(message): if self._try_invoke_custom_command(message):
return return
await self.process_commands(message) await self.process_commands(message)
async def _try_invoke_custom_command(self, message: discord.Message) -> bool: async def _try_invoke_custom_command(self, message: Message) -> bool:
"""Check if the message tries to invoke a custom command """Check if the message tries to invoke a custom command
If it does, send the reply associated with it If it does, send the reply associated with it
""" """
# Doesn't start with the custom command prefix
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False return False
async with self.db_session as session:
# Remove the prefix
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
command = await custom_commands.get_command(session, content)
# Command found
if command is not None:
await message.reply(command.response, mention_author=False)
return True
# Nothing found
return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors""" """Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to # If developing, print everything to stdout so you don't have to

View File

@ -1,5 +1,4 @@
import asyncio from typing import AsyncGenerator
from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
@ -12,14 +11,7 @@ from didier import Didier
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def event_loop() -> Generator: def tables():
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@pytest.fixture(scope="session")
def tables(event_loop):
"""Initialize a database before the tests, and then tear it down again """Initialize a database before the tests, and then tear it down again
Starts from an empty database and runs through all the migrations to check those as well Starts from an empty database and runs through all the migrations to check those as well
while we're at it while we're at it
@ -31,7 +23,7 @@ def tables(event_loop):
@pytest.fixture @pytest.fixture
async def database_session(tables, event_loop) -> AsyncGenerator[AsyncSession, None]: async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test """Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database Rollbacks the transaction afterwards so that the future tests start with a clean database
""" """