From 000337107b7fd3c2a546d3e00143da25a86c7c9e Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 18:44:47 +0200 Subject: [PATCH 1/6] Parse publication time of notifications --- database/crud/ufora_announcements.py | 2 +- didier/cogs/tasks.py | 3 ++- didier/data/embeds/ufora/announcements.py | 17 +++++++++++++++-- didier/didier.py | 2 +- didier/utils/discord/__init__.py | 0 didier/utils/{ => discord}/prefix.py | 0 didier/utils/types/__init__.py | 0 didier/utils/types/datetime.py | 3 +++ didier/utils/types/string.py | 21 +++++++++++++++++++++ 9 files changed, 43 insertions(+), 5 deletions(-) create mode 100644 didier/utils/discord/__init__.py rename didier/utils/{ => discord}/prefix.py (100%) create mode 100644 didier/utils/types/__init__.py create mode 100644 didier/utils/types/datetime.py create mode 100644 didier/utils/types/string.py diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 0d8e15f..c4d67f6 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -13,7 +13,7 @@ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCou async def create_new_announcement( - session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime + session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime.datetime ) -> UforaAnnouncement: """Add a new announcement to the database""" new_announcement = UforaAnnouncement( diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index b56e426..8b366fa 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -13,7 +13,8 @@ class Tasks(commands.Cog): client: Didier - def __init__(self, client: Didier): # pylint: disable=no-member + def __init__(self, client: Didier): + # pylint: disable=no-member self.client = client # Only pull announcements if a token was provided diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index 71b3191..52d5849 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud from database.models import UforaCourse +from didier.utils.types.datetime import int_to_weekday +from didier.utils.types.string import leading @dataclass @@ -88,8 +90,19 @@ class UforaNotification: def _get_published(self) -> str: """Get a formatted string that represents when this announcement was published""" - # TODO - return "Placeholder :) TODO make the functions to format this" + return ( + f"{int_to_weekday(self.published_dt.weekday())} " + f"{leading('0', str(self.published_dt.day))}" + "/" + f"{leading('0', str(self.published_dt.month))}" + "/" + f"{self.published_dt.year} " + f"om {leading('0', str(self.published_dt.hour))}" + ":" + f"{leading('0', str(self.published_dt.minute))}" + ":" + f"{leading('0', str(self.published_dt.second))}" + ) def parse_ids(url: str) -> Optional[tuple[int, int]]: diff --git a/didier/didier.py b/didier/didier.py index 0b5b457..fa2eb24 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -8,7 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.engine import DBSession -from didier.utils.prefix import get_prefix +from didier.utils.discord.prefix import get_prefix class Didier(commands.Bot): diff --git a/didier/utils/discord/__init__.py b/didier/utils/discord/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/prefix.py b/didier/utils/discord/prefix.py similarity index 100% rename from didier/utils/prefix.py rename to didier/utils/discord/prefix.py diff --git a/didier/utils/types/__init__.py b/didier/utils/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py new file mode 100644 index 0000000..6e5c88e --- /dev/null +++ b/didier/utils/types/datetime.py @@ -0,0 +1,3 @@ +def int_to_weekday(number: int) -> str: + """Get the Dutch name of a weekday from the number""" + return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] diff --git a/didier/utils/types/string.py b/didier/utils/types/string.py new file mode 100644 index 0000000..773890a --- /dev/null +++ b/didier/utils/types/string.py @@ -0,0 +1,21 @@ +from typing import Optional + + +def leading(character: str, string: str, target_length: Optional[int] = 2) -> str: + """Add a leading [character] to [string] to make it length [target_length] + Pass None to target length to always do it, no matter the length + """ + # Cast to string just in case + string = str(string) + + # Add no matter what + if target_length is None: + return character + string + + # String is already long enough + if len(string) >= target_length: + return string + + frequency = (target_length - len(string)) // len(character) + + return (frequency * character) + string From 5a76cbd2ec68148ff3b9609dee4aa3cae5622974 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 18:50:00 +0200 Subject: [PATCH 2/6] Fix mypy error --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index bfc1ced..5ec9b79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,7 @@ line-length = 120 plugins = [ "sqlalchemy.ext.mypy.plugin" ] +namespace_packages = true [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "markdownify.*"] ignore_missing_imports = true From 868cd392c34d9031edbfc055d4b18c612e1ef121 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 18:58:33 +0200 Subject: [PATCH 3/6] Fix mypy error --- didier/cogs/tasks.py | 2 +- pyproject.toml | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 8b366fa..29e9530 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,6 +1,6 @@ import traceback -from discord.ext import commands, tasks +from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error import settings from database.crud.ufora_announcements import remove_old_announcements diff --git a/pyproject.toml b/pyproject.toml index 5ec9b79..bfc1ced 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,7 +5,6 @@ line-length = 120 plugins = [ "sqlalchemy.ext.mypy.plugin" ] -namespace_packages = true [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "markdownify.*"] ignore_missing_imports = true From 5c2c62c6c49477e868bb863c46a9bece08458da0 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 20:30:11 +0200 Subject: [PATCH 4/6] Add sync command, clean up db sessions --- didier/cogs/owner.py | 27 +++++++++++++++++++ didier/cogs/tasks.py | 12 +++++---- didier/utils/discord/checks/__init__.py | 0 .../utils/discord/checks/message_commands.py | 0 4 files changed, 34 insertions(+), 5 deletions(-) create mode 100644 didier/cogs/owner.py create mode 100644 didier/utils/discord/checks/__init__.py create mode 100644 didier/utils/discord/checks/message_commands.py diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py new file mode 100644 index 0000000..af50633 --- /dev/null +++ b/didier/cogs/owner.py @@ -0,0 +1,27 @@ +from typing import Optional + +import discord +from discord.ext import commands + +from didier import Didier + + +class Owner(commands.Cog): + """Cog for owner-only commands""" + + client: Didier + + def __init__(self, client: Didier): + self.client = client + + @commands.command(name="Sync") + @commands.is_owner() + async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): + """Sync all application-commands in Discord""" + await self.client.tree.sync(guild=guild) + await ctx.message.add_reaction("🔄") + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(Owner(client)) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 29e9530..c37b8b8 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -29,11 +29,12 @@ class Tasks(commands.Cog): if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: return - announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) - announcements = await fetch_ufora_announcements(self.client.db_session) + async with self.client.db_session as session: + announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) + announcements = await fetch_ufora_announcements(session) - for announcement in announcements: - await announcements_channel.send(embed=announcement.to_embed()) + for announcement in announcements: + await announcements_channel.send(embed=announcement.to_embed()) @pull_ufora_announcements.before_loop async def _before_ufora_announcements(self): @@ -48,7 +49,8 @@ class Tasks(commands.Cog): @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" - await remove_old_announcements(self.client.db_session) + async with self.client.db_session as session: + await remove_old_announcements(session) @remove_old_ufora_announcements.before_loop async def _before_remove_old_ufora_announcements(self): diff --git a/didier/utils/discord/checks/__init__.py b/didier/utils/discord/checks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/discord/checks/message_commands.py b/didier/utils/discord/checks/message_commands.py new file mode 100644 index 0000000..e69de29 From 53f58eb743120a1e272f7f131ecf2beb662e7f47 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 21:06:11 +0200 Subject: [PATCH 5/6] Write a few tests --- tests/conftest.py | 16 +++- tests/test_didier/__init__.py | 0 tests/test_didier/test_utils/__init__.py | 0 .../test_utils/test_discord/__init__.py | 0 .../test_utils/test_discord/test_prefix.py | 84 +++++++++++++++++++ tests/test_dummy.py | 2 - 6 files changed, 99 insertions(+), 3 deletions(-) create mode 100644 tests/test_didier/__init__.py create mode 100644 tests/test_didier/test_utils/__init__.py create mode 100644 tests/test_didier/test_utils/test_discord/__init__.py create mode 100644 tests/test_didier/test_utils/test_discord/test_prefix.py delete mode 100644 tests/test_dummy.py diff --git a/tests/conftest.py b/tests/conftest.py index 1ae9878..a74ba4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -import os from typing import AsyncGenerator +from unittest.mock import MagicMock import pytest @@ -7,6 +7,7 @@ from alembic import command, config from sqlalchemy.ext.asyncio import AsyncSession from database.engine import engine +from didier import Didier @pytest.fixture(scope="session") @@ -38,3 +39,16 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: await transaction.rollback() await connection.close() + + +@pytest.fixture +def mock_client() -> Didier: + """Fixture to get a mock Didier instance + The mock uses 0 as the id + """ + mock_client = MagicMock() + mock_user = MagicMock() + mock_user.id = 0 + mock_client.user = mock_user + + return mock_client diff --git a/tests/test_didier/__init__.py b/tests/test_didier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/__init__.py b/tests/test_didier/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/test_discord/__init__.py b/tests/test_didier/test_utils/test_discord/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/test_discord/test_prefix.py b/tests/test_didier/test_utils/test_discord/test_prefix.py new file mode 100644 index 0000000..d49e8e3 --- /dev/null +++ b/tests/test_didier/test_utils/test_discord/test_prefix.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock + +from didier import Didier +from didier.utils.discord.prefix import get_prefix + + +def test_get_prefix_didier(mock_client: Didier): + """Test the "didier" prefix""" + mock_message = MagicMock() + mock_message.content = "didier test" + assert get_prefix(mock_client, mock_message) == "didier " + + +def test_get_prefix_didier_cased(mock_client: Didier): + """Test the "didier" prefix with random casing""" + mock_message = MagicMock() + mock_message.content = "Didier test" + assert get_prefix(mock_client, mock_message) == "Didier " + + mock_message = MagicMock() + mock_message.content = "DIDIER test" + assert get_prefix(mock_client, mock_message) == "DIDIER " + + mock_message = MagicMock() + mock_message.content = "DiDiEr test" + assert get_prefix(mock_client, mock_message) == "DiDiEr " + + +def test_get_prefix_default(mock_client: Didier): + """Test the fallback prefix (used when nothing matched)""" + mock_message = MagicMock() + mock_message.content = "random message" + assert get_prefix(mock_client, mock_message) == "didier" + + +def test_get_prefix_big_d(mock_client: Didier): + """Test the "big d" prefix""" + mock_message = MagicMock() + mock_message.content = "big d test" + assert get_prefix(mock_client, mock_message) == "big d " + + +def test_get_prefix_big_d_cased(mock_client: Didier): + """Test the "big d" prefix with random casing""" + mock_message = MagicMock() + mock_message.content = "Big d test" + assert get_prefix(mock_client, mock_message) == "Big d " + + mock_message = MagicMock() + mock_message.content = "Big D test" + assert get_prefix(mock_client, mock_message) == "Big D " + + mock_message = MagicMock() + mock_message.content = "BIG D test" + assert get_prefix(mock_client, mock_message) == "BIG D " + + +def test_get_prefix_mention_username(mock_client: Didier): + """Test the @mention prefix when mentioned by username""" + mock_message = MagicMock() + prefix = f"<@{mock_client.user.id}> " + mock_message.content = f"{prefix}test" + + assert get_prefix(mock_client, mock_message) == prefix + + +def test_get_prefix_mention_nickname(mock_client: Didier): + """Test the @mention prefix when mentioned by server nickname""" + mock_message = MagicMock() + prefix = f"<@!{mock_client.user.id}> " + mock_message.content = f"{prefix}test" + + assert get_prefix(mock_client, mock_message) == prefix + + +def test_get_prefix_whitespace(mock_client: Didier): + """Test that variable whitespace doesn't matter""" + mock_message = MagicMock() + mock_message.content = "didiertest" + assert get_prefix(mock_client, mock_message) == "didier" + + mock_message = MagicMock() + mock_message.content = "didier test" + assert get_prefix(mock_client, mock_message) == "didier " diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 569bcda..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_dummy(tables): - assert True From fd57b5a79b954237af08136d71aa8108ed6d2318 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 21 Jun 2022 23:58:21 +0200 Subject: [PATCH 6/6] Crud & tests for custom commands --- .../b2d511552a1f_add_custom_commands.py | 57 +++++++++++ database/crud/custom_commands.py | 68 +++++++++++++ database/exceptions/__init__.py | 0 database/exceptions/constraints.py | 2 + database/exceptions/not_found.py | 2 + database/models.py | 28 ++++++ didier/didier.py | 24 +++++ settings.py | 2 + tests/test_database/__init__.py | 0 tests/test_database/test_crud/__init__.py | 0 .../test_crud/test_custom_commands.py | 98 +++++++++++++++++++ 11 files changed, 281 insertions(+) create mode 100644 alembic/versions/b2d511552a1f_add_custom_commands.py create mode 100644 database/crud/custom_commands.py create mode 100644 database/exceptions/__init__.py create mode 100644 database/exceptions/constraints.py create mode 100644 database/exceptions/not_found.py create mode 100644 tests/test_database/__init__.py create mode 100644 tests/test_database/test_crud/__init__.py create mode 100644 tests/test_database/test_crud/test_custom_commands.py diff --git a/alembic/versions/b2d511552a1f_add_custom_commands.py b/alembic/versions/b2d511552a1f_add_custom_commands.py new file mode 100644 index 0000000..83b004a --- /dev/null +++ b/alembic/versions/b2d511552a1f_add_custom_commands.py @@ -0,0 +1,57 @@ +"""Add custom commands + +Revision ID: b2d511552a1f +Revises: 4ec79dd5b191 +Create Date: 2022-06-21 22:10:05.590846 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b2d511552a1f' +down_revision = '4ec79dd5b191' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('custom_commands', + sa.Column('command_id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('indexed_name', sa.Text(), nullable=False), + sa.Column('response', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('command_id'), + sa.UniqueConstraint('name') + ) + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False) + + op.create_table('custom_command_aliases', + sa.Column('alias_id', sa.Integer(), nullable=False), + sa.Column('alias', sa.Text(), nullable=False), + sa.Column('indexed_alias', sa.Text(), nullable=False), + sa.Column('command_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ), + sa.PrimaryKeyConstraint('alias_id'), + sa.UniqueConstraint('alias') + ) + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias')) + + op.drop_table('custom_command_aliases') + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name')) + + op.drop_table('custom_commands') + # ### end Alembic commands ### diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py new file mode 100644 index 0000000..afd41ce --- /dev/null +++ b/database/crud/custom_commands.py @@ -0,0 +1,68 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions.constraints import DuplicateInsertException +from database.exceptions.not_found import NoResultFoundException +from database.models import CustomCommand, CustomCommandAlias + + +def clean_name(name: str) -> str: + """Convert a name to lowercase & remove spaces to allow easier matching""" + return name.lower().replace(" ", "") + + +async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand: + """Create a new custom command""" + # Check if command or alias already exists + command = await get_command(session, name) + if command is not None: + raise DuplicateInsertException + + command = CustomCommand(name=name, indexed_name=clean_name(name), response=response) + session.add(command) + await session.commit() + return command + + +async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias: + """Create an alias for a command""" + # Check if the command exists + command_instance = await get_command(session, command) + if command_instance is None: + raise NoResultFoundException + + # Check if the alias exists (either as an alias or as a name) + alias_instance = await get_command(session, alias) + if alias_instance is not None: + raise DuplicateInsertException + + alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) + session.add(alias_instance) + await session.commit() + + return alias_instance + + +async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command out of a message""" + # Search lowercase & without spaces, and strip the prefix + message = clean_name(message) + return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) + + +async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its name""" + statement = select(CustomCommand).where(CustomCommand.indexed_name == message) + return (await session.execute(statement)).scalar_one_or_none() + + +async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its alias""" + statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message) + alias = (await session.execute(statement)).scalar_one_or_none() + if alias is None: + return None + + return alias.command diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/exceptions/constraints.py b/database/exceptions/constraints.py new file mode 100644 index 0000000..1e78089 --- /dev/null +++ b/database/exceptions/constraints.py @@ -0,0 +1,2 @@ +class DuplicateInsertException(Exception): + """Exception raised when a value already exists""" diff --git a/database/exceptions/not_found.py b/database/exceptions/not_found.py new file mode 100644 index 0000000..eccfaa5 --- /dev/null +++ b/database/exceptions/not_found.py @@ -0,0 +1,2 @@ +class NoResultFoundException(Exception): + """Exception raised when nothing was found""" diff --git a/database/models.py b/database/models.py index 4414326..1c28d37 100644 --- a/database/models.py +++ b/database/models.py @@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() +class CustomCommand(Base): + """Custom commands to fill the hole Dyno couldn't""" + + __tablename__ = "custom_commands" + + command_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + indexed_name: str = Column(Text, nullable=False, index=True) + response: str = Column(Text, nullable=False) + + aliases: list[CustomCommandAlias] = relationship( + "CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" + ) + + +class CustomCommandAlias(Base): + """Aliases for custom commands""" + + __tablename__ = "custom_command_aliases" + + alias_id: int = Column(Integer, primary_key=True) + alias: str = Column(Text, nullable=False, unique=True) + indexed_alias: str = Column(Text, nullable=False, index=True) + command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) + + command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") + + class UforaCourse(Base): """A course on Ufora""" diff --git a/didier/didier.py b/didier/didier.py index fa2eb24..b42dfcc 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -3,6 +3,7 @@ import sys import traceback import discord +from discord import Message from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession @@ -88,6 +89,29 @@ class Didier(commands.Bot): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) + async def on_message(self, message: Message, /) -> None: + """Event triggered when a message is sent""" + # Ignore messages by bots + if message.author.bot: + return + + # Boos react to people that say Dider + if "dider" in message.content.lower() and message.author.id != self.user.id: + await message.add_reaction(settings.DISCORD_BOOS_REACT) + + # Potential custom command + if self._try_invoke_custom_command(message): + return + + await self.process_commands(message) + + async def _try_invoke_custom_command(self, message: Message) -> bool: + """Check if the message tries to invoke a custom command + If it does, send the reply associated with it + """ + if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): + return False + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" # If developing, print everything to stdout so you don't have to diff --git a/settings.py b/settings.py index ee0c448..797b052 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN") DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int) +DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>") +DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?") UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) """API Keys""" diff --git a/tests/test_database/__init__.py b/tests/test_database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/__init__.py b/tests/test_database/test_crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py new file mode 100644 index 0000000..5f41983 --- /dev/null +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -0,0 +1,98 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import custom_commands as crud +from database.exceptions.constraints import DuplicateInsertException +from database.models import CustomCommand, CustomCommandAlias + + +async def test_create_command_non_existing(database_session: AsyncSession): + """Test creating a new command when it doesn't exist yet""" + await crud.create_command(database_session, "name", "response") + + commands = (await database_session.execute(select(CustomCommand))).scalars().all() + assert len(commands) == 1 + assert commands[0].name == "name" + + +async def test_create_command_duplicate_name(database_session: AsyncSession): + """Test creating a command when the name already exists""" + await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "name", "other response") + + +async def test_create_command_name_is_alias(database_session: AsyncSession): + """Test creating a command when the name is taken by an alias""" + await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, "name", "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "n", "other response") + + +async def test_create_alias_non_existing(database_session: AsyncSession): + """Test creating an alias when the name is still free""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + await database_session.refresh(command) + assert len(command.aliases) == 1 + assert command.aliases[0].alias == "n" + + +async def test_create_alias_duplicate(database_session: AsyncSession): + """Test creating an alias when another alias already has this name""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_is_command(database_session: AsyncSession): + """Test creating an alias when the name is taken by a command""" + await crud.create_command(database_session, "n", "response") + command = await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_match_by_alias(database_session: AsyncSession): + """Test creating an alias for a command when matching the name to another alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + alias = await crud.create_alias(database_session, "a1", "a2") + assert alias.command == command + + +async def test_get_command_by_name_exists(database_session: AsyncSession): + """Test getting a command by name""" + await crud.create_command(database_session, "name", "response") + command = await crud.get_command(database_session, "name") + assert command is not None + + +async def test_get_command_by_cleaned_name(database_session: AsyncSession): + """Test getting a command by the cleaned version of the name""" + command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") + found = await crud.get_command(database_session, "capitalizednamewithspaces") + assert command == found + + +async def test_get_command_by_alias(database_session: AsyncSession): + """Test getting a command by an alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + await crud.create_alias(database_session, command.name, "a2") + + found = await crud.get_command(database_session, "a1") + assert command == found + + +async def test_get_command_non_existing(database_session: AsyncSession): + """Test getting a command when it doesn't exist""" + assert await crud.get_command(database_session, "name") is None