Crud & tests for custom commands

pull/115/head
stijndcl 2022-06-21 23:58:21 +02:00
parent 53f58eb743
commit fd57b5a79b
11 changed files with 281 additions and 0 deletions

View File

@ -0,0 +1,57 @@
"""Add custom commands
Revision ID: b2d511552a1f
Revises: 4ec79dd5b191
Create Date: 2022-06-21 22:10:05.590846
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'b2d511552a1f'
down_revision = '4ec79dd5b191'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('custom_commands',
sa.Column('command_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('indexed_name', sa.Text(), nullable=False),
sa.Column('response', sa.Text(), nullable=False),
sa.PrimaryKeyConstraint('command_id'),
sa.UniqueConstraint('name')
)
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False)
op.create_table('custom_command_aliases',
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias', sa.Text(), nullable=False),
sa.Column('indexed_alias', sa.Text(), nullable=False),
sa.Column('command_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ),
sa.PrimaryKeyConstraint('alias_id'),
sa.UniqueConstraint('alias')
)
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias'))
op.drop_table('custom_command_aliases')
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name'))
op.drop_table('custom_commands')
# ### end Alembic commands ###

View File

@ -0,0 +1,68 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from database.models import CustomCommand, CustomCommandAlias
def clean_name(name: str) -> str:
"""Convert a name to lowercase & remove spaces to allow easier matching"""
return name.lower().replace(" ", "")
async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
"""Create a new custom command"""
# Check if command or alias already exists
command = await get_command(session, name)
if command is not None:
raise DuplicateInsertException
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
session.add(command)
await session.commit()
return command
async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
"""Create an alias for a command"""
# Check if the command exists
command_instance = await get_command(session, command)
if command_instance is None:
raise NoResultFoundException
# Check if the alias exists (either as an alias or as a name)
alias_instance = await get_command(session, alias)
if alias_instance is not None:
raise DuplicateInsertException
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
session.add(alias_instance)
await session.commit()
return alias_instance
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message"""
# Search lowercase & without spaces, and strip the prefix
message = clean_name(message)
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its name"""
statement = select(CustomCommand).where(CustomCommand.indexed_name == message)
return (await session.execute(statement)).scalar_one_or_none()
async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command by its alias"""
statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
alias = (await session.execute(statement)).scalar_one_or_none()
if alias is None:
return None
return alias.command

View File

View File

@ -0,0 +1,2 @@
class DuplicateInsertException(Exception):
"""Exception raised when a value already exists"""

View File

@ -0,0 +1,2 @@
class NoResultFoundException(Exception):
"""Exception raised when nothing was found"""

View File

@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
__tablename__ = "custom_commands"
command_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
indexed_name: str = Column(Text, nullable=False, index=True)
response: str = Column(Text, nullable=False)
aliases: list[CustomCommandAlias] = relationship(
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
)
class CustomCommandAlias(Base):
"""Aliases for custom commands"""
__tablename__ = "custom_command_aliases"
alias_id: int = Column(Integer, primary_key=True)
alias: str = Column(Text, nullable=False, unique=True)
indexed_alias: str = Column(Text, nullable=False, index=True)
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class UforaCourse(Base):
"""A course on Ufora"""

View File

@ -3,6 +3,7 @@ import sys
import traceback
import discord
from discord import Message
from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession
@ -88,6 +89,29 @@ class Didier(commands.Bot):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if self._try_invoke_custom_command(message):
return
await self.process_commands(message)
async def _try_invoke_custom_command(self, message: Message) -> bool:
"""Check if the message tries to invoke a custom command
If it does, send the reply associated with it
"""
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to

View File

@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN")
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?")
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys"""

View File

View File

@ -0,0 +1,98 @@
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import custom_commands as crud
from database.exceptions.constraints import DuplicateInsertException
from database.models import CustomCommand, CustomCommandAlias
async def test_create_command_non_existing(database_session: AsyncSession):
"""Test creating a new command when it doesn't exist yet"""
await crud.create_command(database_session, "name", "response")
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
assert len(commands) == 1
assert commands[0].name == "name"
async def test_create_command_duplicate_name(database_session: AsyncSession):
"""Test creating a command when the name already exists"""
await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "name", "other response")
async def test_create_command_name_is_alias(database_session: AsyncSession):
"""Test creating a command when the name is taken by an alias"""
await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, "name", "n")
with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "n", "other response")
async def test_create_alias_non_existing(database_session: AsyncSession):
"""Test creating an alias when the name is still free"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
await database_session.refresh(command)
assert len(command.aliases) == 1
assert command.aliases[0].alias == "n"
async def test_create_alias_duplicate(database_session: AsyncSession):
"""Test creating an alias when another alias already has this name"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "n")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_is_command(database_session: AsyncSession):
"""Test creating an alias when the name is taken by a command"""
await crud.create_command(database_session, "n", "response")
command = await crud.create_command(database_session, "name", "response")
with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n")
async def test_create_alias_match_by_alias(database_session: AsyncSession):
"""Test creating an alias for a command when matching the name to another alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
alias = await crud.create_alias(database_session, "a1", "a2")
assert alias.command == command
async def test_get_command_by_name_exists(database_session: AsyncSession):
"""Test getting a command by name"""
await crud.create_command(database_session, "name", "response")
command = await crud.get_command(database_session, "name")
assert command is not None
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
"""Test getting a command by the cleaned version of the name"""
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
found = await crud.get_command(database_session, "capitalizednamewithspaces")
assert command == found
async def test_get_command_by_alias(database_session: AsyncSession):
"""Test getting a command by an alias"""
command = await crud.create_command(database_session, "name", "response")
await crud.create_alias(database_session, command.name, "a1")
await crud.create_alias(database_session, command.name, "a2")
found = await crud.get_command(database_session, "a1")
assert command == found
async def test_get_command_non_existing(database_session: AsyncSession):
"""Test getting a command when it doesn't exist"""
assert await crud.get_command(database_session, "name") is None