mirror of https://github.com/stijndcl/didier
Crud & tests for custom commands
parent
53f58eb743
commit
fd57b5a79b
|
@ -0,0 +1,57 @@
|
||||||
|
"""Add custom commands
|
||||||
|
|
||||||
|
Revision ID: b2d511552a1f
|
||||||
|
Revises: 4ec79dd5b191
|
||||||
|
Create Date: 2022-06-21 22:10:05.590846
|
||||||
|
|
||||||
|
"""
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = 'b2d511552a1f'
|
||||||
|
down_revision = '4ec79dd5b191'
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table('custom_commands',
|
||||||
|
sa.Column('command_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('name', sa.Text(), nullable=False),
|
||||||
|
sa.Column('indexed_name', sa.Text(), nullable=False),
|
||||||
|
sa.Column('response', sa.Text(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint('command_id'),
|
||||||
|
sa.UniqueConstraint('name')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
|
||||||
|
batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False)
|
||||||
|
|
||||||
|
op.create_table('custom_command_aliases',
|
||||||
|
sa.Column('alias_id', sa.Integer(), nullable=False),
|
||||||
|
sa.Column('alias', sa.Text(), nullable=False),
|
||||||
|
sa.Column('indexed_alias', sa.Text(), nullable=False),
|
||||||
|
sa.Column('command_id', sa.Integer(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ),
|
||||||
|
sa.PrimaryKeyConstraint('alias_id'),
|
||||||
|
sa.UniqueConstraint('alias')
|
||||||
|
)
|
||||||
|
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
|
||||||
|
batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False)
|
||||||
|
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias'))
|
||||||
|
|
||||||
|
op.drop_table('custom_command_aliases')
|
||||||
|
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
|
||||||
|
batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name'))
|
||||||
|
|
||||||
|
op.drop_table('custom_commands')
|
||||||
|
# ### end Alembic commands ###
|
|
@ -0,0 +1,68 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
|
from database.models import CustomCommand, CustomCommandAlias
|
||||||
|
|
||||||
|
|
||||||
|
def clean_name(name: str) -> str:
|
||||||
|
"""Convert a name to lowercase & remove spaces to allow easier matching"""
|
||||||
|
return name.lower().replace(" ", "")
|
||||||
|
|
||||||
|
|
||||||
|
async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
|
||||||
|
"""Create a new custom command"""
|
||||||
|
# Check if command or alias already exists
|
||||||
|
command = await get_command(session, name)
|
||||||
|
if command is not None:
|
||||||
|
raise DuplicateInsertException
|
||||||
|
|
||||||
|
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
|
||||||
|
session.add(command)
|
||||||
|
await session.commit()
|
||||||
|
return command
|
||||||
|
|
||||||
|
|
||||||
|
async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
|
||||||
|
"""Create an alias for a command"""
|
||||||
|
# Check if the command exists
|
||||||
|
command_instance = await get_command(session, command)
|
||||||
|
if command_instance is None:
|
||||||
|
raise NoResultFoundException
|
||||||
|
|
||||||
|
# Check if the alias exists (either as an alias or as a name)
|
||||||
|
alias_instance = await get_command(session, alias)
|
||||||
|
if alias_instance is not None:
|
||||||
|
raise DuplicateInsertException
|
||||||
|
|
||||||
|
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
|
||||||
|
session.add(alias_instance)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
return alias_instance
|
||||||
|
|
||||||
|
|
||||||
|
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||||
|
"""Try to get a command out of a message"""
|
||||||
|
# Search lowercase & without spaces, and strip the prefix
|
||||||
|
message = clean_name(message)
|
||||||
|
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
|
||||||
|
|
||||||
|
|
||||||
|
async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||||
|
"""Try to get a command by its name"""
|
||||||
|
statement = select(CustomCommand).where(CustomCommand.indexed_name == message)
|
||||||
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||||
|
"""Try to get a command by its alias"""
|
||||||
|
statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
|
||||||
|
alias = (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
if alias is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return alias.command
|
|
@ -0,0 +1,2 @@
|
||||||
|
class DuplicateInsertException(Exception):
|
||||||
|
"""Exception raised when a value already exists"""
|
|
@ -0,0 +1,2 @@
|
||||||
|
class NoResultFoundException(Exception):
|
||||||
|
"""Exception raised when nothing was found"""
|
|
@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
class CustomCommand(Base):
|
||||||
|
"""Custom commands to fill the hole Dyno couldn't"""
|
||||||
|
|
||||||
|
__tablename__ = "custom_commands"
|
||||||
|
|
||||||
|
command_id: int = Column(Integer, primary_key=True)
|
||||||
|
name: str = Column(Text, nullable=False, unique=True)
|
||||||
|
indexed_name: str = Column(Text, nullable=False, index=True)
|
||||||
|
response: str = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
aliases: list[CustomCommandAlias] = relationship(
|
||||||
|
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CustomCommandAlias(Base):
|
||||||
|
"""Aliases for custom commands"""
|
||||||
|
|
||||||
|
__tablename__ = "custom_command_aliases"
|
||||||
|
|
||||||
|
alias_id: int = Column(Integer, primary_key=True)
|
||||||
|
alias: str = Column(Text, nullable=False, unique=True)
|
||||||
|
indexed_alias: str = Column(Text, nullable=False, index=True)
|
||||||
|
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
|
||||||
|
|
||||||
|
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
class UforaCourse(Base):
|
class UforaCourse(Base):
|
||||||
"""A course on Ufora"""
|
"""A course on Ufora"""
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
from discord import Message
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
@ -88,6 +89,29 @@ class Didier(commands.Bot):
|
||||||
"""Event triggered when the bot is ready"""
|
"""Event triggered when the bot is ready"""
|
||||||
print(settings.DISCORD_READY_MESSAGE)
|
print(settings.DISCORD_READY_MESSAGE)
|
||||||
|
|
||||||
|
async def on_message(self, message: Message, /) -> None:
|
||||||
|
"""Event triggered when a message is sent"""
|
||||||
|
# Ignore messages by bots
|
||||||
|
if message.author.bot:
|
||||||
|
return
|
||||||
|
|
||||||
|
# Boos react to people that say Dider
|
||||||
|
if "dider" in message.content.lower() and message.author.id != self.user.id:
|
||||||
|
await message.add_reaction(settings.DISCORD_BOOS_REACT)
|
||||||
|
|
||||||
|
# Potential custom command
|
||||||
|
if self._try_invoke_custom_command(message):
|
||||||
|
return
|
||||||
|
|
||||||
|
await self.process_commands(message)
|
||||||
|
|
||||||
|
async def _try_invoke_custom_command(self, message: Message) -> bool:
|
||||||
|
"""Check if the message tries to invoke a custom command
|
||||||
|
If it does, send the reply associated with it
|
||||||
|
"""
|
||||||
|
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
||||||
|
return False
|
||||||
|
|
||||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||||
"""Event triggered when a regular command errors"""
|
"""Event triggered when a regular command errors"""
|
||||||
# If developing, print everything to stdout so you don't have to
|
# If developing, print everything to stdout so you don't have to
|
||||||
|
|
|
@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN")
|
||||||
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
|
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
|
||||||
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
|
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
|
||||||
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
|
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
|
||||||
|
DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>")
|
||||||
|
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?")
|
||||||
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
||||||
|
|
||||||
"""API Keys"""
|
"""API Keys"""
|
||||||
|
|
|
@ -0,0 +1,98 @@
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import custom_commands as crud
|
||||||
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
|
from database.models import CustomCommand, CustomCommandAlias
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_command_non_existing(database_session: AsyncSession):
|
||||||
|
"""Test creating a new command when it doesn't exist yet"""
|
||||||
|
await crud.create_command(database_session, "name", "response")
|
||||||
|
|
||||||
|
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
|
||||||
|
assert len(commands) == 1
|
||||||
|
assert commands[0].name == "name"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_command_duplicate_name(database_session: AsyncSession):
|
||||||
|
"""Test creating a command when the name already exists"""
|
||||||
|
await crud.create_command(database_session, "name", "response")
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateInsertException):
|
||||||
|
await crud.create_command(database_session, "name", "other response")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_command_name_is_alias(database_session: AsyncSession):
|
||||||
|
"""Test creating a command when the name is taken by an alias"""
|
||||||
|
await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.create_alias(database_session, "name", "n")
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateInsertException):
|
||||||
|
await crud.create_command(database_session, "n", "other response")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_alias_non_existing(database_session: AsyncSession):
|
||||||
|
"""Test creating an alias when the name is still free"""
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.create_alias(database_session, command.name, "n")
|
||||||
|
|
||||||
|
await database_session.refresh(command)
|
||||||
|
assert len(command.aliases) == 1
|
||||||
|
assert command.aliases[0].alias == "n"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_alias_duplicate(database_session: AsyncSession):
|
||||||
|
"""Test creating an alias when another alias already has this name"""
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.create_alias(database_session, command.name, "n")
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateInsertException):
|
||||||
|
await crud.create_alias(database_session, command.name, "n")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_alias_is_command(database_session: AsyncSession):
|
||||||
|
"""Test creating an alias when the name is taken by a command"""
|
||||||
|
await crud.create_command(database_session, "n", "response")
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
|
||||||
|
with pytest.raises(DuplicateInsertException):
|
||||||
|
await crud.create_alias(database_session, command.name, "n")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_create_alias_match_by_alias(database_session: AsyncSession):
|
||||||
|
"""Test creating an alias for a command when matching the name to another alias"""
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.create_alias(database_session, command.name, "a1")
|
||||||
|
alias = await crud.create_alias(database_session, "a1", "a2")
|
||||||
|
assert alias.command == command
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_command_by_name_exists(database_session: AsyncSession):
|
||||||
|
"""Test getting a command by name"""
|
||||||
|
await crud.create_command(database_session, "name", "response")
|
||||||
|
command = await crud.get_command(database_session, "name")
|
||||||
|
assert command is not None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
|
||||||
|
"""Test getting a command by the cleaned version of the name"""
|
||||||
|
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
|
||||||
|
found = await crud.get_command(database_session, "capitalizednamewithspaces")
|
||||||
|
assert command == found
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_command_by_alias(database_session: AsyncSession):
|
||||||
|
"""Test getting a command by an alias"""
|
||||||
|
command = await crud.create_command(database_session, "name", "response")
|
||||||
|
await crud.create_alias(database_session, command.name, "a1")
|
||||||
|
await crud.create_alias(database_session, command.name, "a2")
|
||||||
|
|
||||||
|
found = await crud.get_command(database_session, "a1")
|
||||||
|
assert command == found
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_command_non_existing(database_session: AsyncSession):
|
||||||
|
"""Test getting a command when it doesn't exist"""
|
||||||
|
assert await crud.get_command(database_session, "name") is None
|
Loading…
Reference in New Issue