From fd57b5a79b954237af08136d71aa8108ed6d2318 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 23:58:21 +0200 Subject: [PATCH] Crud & tests for custom commands --- .../b2d511552a1f_add_custom_commands.py | 57 +++++++++++ database/crud/custom_commands.py | 68 +++++++++++++ database/exceptions/__init__.py | 0 database/exceptions/constraints.py | 2 + database/exceptions/not_found.py | 2 + database/models.py | 28 ++++++ didier/didier.py | 24 +++++ settings.py | 2 + tests/test_database/__init__.py | 0 tests/test_database/test_crud/__init__.py | 0 .../test_crud/test_custom_commands.py | 98 +++++++++++++++++++ 11 files changed, 281 insertions(+) create mode 100644 alembic/versions/b2d511552a1f_add_custom_commands.py create mode 100644 database/crud/custom_commands.py create mode 100644 database/exceptions/__init__.py create mode 100644 database/exceptions/constraints.py create mode 100644 database/exceptions/not_found.py create mode 100644 tests/test_database/__init__.py create mode 100644 tests/test_database/test_crud/__init__.py create mode 100644 tests/test_database/test_crud/test_custom_commands.py diff --git a/alembic/versions/b2d511552a1f_add_custom_commands.py b/alembic/versions/b2d511552a1f_add_custom_commands.py new file mode 100644 index 0000000..83b004a --- /dev/null +++ b/alembic/versions/b2d511552a1f_add_custom_commands.py @@ -0,0 +1,57 @@ +"""Add custom commands + +Revision ID: b2d511552a1f +Revises: 4ec79dd5b191 +Create Date: 2022-06-21 22:10:05.590846 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b2d511552a1f' +down_revision = '4ec79dd5b191' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('custom_commands', + sa.Column('command_id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('indexed_name', sa.Text(), nullable=False), + sa.Column('response', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('command_id'), + sa.UniqueConstraint('name') + ) + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False) + + op.create_table('custom_command_aliases', + sa.Column('alias_id', sa.Integer(), nullable=False), + sa.Column('alias', sa.Text(), nullable=False), + sa.Column('indexed_alias', sa.Text(), nullable=False), + sa.Column('command_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ), + sa.PrimaryKeyConstraint('alias_id'), + sa.UniqueConstraint('alias') + ) + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias')) + + op.drop_table('custom_command_aliases') + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name')) + + op.drop_table('custom_commands') + # ### end Alembic commands ### diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py new file mode 100644 index 0000000..afd41ce --- /dev/null +++ b/database/crud/custom_commands.py @@ -0,0 +1,68 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions.constraints import DuplicateInsertException +from database.exceptions.not_found import NoResultFoundException +from database.models import CustomCommand, CustomCommandAlias + + +def clean_name(name: str) -> str: + """Convert a name to lowercase & remove spaces to allow easier matching""" + return name.lower().replace(" ", "") + + +async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand: + """Create a new custom command""" + # Check if command or alias already exists + command = await get_command(session, name) + if command is not None: + raise DuplicateInsertException + + command = CustomCommand(name=name, indexed_name=clean_name(name), response=response) + session.add(command) + await session.commit() + return command + + +async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias: + """Create an alias for a command""" + # Check if the command exists + command_instance = await get_command(session, command) + if command_instance is None: + raise NoResultFoundException + + # Check if the alias exists (either as an alias or as a name) + alias_instance = await get_command(session, alias) + if alias_instance is not None: + raise DuplicateInsertException + + alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) + session.add(alias_instance) + await session.commit() + + return alias_instance + + +async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command out of a message""" + # Search lowercase & without spaces, and strip the prefix + message = clean_name(message) + return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) + + +async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its name""" + statement = select(CustomCommand).where(CustomCommand.indexed_name == message) + return (await session.execute(statement)).scalar_one_or_none() + + +async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its alias""" + statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message) + alias = (await session.execute(statement)).scalar_one_or_none() + if alias is None: + return None + + return alias.command diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/exceptions/constraints.py b/database/exceptions/constraints.py new file mode 100644 index 0000000..1e78089 --- /dev/null +++ b/database/exceptions/constraints.py @@ -0,0 +1,2 @@ +class DuplicateInsertException(Exception): + """Exception raised when a value already exists""" diff --git a/database/exceptions/not_found.py b/database/exceptions/not_found.py new file mode 100644 index 0000000..eccfaa5 --- /dev/null +++ b/database/exceptions/not_found.py @@ -0,0 +1,2 @@ +class NoResultFoundException(Exception): + """Exception raised when nothing was found""" diff --git a/database/models.py b/database/models.py index 4414326..1c28d37 100644 --- a/database/models.py +++ b/database/models.py @@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() +class CustomCommand(Base): + """Custom commands to fill the hole Dyno couldn't""" + + __tablename__ = "custom_commands" + + command_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + indexed_name: str = Column(Text, nullable=False, index=True) + response: str = Column(Text, nullable=False) + + aliases: list[CustomCommandAlias] = relationship( + "CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" + ) + + +class CustomCommandAlias(Base): + """Aliases for custom commands""" + + __tablename__ = "custom_command_aliases" + + alias_id: int = Column(Integer, primary_key=True) + alias: str = Column(Text, nullable=False, unique=True) + indexed_alias: str = Column(Text, nullable=False, index=True) + command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) + + command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") + + class UforaCourse(Base): """A course on Ufora""" diff --git a/didier/didier.py b/didier/didier.py index fa2eb24..b42dfcc 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -3,6 +3,7 @@ import sys import traceback import discord +from discord import Message from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession @@ -88,6 +89,29 @@ class Didier(commands.Bot): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) + async def on_message(self, message: Message, /) -> None: + """Event triggered when a message is sent""" + # Ignore messages by bots + if message.author.bot: + return + + # Boos react to people that say Dider + if "dider" in message.content.lower() and message.author.id != self.user.id: + await message.add_reaction(settings.DISCORD_BOOS_REACT) + + # Potential custom command + if self._try_invoke_custom_command(message): + return + + await self.process_commands(message) + + async def _try_invoke_custom_command(self, message: Message) -> bool: + """Check if the message tries to invoke a custom command + If it does, send the reply associated with it + """ + if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): + return False + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" # If developing, print everything to stdout so you don't have to diff --git a/settings.py b/settings.py index ee0c448..797b052 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN") DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int) +DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>") +DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?") UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) """API Keys""" diff --git a/tests/test_database/__init__.py b/tests/test_database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/__init__.py b/tests/test_database/test_crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py new file mode 100644 index 0000000..5f41983 --- /dev/null +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -0,0 +1,98 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import custom_commands as crud +from database.exceptions.constraints import DuplicateInsertException +from database.models import CustomCommand, CustomCommandAlias + + +async def test_create_command_non_existing(database_session: AsyncSession): + """Test creating a new command when it doesn't exist yet""" + await crud.create_command(database_session, "name", "response") + + commands = (await database_session.execute(select(CustomCommand))).scalars().all() + assert len(commands) == 1 + assert commands[0].name == "name" + + +async def test_create_command_duplicate_name(database_session: AsyncSession): + """Test creating a command when the name already exists""" + await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "name", "other response") + + +async def test_create_command_name_is_alias(database_session: AsyncSession): + """Test creating a command when the name is taken by an alias""" + await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, "name", "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "n", "other response") + + +async def test_create_alias_non_existing(database_session: AsyncSession): + """Test creating an alias when the name is still free""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + await database_session.refresh(command) + assert len(command.aliases) == 1 + assert command.aliases[0].alias == "n" + + +async def test_create_alias_duplicate(database_session: AsyncSession): + """Test creating an alias when another alias already has this name""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_is_command(database_session: AsyncSession): + """Test creating an alias when the name is taken by a command""" + await crud.create_command(database_session, "n", "response") + command = await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_match_by_alias(database_session: AsyncSession): + """Test creating an alias for a command when matching the name to another alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + alias = await crud.create_alias(database_session, "a1", "a2") + assert alias.command == command + + +async def test_get_command_by_name_exists(database_session: AsyncSession): + """Test getting a command by name""" + await crud.create_command(database_session, "name", "response") + command = await crud.get_command(database_session, "name") + assert command is not None + + +async def test_get_command_by_cleaned_name(database_session: AsyncSession): + """Test getting a command by the cleaned version of the name""" + command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") + found = await crud.get_command(database_session, "capitalizednamewithspaces") + assert command == found + + +async def test_get_command_by_alias(database_session: AsyncSession): + """Test getting a command by an alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + await crud.create_alias(database_session, command.name, "a2") + + found = await crud.get_command(database_session, "a1") + assert command == found + + +async def test_get_command_non_existing(database_session: AsyncSession): + """Test getting a command when it doesn't exist""" + assert await crud.get_command(database_session, "name") is None