Compare commits

...

5 Commits

Author SHA1 Message Date
stijndcl b54aed24e8 Log errors in Discord channels 2022-07-24 21:35:38 +02:00
stijndcl 2e3b4823d0 Update sync 2022-07-24 18:30:03 +02:00
stijndcl 424399b88a Translate to english 2022-07-24 17:09:42 +02:00
stijndcl edc6343e12 Create owner-guild-only commands, make sync a bit fancier 2022-07-24 16:39:27 +02:00
stijndcl 0834a4ccbc Write tests for tasks 2022-07-24 01:49:52 +02:00
18 changed files with 281 additions and 53 deletions

View File

@ -23,6 +23,8 @@ extend-ignore =
D401, D401,
# Whitespace before ":" # Whitespace before ":"
E203, E203,
# Standard pseudo-random generators are not suitable for security/cryptographic purposes.
S311,
# Don't require docstrings when overriding a method, # Don't require docstrings when overriding a method,
# the base method should have a docstring but the rest not # the base method should have a docstring but the rest not
ignore-decorators=overrides ignore-decorators=overrides

View File

@ -42,4 +42,4 @@ async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> lis
months = extract("month", Birthday.birthday) months = extract("month", Birthday.birthday)
statement = select(Birthday).where((days == day.day) & (months == day.month)) statement = select(Birthday).where((days == day.day) & (months == day.month))
return list((await session.execute(statement)).scalars()) return list((await session.execute(statement)).scalars().all())

View File

@ -49,6 +49,12 @@ class Discord(commands.Cog):
await birthdays.add_birthday(session, ctx.author.id, date) await birthdays.add_birthday(session, ctx.author.id, date)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@commands.command(name="Join", usage="[Thread]")
async def join(self, ctx: commands.Context, thread: discord.Thread):
"""Make Didier join a thread"""
if thread.me is not None:
return await ctx.reply()
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -15,11 +15,14 @@ class Other(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]") @commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Term]")
async def define(self, ctx: commands.Context, *, query: str): async def define(self, ctx: commands.Context, *, query: str):
"""Look up the definition of a word on the Urban Dictionary""" """Look up the definition of a word on the Urban Dictionary"""
async with ctx.typing(): async with ctx.typing():
definitions = await urban_dictionary.lookup(self.client.http_session, query) status_code, definitions = await urban_dictionary.lookup(self.client.http_session, query)
if not definitions:
return await ctx.reply(f"Something went wrong (status {status_code})")
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False) await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
@commands.hybrid_command(name="google", description="Google search", usage="[Query]") @commands.hybrid_command(name="google", description="Google search", usage="[Query]")

View File

