Compare commits

..

6 Commits

Author SHA1 Message Date
stijndcl fd57b5a79b Crud & tests for custom commands 2022-06-21 23:58:21 +02:00
stijndcl 53f58eb743 Write a few tests 2022-06-21 21:06:11 +02:00
stijndcl 5c2c62c6c4 Add sync command, clean up db sessions 2022-06-21 20:30:11 +02:00
stijndcl 868cd392c3 Fix mypy error 2022-06-21 18:58:33 +02:00
stijndcl 5a76cbd2ec Fix mypy error 2022-06-21 18:50:00 +02:00
stijndcl 000337107b Parse publication time of notifications 2022-06-21 18:44:47 +02:00
28 changed files with 458 additions and 14 deletions

View File

@ -0,0 +1,57 @@
"""Add custom commands
Revision ID: b2d511552a1f
Revises: 4ec79dd5b191
Create Date: 2022-06-21 22:10:05.590846
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b2d511552a1f'
down_revision = '4ec79dd5b191'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('custom_commands',
sa.Column('command_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('indexed_name', sa.Text(), nullable=False),
sa.Column('response', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('command_id'),
sa.UniqueConstraint('name')
)
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False)
op.create_table('custom_command_aliases',
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias', sa.Text(), nullable=False),
sa.Column('indexed_alias', sa.Text(), nullable=False),
sa.Column('command_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ),
sa.PrimaryKeyConstraint('alias_id'),
sa.UniqueConstraint('alias')
)
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias'))
op.drop_table('custom_command_aliases')
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name'))
op.drop_table('custom_commands')
# ### end Alembic commands ###

View File

@ -0,0 +1,68 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from database.models import CustomCommand, CustomCommandAlias
def clean_name(name: str) -> str:
"""Convert a name to lowercase & remove spaces to allow easier matching"""
return name.lower().replace(" ", "")
async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
"""Create a new custom command"""
# Check if command or alias already exists
command = await get_command(session, name)
if command is not None:
raise DuplicateInsertException
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
session.add(command)
await session.commit()
return command
async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
"""Create an alias for a command"""
# Check if the command exists
command_instance = await get_command(session, command)
if command_instance is None:
raise NoResultFoundException
# Check if the alias exists (either as an alias or as a name)
alias_instance = await get_command(session, alias)
if alias_instance is not None:
raise DuplicateInsertException
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
session.add(alias_instance)
await session.commit()
return alias_instance
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message"""
# Search lowercase & without spaces, and strip the prefix
message = clean_name(message)
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its name"""
statement = select(CustomCommand).where(CustomCommand.indexed_name == message)
return (await session.execute(statement)).scalar_one_or_none()
async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its alias"""
statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
alias = (await session.execute(statement)).scalar_one_or_none()
if alias is None:
return None
return alias.command

View File

