Compare commits

..

6 Commits

Author SHA1 Message Date
stijndcl fd57b5a79b Crud & tests for custom commands 2022-06-21 23:58:21 +02:00
stijndcl 53f58eb743 Write a few tests 2022-06-21 21:06:11 +02:00
stijndcl 5c2c62c6c4 Add sync command, clean up db sessions 2022-06-21 20:30:11 +02:00
stijndcl 868cd392c3 Fix mypy error 2022-06-21 18:58:33 +02:00
stijndcl 5a76cbd2ec Fix mypy error 2022-06-21 18:50:00 +02:00
stijndcl 000337107b Parse publication time of notifications 2022-06-21 18:44:47 +02:00
28 changed files with 458 additions and 14 deletions

View File

@ -0,0 +1,57 @@
"""Add custom commands
Revision ID: b2d511552a1f
Revises: 4ec79dd5b191
Create Date: 2022-06-21 22:10:05.590846
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b2d511552a1f'
down_revision = '4ec79dd5b191'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('custom_commands',
sa.Column('command_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('indexed_name', sa.Text(), nullable=False),
sa.Column('response', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('command_id'),
sa.UniqueConstraint('name')
)
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False)
op.create_table('custom_command_aliases',
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias', sa.Text(), nullable=False),
sa.Column('indexed_alias', sa.Text(), nullable=False),
sa.Column('command_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ),
sa.PrimaryKeyConstraint('alias_id'),
sa.UniqueConstraint('alias')
)
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias'))
op.drop_table('custom_command_aliases')
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name'))
op.drop_table('custom_commands')
# ### end Alembic commands ###

View File

@ -0,0 +1,68 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from database.models import CustomCommand, CustomCommandAlias
def clean_name(name: str) -> str:
"""Convert a name to lowercase & remove spaces to allow easier matching"""
return name.lower().replace(" ", "")
async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
"""Create a new custom command"""
# Check if command or alias already exists
command = await get_command(session, name)
if command is not None:
raise DuplicateInsertException
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
session.add(command)
await session.commit()
return command
async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
"""Create an alias for a command"""
# Check if the command exists
command_instance = await get_command(session, command)
if command_instance is None:
raise NoResultFoundException
# Check if the alias exists (either as an alias or as a name)
alias_instance = await get_command(session, alias)
if alias_instance is not None:
raise DuplicateInsertException
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
session.add(alias_instance)
await session.commit()
return alias_instance
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message"""
# Search lowercase & without spaces, and strip the prefix
message = clean_name(message)
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its name"""
statement = select(CustomCommand).where(CustomCommand.indexed_name == message)
return (await session.execute(statement)).scalar_one_or_none()
async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its alias"""
statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
alias = (await session.execute(statement)).scalar_one_or_none()
if alias is None:
return None
return alias.command

View File

@ -13,7 +13,7 @@ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCou
async def create_new_announcement(
session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime
session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime.datetime
) -> UforaAnnouncement:
"""Add a new announcement to the database"""
new_announcement = UforaAnnouncement(

View File

View File

@ -0,0 +1,2 @@
class DuplicateInsertException(Exception):
"""Exception raised when a value already exists"""

View File

@ -0,0 +1,2 @@
class NoResultFoundException(Exception):
"""Exception raised when nothing was found"""

View File

@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
__tablename__ = "custom_commands"
command_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
indexed_name: str = Column(Text, nullable=False, index=True)
response: str = Column(Text, nullable=False)
aliases: list[CustomCommandAlias] = relationship(
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
)
class CustomCommandAlias(Base):
"""Aliases for custom commands"""
__tablename__ = "custom_command_aliases"
alias_id: int = Column(Integer, primary_key=True)
alias: str = Column(Text, nullable=False, unique=True)
indexed_alias: str = Column(Text, nullable=False, index=True)
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class UforaCourse(Base):
"""A course on Ufora"""

View File

@ -0,0 +1,27 @@
from typing import Optional
import discord
from discord.ext import commands
from didier import Didier
class Owner(commands.Cog):
"""Cog for owner-only commands"""
client: Didier
def __init__(self, client: Didier):
self.client = client
@commands.command(name="Sync")
@commands.is_owner()
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
"""Sync all application-commands in Discord"""
await self.client.tree.sync(guild=guild)
await ctx.message.add_reaction("🔄")
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Owner(client))

View File