@ -1,14 +1,15 @@
from typing import Optional from typing import Literal, Optional
import discord import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
import settings
from database.crud import custom_commands from database.crud import custom_commands
from database.exceptions.constraints import DuplicateInsertException from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from didier import Didier from didier import Didier
from didier.utils.discord.flags.owner import EditCustomFlags from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
@ -18,8 +19,18 @@ class Owner(commands.Cog):
client: Didier client: Didier
# Slash groups # Slash groups
add_slash = app_commands.Group(name="add", description="Add something new to the database") add_slash = app_commands.Group(
edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry") name="add",
description="Add something new to the database",
guild_ids=settings.DISCORD_OWNER_GUILDS,
guild_only=True,
)
edit_slash = app_commands.Group(
name="edit",
description="Edit an existing database entry",
guild_ids=settings.DISCORD_OWNER_GUILDS,
guild_only=True,
)
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@ -31,16 +42,34 @@ class Owner(commands.Cog):
""" """
return await self.client.is_owner(ctx.author) return await self.client.is_owner(ctx.author)
@commands.command(name="Error") @commands.command(name="Error", aliases=["Raise"])
async def _error(self, ctx: commands.Context): async def _error(self, ctx: commands.Context, *, message: str = "Debug"):
"""Raise an exception for debugging purposes""" """Raise an exception for debugging purposes"""
raise Exception("Debug") raise Exception(message)
@commands.command(name="Sync") @commands.command(name="Sync")
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): async def sync(
self,
ctx: commands.Context,
guild: Optional[discord.Guild] = None,
symbol: Optional[Literal["."]] = None,
*,
flags: SyncOptionFlags,
):
"""Sync all application-commands in Discord""" """Sync all application-commands in Discord"""
# Allow using "." to specify the current guild
# When passing flags, and no guild was specified, default to the current guild as well
# because these don't work on global syncs
if guild is None and (symbol == "." or flags.clear or flags.copy_globals):
guild = ctx.guild
if guild is not None: if guild is not None:
self.client.tree.copy_global_to(guild=guild) if flags.clear:
self.client.tree.clear_commands(guild=guild)
if flags.copy_globals:
self.client.tree.copy_global_to(guild=guild)
await self.client.tree.sync(guild=guild) await self.client.tree.sync(guild=guild)
else: else:
await self.client.tree.sync() await self.client.tree.sync()
@ -52,37 +81,35 @@ class Owner(commands.Cog):
"""Command group for [add X] message commands""" """Command group for [add X] message commands"""
@add_msg.command(name="Custom") @add_msg.command(name="Custom")
async def add_custom(self, ctx: commands.Context, name: str, *, response: str): async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command""" """Add a new custom command"""
async with self.client.db_session as session: async with self.client.db_session as session:
try: try:
await custom_commands.create_command(session, name, response) await custom_commands.create_command(session, name, response)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
except DuplicateInsertException: except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.") await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@add_msg.command(name="Alias") @add_msg.command(name="Alias")
async def add_alias(self, ctx: commands.Context, command: str, alias: str): async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command""" """Add a new alias for a custom command"""
async with self.client.db_session as session: async with self.client.db_session as session:
try: try:
await custom_commands.create_alias(session, command, alias) await custom_commands.create_alias(session, command, alias)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
except NoResultFoundException: except NoResultFoundException:
await ctx.reply(f'Geen commando gevonden voor "{command}".') await ctx.reply(f"No command found matching `{command}`.")
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
except DuplicateInsertException: except DuplicateInsertException:
await ctx.reply("Er bestaat al een commando met deze naam.") await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@add_slash.command(name="custom", description="Add a custom command") @add_slash.command(name="custom", description="Add a custom command")
async def add_custom_slash(self, interaction: discord.Interaction): async def add_custom_slash(self, interaction: discord.Interaction):
"""Slash command to add a custom command""" """Slash command to add a custom command"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
return interaction.response.send_message( return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
)
modal = CreateCustomCommand(self.client) modal = CreateCustomCommand(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@ -91,9 +118,7 @@ class Owner(commands.Cog):
async def add_dad_joke_slash(self, interaction: discord.Interaction): async def add_dad_joke_slash(self, interaction: discord.Interaction):
"""Slash command to add a dad joke""" """Slash command to add a dad joke"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
return interaction.response.send_message( return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
)
modal = AddDadJoke(self.client) modal = AddDadJoke(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@ -110,7 +135,7 @@ class Owner(commands.Cog):
await custom_commands.edit_command(session, command, flags.name, flags.response) await custom_commands.edit_command(session, command, flags.name, flags.response)
return await self.client.confirm_message(ctx.message) return await self.client.confirm_message(ctx.message)
except NoResultFoundException: except NoResultFoundException:
await ctx.reply(f"Geen commando gevonden voor ``{command}``.") await ctx.reply(f"No command found matching ``{command}``.")
return await self.client.reject_message(ctx.message) return await self.client.reject_message(ctx.message)
@edit_slash.command(name="custom", description="Edit a custom command") @edit_slash.command(name="custom", description="Edit a custom command")

View File

@ -6,7 +6,7 @@ from discord.ext import commands
from database.crud import ufora_courses from database.crud import ufora_courses
from didier import Didier from didier import Didier
from didier.data import constants from didier.utils.discord.flags.school import StudyGuideFlags
class School(commands.Cog): class School(commands.Cog):
@ -36,43 +36,46 @@ class School(commands.Cog):
# Didn't fix it, sad # Didn't fix it, sad
if message is None: if message is None:
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10) return await ctx.reply("Found no message to pin.", delete_after=10)
if message.pinned:
return await ctx.reply("This message is already pinned.", delete_after=10)
if message.is_system(): if message.is_system():
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.") return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}") await message.pin(reason=f"Didier Pin by {ctx.author.display_name}")
await message.add_reaction("📌") await message.add_reaction("📌")
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message): async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
"""Pin a message in the current channel""" """Pin a message in the current channel"""
# Is already pinned # Is already pinned
if message.pinned: if message.pinned:
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True) return await interaction.response.send_message("This message is already pinned.", ephemeral=True)
if message.is_system(): if message.is_system():
return await interaction.response.send_message( return await interaction.response.send_message(
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True "Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
) )
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}") await message.pin(reason=f"Didier Pin by {interaction.user.display_name}")
await message.add_reaction("📌") await message.add_reaction("📌")
return await interaction.response.send_message("📌", ephemeral=True) return await interaction.response.send_message("📌", ephemeral=True)
@commands.hybrid_command( @commands.hybrid_command(
name="fiche", description="Stuurt de link naar de studiefiche voor [Vak]", aliases=["guide", "studiefiche"] name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"]
) )
@app_commands.describe(course="vak") @app_commands.describe(course="vak")
async def study_guide(self, ctx: commands.Context, course: str): async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
"""Create links to study guides""" """Create links to study guides"""
async with self.client.db_session as session: async with self.client.db_session as session:
ufora_course = await ufora_courses.get_course_by_name(session, course) ufora_course = await ufora_courses.get_course_by_name(session, course)
if ufora_course is None: if ufora_course is None:
return await ctx.reply(f"Geen vak gevonden voor ``{course}``", ephemeral=True) return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True)
return await ctx.reply( return await ctx.reply(
f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{constants.CURRENT_YEAR}", f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}",
mention_author=False, mention_author=False,
) )