@ -13,7 +13,7 @@ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCou
async def create_new_announcement( async def create_new_announcement(
session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime.datetime
) -> UforaAnnouncement: ) -> UforaAnnouncement:
"""Add a new announcement to the database""" """Add a new announcement to the database"""
new_announcement = UforaAnnouncement( new_announcement = UforaAnnouncement(

View File

View File

@ -0,0 +1,2 @@
class DuplicateInsertException(Exception):
"""Exception raised when a value already exists"""

View File

@ -0,0 +1,2 @@
class NoResultFoundException(Exception):
"""Exception raised when nothing was found"""

View File

@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base() Base = declarative_base()
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
__tablename__ = "custom_commands"
command_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
indexed_name: str = Column(Text, nullable=False, index=True)
response: str = Column(Text, nullable=False)
aliases: list[CustomCommandAlias] = relationship(
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
)
class CustomCommandAlias(Base):
"""Aliases for custom commands"""
__tablename__ = "custom_command_aliases"
alias_id: int = Column(Integer, primary_key=True)
alias: str = Column(Text, nullable=False, unique=True)
indexed_alias: str = Column(Text, nullable=False, index=True)
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class UforaCourse(Base): class UforaCourse(Base):
"""A course on Ufora""" """A course on Ufora"""

View File

@ -0,0 +1,27 @@
from typing import Optional
import discord
from discord.ext import commands
from didier import Didier
class Owner(commands.Cog):
"""Cog for owner-only commands"""
client: Didier
def __init__(self, client: Didier):
self.client = client
@commands.command(name="Sync")
@commands.is_owner()
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
"""Sync all application-commands in Discord"""
await self.client.tree.sync(guild=guild)
await ctx.message.add_reaction("🔄")
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Owner(client))

View File

@ -1,6 +1,6 @@
import traceback import traceback
from discord.ext import commands, tasks from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
import settings import settings
from database.crud.ufora_announcements import remove_old_announcements from database.crud.ufora_announcements import remove_old_announcements
@ -13,7 +13,8 @@ class Tasks(commands.Cog):
client: Didier client: Didier
def __init__(self, client: Didier): # pylint: disable=no-member def __init__(self, client: Didier):
# pylint: disable=no-member
self.client = client self.client = client
# Only pull announcements if a token was provided # Only pull announcements if a token was provided
@ -28,11 +29,12 @@ class Tasks(commands.Cog):
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
return return
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) async with self.client.db_session as session:
announcements = await fetch_ufora_announcements(self.client.db_session) announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
announcements = await fetch_ufora_announcements(session)
for announcement in announcements: for announcement in announcements:
await announcements_channel.send(embed=announcement.to_embed()) await announcements_channel.send(embed=announcement.to_embed())
@pull_ufora_announcements.before_loop @pull_ufora_announcements.before_loop
async def _before_ufora_announcements(self): async def _before_ufora_announcements(self):
@ -47,7 +49,8 @@ class Tasks(commands.Cog):
@tasks.loop(hours=24) @tasks.loop(hours=24)
async def remove_old_ufora_announcements(self): async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day""" """Remove all announcements that are over 1 week old, once per day"""
await remove_old_announcements(self.client.db_session) async with self.client.db_session as session:
await remove_old_announcements(session)
@remove_old_ufora_announcements.before_loop @remove_old_ufora_announcements.before_loop
async def _before_remove_old_ufora_announcements(self): async def _before_remove_old_ufora_announcements(self):

View File

@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud
from database.models import UforaCourse from database.models import UforaCourse
from didier.utils.types.datetime import int_to_weekday
from didier.utils.types.string import leading
@dataclass @dataclass
@ -88,8 +90,19 @@ class UforaNotification:
def _get_published(self) -> str: def _get_published(self) -> str:
"""Get a formatted string that represents when this announcement was published""" """Get a formatted string that represents when this announcement was published"""
# TODO return (
return "Placeholder :) TODO make the functions to format this" f"{int_to_weekday(self.published_dt.weekday())} "
f"{leading('0', str(self.published_dt.day))}"
"/"
f"{leading('0', str(self.published_dt.month))}"
"/"
f"{self.published_dt.year} "
f"om {leading('0', str(self.published_dt.hour))}"
":"
f"{leading('0', str(self.published_dt.minute))}"
":"
f"{leading('0', str(self.published_dt.second))}"
)
def parse_ids(url: str) -> Optional[tuple[int, int]]: def parse_ids(url: str) -> Optional[tuple[int, int]]:

View File

@ -3,12 +3,13 @@ import sys
import traceback import traceback
import discord import discord
from discord import Message
from discord.ext import commands from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.engine import DBSession from database.engine import DBSession
from didier.utils.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
class Didier(commands.Bot): class Didier(commands.Bot):
@ -88,6 +89,29 @@ class Didier(commands.Bot):
"""Event triggered when the bot is ready""" """Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE) print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if self._try_invoke_custom_command(message):
return
await self.process_commands(message)
async def _try_invoke_custom_command(self, message: Message) -> bool:
"""Check if the message tries to invoke a custom command
If it does, send the reply associated with it
"""
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors""" """Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to # If developing, print everything to stdout so you don't have to

View File

View File

View File

@ -0,0 +1,3 @@
def int_to_weekday(number: int) -> str:
"""Get the Dutch name of a weekday from the number"""
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]

View File

@ -0,0 +1,21 @@
from typing import Optional
def leading(character: str, string: str, target_length: Optional[int] = 2) -> str:
"""Add a leading [character] to [string] to make it length [target_length]
Pass None to target length to always do it, no matter the length
"""
# Cast to string just in case
string = str(string)
# Add no matter what
if target_length is None:
return character + string
# String is already long enough
if len(string) >= target_length:
return string
frequency = (target_length - len(string)) // len(character)
return (frequency * character) + string

View File