@ -1,6 +1,6 @@
import traceback
from discord.ext import commands, tasks
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
import settings
from database.crud.ufora_announcements import remove_old_announcements
@ -13,7 +13,8 @@ class Tasks(commands.Cog):
client: Didier
def __init__(self, client: Didier): # pylint: disable=no-member
def __init__(self, client: Didier):
# pylint: disable=no-member
self.client = client
# Only pull announcements if a token was provided
@ -28,11 +29,12 @@ class Tasks(commands.Cog):
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
return
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
announcements = await fetch_ufora_announcements(self.client.db_session)
async with self.client.db_session as session:
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
announcements = await fetch_ufora_announcements(session)
for announcement in announcements:
await announcements_channel.send(embed=announcement.to_embed())
for announcement in announcements:
await announcements_channel.send(embed=announcement.to_embed())
@pull_ufora_announcements.before_loop
async def _before_ufora_announcements(self):
@ -47,7 +49,8 @@ class Tasks(commands.Cog):
@tasks.loop(hours=24)
async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day"""
await remove_old_announcements(self.client.db_session)
async with self.client.db_session as session:
await remove_old_announcements(session)
@remove_old_ufora_announcements.before_loop
async def _before_remove_old_ufora_announcements(self):

View File

@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings
from database.crud import ufora_announcements as crud
from database.models import UforaCourse
from didier.utils.types.datetime import int_to_weekday
from didier.utils.types.string import leading
@dataclass
@ -88,8 +90,19 @@ class UforaNotification:
def _get_published(self) -> str:
"""Get a formatted string that represents when this announcement was published"""
# TODO
return "Placeholder :) TODO make the functions to format this"
return (
f"{int_to_weekday(self.published_dt.weekday())} "
f"{leading('0', str(self.published_dt.day))}"
"/"
f"{leading('0', str(self.published_dt.month))}"
"/"
f"{self.published_dt.year} "
f"om {leading('0', str(self.published_dt.hour))}"
":"
f"{leading('0', str(self.published_dt.minute))}"
":"
f"{leading('0', str(self.published_dt.second))}"
)
def parse_ids(url: str) -> Optional[tuple[int, int]]:

View File

@ -3,12 +3,13 @@ import sys
import traceback
import discord
from discord import Message
from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession
import settings
from database.engine import DBSession
from didier.utils.prefix import get_prefix
from didier.utils.discord.prefix import get_prefix
class Didier(commands.Bot):
@ -88,6 +89,29 @@ class Didier(commands.Bot):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if self._try_invoke_custom_command(message):
return
await self.process_commands(message)
async def _try_invoke_custom_command(self, message: Message) -> bool:
"""Check if the message tries to invoke a custom command
If it does, send the reply associated with it
"""
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to

View File

View File

View File

@ -0,0 +1,3 @@
def int_to_weekday(number: int) -> str:
"""Get the Dutch name of a weekday from the number"""
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]

View File

@ -0,0 +1,21 @@
from typing import Optional
def leading(character: str, string: str, target_length: Optional[int] = 2) -> str:
"""Add a leading [character] to [string] to make it length [target_length]
Pass None to target length to always do it, no matter the length
"""
# Cast to string just in case
string = str(string)
# Add no matter what
if target_length is None:
return character + string
# String is already long enough
if len(string) >= target_length:
return string
frequency = (target_length - len(string)) // len(character)
return (frequency * character) + string

View File

@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN")
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?")
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys"""

View File

@ -1,5 +1,5 @@
import os
from typing import AsyncGenerator
from unittest.mock import MagicMock
import pytest
@ -7,6 +7,7 @@ from alembic import command, config
from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine
from didier import Didier
@pytest.fixture(scope="session")
@ -38,3 +39,16 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
await transaction.rollback()
await connection.close()
@pytest.fixture
def mock_client() -> Didier:
"""Fixture to get a mock Didier instance
The mock uses 0 as the id
"""
mock_client = MagicMock()
mock_user = MagicMock()
mock_user.id = 0
mock_client.user = mock_user
return mock_client

View File

View File

