From 0834a4ccbcb0567adde04aef36af9f1e00085842 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 24 Jul 2022 01:49:52 +0200 Subject: [PATCH 1/5] Write tests for tasks --- database/crud/birthdays.py | 2 +- tests/test_database/test_crud/test_tasks.py | 63 +++++++++++++++++++++ 2 files changed, 64 insertions(+), 1 deletion(-) create mode 100644 tests/test_database/test_crud/test_tasks.py diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 054d4c5..df59dfc 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -42,4 +42,4 @@ async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> lis months = extract("month", Birthday.birthday) statement = select(Birthday).where((days == day.day) & (months == day.month)) - return list((await session.execute(statement)).scalars()) + return list((await session.execute(statement)).scalars().all()) diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py new file mode 100644 index 0000000..e1e4f97 --- /dev/null +++ b/tests/test_database/test_crud/test_tasks.py @@ -0,0 +1,63 @@ +import datetime + +import pytest +from freezegun import freeze_time +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import tasks as crud +from database.enums import TaskType +from database.models import Task + + +@pytest.fixture +def task_type() -> TaskType: + """Fixture to use the same TaskType in every test""" + return TaskType.BIRTHDAYS + + +@pytest.fixture +async def task(database_session: AsyncSession, task_type: TaskType) -> Task: + """Fixture to create a task""" + task = Task(task=task_type) + database_session.add(task) + await database_session.commit() + return task + + +async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType): + """Test getting a task by its enum type when it exists""" + result = await crud.get_task_by_enum(database_session, task_type) + assert result is not None + assert result == task + + +async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType): + """Test getting a task by its enum type when it doesn't exist""" + result = await crud.get_task_by_enum(database_session, task_type) + assert result is None + + +@freeze_time("2022/07/24") +async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType): + """Test setting the execution time of an existing task""" + await database_session.refresh(task) + assert task.previous_run is None + + await crud.set_last_task_execution_time(database_session, task_type) + await database_session.refresh(task) + assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) + + +@freeze_time("2022/07/24") +async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType): + """Test setting the execution time of a non-existing task""" + statement = select(Task).where(Task.task == task_type) + results = list((await database_session.execute(statement)).scalars().all()) + assert len(results) == 0 + + await crud.set_last_task_execution_time(database_session, task_type) + results = list((await database_session.execute(statement)).scalars().all()) + assert len(results) == 1 + task = results[0] + assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) From edc6343e12e1b11c6e566c188e1f55c394dd1ba7 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 24 Jul 2022 16:39:27 +0200 Subject: [PATCH 2/5] Create owner-guild-only commands, make sync a bit fancier --- didier/cogs/other.py | 5 +++- didier/cogs/owner.py | 36 ++++++++++++++++++++-------- didier/data/apis/urban_dictionary.py | 9 +++++-- didier/utils/discord/flags/owner.py | 9 ++++++- settings.py | 1 + 5 files changed, 46 insertions(+), 14 deletions(-) diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 04175ed..2a3da2a 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -19,7 +19,10 @@ class Other(commands.Cog): async def define(self, ctx: commands.Context, *, query: str): """Look up the definition of a word on the Urban Dictionary""" async with ctx.typing(): - definitions = await urban_dictionary.lookup(self.client.http_session, query) + status_code, definitions = await urban_dictionary.lookup(self.client.http_session, query) + if not definitions: + return await ctx.reply(f"Something went wrong (status {status_code})") + await ctx.reply(embed=definitions[0].to_embed(), mention_author=False) @commands.hybrid_command(name="google", description="Google search", usage="[Query]") diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 9a93fdf..3977b49 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -4,11 +4,12 @@ import discord from discord import app_commands from discord.ext import commands +import settings from database.crud import custom_commands from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.utils.discord.flags.owner import EditCustomFlags +from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand @@ -18,8 +19,18 @@ class Owner(commands.Cog): client: Didier # Slash groups - add_slash = app_commands.Group(name="add", description="Add something new to the database") - edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry") + add_slash = app_commands.Group( + name="add", + description="Add something new to the database", + guild_ids=settings.DISCORD_OWNER_GUILDS, + guild_only=True, + ) + edit_slash = app_commands.Group( + name="edit", + description="Edit an existing database entry", + guild_ids=settings.DISCORD_OWNER_GUILDS, + guild_only=True, + ) def __init__(self, client: Didier): self.client = client @@ -31,16 +42,21 @@ class Owner(commands.Cog): """ return await self.client.is_owner(ctx.author) - @commands.command(name="Error") - async def _error(self, ctx: commands.Context): + @commands.command(name="Error", aliases=["Raise"]) + async def _error(self, ctx: commands.Context, message: str = "Debug"): """Raise an exception for debugging purposes""" - raise Exception("Debug") + raise Exception(message) @commands.command(name="Sync") - async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None): + async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None, *, flags: SyncOptionFlags): """Sync all application-commands in Discord""" if guild is not None: - self.client.tree.copy_global_to(guild=guild) + if flags.clear: + self.client.tree.clear_commands(guild=guild) + + if flags.copy_globals: + self.client.tree.copy_global_to(guild=guild) + await self.client.tree.sync(guild=guild) else: await self.client.tree.sync() @@ -52,7 +68,7 @@ class Owner(commands.Cog): """Command group for [add X] message commands""" @add_msg.command(name="Custom") - async def add_custom(self, ctx: commands.Context, name: str, *, response: str): + async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" async with self.client.db_session as session: try: @@ -63,7 +79,7 @@ class Owner(commands.Cog): await self.client.reject_message(ctx.message) @add_msg.command(name="Alias") - async def add_alias(self, ctx: commands.Context, command: str, alias: str): + async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" async with self.client.db_session as session: try: diff --git a/didier/data/apis/urban_dictionary.py b/didier/data/apis/urban_dictionary.py index 40381ea..e0ebf15 100644 --- a/didier/data/apis/urban_dictionary.py +++ b/didier/data/apis/urban_dictionary.py @@ -1,3 +1,5 @@ +from http import HTTPStatus + from aiohttp import ClientSession from didier.data.embeds.urban_dictionary import Definition @@ -8,10 +10,13 @@ __all__ = ["lookup", "PER_PAGE"] PER_PAGE = 10 -async def lookup(http_session: ClientSession, query: str) -> list[Definition]: +async def lookup(http_session: ClientSession, query: str) -> tuple[int, list[Definition]]: """Fetch the Urban Dictionary definitions for a given word""" url = "https://api.urbandictionary.com/v0/define" async with http_session.get(url, params={"term": query}) as response: + if response.status != HTTPStatus.OK: + return response.status, [] + response_json = await response.json() - return list(map(Definition.parse_obj, response_json["list"])) + return 200, list(map(Definition.parse_obj, response_json["list"])) diff --git a/didier/utils/discord/flags/owner.py b/didier/utils/discord/flags/owner.py index 282957c..b12fb8d 100644 --- a/didier/utils/discord/flags/owner.py +++ b/didier/utils/discord/flags/owner.py @@ -2,7 +2,7 @@ from typing import Optional from didier.utils.discord.flags import PosixFlags -__all__ = ["EditCustomFlags"] +__all__ = ["EditCustomFlags", "SyncOptionFlags"] class EditCustomFlags(PosixFlags): @@ -10,3 +10,10 @@ class EditCustomFlags(PosixFlags): name: Optional[str] = None response: Optional[str] = None + + +class SyncOptionFlags(PosixFlags): + """Flags for the sync command""" + + clear: bool = False + copy_globals: bool = False diff --git a/settings.py b/settings.py index d03b48b..cf20b95 100644 --- a/settings.py +++ b/settings.py @@ -44,6 +44,7 @@ DISCORD_TOKEN: str = env.str("DISCORD_TOKEN") DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.") DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) +DISCORD_OWNER_GUILDS: Optional[list[int]] = env.list("DISCORD_OWNER_GUILDS", [], subcast=int) or None DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?") BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None) From 424399b88a78c37b2d06cd503279278a4e8f1502 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 24 Jul 2022 17:09:42 +0200 Subject: [PATCH 3/5] Translate to english --- .flake8 | 2 ++ didier/cogs/other.py | 2 +- didier/cogs/owner.py | 16 ++++++---------- didier/cogs/school.py | 21 ++++++++++++--------- didier/cogs/tasks.py | 16 +++++++++++----- didier/data/embeds/google/google_search.py | 4 ++-- didier/data/embeds/urban_dictionary.py | 8 ++++---- didier/utils/discord/flags/school.py | 12 ++++++++++++ pyproject.toml | 3 ++- 9 files changed, 52 insertions(+), 32 deletions(-) create mode 100644 didier/utils/discord/flags/school.py diff --git a/.flake8 b/.flake8 index 1707912..7e28a5f 100644 --- a/.flake8 +++ b/.flake8 @@ -23,6 +23,8 @@ extend-ignore = D401, # Whitespace before ":" E203, + # Standard pseudo-random generators are not suitable for security/cryptographic purposes. + S311, # Don't require docstrings when overriding a method, # the base method should have a docstring but the rest not ignore-decorators=overrides diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 2a3da2a..de642a7 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -15,7 +15,7 @@ class Other(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]") + @commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Term]") async def define(self, ctx: commands.Context, *, query: str): """Look up the definition of a word on the Urban Dictionary""" async with ctx.typing(): diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 3977b49..9f1e3c6 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -75,7 +75,7 @@ class Owner(commands.Cog): await custom_commands.create_command(session, name, response) await self.client.confirm_message(ctx.message) except DuplicateInsertException: - await ctx.reply("Er bestaat al een commando met deze naam.") + await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) @add_msg.command(name="Alias") @@ -86,19 +86,17 @@ class Owner(commands.Cog): await custom_commands.create_alias(session, command, alias) await self.client.confirm_message(ctx.message) except NoResultFoundException: - await ctx.reply(f'Geen commando gevonden voor "{command}".') + await ctx.reply(f"No command found matching `{command}`.") await self.client.reject_message(ctx.message) except DuplicateInsertException: - await ctx.reply("Er bestaat al een commando met deze naam.") + await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) @add_slash.command(name="custom", description="Add a custom command") async def add_custom_slash(self, interaction: discord.Interaction): """Slash command to add a custom command""" if not await self.client.is_owner(interaction.user): - return interaction.response.send_message( - "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True - ) + return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) @@ -107,9 +105,7 @@ class Owner(commands.Cog): async def add_dad_joke_slash(self, interaction: discord.Interaction): """Slash command to add a dad joke""" if not await self.client.is_owner(interaction.user): - return interaction.response.send_message( - "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True - ) + return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) @@ -126,7 +122,7 @@ class Owner(commands.Cog): await custom_commands.edit_command(session, command, flags.name, flags.response) return await self.client.confirm_message(ctx.message) except NoResultFoundException: - await ctx.reply(f"Geen commando gevonden voor ``{command}``.") + await ctx.reply(f"No command found matching ``{command}``.") return await self.client.reject_message(ctx.message) @edit_slash.command(name="custom", description="Edit a custom command") diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 716c59d..ee13a0c 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -6,7 +6,7 @@ from discord.ext import commands from database.crud import ufora_courses from didier import Didier -from didier.data import constants +from didier.utils.discord.flags.school import StudyGuideFlags class School(commands.Cog): @@ -36,43 +36,46 @@ class School(commands.Cog): # Didn't fix it, sad if message is None: - return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10) + return await ctx.reply("Found no message to pin.", delete_after=10) + + if message.pinned: + return await ctx.reply("This message is already pinned.", delete_after=10) if message.is_system(): return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.") - await message.pin(reason=f"Didier Pin door {ctx.author.display_name}") + await message.pin(reason=f"Didier Pin by {ctx.author.display_name}") await message.add_reaction("📌") async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message): """Pin a message in the current channel""" # Is already pinned if message.pinned: - return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True) + return await interaction.response.send_message("This message is already pinned.", ephemeral=True) if message.is_system(): return await interaction.response.send_message( "Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True ) - await message.pin(reason=f"Didier Pin door {interaction.user.display_name}") + await message.pin(reason=f"Didier Pin by {interaction.user.display_name}") await message.add_reaction("📌") return await interaction.response.send_message("📌", ephemeral=True) @commands.hybrid_command( - name="fiche", description="Stuurt de link naar de studiefiche voor [Vak]", aliases=["guide", "studiefiche"] + name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"] ) @app_commands.describe(course="vak") - async def study_guide(self, ctx: commands.Context, course: str): + async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): """Create links to study guides""" async with self.client.db_session as session: ufora_course = await ufora_courses.get_course_by_name(session, course) if ufora_course is None: - return await ctx.reply(f"Geen vak gevonden voor ``{course}``", ephemeral=True) + return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True) return await ctx.reply( - f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{constants.CURRENT_YEAR}", + f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}", mention_author=False, ) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 88bcb6e..3764331 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,4 +1,5 @@ import datetime +import random import traceback from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error @@ -18,6 +19,10 @@ DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) +# TODO more messages? +BIRTHDAY_MESSAGES = ["Gelukkige verjaardag {mention}!", "Happy birthday {mention}!"] + + class Tasks(commands.Cog): """Task loops that run periodically @@ -52,12 +57,12 @@ class Tasks(commands.Cog): """ raise NotImplementedError() - @tasks_group.command(name="Force", case_insensitive=True) + @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") async def force_task(self, ctx: commands.Context, name: str): - """Command to force-run a task without waiting for the run time""" + """Command to force-run a task without waiting for the specified run time""" name = name.lower() if name not in self._tasks: - return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False) + return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False) task = self._tasks[name] await task() @@ -76,8 +81,8 @@ class Tasks(commands.Cog): for birthday in birthdays: user = self.client.get_user(birthday.user_id) - # TODO more messages? - await channel.send(f"Gelukkig verjaardag {user.mention}!") + + await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention)) @check_birthdays.before_loop async def _before_check_birthdays(self): @@ -114,6 +119,7 @@ class Tasks(commands.Cog): async def _on_tasks_error(self, error: BaseException): """Error handler for all tasks""" print("".join(traceback.format_exception(type(error), error, error.__traceback__))) + self.client.dispatch("task_error") async def setup(client: Didier): diff --git a/didier/data/embeds/google/google_search.py b/didier/data/embeds/google/google_search.py index 9a8eefc..7d31859 100644 --- a/didier/data/embeds/google/google_search.py +++ b/didier/data/embeds/google/google_search.py @@ -24,11 +24,11 @@ class GoogleSearch(EmbedBaseModel): # Empty embed if not self.data.results: - embed.description = "Geen resultaten gevonden" + embed.description = "Found no results" return embed # Error embed - embed.description = f"Status {self.data.status_code}" + embed.description = f"Something went wrong (status {self.data.status_code})" return embed diff --git a/didier/data/embeds/urban_dictionary.py b/didier/data/embeds/urban_dictionary.py index 14086bb..4f42378 100644 --- a/didier/data/embeds/urban_dictionary.py +++ b/didier/data/embeds/urban_dictionary.py @@ -50,10 +50,10 @@ class Definition(EmbedPydantic): embed = discord.Embed(colour=colours.urban_dictionary_green()) embed.set_author(name="Urban Dictionary") - embed.add_field(name="Woord", value=self.word, inline=True) - embed.add_field(name="Auteur", value=self.author, inline=True) - embed.add_field(name="Definitie", value=self.definition, inline=False) - embed.add_field(name="Voorbeeld", value=self.example or "\u200B", inline=False) + embed.add_field(name="Term", value=self.word, inline=True) + embed.add_field(name="Author", value=self.author, inline=True) + embed.add_field(name="Definition", value=self.definition, inline=False) + embed.add_field(name="Example", value=self.example or "\u200B", inline=False) embed.add_field( name="Rating", value=f"{self.ratio}% ({self.thumbs_up}/{self.thumbs_up + self.thumbs_down})", inline=True ) diff --git a/didier/utils/discord/flags/school.py b/didier/utils/discord/flags/school.py new file mode 100644 index 0000000..f2c4713 --- /dev/null +++ b/didier/utils/discord/flags/school.py @@ -0,0 +1,12 @@ +from typing import Optional + +from didier.data import constants +from didier.utils.discord.flags import PosixFlags + +__all__ = ["StudyGuideFlags"] + + +class StudyGuideFlags(PosixFlags): + """Flags for the study guide command""" + + year: Optional[int] = constants.CURRENT_YEAR diff --git a/pyproject.toml b/pyproject.toml index aa1f485..f59d25e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,8 @@ omit = [ "./didier/data/embeds/*", "./didier/data/flags/*", "./didier/utils/discord/colours.py", - "./didier/utils/discord/constants.py" + "./didier/utils/discord/constants.py", + "./didier/utils/discord/flags/*", ] [tool.isort] From 2e3b4823d0113558f58fc099e321c41df0a0d733 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 24 Jul 2022 18:30:03 +0200 Subject: [PATCH 4/5] Update sync --- didier/cogs/owner.py | 17 +++++++++++++++-- didier/utils/discord/flags/owner.py | 4 +++- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 9f1e3c6..c251b73 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Literal, Optional import discord from discord import app_commands @@ -48,8 +48,21 @@ class Owner(commands.Cog): raise Exception(message) @commands.command(name="Sync") - async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None, *, flags: SyncOptionFlags): + async def sync( + self, + ctx: commands.Context, + guild: Optional[discord.Guild] = None, + symbol: Optional[Literal["."]] = None, + *, + flags: SyncOptionFlags, + ): """Sync all application-commands in Discord""" + # Allow using "." to specify the current guild + # When passing flags, and no guild was specified, default to the current guild as well + # because these don't work on global syncs + if guild is None and (symbol == "." or flags.clear or flags.copy_globals): + guild = ctx.guild + if guild is not None: if flags.clear: self.client.tree.clear_commands(guild=guild) diff --git a/didier/utils/discord/flags/owner.py b/didier/utils/discord/flags/owner.py index b12fb8d..1fc8562 100644 --- a/didier/utils/discord/flags/owner.py +++ b/didier/utils/discord/flags/owner.py @@ -1,5 +1,7 @@ from typing import Optional +from discord.ext import commands + from didier.utils.discord.flags import PosixFlags __all__ = ["EditCustomFlags", "SyncOptionFlags"] @@ -16,4 +18,4 @@ class SyncOptionFlags(PosixFlags): """Flags for the sync command""" clear: bool = False - copy_globals: bool = False + copy_globals: bool = commands.flag(aliases=["copy_global", "copy"], default=False) From b54aed24e86952a7e151c7cd99bd7e51ea5d9db5 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 24 Jul 2022 21:35:38 +0200 Subject: [PATCH 5/5] Log errors in Discord channels --- didier/cogs/discord.py | 6 ++++ didier/cogs/owner.py | 2 +- didier/data/embeds/error_embed.py | 47 +++++++++++++++++++++++++++ didier/didier.py | 53 ++++++++++++++++++++++++++++--- readme.md | 2 ++ 5 files changed, 104 insertions(+), 6 deletions(-) create mode 100644 didier/data/embeds/error_embed.py diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index a73ead3..db9ae7d 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -49,6 +49,12 @@ class Discord(commands.Cog): await birthdays.add_birthday(session, ctx.author.id, date) await self.client.confirm_message(ctx.message) + @commands.command(name="Join", usage="[Thread]") + async def join(self, ctx: commands.Context, thread: discord.Thread): + """Make Didier join a thread""" + if thread.me is not None: + return await ctx.reply() + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index c251b73..30090df 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -43,7 +43,7 @@ class Owner(commands.Cog): return await self.client.is_owner(ctx.author) @commands.command(name="Error", aliases=["Raise"]) - async def _error(self, ctx: commands.Context, message: str = "Debug"): + async def _error(self, ctx: commands.Context, *, message: str = "Debug"): """Raise an exception for debugging purposes""" raise Exception(message) diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py new file mode 100644 index 0000000..0dc1b80 --- /dev/null +++ b/didier/data/embeds/error_embed.py @@ -0,0 +1,47 @@ +import traceback + +import discord +from discord.ext import commands + +from didier.utils.discord.constants import Limits +from didier.utils.types.string import abbreviate + +__all__ = ["create_error_embed"] + + +def _get_traceback(exception: Exception) -> str: + """Get a proper representation of the exception""" + tb = traceback.format_exception(type(exception), exception, exception.__traceback__) + error_string = "" + for line in tb: + # Don't add endless tracebacks + if line.strip().startswith("The above exception was the direct cause of"): + break + + # Escape Discord markdown formatting + error_string += line.replace(r"*", r"\*").replace(r"_", r"\_") + if line.strip(): + error_string += "\n" + + return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH) + + +def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed: + """Create an embed for the traceback of an exception""" + description = _get_traceback(exception) + + if ctx.guild is None: + origin = "DM" + else: + origin = f"{ctx.channel.mention} ({ctx.guild.name})" + + invocation = f"{ctx.author.display_name} in {origin}" + + embed = discord.Embed(colour=discord.Colour.red()) + embed.set_author(name="Error") + embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True) + embed.add_field(name="Context", value=invocation, inline=True) + embed.add_field(name="Exception", value=abbreviate(str(exception), Limits.EMBED_FIELD_VALUE_LENGTH), inline=False) + embed.add_field(name="Traceback", value=description, inline=False) + + return embed diff --git a/didier/didier.py b/didier/didier.py index 2f07372..d5745c2 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -10,6 +10,7 @@ import settings from database.crud import custom_commands from database.engine import DBSession from database.utils.caches import CacheManager +from didier.data.embeds.error_embed import create_error_embed from didier.utils.discord.prefix import get_prefix __all__ = ["Didier"] @@ -139,6 +140,9 @@ class Didier(commands.Bot): await self.process_commands(message) + # TODO easter eggs + # TODO stats + async def _try_invoke_custom_command(self, message: discord.Message) -> bool: """Check if the message tries to invoke a custom command @@ -162,11 +166,50 @@ class Didier(commands.Bot): # Nothing found return False - async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None: - """Event triggered when a regular command errors""" - # Print everything to the logs/stderr - await super().on_command_error(context, exception) + async def on_thread_create(self, thread: discord.Thread): + """Event triggered when a new thread is created""" + await thread.join() - # If developing, do nothing special + async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /): + """Event triggered when a regular command errors""" + # If working locally, print everything to your console if settings.SANDBOX: + await super().on_command_error(ctx, exception) return + + # If commands have their own error handler, let it handle the error instead + if hasattr(ctx.command, "on_error"): + return + + # Ignore exceptions that aren't important + if isinstance( + exception, + ( + commands.CommandNotFound, + commands.CheckFailure, + commands.TooManyArguments, + ), + ): + return + + # Print everything that we care about to the logs/stderr + await super().on_command_error(ctx, exception) + + if isinstance(exception, commands.MessageNotFound): + return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10) + + if isinstance( + exception, + ( + commands.BadArgument, + commands.MissingRequiredArgument, + commands.UnexpectedQuoteError, + commands.ExpectedClosingQuoteError, + ), + ): + return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10) + + if settings.ERRORS_CHANNEL is not None: + embed = create_error_embed(ctx, exception) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) diff --git a/readme.md b/readme.md index dae8347..1060271 100644 --- a/readme.md +++ b/readme.md @@ -61,3 +61,5 @@ black flake8 mypy ``` + +It's also convenient to have code-formatting happen automatically on-save. The [`Black documentation`](https://black.readthedocs.io/en/stable/integrations/editors.html) explains how to set this up for different types of editors.