@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN")
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int) DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?")
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys""" """API Keys"""

View File

@ -1,5 +1,5 @@
import os
from typing import AsyncGenerator from typing import AsyncGenerator
from unittest.mock import MagicMock
import pytest import pytest
@ -7,6 +7,7 @@ from alembic import command, config
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine from database.engine import engine
from didier import Didier
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -38,3 +39,16 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
await transaction.rollback() await transaction.rollback()
await connection.close() await connection.close()
@pytest.fixture
def mock_client() -> Didier:
"""Fixture to get a mock Didier instance
The mock uses 0 as the id
"""
mock_client = MagicMock()
mock_user = MagicMock()
mock_user.id = 0
mock_client.user = mock_user
return mock_client

View File

View File

@ -0,0 +1,98 @@
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import custom_commands as crud
from database.exceptions.constraints import DuplicateInsertException
from database.models import CustomCommand, CustomCommandAlias
async def test_create_command_non_existing(database_session: AsyncSession):
"""Test creating a new command when it doesn't exist yet"""
await crud.create_command(database_session, "name", "response")
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
assert len(commands) == 1
assert commands[0].name == "name"
async def test_create_command_duplicate_name(database_session: AsyncSession):
"""Test creating a command when the name already exists"""
await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "name", "other response")
async def test_create_command_name_is_alias(database_session: AsyncSession):
"""Test creating a command when the name is taken by an alias"""
await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, "name", "n")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "n", "other response")
async def test_create_alias_non_existing(database_session: AsyncSession):
"""Test creating an alias when the name is still free"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
await database_session.refresh(command)
assert len(command.aliases) == 1
assert command.aliases[0].alias == "n"
async def test_create_alias_duplicate(database_session: AsyncSession):
"""Test creating an alias when another alias already has this name"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_is_command(database_session: AsyncSession):
"""Test creating an alias when the name is taken by a command"""
await crud.create_command(database_session, "n", "response")
command = await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_match_by_alias(database_session: AsyncSession):
"""Test creating an alias for a command when matching the name to another alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
alias = await crud.create_alias(database_session, "a1", "a2")
assert alias.command == command
async def test_get_command_by_name_exists(database_session: AsyncSession):
"""Test getting a command by name"""
await crud.create_command(database_session, "name", "response")
command = await crud.get_command(database_session, "name")
assert command is not None
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
"""Test getting a command by the cleaned version of the name"""
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
found = await crud.get_command(database_session, "capitalizednamewithspaces")
assert command == found
async def test_get_command_by_alias(database_session: AsyncSession):
"""Test getting a command by an alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
await crud.create_alias(database_session, command.name, "a2")
found = await crud.get_command(database_session, "a1")
assert command == found
async def test_get_command_non_existing(database_session: AsyncSession):
"""Test getting a command when it doesn't exist"""
assert await crud.get_command(database_session, "name") is None

View File

View File

@ -0,0 +1,84 @@
from unittest.mock import MagicMock
from didier import Didier
from didier.utils.discord.prefix import get_prefix
def test_get_prefix_didier(mock_client: Didier):
"""Test the "didier" prefix"""
mock_message = MagicMock()
mock_message.content = "didier test"
assert get_prefix(mock_client, mock_message) == "didier "
def test_get_prefix_didier_cased(mock_client: Didier):
"""Test the "didier" prefix with random casing"""
mock_message = MagicMock()
mock_message.content = "Didier test"
assert get_prefix(mock_client, mock_message) == "Didier "
mock_message = MagicMock()
mock_message.content = "DIDIER test"
assert get_prefix(mock_client, mock_message) == "DIDIER "
mock_message = MagicMock()
mock_message.content = "DiDiEr test"
assert get_prefix(mock_client, mock_message) == "DiDiEr "
def test_get_prefix_default(mock_client: Didier):
"""Test the fallback prefix (used when nothing matched)"""
mock_message = MagicMock()
mock_message.content = "random message"
assert get_prefix(mock_client, mock_message) == "didier"
def test_get_prefix_big_d(mock_client: Didier):
"""Test the "big d" prefix"""
mock_message = MagicMock()
mock_message.content = "big d test"
assert get_prefix(mock_client, mock_message) == "big d "
def test_get_prefix_big_d_cased(mock_client: Didier):
"""Test the "big d" prefix with random casing"""
mock_message = MagicMock()
mock_message.content = "Big d test"
assert get_prefix(mock_client, mock_message) == "Big d "
mock_message = MagicMock()
mock_message.content = "Big D test"
assert get_prefix(mock_client, mock_message) == "Big D "
mock_message = MagicMock()
mock_message.content = "BIG D test"
assert get_prefix(mock_client, mock_message) == "BIG D "
def test_get_prefix_mention_username(mock_client: Didier):
"""Test the @mention prefix when mentioned by username"""
mock_message = MagicMock()
prefix = f"<@{mock_client.user.id}> "
mock_message.content = f"{prefix}test"
assert get_prefix(mock_client, mock_message) == prefix
def test_get_prefix_mention_nickname(mock_client: Didier):
"""Test the @mention prefix when mentioned by server nickname"""
mock_message = MagicMock()
prefix = f"<@!{mock_client.user.id}> "
mock_message.content = f"{prefix}test"
assert get_prefix(mock_client, mock_message) == prefix
def test_get_prefix_whitespace(mock_client: Didier):
"""Test that variable whitespace doesn't matter"""
mock_message = MagicMock()
mock_message.content = "didiertest"
assert get_prefix(mock_client, mock_message) == "didier"
mock_message = MagicMock()
mock_message.content = "didier test"
assert get_prefix(mock_client, mock_message) == "didier "

View File

@ -1,2 +0,0 @@
def test_dummy(tables):
assert True