mirror of https://github.com/stijndcl/didier
Compare commits
6 Commits
d75831f848
...
fd57b5a79b
| Author | SHA1 | Date |
|---|---|---|
|
|
fd57b5a79b | |
|
|
53f58eb743 | |
|
|
5c2c62c6c4 | |
|
|
868cd392c3 | |
|
|
5a76cbd2ec | |
|
|
000337107b |
|
|
@ -0,0 +1,57 @@
|
|||
"""Add custom commands
|
||||
|
||||
Revision ID: b2d511552a1f
|
||||
Revises: 4ec79dd5b191
|
||||
Create Date: 2022-06-21 22:10:05.590846
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'b2d511552a1f'
|
||||
down_revision = '4ec79dd5b191'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('custom_commands',
|
||||
sa.Column('command_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('indexed_name', sa.Text(), nullable=False),
|
||||
sa.Column('response', sa.Text(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('command_id'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_custom_commands_indexed_name'), ['indexed_name'], unique=False)
|
||||
|
||||
op.create_table('custom_command_aliases',
|
||||
sa.Column('alias_id', sa.Integer(), nullable=False),
|
||||
sa.Column('alias', sa.Text(), nullable=False),
|
||||
sa.Column('indexed_alias', sa.Text(), nullable=False),
|
||||
sa.Column('command_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['command_id'], ['custom_commands.command_id'], ),
|
||||
sa.PrimaryKeyConstraint('alias_id'),
|
||||
sa.UniqueConstraint('alias')
|
||||
)
|
||||
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
|
||||
batch_op.create_index(batch_op.f('ix_custom_command_aliases_indexed_alias'), ['indexed_alias'], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table('custom_command_aliases', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_custom_command_aliases_indexed_alias'))
|
||||
|
||||
op.drop_table('custom_command_aliases')
|
||||
with op.batch_alter_table('custom_commands', schema=None) as batch_op:
|
||||
batch_op.drop_index(batch_op.f('ix_custom_commands_indexed_name'))
|
||||
|
||||
op.drop_table('custom_commands')
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,68 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.models import CustomCommand, CustomCommandAlias
|
||||
|
||||
|
||||
def clean_name(name: str) -> str:
|
||||
"""Convert a name to lowercase & remove spaces to allow easier matching"""
|
||||
return name.lower().replace(" ", "")
|
||||
|
||||
|
||||
async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
|
||||
"""Create a new custom command"""
|
||||
# Check if command or alias already exists
|
||||
command = await get_command(session, name)
|
||||
if command is not None:
|
||||
raise DuplicateInsertException
|
||||
|
||||
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
|
||||
session.add(command)
|
||||
await session.commit()
|
||||
return command
|
||||
|
||||
|
||||
async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
|
||||
"""Create an alias for a command"""
|
||||
# Check if the command exists
|
||||
command_instance = await get_command(session, command)
|
||||
if command_instance is None:
|
||||
raise NoResultFoundException
|
||||
|
||||
# Check if the alias exists (either as an alias or as a name)
|
||||
alias_instance = await get_command(session, alias)
|
||||
if alias_instance is not None:
|
||||
raise DuplicateInsertException
|
||||
|
||||
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
|
||||
session.add(alias_instance)
|
||||
await session.commit()
|
||||
|
||||
return alias_instance
|
||||
|
||||
|
||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
"""Try to get a command out of a message"""
|
||||
# Search lowercase & without spaces, and strip the prefix
|
||||
message = clean_name(message)
|
||||
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
|
||||
|
||||
|
||||
async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
"""Try to get a command by its name"""
|
||||
statement = select(CustomCommand).where(CustomCommand.indexed_name == message)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
"""Try to get a command by its alias"""
|
||||
statement = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
|
||||
alias = (await session.execute(statement)).scalar_one_or_none()
|
||||
if alias is None:
|
||||
return None
|
||||
|
||||
return alias.command
|
||||
|
|
@ -13,7 +13,7 @@ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCou
|
|||
|
||||
|
||||
async def create_new_announcement(
|
||||
session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime
|
||||
session: AsyncSession, announcement_id: int, course: UforaCourse, publication_date: datetime.datetime
|
||||
) -> UforaAnnouncement:
|
||||
"""Add a new announcement to the database"""
|
||||
new_announcement = UforaAnnouncement(
|
||||
|
|
|
|||
|
|
@ -0,0 +1,2 @@
|
|||
class DuplicateInsertException(Exception):
|
||||
"""Exception raised when a value already exists"""
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
class NoResultFoundException(Exception):
|
||||
"""Exception raised when nothing was found"""
|
||||
|
|
@ -8,6 +8,34 @@ from sqlalchemy.orm import declarative_base, relationship
|
|||
Base = declarative_base()
|
||||
|
||||
|
||||
class CustomCommand(Base):
|
||||
"""Custom commands to fill the hole Dyno couldn't"""
|
||||
|
||||
__tablename__ = "custom_commands"
|
||||
|
||||
command_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False, unique=True)
|
||||
indexed_name: str = Column(Text, nullable=False, index=True)
|
||||
response: str = Column(Text, nullable=False)
|
||||
|
||||
aliases: list[CustomCommandAlias] = relationship(
|
||||
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
class CustomCommandAlias(Base):
|
||||
"""Aliases for custom commands"""
|
||||
|
||||
__tablename__ = "custom_command_aliases"
|
||||
|
||||
alias_id: int = Column(Integer, primary_key=True)
|
||||
alias: str = Column(Text, nullable=False, unique=True)
|
||||
indexed_alias: str = Column(Text, nullable=False, index=True)
|
||||
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
|
||||
|
||||
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class UforaCourse(Base):
|
||||
"""A course on Ufora"""
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,27 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from didier import Didier
|
||||
|
||||
|
||||
class Owner(commands.Cog):
|
||||
"""Cog for owner-only commands"""
|
||||
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.command(name="Sync")
|
||||
@commands.is_owner()
|
||||
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
|
||||
"""Sync all application-commands in Discord"""
|
||||
await self.client.tree.sync(guild=guild)
|
||||
await ctx.message.add_reaction("🔄")
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(Owner(client))
|
||||
|
|
@ -1,6 +1,6 @@
|
|||
import traceback
|
||||
|
||||
from discord.ext import commands, tasks
|
||||
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
|
||||
|
||||
import settings
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
|
|
@ -13,7 +13,8 @@ class Tasks(commands.Cog):
|
|||
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier): # pylint: disable=no-member
|
||||
def __init__(self, client: Didier):
|
||||
# pylint: disable=no-member
|
||||
self.client = client
|
||||
|
||||
# Only pull announcements if a token was provided
|
||||
|
|
@ -28,11 +29,12 @@ class Tasks(commands.Cog):
|
|||
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
||||
return
|
||||
|
||||
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
||||
announcements = await fetch_ufora_announcements(self.client.db_session)
|
||||
async with self.client.db_session as session:
|
||||
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
||||
announcements = await fetch_ufora_announcements(session)
|
||||
|
||||
for announcement in announcements:
|
||||
await announcements_channel.send(embed=announcement.to_embed())
|
||||
for announcement in announcements:
|
||||
await announcements_channel.send(embed=announcement.to_embed())
|
||||
|
||||
@pull_ufora_announcements.before_loop
|
||||
async def _before_ufora_announcements(self):
|
||||
|
|
@ -47,7 +49,8 @@ class Tasks(commands.Cog):
|
|||
@tasks.loop(hours=24)
|
||||
async def remove_old_ufora_announcements(self):
|
||||
"""Remove all announcements that are over 1 week old, once per day"""
|
||||
await remove_old_announcements(self.client.db_session)
|
||||
async with self.client.db_session as session:
|
||||
await remove_old_announcements(session)
|
||||
|
||||
@remove_old_ufora_announcements.before_loop
|
||||
async def _before_remove_old_ufora_announcements(self):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
import settings
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.models import UforaCourse
|
||||
from didier.utils.types.datetime import int_to_weekday
|
||||
from didier.utils.types.string import leading
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -88,8 +90,19 @@ class UforaNotification:
|
|||
|
||||
def _get_published(self) -> str:
|
||||
"""Get a formatted string that represents when this announcement was published"""
|
||||
# TODO
|
||||
return "Placeholder :) TODO make the functions to format this"
|
||||
return (
|
||||
f"{int_to_weekday(self.published_dt.weekday())} "
|
||||
f"{leading('0', str(self.published_dt.day))}"
|
||||
"/"
|
||||
f"{leading('0', str(self.published_dt.month))}"
|
||||
"/"
|
||||
f"{self.published_dt.year} "
|
||||
f"om {leading('0', str(self.published_dt.hour))}"
|
||||
":"
|
||||
f"{leading('0', str(self.published_dt.minute))}"
|
||||
":"
|
||||
f"{leading('0', str(self.published_dt.second))}"
|
||||
)
|
||||
|
||||
|
||||
def parse_ids(url: str) -> Optional[tuple[int, int]]:
|
||||
|
|
|
|||
|
|
@ -3,12 +3,13 @@ import sys
|
|||
import traceback
|
||||
|
||||
import discord
|
||||
from discord import Message
|
||||
from discord.ext import commands
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import settings
|
||||
from database.engine import DBSession
|
||||
from didier.utils.prefix import get_prefix
|
||||
from didier.utils.discord.prefix import get_prefix
|
||||
|
||||
|
||||
class Didier(commands.Bot):
|
||||
|
|
@ -88,6 +89,29 @@ class Didier(commands.Bot):
|
|||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
||||
async def on_message(self, message: Message, /) -> None:
|
||||
"""Event triggered when a message is sent"""
|
||||
# Ignore messages by bots
|
||||
if message.author.bot:
|
||||
return
|
||||
|
||||
# Boos react to people that say Dider
|
||||
if "dider" in message.content.lower() and message.author.id != self.user.id:
|
||||
await message.add_reaction(settings.DISCORD_BOOS_REACT)
|
||||
|
||||
# Potential custom command
|
||||
if self._try_invoke_custom_command(message):
|
||||
return
|
||||
|
||||
await self.process_commands(message)
|
||||
|
||||
async def _try_invoke_custom_command(self, message: Message) -> bool:
|
||||
"""Check if the message tries to invoke a custom command
|
||||
If it does, send the reply associated with it
|
||||
"""
|
||||
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
||||
return False
|
||||
|
||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||
"""Event triggered when a regular command errors"""
|
||||
# If developing, print everything to stdout so you don't have to
|
||||
|
|
|
|||
|
|
@ -0,0 +1,3 @@
|
|||
def int_to_weekday(number: int) -> str:
|
||||
"""Get the Dutch name of a weekday from the number"""
|
||||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||
|
|
@ -0,0 +1,21 @@
|
|||
from typing import Optional
|
||||
|
||||
|
||||
def leading(character: str, string: str, target_length: Optional[int] = 2) -> str:
|
||||
"""Add a leading [character] to [string] to make it length [target_length]
|
||||
Pass None to target length to always do it, no matter the length
|
||||
"""
|
||||
# Cast to string just in case
|
||||
string = str(string)
|
||||
|
||||
# Add no matter what
|
||||
if target_length is None:
|
||||
return character + string
|
||||
|
||||
# String is already long enough
|
||||
if len(string) >= target_length:
|
||||
return string
|
||||
|
||||
frequency = (target_length - len(string)) // len(character)
|
||||
|
||||
return (frequency * character) + string
|
||||
|
|
@ -24,6 +24,8 @@ DISCORD_TOKEN: str = env.str("DISC_TOKEN")
|
|||
DISCORD_READY_MESSAGE: str = env.str("DISC_READY_MESSAGE", "I'M READY I'M READY I'M READY")
|
||||
DISCORD_STATUS_MESSAGE: str = env.str("DISC_STATUS_MESSAGE", "with your Didier Dinks.")
|
||||
DISCORD_TEST_GUILDS: list[int] = env.list("DISC_TEST_GUILDS", [], subcast=int)
|
||||
DISCORD_BOOS_REACT: str = env.str("DISC_BOOS_REACT", "<:boos:629603785840263179>")
|
||||
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISC_CUSTOM_COMMAND_PREFIX", "?")
|
||||
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
||||
|
||||
"""API Keys"""
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
import os
|
||||
from typing import AsyncGenerator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
|
@ -7,6 +7,7 @@ from alembic import command, config
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.engine import engine
|
||||
from didier import Didier
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
|
@ -38,3 +39,16 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
|||
await transaction.rollback()
|
||||
|
||||
await connection.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client() -> Didier:
|
||||
"""Fixture to get a mock Didier instance
|
||||
The mock uses 0 as the id
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
mock_user = MagicMock()
|
||||
mock_user.id = 0
|
||||
mock_client.user = mock_user
|
||||
|
||||
return mock_client
|
||||
|
|
|
|||
|
|
@ -0,0 +1,98 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import custom_commands as crud
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.models import CustomCommand, CustomCommandAlias
|
||||
|
||||
|
||||
async def test_create_command_non_existing(database_session: AsyncSession):
|
||||
"""Test creating a new command when it doesn't exist yet"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
|
||||
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
|
||||
assert len(commands) == 1
|
||||
assert commands[0].name == "name"
|
||||
|
||||
|
||||
async def test_create_command_duplicate_name(database_session: AsyncSession):
|
||||
"""Test creating a command when the name already exists"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_command(database_session, "name", "other response")
|
||||
|
||||
|
||||
async def test_create_command_name_is_alias(database_session: AsyncSession):
|
||||
"""Test creating a command when the name is taken by an alias"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, "name", "n")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_command(database_session, "n", "other response")
|
||||
|
||||
|
||||
async def test_create_alias_non_existing(database_session: AsyncSession):
|
||||
"""Test creating an alias when the name is still free"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
|
||||
await database_session.refresh(command)
|
||||
assert len(command.aliases) == 1
|
||||
assert command.aliases[0].alias == "n"
|
||||
|
||||
|
||||
async def test_create_alias_duplicate(database_session: AsyncSession):
|
||||
"""Test creating an alias when another alias already has this name"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_is_command(database_session: AsyncSession):
|
||||
"""Test creating an alias when the name is taken by a command"""
|
||||
await crud.create_command(database_session, "n", "response")
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_match_by_alias(database_session: AsyncSession):
|
||||
"""Test creating an alias for a command when matching the name to another alias"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "a1")
|
||||
alias = await crud.create_alias(database_session, "a1", "a2")
|
||||
assert alias.command == command
|
||||
|
||||
|
||||
async def test_get_command_by_name_exists(database_session: AsyncSession):
|
||||
"""Test getting a command by name"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
command = await crud.get_command(database_session, "name")
|
||||
assert command is not None
|
||||
|
||||
|
||||
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
|
||||
"""Test getting a command by the cleaned version of the name"""
|
||||
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
|
||||
found = await crud.get_command(database_session, "capitalizednamewithspaces")
|
||||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_by_alias(database_session: AsyncSession):
|
||||
"""Test getting a command by an alias"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "a1")
|
||||
await crud.create_alias(database_session, command.name, "a2")
|
||||
|
||||
found = await crud.get_command(database_session, "a1")
|
||||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_non_existing(database_session: AsyncSession):
|
||||
"""Test getting a command when it doesn't exist"""
|
||||
assert await crud.get_command(database_session, "name") is None
|
||||
|
|
@ -0,0 +1,84 @@
|
|||
from unittest.mock import MagicMock
|
||||
|
||||
from didier import Didier
|
||||
from didier.utils.discord.prefix import get_prefix
|
||||
|
||||
|
||||
def test_get_prefix_didier(mock_client: Didier):
|
||||
"""Test the "didier" prefix"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "didier test"
|
||||
assert get_prefix(mock_client, mock_message) == "didier "
|
||||
|
||||
|
||||
def test_get_prefix_didier_cased(mock_client: Didier):
|
||||
"""Test the "didier" prefix with random casing"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Didier test"
|
||||
assert get_prefix(mock_client, mock_message) == "Didier "
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "DIDIER test"
|
||||
assert get_prefix(mock_client, mock_message) == "DIDIER "
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "DiDiEr test"
|
||||
assert get_prefix(mock_client, mock_message) == "DiDiEr "
|
||||
|
||||
|
||||
def test_get_prefix_default(mock_client: Didier):
|
||||
"""Test the fallback prefix (used when nothing matched)"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "random message"
|
||||
assert get_prefix(mock_client, mock_message) == "didier"
|
||||
|
||||
|
||||
def test_get_prefix_big_d(mock_client: Didier):
|
||||
"""Test the "big d" prefix"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "big d test"
|
||||
assert get_prefix(mock_client, mock_message) == "big d "
|
||||
|
||||
|
||||
def test_get_prefix_big_d_cased(mock_client: Didier):
|
||||
"""Test the "big d" prefix with random casing"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Big d test"
|
||||
assert get_prefix(mock_client, mock_message) == "Big d "
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "Big D test"
|
||||
assert get_prefix(mock_client, mock_message) == "Big D "
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "BIG D test"
|
||||
assert get_prefix(mock_client, mock_message) == "BIG D "
|
||||
|
||||
|
||||
def test_get_prefix_mention_username(mock_client: Didier):
|
||||
"""Test the @mention prefix when mentioned by username"""
|
||||
mock_message = MagicMock()
|
||||
prefix = f"<@{mock_client.user.id}> "
|
||||
mock_message.content = f"{prefix}test"
|
||||
|
||||
assert get_prefix(mock_client, mock_message) == prefix
|
||||
|
||||
|
||||
def test_get_prefix_mention_nickname(mock_client: Didier):
|
||||
"""Test the @mention prefix when mentioned by server nickname"""
|
||||
mock_message = MagicMock()
|
||||
prefix = f"<@!{mock_client.user.id}> "
|
||||
mock_message.content = f"{prefix}test"
|
||||
|
||||
assert get_prefix(mock_client, mock_message) == prefix
|
||||
|
||||
|
||||
def test_get_prefix_whitespace(mock_client: Didier):
|
||||
"""Test that variable whitespace doesn't matter"""
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "didiertest"
|
||||
assert get_prefix(mock_client, mock_message) == "didier"
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "didier test"
|
||||
assert get_prefix(mock_client, mock_message) == "didier "
|
||||
|
|
@ -1,2 +0,0 @@
|
|||
def test_dummy(tables):
|
||||
assert True
|
||||
Loading…
Reference in New Issue