View File

@ -1,4 +1,5 @@
import datetime import datetime
import random
import traceback import traceback
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
@ -18,6 +19,10 @@ DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
# TODO more messages?
BIRTHDAY_MESSAGES = ["Gelukkige verjaardag {mention}!", "Happy birthday {mention}!"]
class Tasks(commands.Cog): class Tasks(commands.Cog):
"""Task loops that run periodically """Task loops that run periodically
@ -52,12 +57,12 @@ class Tasks(commands.Cog):
""" """
raise NotImplementedError() raise NotImplementedError()
@tasks_group.command(name="Force", case_insensitive=True) @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]")
async def force_task(self, ctx: commands.Context, name: str): async def force_task(self, ctx: commands.Context, name: str):
"""Command to force-run a task without waiting for the run time""" """Command to force-run a task without waiting for the specified run time"""
name = name.lower() name = name.lower()
if name not in self._tasks: if name not in self._tasks:
return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False) return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False)
task = self._tasks[name] task = self._tasks[name]
await task() await task()
@ -76,8 +81,8 @@ class Tasks(commands.Cog):
for birthday in birthdays: for birthday in birthdays:
user = self.client.get_user(birthday.user_id) user = self.client.get_user(birthday.user_id)
# TODO more messages?
await channel.send(f"Gelukkig verjaardag {user.mention}!") await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention))
@check_birthdays.before_loop @check_birthdays.before_loop
async def _before_check_birthdays(self): async def _before_check_birthdays(self):
@ -114,6 +119,7 @@ class Tasks(commands.Cog):
async def _on_tasks_error(self, error: BaseException): async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks""" """Error handler for all tasks"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__))) print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
self.client.dispatch("task_error")
async def setup(client: Didier): async def setup(client: Didier):

View File

@ -1,3 +1,5 @@
from http import HTTPStatus
from aiohttp import ClientSession from aiohttp import ClientSession
from didier.data.embeds.urban_dictionary import Definition from didier.data.embeds.urban_dictionary import Definition
@ -8,10 +10,13 @@ __all__ = ["lookup", "PER_PAGE"]
PER_PAGE = 10 PER_PAGE = 10
async def lookup(http_session: ClientSession, query: str) -> list[Definition]: async def lookup(http_session: ClientSession, query: str) -> tuple[int, list[Definition]]:
"""Fetch the Urban Dictionary definitions for a given word""" """Fetch the Urban Dictionary definitions for a given word"""
url = "https://api.urbandictionary.com/v0/define" url = "https://api.urbandictionary.com/v0/define"
async with http_session.get(url, params={"term": query}) as response: async with http_session.get(url, params={"term": query}) as response:
if response.status != HTTPStatus.OK:
return response.status, []
response_json = await response.json() response_json = await response.json()
return list(map(Definition.parse_obj, response_json["list"])) return 200, list(map(Definition.parse_obj, response_json["list"]))

