diff --git a/alembic/versions/b2d511552a1f_add_custom_commands.py b/alembic/versions/b2d511552a1f_add_custom_commands.py new file mode 100644 index 0000000..83b004a --- /dev/null +++ b/alembic/versions/b2d511552a1f_add_custom_commands.py @@ -0,0 +1,57 @@ +"""Add custom commands + +Revision ID: b2d511552a1f +Revises: 4ec79dd5b191 +Create Date: 2022-06-21 22:10:05.590846 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'b2d511552a1f' +down_revision = '4ec79dd5b191' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('custom_commands', + sa.Column('command_id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('indexed_name', sa.Text(), nullable=False), + sa.Column('response', sa.Text(), nullable=False), + sa.PrimaryKeyConstraint('command_id'), + sa.UniqueConstraint('name') + ) + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False) + + op.create_table('custom_command_aliases', + sa.Column('alias_id', sa.Integer(), nullable=False), + sa.Column('alias', sa.Text(), nullable=False), + sa.Column('indexed_alias', sa.Text(), nullable=False), + sa.Column('command_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ), + sa.PrimaryKeyConstraint('alias_id'), + sa.UniqueConstraint('alias') + ) + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias')) + + op.drop_table('custom_command_aliases') + with op.batch_alter_table('custom_commands', schema=None) as batch_op: + batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name')) + + op.drop_table('custom_commands') + # ### end Alembic commands ### diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py new file mode 100644 index 0000000..afd41ce --- /dev/null +++ b/database/crud/custom_commands.py @@ -0,0 +1,68 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions.constraints import DuplicateInsertException +from database.exceptions.not_found import NoResultFoundException +from database.models import CustomCommand, CustomCommandAlias + + +def clean_name(name: str) -> str: + """Convert a name to lowercase & remove spaces to allow easier matching""" + return name.lower().replace(" ", "") + + +async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand: + """Create a new custom command""" + # Check if command or alias already exists + command = await get_command(session, name) + if command is not None: + raise DuplicateInsertException + + command = CustomCommand(name=name, indexed_name=clean_name(name), response=response) + session.add(command) + await session.commit() + return command + + +async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias: + """Create an alias for a command""" + # Check if the command exists + command_instance = await get_command(session, command) + if command_instance is None: + raise NoResultFoundException + + # Check if the alias exists (either as an alias or as a name) + alias_instance = await get_command(session, alias) + if alias_instance is not None: + raise DuplicateInsertException + + alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance) + session.add(alias_instance) + await session.commit() + + return alias_instance + + +async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command out of a message""" + # Search lowercase & without spaces, and strip the prefix + message = clean_name(message) + return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message)) + + +async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its name""" + statement = select(CustomCommand).where(CustomCommand.indexed_name == message) + return (await session.execute(statement)).scalar_one_or_none() + + +async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]: + """Try to get a command by its alias""" + statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message) + alias = (await session.execute(statement)).scalar_one_or_none() + if alias is None: + return None + + return alias.command diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 0d8e15f..c4d67f6 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -13,7 +13,7 @@ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCou async def create_new_announcement( - session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime + session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime.datetime ) -> UforaAnnouncement: """Add a new announcement to the database""" new_announcement = UforaAnnouncement( diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/exceptions/constraints.py b/database/exceptions/constraints.py new file mode 100644 index 0000000..1e78089 --- /dev/null +++ b/database/exceptions/constraints.py @@ -0,0 +1,2 @@ +class DuplicateInsertException(Exception): + """Exception raised when a value already exists""" diff --git a/database/exceptions/not_found.py b/database/exceptions/not_found.py new file mode 100644 index 0000000..eccfaa5 --- /dev/null +++ b/database/exceptions/not_found.py @@ -0,0 +1,2 @@ +class NoResultFoundException(Exception): + """Exception raised when nothing was found""" diff --git a/database/models.py b/database/models.py index 4414326..1c28d37 100644 --- a/database/models.py +++ b/database/models.py @@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() +class CustomCommand(Base): + """Custom commands to fill the hole Dyno couldn't""" + + __tablename__ = "custom_commands" + + command_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + indexed_name: str = Column(Text, nullable=False, index=True) + response: str = Column(Text, nullable=False) + + aliases: list[CustomCommandAlias] = relationship( + "CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" + ) + + +class CustomCommandAlias(Base): + """Aliases for custom commands""" + + __tablename__ = "custom_command_aliases" + + alias_id: int = Column(Integer, primary_key=True) + alias: str = Column(Text, nullable=False, unique=True) + indexed_alias: str = Column(Text, nullable=False, index=True) + command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) + + command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") + + class UforaCourse(Base): """A course on Ufora""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py new file mode 100644 index 0000000..af50633 --- /dev/null +++ b/didier/cogs/owner.py @@ -0,0 +1,27 @@ +from typing import Optional + +import discord +from discord.ext import commands + +from didier import Didier + + +class Owner(commands.Cog): + """Cog for owner-only commands""" + + client: Didier + + def __init__(self, client: Didier): + self.client = client + + @commands.command(name="Sync") + @commands.is_owner() + async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): + """Sync all application-commands in Discord""" + await self.client.tree.sync(guild=guild) + await ctx.message.add_reaction("🔄") + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(Owner(client)) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index b56e426..c37b8b8 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,6 +1,6 @@ import traceback -from discord.ext import commands, tasks +from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error import settings from database.crud.ufora_announcements import remove_old_announcements @@ -13,7 +13,8 @@ class Tasks(commands.Cog): client: Didier - def __init__(self, client: Didier): # pylint: disable=no-member + def __init__(self, client: Didier): + # pylint: disable=no-member self.client = client # Only pull announcements if a token was provided @@ -28,11 +29,12 @@ class Tasks(commands.Cog): if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: return - announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) - announcements = await fetch_ufora_announcements(self.client.db_session) + async with self.client.db_session as session: + announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) + announcements = await fetch_ufora_announcements(session) - for announcement in announcements: - await announcements_channel.send(embed=announcement.to_embed()) + for announcement in announcements: + await announcements_channel.send(embed=announcement.to_embed()) @pull_ufora_announcements.before_loop async def _before_ufora_announcements(self): @@ -47,7 +49,8 @@ class Tasks(commands.Cog): @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" - await remove_old_announcements(self.client.db_session) + async with self.client.db_session as session: + await remove_old_announcements(session) @remove_old_ufora_announcements.before_loop async def _before_remove_old_ufora_announcements(self): diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index 71b3191..52d5849 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud from database.models import UforaCourse +from didier.utils.types.datetime import int_to_weekday +from didier.utils.types.string import leading @dataclass @@ -88,8 +90,19 @@ class UforaNotification: def _get_published(self) -> str: """Get a formatted string that represents when this announcement was published""" - # TODO - return "Placeholder :) TODO make the functions to format this" + return ( + f"{int_to_weekday(self.published_dt.weekday())} " + f"{leading('0', str(self.published_dt.day))}" + "/" + f"{leading('0', str(self.published_dt.month))}" + "/" + f"{self.published_dt.year} " + f"om {leading('0', str(self.published_dt.hour))}" + ":" + f"{leading('0', str(self.published_dt.minute))}" + ":" + f"{leading('0', str(self.published_dt.second))}" + ) def parse_ids(url: str) -> Optional[tuple[int, int]]: diff --git a/didier/didier.py b/didier/didier.py index 0b5b457..b42dfcc 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -3,12 +3,13 @@ import sys import traceback import discord +from discord import Message from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings from database.engine import DBSession -from didier.utils.prefix import get_prefix +from didier.utils.discord.prefix import get_prefix class Didier(commands.Bot): @@ -88,6 +89,29 @@ class Didier(commands.Bot): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) + async def on_message(self, message: Message, /) -> None: + """Event triggered when a message is sent""" + # Ignore messages by bots + if message.author.bot: + return + + # Boos react to people that say Dider + if "dider" in message.content.lower() and message.author.id != self.user.id: + await message.add_reaction(settings.DISCORD_BOOS_REACT) + + # Potential custom command + if self._try_invoke_custom_command(message): + return + + await self.process_commands(message) + + async def _try_invoke_custom_command(self, message: Message) -> bool: + """Check if the message tries to invoke a custom command + If it does, send the reply associated with it + """ + if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): + return False + async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: """Event triggered when a regular command errors""" # If developing, print everything to stdout so you don't have to diff --git a/didier/utils/discord/__init__.py b/didier/utils/discord/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/discord/checks/__init__.py b/didier/utils/discord/checks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/discord/checks/message_commands.py b/didier/utils/discord/checks/message_commands.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/prefix.py b/didier/utils/discord/prefix.py similarity index 100% rename from didier/utils/prefix.py rename to didier/utils/discord/prefix.py diff --git a/didier/utils/types/__init__.py b/didier/utils/types/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py new file mode 100644 index 0000000..6e5c88e --- /dev/null +++ b/didier/utils/types/datetime.py @@ -0,0 +1,3 @@ +def int_to_weekday(number: int) -> str: + """Get the Dutch name of a weekday from the number""" + return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] diff --git a/didier/utils/types/string.py b/didier/utils/types/string.py new file mode 100644 index 0000000..773890a --- /dev/null +++ b/didier/utils/types/string.py @@ -0,0 +1,21 @@ +from typing import Optional + + +def leading(character: str, string: str, target_length: Optional[int] = 2) -> str: + """Add a leading [character] to [string] to make it length [target_length] + Pass None to target length to always do it, no matter the length + """ + # Cast to string just in case + string = str(string) + + # Add no matter what + if target_length is None: + return character + string + + # String is already long enough + if len(string) >= target_length: + return string + + frequency = (target_length - len(string)) // len(character) + + return (frequency * character) + string diff --git a/settings.py b/settings.py index ee0c448..797b052 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN") DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int) +DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>") +DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?") UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) """API Keys""" diff --git a/tests/conftest.py b/tests/conftest.py index 1ae9878..a74ba4c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,5 @@ -import os from typing import AsyncGenerator +from unittest.mock import MagicMock import pytest @@ -7,6 +7,7 @@ from alembic import command, config from sqlalchemy.ext.asyncio import AsyncSession from database.engine import engine +from didier import Didier @pytest.fixture(scope="session") @@ -38,3 +39,16 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: await transaction.rollback() await connection.close() + + +@pytest.fixture +def mock_client() -> Didier: + """Fixture to get a mock Didier instance + The mock uses 0 as the id + """ + mock_client = MagicMock() + mock_user = MagicMock() + mock_user.id = 0 + mock_client.user = mock_user + + return mock_client diff --git a/tests/test_database/__init__.py b/tests/test_database/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/__init__.py b/tests/test_database/test_crud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py new file mode 100644 index 0000000..5f41983 --- /dev/null +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -0,0 +1,98 @@ +import pytest +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import custom_commands as crud +from database.exceptions.constraints import DuplicateInsertException +from database.models import CustomCommand, CustomCommandAlias + + +async def test_create_command_non_existing(database_session: AsyncSession): + """Test creating a new command when it doesn't exist yet""" + await crud.create_command(database_session, "name", "response") + + commands = (await database_session.execute(select(CustomCommand))).scalars().all() + assert len(commands) == 1 + assert commands[0].name == "name" + + +async def test_create_command_duplicate_name(database_session: AsyncSession): + """Test creating a command when the name already exists""" + await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "name", "other response") + + +async def test_create_command_name_is_alias(database_session: AsyncSession): + """Test creating a command when the name is taken by an alias""" + await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, "name", "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_command(database_session, "n", "other response") + + +async def test_create_alias_non_existing(database_session: AsyncSession): + """Test creating an alias when the name is still free""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + await database_session.refresh(command) + assert len(command.aliases) == 1 + assert command.aliases[0].alias == "n" + + +async def test_create_alias_duplicate(database_session: AsyncSession): + """Test creating an alias when another alias already has this name""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "n") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_is_command(database_session: AsyncSession): + """Test creating an alias when the name is taken by a command""" + await crud.create_command(database_session, "n", "response") + command = await crud.create_command(database_session, "name", "response") + + with pytest.raises(DuplicateInsertException): + await crud.create_alias(database_session, command.name, "n") + + +async def test_create_alias_match_by_alias(database_session: AsyncSession): + """Test creating an alias for a command when matching the name to another alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + alias = await crud.create_alias(database_session, "a1", "a2") + assert alias.command == command + + +async def test_get_command_by_name_exists(database_session: AsyncSession): + """Test getting a command by name""" + await crud.create_command(database_session, "name", "response") + command = await crud.get_command(database_session, "name") + assert command is not None + + +async def test_get_command_by_cleaned_name(database_session: AsyncSession): + """Test getting a command by the cleaned version of the name""" + command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") + found = await crud.get_command(database_session, "capitalizednamewithspaces") + assert command == found + + +async def test_get_command_by_alias(database_session: AsyncSession): + """Test getting a command by an alias""" + command = await crud.create_command(database_session, "name", "response") + await crud.create_alias(database_session, command.name, "a1") + await crud.create_alias(database_session, command.name, "a2") + + found = await crud.get_command(database_session, "a1") + assert command == found + + +async def test_get_command_non_existing(database_session: AsyncSession): + """Test getting a command when it doesn't exist""" + assert await crud.get_command(database_session, "name") is None diff --git a/tests/test_didier/__init__.py b/tests/test_didier/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/__init__.py b/tests/test_didier/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/test_discord/__init__.py b/tests/test_didier/test_utils/test_discord/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/test_discord/test_prefix.py b/tests/test_didier/test_utils/test_discord/test_prefix.py new file mode 100644 index 0000000..d49e8e3 --- /dev/null +++ b/tests/test_didier/test_utils/test_discord/test_prefix.py @@ -0,0 +1,84 @@ +from unittest.mock import MagicMock + +from didier import Didier +from didier.utils.discord.prefix import get_prefix + + +def test_get_prefix_didier(mock_client: Didier): + """Test the "didier" prefix""" + mock_message = MagicMock() + mock_message.content = "didier test" + assert get_prefix(mock_client, mock_message) == "didier " + + +def test_get_prefix_didier_cased(mock_client: Didier): + """Test the "didier" prefix with random casing""" + mock_message = MagicMock() + mock_message.content = "Didier test" + assert get_prefix(mock_client, mock_message) == "Didier " + + mock_message = MagicMock() + mock_message.content = "DIDIER test" + assert get_prefix(mock_client, mock_message) == "DIDIER " + + mock_message = MagicMock() + mock_message.content = "DiDiEr test" + assert get_prefix(mock_client, mock_message) == "DiDiEr " + + +def test_get_prefix_default(mock_client: Didier): + """Test the fallback prefix (used when nothing matched)""" + mock_message = MagicMock() + mock_message.content = "random message" + assert get_prefix(mock_client, mock_message) == "didier" + + +def test_get_prefix_big_d(mock_client: Didier): + """Test the "big d" prefix""" + mock_message = MagicMock() + mock_message.content = "big d test" + assert get_prefix(mock_client, mock_message) == "big d " + + +def test_get_prefix_big_d_cased(mock_client: Didier): + """Test the "big d" prefix with random casing""" + mock_message = MagicMock() + mock_message.content = "Big d test" + assert get_prefix(mock_client, mock_message) == "Big d " + + mock_message = MagicMock() + mock_message.content = "Big D test" + assert get_prefix(mock_client, mock_message) == "Big D " + + mock_message = MagicMock() + mock_message.content = "BIG D test" + assert get_prefix(mock_client, mock_message) == "BIG D " + + +def test_get_prefix_mention_username(mock_client: Didier): + """Test the @mention prefix when mentioned by username""" + mock_message = MagicMock() + prefix = f"<@{mock_client.user.id}> " + mock_message.content = f"{prefix}test" + + assert get_prefix(mock_client, mock_message) == prefix + + +def test_get_prefix_mention_nickname(mock_client: Didier): + """Test the @mention prefix when mentioned by server nickname""" + mock_message = MagicMock() + prefix = f"<@!{mock_client.user.id}> " + mock_message.content = f"{prefix}test" + + assert get_prefix(mock_client, mock_message) == prefix + + +def test_get_prefix_whitespace(mock_client: Didier): + """Test that variable whitespace doesn't matter""" + mock_message = MagicMock() + mock_message.content = "didiertest" + assert get_prefix(mock_client, mock_message) == "didier" + + mock_message = MagicMock() + mock_message.content = "didier test" + assert get_prefix(mock_client, mock_message) == "didier " diff --git a/tests/test_dummy.py b/tests/test_dummy.py deleted file mode 100644 index 569bcda..0000000 --- a/tests/test_dummy.py +++ /dev/null @@ -1,2 +0,0 @@ -def test_dummy(tables): - assert True