@ -0,0 +1,98 @@
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import custom_commands as crud
from database.exceptions.constraints import DuplicateInsertException
from database.models import CustomCommand, CustomCommandAlias
async def test_create_command_non_existing(database_session: AsyncSession):
"""Test creating a new command when it doesn't exist yet"""
await crud.create_command(database_session, "name", "response")
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
assert len(commands) == 1
assert commands[0].name == "name"
async def test_create_command_duplicate_name(database_session: AsyncSession):
"""Test creating a command when the name already exists"""
await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "name", "other response")
async def test_create_command_name_is_alias(database_session: AsyncSession):
"""Test creating a command when the name is taken by an alias"""
await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, "name", "n")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "n", "other response")
async def test_create_alias_non_existing(database_session: AsyncSession):
"""Test creating an alias when the name is still free"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
await database_session.refresh(command)
assert len(command.aliases) == 1
assert command.aliases[0].alias == "n"
async def test_create_alias_duplicate(database_session: AsyncSession):
"""Test creating an alias when another alias already has this name"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_is_command(database_session: AsyncSession):
"""Test creating an alias when the name is taken by a command"""
await crud.create_command(database_session, "n", "response")
command = await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_match_by_alias(database_session: AsyncSession):
"""Test creating an alias for a command when matching the name to another alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
alias = await crud.create_alias(database_session, "a1", "a2")
assert alias.command == command
async def test_get_command_by_name_exists(database_session: AsyncSession):
"""Test getting a command by name"""
await crud.create_command(database_session, "name", "response")
command = await crud.get_command(database_session, "name")
assert command is not None
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
"""Test getting a command by the cleaned version of the name"""
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
found = await crud.get_command(database_session, "capitalizednamewithspaces")
assert command == found
async def test_get_command_by_alias(database_session: AsyncSession):
"""Test getting a command by an alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
await crud.create_alias(database_session, command.name, "a2")
found = await crud.get_command(database_session, "a1")
assert command == found
async def test_get_command_non_existing(database_session: AsyncSession):
"""Test getting a command when it doesn't exist"""
assert await crud.get_command(database_session, "name") is None

View File

View File

@ -0,0 +1,84 @@
from unittest.mock import MagicMock
from didier import Didier
from didier.utils.discord.prefix import get_prefix
def test_get_prefix_didier(mock_client: Didier):
"""Test the "didier" prefix"""
mock_message = MagicMock()
mock_message.content = "didier test"
assert get_prefix(mock_client, mock_message) == "didier "
def test_get_prefix_didier_cased(mock_client: Didier):
"""Test the "didier" prefix with random casing"""
mock_message = MagicMock()
mock_message.content = "Didier test"
assert get_prefix(mock_client, mock_message) == "Didier "
mock_message = MagicMock()
mock_message.content = "DIDIER test"
assert get_prefix(mock_client, mock_message) == "DIDIER "
mock_message = MagicMock()
mock_message.content = "DiDiEr test"
assert get_prefix(mock_client, mock_message) == "DiDiEr "
def test_get_prefix_default(mock_client: Didier):
"""Test the fallback prefix (used when nothing matched)"""
mock_message = MagicMock()
mock_message.content = "random message"
assert get_prefix(mock_client, mock_message) == "didier"
def test_get_prefix_big_d(mock_client: Didier):
"""Test the "big d" prefix"""
mock_message = MagicMock()
mock_message.content = "big d test"
assert get_prefix(mock_client, mock_message) == "big d "
def test_get_prefix_big_d_cased(mock_client: Didier):
"""Test the "big d" prefix with random casing"""
mock_message = MagicMock()
mock_message.content = "Big d test"
assert get_prefix(mock_client, mock_message) == "Big d "
mock_message = MagicMock()
mock_message.content = "Big D test"
assert get_prefix(mock_client, mock_message) == "Big D "
mock_message = MagicMock()
mock_message.content = "BIG D test"
assert get_prefix(mock_client, mock_message) == "BIG D "
def test_get_prefix_mention_username(mock_client: Didier):
"""Test the @mention prefix when mentioned by username"""
mock_message = MagicMock()
prefix = f"<@{mock_client.user.id}> "
mock_message.content = f"{prefix}test"
assert get_prefix(mock_client, mock_message) == prefix
def test_get_prefix_mention_nickname(mock_client: Didier):
"""Test the @mention prefix when mentioned by server nickname"""
mock_message = MagicMock()
prefix = f"<@!{mock_client.user.id}> "
mock_message.content = f"{prefix}test"
assert get_prefix(mock_client, mock_message) == prefix
def test_get_prefix_whitespace(mock_client: Didier):
"""Test that variable whitespace doesn't matter"""
mock_message = MagicMock()
mock_message.content = "didiertest"
assert get_prefix(mock_client, mock_message) == "didier"
mock_message = MagicMock()
mock_message.content = "didier test"
assert get_prefix(mock_client, mock_message) == "didier "

View File

@ -1,2 +0,0 @@
def test_dummy(tables):
assert True