View File

@ -0,0 +1,47 @@
import traceback
import discord
from discord.ext import commands
from didier.utils.discord.constants import Limits
from didier.utils.types.string import abbreviate
__all__ = ["create_error_embed"]
def _get_traceback(exception: Exception) -> str:
"""Get a proper representation of the exception"""
tb = traceback.format_exception(type(exception), exception, exception.__traceback__)
error_string = ""
for line in tb:
# Don't add endless tracebacks
if line.strip().startswith("The above exception was the direct cause of"):
break
# Escape Discord markdown formatting
error_string += line.replace(r"*", r"\*").replace(r"_", r"\_")
if line.strip():
error_string += "\n"
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH)
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed:
"""Create an embed for the traceback of an exception"""
description = _get_traceback(exception)
if ctx.guild is None:
origin = "DM"
else:
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
invocation = f"{ctx.author.display_name} in {origin}"
embed = discord.Embed(colour=discord.Colour.red())
embed.set_author(name="Error")
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
embed.add_field(name="Context", value=invocation, inline=True)
embed.add_field(name="Exception", value=abbreviate(str(exception), Limits.EMBED_FIELD_VALUE_LENGTH), inline=False)
embed.add_field(name="Traceback", value=description, inline=False)
return embed

View File

@ -24,11 +24,11 @@ class GoogleSearch(EmbedBaseModel):
# Empty embed # Empty embed
if not self.data.results: if not self.data.results:
embed.description = "Geen resultaten gevonden" embed.description = "Found no results"
return embed return embed
# Error embed # Error embed
embed.description = f"Status {self.data.status_code}" embed.description = f"Something went wrong (status {self.data.status_code})"
return embed return embed

View File

@ -50,10 +50,10 @@ class Definition(EmbedPydantic):
embed = discord.Embed(colour=colours.urban_dictionary_green()) embed = discord.Embed(colour=colours.urban_dictionary_green())
embed.set_author(name="Urban Dictionary") embed.set_author(name="Urban Dictionary")
embed.add_field(name="Woord", value=self.word, inline=True) embed.add_field(name="Term", value=self.word, inline=True)
embed.add_field(name="Auteur", value=self.author, inline=True) embed.add_field(name="Author", value=self.author, inline=True)
embed.add_field(name="Definitie", value=self.definition, inline=False) embed.add_field(name="Definition", value=self.definition, inline=False)
embed.add_field(name="Voorbeeld", value=self.example or "\u200B", inline=False) embed.add_field(name="Example", value=self.example or "\u200B", inline=False)
embed.add_field( embed.add_field(
name="Rating", value=f"{self.ratio}% ({self.thumbs_up}/{self.thumbs_up + self.thumbs_down})", inline=True name="Rating", value=f"{self.ratio}% ({self.thumbs_up}/{self.thumbs_up + self.thumbs_down})", inline=True
) )

View File

@ -10,6 +10,7 @@ import settings
from database.crud import custom_commands from database.crud import custom_commands
from database.engine import DBSession from database.engine import DBSession
from database.utils.caches import CacheManager from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
__all__ = ["Didier"] __all__ = ["Didier"]
@ -139,6 +140,9 @@ class Didier(commands.Bot):
await self.process_commands(message) await self.process_commands(message)
# TODO easter eggs
# TODO stats
async def _try_invoke_custom_command(self, message: discord.Message) -> bool: async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
"""Check if the message tries to invoke a custom command """Check if the message tries to invoke a custom command
@ -162,11 +166,50 @@ class Didier(commands.Bot):
# Nothing found # Nothing found
return False return False
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a regular command errors""" """Event triggered when a new thread is created"""
# Print everything to the logs/stderr await thread.join()
await super().on_command_error(context, exception)
# If developing, do nothing special async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /):
"""Event triggered when a regular command errors"""
# If working locally, print everything to your console
if settings.SANDBOX: if settings.SANDBOX:
await super().on_command_error(ctx, exception)
return return
# If commands have their own error handler, let it handle the error instead
if hasattr(ctx.command, "on_error"):
return
# Ignore exceptions that aren't important
if isinstance(
exception,
(
commands.CommandNotFound,
commands.CheckFailure,
commands.TooManyArguments,
),
):
return
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if isinstance(exception, commands.MessageNotFound):
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
if isinstance(
exception,
(
commands.BadArgument,
commands.MissingRequiredArgument,
commands.UnexpectedQuoteError,
commands.ExpectedClosingQuoteError,
),
):
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
if settings.ERRORS_CHANNEL is not None:
embed = create_error_embed(ctx, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)

View File

@ -1,8 +1,10 @@
from typing import Optional from typing import Optional
from discord.ext import commands
from didier.utils.discord.flags import PosixFlags from didier.utils.discord.flags import PosixFlags
__all__ = ["EditCustomFlags"] __all__ = ["EditCustomFlags", "SyncOptionFlags"]
class EditCustomFlags(PosixFlags): class EditCustomFlags(PosixFlags):
@ -10,3 +12,10 @@ class EditCustomFlags(PosixFlags):
name: Optional[str] = None name: Optional[str] = None
response: Optional[str] = None response: Optional[str] = None
class SyncOptionFlags(PosixFlags):
"""Flags for the sync command"""
clear: bool = False
copy_globals: bool = commands.flag(aliases=["copy_global", "copy"], default=False)

View File

@ -0,0 +1,12 @@
from typing import Optional
from didier.data import constants
from didier.utils.discord.flags import PosixFlags
__all__ = ["StudyGuideFlags"]
class StudyGuideFlags(PosixFlags):
"""Flags for the study guide command"""
year: Optional[int] = constants.CURRENT_YEAR

View File

@ -18,7 +18,8 @@ omit = [
"./didier/data/embeds/*", "./didier/data/embeds/*",
"./didier/data/flags/*", "./didier/data/flags/*",
"./didier/utils/discord/colours.py", "./didier/utils/discord/colours.py",
"./didier/utils/discord/constants.py" "./didier/utils/discord/constants.py",
"./didier/utils/discord/flags/*",
] ]
[tool.isort] [tool.isort]

View File

@ -61,3 +61,5 @@ black
flake8 flake8
mypy mypy
``` ```
It's also convenient to have code-formatting happen automatically on-save. The [`Black documentation`](https://black.readthedocs.io/en/stable/integrations/editors.html) explains how to set this up for different types of editors.

View File

@ -44,6 +44,7 @@ DISCORD_TOKEN: str = env.str("DISCORD_TOKEN")
DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY")
DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.")
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
DISCORD_OWNER_GUILDS: Optional[list[int]] = env.list("DISCORD_OWNER_GUILDS", [], subcast=int) or None
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?") DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None) BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)

View File

@ -0,0 +1,63 @@
import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import tasks as crud
from database.enums import TaskType
from database.models import Task
@pytest.fixture
def task_type() -> TaskType:
"""Fixture to use the same TaskType in every test"""
return TaskType.BIRTHDAYS
@pytest.fixture
async def task(database_session: AsyncSession, task_type: TaskType) -> Task:
"""Fixture to create a task"""
task = Task(task=task_type)
database_session.add(task)
await database_session.commit()
return task
async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType):
"""Test getting a task by its enum type when it exists"""
result = await crud.get_task_by_enum(database_session, task_type)
assert result is not None
assert result == task
async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType):
"""Test getting a task by its enum type when it doesn't exist"""
result = await crud.get_task_by_enum(database_session, task_type)
assert result is None
@freeze_time("2022/07/24")
async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType):
"""Test setting the execution time of an existing task"""
await database_session.refresh(task)
assert task.previous_run is None
await crud.set_last_task_execution_time(database_session, task_type)
await database_session.refresh(task)
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
@freeze_time("2022/07/24")
async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType):
"""Test setting the execution time of a non-existing task"""
statement = select(Task).where(Task.task == task_type)
results = list((await database_session.execute(statement)).scalars().all())
assert len(results) == 0
await crud.set_last_task_execution_time(database_session, task_type)
results = list((await database_session.execute(statement)).scalars().all())
assert len(results) == 1
task = results[0]
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)