diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index 925e645..d696e50 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[ if query is not None: statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index eac9c7e..efbc689 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: """Get a list of all commands""" statement = select(CustomCommand) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index a539518..78d623f 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -38,4 +38,4 @@ async def get_deadlines( statement = statement.where(Deadline.course_id == course.course_id) statement = statement.options(selectinload(Deadline.course)) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py index f4a55ed..d4c25d9 100644 --- a/database/crud/easter_eggs.py +++ b/database/crud/easter_eggs.py @@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"] async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: """Return a list of all easter eggs""" statement = select(EasterEgg) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/events.py b/database/crud/events.py index e8dfac5..19887e6 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" statement = select(Event).where(Event.timestamp > now) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: diff --git a/database/crud/free_games.py b/database/crud/free_games.py index e0dab5b..b2d835d 100644 --- a/database/crud/free_games.py +++ b/database/crud/free_games.py @@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]): async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]: """Filter a list of game IDs down to the ones that aren't in the database yet""" statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids)) - matches: list[int] = list((await session.execute(statement)).scalars().all()) + matches: list[int] = (await session.execute(statement)).scalars().all() return list(set(game_ids).difference(matches)) diff --git a/database/crud/github.py b/database/crud/github.py index a352bae..0d32377 100644 --- a/database/crud/github.py +++ b/database/crud/github.py @@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: """Get a user's GitHub links""" statement = select(GitHubLink).where(GitHubLink.user_id == user_id) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/links.py b/database/crud/links.py index 20bcab3..495e0f3 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] async def get_all_links(session: AsyncSession) -> list[Link]: """Get a list of all links""" statement = select(Link) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def add_link(session: AsyncSession, name: str, url: str) -> Link: diff --git a/database/crud/memes.py b/database/crud/memes.py index b1ed1e0..ab288aa 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: """Get a list of all memes""" statement = select(MemeTemplate) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: diff --git a/database/crud/reminders.py b/database/crud/reminders.py index 78350e6..007a779 100644 --- a/database/crud/reminders.py +++ b/database/crud/reminders.py @@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"] async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: """Get a list of all Reminders for a given category""" statement = select(Reminder).where(Reminder.category == category) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 06c2b58..688bcc7 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]: """Get all courses where announcements are enabled""" statement = select(UforaCourse).where(UforaCourse.log_announcements) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def create_new_announcement( diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index d4cf728..5374c07 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor # Search case-insensitively query = query.lower() - course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) - course_result = (await session.execute(course_statement)).scalars().first() - if course_result: - return course_result + statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + if result: + return result - alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) - alias_result = (await session.execute(alias_statement)).scalars().first() - return alias_result.course if alias_result else None + statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + return result.course if result else None diff --git a/database/utils/caches.py b/database/utils/caches.py index 2df9dac..248eb5f 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -69,7 +69,7 @@ class LinkCache(DatabaseCache): self.clear() all_links = await links.get_all_links(database_session) - self.data = list(map(lambda link: link.name, all_links)) + self.data = list(map(lambda l: l.name, all_links)) self.data.sort() self.data_transformed = list(map(str.lower, self.data)) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 4049654..709a461 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -25,7 +25,7 @@ class Currency(commands.Cog): super().__init__() self.client = client - @commands.command(name="award") # type: ignore[arg-type] + @commands.command(name="award") @commands.check(is_owner) async def award( self, @@ -49,9 +49,7 @@ class Currency(commands.Cog): bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue()) - - if ctx.author.avatar is not None: - embed.set_thumbnail(url=ctx.author.avatar.url) + embed.set_thumbnail(url=ctx.author.avatar.url) embed.add_field(name="Interest level", value=bank.interest_level) embed.add_field(name="Capacity level", value=bank.capacity_level) @@ -59,9 +57,7 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank.group( # type: ignore[arg-type] - name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True - ) + @bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True) async def bank_upgrades(self, ctx: commands.Context): """List the upgrades you can buy & their prices.""" async with self.client.postgres_session as session: @@ -81,7 +77,7 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type] + @bank_upgrades.command(name="capacity", aliases=["c"]) async def bank_upgrade_capacity(self, ctx: commands.Context): """Upgrade the capacity level of your bank.""" async with self.client.postgres_session as session: @@ -92,7 +88,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type] + @bank_upgrades.command(name="interest", aliases=["i"]) async def bank_upgrade_interest(self, ctx: commands.Context): """Upgrade the interest level of your bank.""" async with self.client.postgres_session as session: @@ -103,7 +99,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type] + @bank_upgrades.command(name="rob", aliases=["r"]) async def bank_upgrade_rob(self, ctx: commands.Context): """Upgrade the rob level of your bank.""" async with self.client.postgres_session as session: @@ -114,7 +110,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @commands.hybrid_command(name="dinks") # type: ignore[arg-type] + @commands.hybrid_command(name="dinks") async def dinks(self, ctx: commands.Context): """Check your Didier Dinks.""" async with self.client.postgres_session as session: @@ -122,7 +118,7 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) - @commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type] + @commands.command(name="invest", aliases=["deposit", "dep"]) async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]): """Invest `amount` Didier Dinks into your bank. @@ -148,7 +144,7 @@ class Currency(commands.Cog): f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False ) - @commands.hybrid_command(name="nightly") # type: ignore[arg-type] + @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): """Claim nightly Didier Dinks.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/debug_cog.py b/didier/cogs/debug_cog.py index a0e4747..2d03b9f 100644 --- a/didier/cogs/debug_cog.py +++ b/didier/cogs/debug_cog.py @@ -13,7 +13,7 @@ class DebugCog(commands.Cog): self.client = client @overrides - async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override] + async def cog_check(self, ctx: commands.Context) -> bool: return await self.client.is_owner(ctx.author) @commands.command(aliases=["Dev"]) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index fdfa05d..4d9b423 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Optional import discord from discord import app_commands @@ -17,7 +17,6 @@ from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.constants import Limits from didier.utils.timer import Timer from didier.utils.types.datetime import localize, str_to_date, tz_aware_now @@ -61,19 +60,9 @@ class Discord(commands.Cog): event = await events.get_event_by_id(session, event_id) if event is None: - return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True) + return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True) channel = self.client.get_channel(event.notification_channel) - if channel is None: - return await self.client.log_error( - f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)." - ) - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable." - ) - human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M") embed = discord.Embed(title=event.name, colour=discord.Colour.blue()) @@ -92,7 +81,7 @@ class Discord(commands.Cog): self.client.loop.create_task(self.timer.update()) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) - async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None): + async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of `user`. Not passing an argument for `user` will show yours instead. @@ -109,10 +98,8 @@ class Discord(commands.Cog): day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False) - @birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type] - async def birthday_set( - self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None - ): + @birthday.command(name="set", aliases=["config"]) + async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None): """Set your birthday to `day`. Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`. @@ -126,9 +113,6 @@ class Discord(commands.Cog): if user is None: user = ctx.author - # Please Mypy - user = cast(Union[discord.User, discord.Member], user) - try: default_year = 2001 date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) @@ -157,7 +141,7 @@ class Discord(commands.Cog): """ # No label: shortcut to display bookmarks if label is None: - return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type] + return await self.bookmark_search(ctx, query=None) async with self.client.postgres_session as session: result = expect( @@ -167,7 +151,7 @@ class Discord(commands.Cog): ) await ctx.reply(result.jump_url, mention_author=False) - @bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type] + @bookmark.command(name="create", aliases=["new"]) async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]): """Create a new bookmark for message `message` with label `label`. @@ -198,7 +182,7 @@ class Discord(commands.Cog): # Label isn't allowed return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) - @bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type] + @bookmark.command(name="delete", aliases=["rm"]) async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): """Delete the bookmark with id `bookmark_id`. @@ -223,7 +207,7 @@ class Discord(commands.Cog): return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) - @bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type] + @bookmark.command(name="search", aliases=["list", "ls"]) async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks. @@ -252,7 +236,7 @@ class Discord(commands.Cog): modal = CreateBookmark(self.client, message.jump_url) await interaction.response.send_modal(modal) - @commands.hybrid_command(name="events") # type: ignore[arg-type] + @commands.hybrid_command(name="events") @app_commands.rename(event_id="id") @app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.") async def events(self, ctx: commands.Context, event_id: Optional[int] = None): @@ -292,16 +276,16 @@ class Discord(commands.Cog): embed.add_field( name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True ) - - channel = self.client.get_channel(result_event.notification_channel) - if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - embed.add_field(name="Channel", value=channel.mention, inline=False) - + embed.add_field( + name="Channel", + value=self.client.get_channel(result_event.notification_channel).mention, + inline=False, + ) embed.description = result_event.description return await ctx.reply(embed=embed, mention_author=False) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) - async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None): + async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): """Show a user's GitHub links. If no user is provided, this shows your links instead. @@ -309,9 +293,6 @@ class Discord(commands.Cog): # Default to author user = user or ctx.author - # Please Mypy - user = cast(Union[discord.User, discord.Member], user) - embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") embed.set_author(name=user.display_name, icon_url=get_user_avatar(user)) @@ -343,7 +324,7 @@ class Discord(commands.Cog): return await ctx.reply(embed=embed, mention_author=False) - @github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type] + @github_group.command(name="add", aliases=["create", "insert"]) async def github_add(self, ctx: commands.Context, link: str): """Add a new link into the database.""" # Remove wrapping which can be used to escape Discord embeds @@ -358,7 +339,7 @@ class Discord(commands.Cog): await self.client.confirm_message(ctx.message) return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False) - @github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type] + @github_group.command(name="delete", aliases=["del", "remove", "rm"]) async def github_delete(self, ctx: commands.Context, link_id: str): """Delete the link with it `link_id` from the database. @@ -430,7 +411,7 @@ class Discord(commands.Cog): await message.add_reaction("📌") return await interaction.response.send_message("📌", ephemeral=True) - @commands.hybrid_command(name="snipe") # type: ignore[arg-type] + @commands.hybrid_command(name="snipe") async def snipe(self, ctx: commands.Context): """Publicly shame people when they edit or delete one of their messages. @@ -439,7 +420,7 @@ class Discord(commands.Cog): if ctx.guild is None: return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True) - sniped_data = self.client.sniped.get(ctx.channel.id) + sniped_data = self.client.sniped.get(ctx.channel.id, None) if sniped_data is None: return await ctx.reply( "There's no one to make fun of in this channel.", mention_author=False, ephemeral=True diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 4ccfb2a..e824ab2 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -28,7 +28,7 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="clap") # type: ignore[arg-type] + @commands.hybrid_command(name="clap") async def clap(self, ctx: commands.Context, *, text: str): """Clap a message with emojis for extra dramatic effect""" chars = list(filter(lambda c: c in constants.EMOJI_MAP, text)) @@ -50,7 +50,10 @@ class Fun(commands.Cog): meme = await generate_meme(self.client.http_session, result, fields) return meme - @commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type] + @commands.hybrid_command( + name="dadjoke", + aliases=["dad", "dj"], + ) async def dad_joke(self, ctx: commands.Context): """Why does Yoda's code always crash? Because there is no try.""" async with self.client.postgres_session as session: @@ -80,13 +83,13 @@ class Fun(commands.Cog): return await self.memegen_ls_msg(ctx) if fields is None: - return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type] + return await self.memegen_preview_msg(ctx, template) async with ctx.typing(): meme = await self._do_generate_meme(template, shlex.split(fields)) return await ctx.reply(meme, mention_author=False) - @memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type] + @memegen_msg.command(name="list", aliases=["ls"]) async def memegen_ls_msg(self, ctx: commands.Context): """Get a list of all available meme templates. @@ -97,14 +100,14 @@ class Fun(commands.Cog): await MemeSource(ctx, results).start() - @memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type] + @memegen_msg.command(name="preview", aliases=["p"]) async def memegen_preview_msg(self, ctx: commands.Context, template: str): """Generate a preview for the meme template `template`, to see how the fields are structured.""" async with ctx.typing(): meme = await self._do_generate_meme(template, []) return await ctx.reply(meme, mention_author=False) - @memes_slash.command(name="generate") # type: ignore[arg-type] + @memes_slash.command(name="generate") async def memegen_slash(self, interaction: discord.Interaction, template: str): """Generate a meme.""" async with self.client.postgres_session as session: @@ -113,7 +116,7 @@ class Fun(commands.Cog): modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) - @memes_slash.command(name="preview") # type: ignore[arg-type] + @memes_slash.command(name="preview") @app_commands.describe(template="The meme template to use in the preview.") async def memegen_preview_slash(self, interaction: discord.Interaction, template: str): """Generate a preview for a meme, to see how the fields are structured.""" @@ -131,7 +134,7 @@ class Fun(commands.Cog): """Autocompletion for the 'template'-parameter""" return self.client.database_caches.memes.get_autocomplete_suggestions(current) - @app_commands.command() # type: ignore[arg-type] + @app_commands.command() @app_commands.describe(message="The text to convert.") async def mock(self, interaction: discord.Interaction, message: str): """Mock a message. @@ -155,7 +158,7 @@ class Fun(commands.Cog): return await interaction.followup.send(mock(message)) - @commands.hybrid_command(name="xkcd") # type: ignore[arg-type] + @commands.hybrid_command(name="xkcd") @app_commands.rename(comic_id="id") async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None): """Fetch comic `#id` from xkcd. diff --git a/didier/cogs/help.py b/didier/cogs/help.py index 57cad6d..459f802 100644 --- a/didier/cogs/help.py +++ b/didier/cogs/help.py @@ -159,9 +159,6 @@ class CustomHelpCommand(commands.MinimalHelpCommand): Code in codeblocks is ignored, as it is used to create examples. """ description = command.help - if description is None: - return "" - codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL) # Regex borrowed from https://stackoverflow.com/a/59843498/13568999 @@ -201,10 +198,13 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return None - async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]: + async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]: """Filter the list of cogs down to all those that the user can see""" - async def _predicate(cog: commands.Cog) -> bool: + async def _predicate(cog: Optional[commands.Cog]) -> bool: + if cog is None: + return False + # Remove cogs that we never want to see in the help page because they # don't contain commands, or shouldn't be visible at all if not cog.get_commands(): @@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return True # Filter list of cogs down - filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)] + filtered_cogs = [cog for cog in cogs if await _predicate(cog)] return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name)) def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]: """Check if a command has flags""" - flag_param = command.params.get("flags") + flag_param = command.params.get("flags", None) if flag_param is None: return None diff --git a/didier/cogs/meta.py b/didier/cogs/meta.py index 861bf58..c330dbd 100644 --- a/didier/cogs/meta.py +++ b/didier/cogs/meta.py @@ -1,6 +1,6 @@ import inspect import os -from typing import Any, Optional, Union +from typing import Optional from discord.ext import commands @@ -76,24 +76,18 @@ class Meta(commands.Cog): if command_name is None: return await ctx.reply(repo_home, mention_author=False) - command: Optional[Union[commands.HelpCommand, commands.Command]] - src: Any - if command_name == "help": command = self.client.help_command - if command is None: - return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) - src = type(self.client.help_command) filename = inspect.getsourcefile(src) else: command = self.client.get_command(command_name) - if command is None: - return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) - src = command.callback.__code__ filename = src.co_filename + if command is None: + return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) + lines, first_line = inspect.getsourcelines(src) if filename is None: diff --git a/didier/cogs/other.py b/didier/cogs/other.py index a48cb5e..02c0095 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -22,7 +22,7 @@ class Other(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type] + @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) async def covid(self, ctx: commands.Context, country: str = "Belgium"): """Show Covid-19 info for a specific country. @@ -43,7 +43,7 @@ class Other(commands.Cog): """Autocompletion for the 'country'-parameter""" return autocomplete_country(value)[:25] - @commands.hybrid_command( # type: ignore[arg-type] + @commands.hybrid_command( name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary" ) async def define(self, ctx: commands.Context, *, query: str): @@ -55,7 +55,7 @@ class Other(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type] + @commands.hybrid_command(name="google", description="Google search") @app_commands.describe(query="Search query") async def google(self, ctx: commands.Context, *, query: str): """Show the Google search results for `query`. @@ -71,7 +71,7 @@ class Other(commands.Cog): embed = GoogleSearch(results).to_embed() await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type] + @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") async def inspire(self, ctx: commands.Context): """Generate an [InspiroBot](https://inspirobot.me/) quote.""" async with ctx.typing(): @@ -82,7 +82,7 @@ class Other(commands.Cog): async with self.client.postgres_session as session: return await get_link_by_name(session, name.lower()) - @commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type] + @commands.command(name="Link", aliases=["Links"]) async def link_msg(self, ctx: commands.Context, name: str): """Get the link to the resource named `name`.""" link = await self._get_link(name) @@ -92,7 +92,7 @@ class Other(commands.Cog): target_message = await self.client.get_reply_target(ctx) await target_message.reply(link.url, mention_author=False) - @app_commands.command(name="link") # type: ignore[arg-type] + @app_commands.command(name="link") @app_commands.describe(name="The name of the resource") async def link_slash(self, interaction: discord.Interaction, name: str): """Get the link to something.""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 1f72eff..139f02c 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -42,7 +42,7 @@ class Owner(commands.Cog): def __init__(self, client: Didier): self.client = client - async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override] + async def cog_check(self, ctx: commands.Context) -> bool: """Global check for every command in this cog This means that we don't have to add is_owner() to every single command separately @@ -102,7 +102,7 @@ class Owner(commands.Cog): async def add_msg(self, ctx: commands.Context): """Command group for [add X] message commands""" - @add_msg.command(name="Alias") # type: ignore[arg-type] + @add_msg.command(name="Alias") async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" async with self.client.postgres_session as session: @@ -116,7 +116,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Custom") # type: ignore[arg-type] + @add_msg.command(name="Custom") async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" async with self.client.postgres_session as session: @@ -127,7 +127,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Link") # type: ignore[arg-type] + @add_msg.command(name="Link") async def add_link_msg(self, ctx: commands.Context, name: str, url: str): """Add a new link""" async with self.client.postgres_session as session: @@ -136,7 +136,7 @@ class Owner(commands.Cog): await self.client.confirm_message(ctx.message) - @add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type] + @add_slash.command(name="custom", description="Add a custom command") async def add_custom_slash(self, interaction: discord.Interaction): """Slash command to add a custom command""" if not await self.client.is_owner(interaction.user): @@ -145,7 +145,7 @@ class Owner(commands.Cog): modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type] + @add_slash.command(name="dadjoke", description="Add a dad joke") async def add_dad_joke_slash(self, interaction: discord.Interaction): """Slash command to add a dad joke""" if not await self.client.is_owner(interaction.user): @@ -154,7 +154,7 @@ class Owner(commands.Cog): modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type] + @add_slash.command(name="deadline", description="Add a deadline") @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") async def add_deadline_slash(self, interaction: discord.Interaction, course: str): """Slash command to add a deadline""" @@ -174,7 +174,7 @@ class Owner(commands.Cog): """Autocompletion for the 'course'-parameter""" return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) - @add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type] + @add_slash.command(name="event", description="Add a new event") async def add_event_slash(self, interaction: discord.Interaction): """Slash command to add new events""" if not await self.client.is_owner(interaction.user): @@ -183,7 +183,7 @@ class Owner(commands.Cog): modal = AddEvent(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type] + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" if not await self.client.is_owner(interaction.user): @@ -192,7 +192,7 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type] + @add_slash.command(name="meme", description="Add a new meme") async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): """Slash command to add new memes""" await interaction.response.defer(ephemeral=True) @@ -205,11 +205,11 @@ class Owner(commands.Cog): await interaction.followup.send(f"Added meme `{meme.meme_id}`.") await self.client.database_caches.memes.invalidate(session) - @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type] + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" - @edit_msg.command(name="Custom") # type: ignore[arg-type] + @edit_msg.command(name="Custom") async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): """Edit an existing custom command""" async with self.client.postgres_session as session: @@ -220,7 +220,7 @@ class Owner(commands.Cog): await ctx.reply(f"No command found matching `{command}`.") return await self.client.reject_message(ctx.message) - @edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type] + @edit_slash.command(name="custom", description="Edit a custom command") @app_commands.describe(command="The name of the command to edit") async def edit_custom_slash(self, interaction: discord.Interaction, command: str): """Slash command to edit a custom command""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index cd9366a..7af8a81 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -27,7 +27,7 @@ class School(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="deadlines") # type: ignore[arg-type] + @commands.hybrid_command(name="deadlines") async def deadlines(self, ctx: commands.Context): """Show upcoming deadlines.""" async with ctx.typing(): @@ -40,7 +40,7 @@ class School(commands.Cog): embed = Deadlines(deadlines).to_embed() await ctx.reply(embed=embed, mention_author=False, ephemeral=False) - @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type] + @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) @app_commands.rename(day_dt="date") async def les( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -72,7 +72,10 @@ class School(commands.Cog): except NotInMainGuildException: return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False) - @commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type] + @commands.hybrid_command( + name="menu", + aliases=["eten", "food"], + ) @app_commands.rename(day_dt="date") async def menu( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -93,7 +96,7 @@ class School(commands.Cog): embed = no_menu_found(day_dt) await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command( # type: ignore[arg-type] + @commands.hybrid_command( name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"] ) @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") @@ -121,7 +124,7 @@ class School(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="ufora") # type: ignore[arg-type] + @commands.hybrid_command(name="ufora") async def ufora(self, ctx: commands.Context, course: str): """Link the Ufora page for a course.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index f59d697..07d6508 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,6 +1,4 @@ -import asyncio import datetime -import logging import random import discord @@ -22,12 +20,9 @@ from didier.data.embeds.schedules import ( from didier.data.rss_feeds.free_games import fetch_free_games from didier.data.rss_feeds.ufora import fetch_ufora_announcements from didier.decorators.tasks import timed_task -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.checks import is_owner from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now -logger = logging.getLogger(__name__) - # datetime.time()-instances for when every task should run DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) @@ -61,7 +56,7 @@ class Tasks(commands.Cog): } @overrides - async def cog_load(self) -> None: + def cog_load(self) -> None: # Only check birthdays if there's a channel to send it to if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: self.check_birthdays.start() @@ -77,10 +72,9 @@ class Tasks(commands.Cog): # Start other tasks self.reminders.start() - asyncio.create_task(self.get_error_channel()) @overrides - async def cog_unload(self) -> None: + def cog_unload(self) -> None: # Cancel all pending tasks for task in self._tasks.values(): if task.is_running(): @@ -102,7 +96,7 @@ class Tasks(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type] + @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") async def force_task(self, ctx: commands.Context, name: str): """Command to force-run a task without waiting for the specified run time""" name = name.lower() @@ -113,53 +107,23 @@ class Tasks(commands.Cog): await task(forced=True) await self.client.confirm_message(ctx.message) - async def get_error_channel(self): - """Get the configured channel from the cache""" - await self.client.wait_until_ready() - - # Configure channel to send errors to - if settings.ERRORS_CHANNEL is not None: - channel = self.client.get_channel(settings.ERRORS_CHANNEL) - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.") - else: - self.client.error_channel = channel - elif self.client.owner_id is not None: - self.client.error_channel = self.client.get_user(self.client.owner_id) - @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) @timed_task(enums.TaskType.BIRTHDAYS) async def check_birthdays(self, **kwargs): """Check if it's currently anyone's birthday""" _ = kwargs - # Can't happen (task isn't started if this is None), but Mypy doesn't know - if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None: - return - now = tz_aware_now().date() async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) if channel is None: - return await self.client.log_error("Unable to fetch channel for birthday announcements.") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable." - ) + return await self.client.log_error("Unable to find channel for birthday announcements") for birthday in birthdays: user = self.client.get_user(birthday.user_id) - if user is None: - await self.client.log_error( - f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement" - ) - continue - await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention)) @check_birthdays.before_loop @@ -179,14 +143,6 @@ class Tasks(commands.Cog): games = await fetch_free_games(self.client.http_session, session) channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL) - if channel is None: - return await self.client.log_error("Unable to fetch channel for free games announcements.") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable." - ) - for game in games: await channel.send(embed=game.to_embed()) @@ -248,17 +204,6 @@ class Tasks(commands.Cog): async with self.client.postgres_session as db_session: announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) - - if announcements_channel is None: - return await self.client.log_error( - f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)." - ) - - if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable." - ) - announcements = await fetch_ufora_announcements(self.client.http_session, db_session) for announcement in announcements: diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py index 696dd80..ea03bfe 100644 --- a/didier/data/embeds/error_embed.py +++ b/didier/data/embeds/error_embed.py @@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> embed = discord.Embed(title="Error", colour=discord.Colour.red()) if ctx is not None: - if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel): + if ctx.guild is None: origin = "DM" else: - origin = f"<#{ctx.channel.id}> ({ctx.guild.name})" + origin = f"{ctx.channel.mention} ({ctx.guild.name})" invocation = f"{ctx.author.display_name} in {origin}" diff --git a/didier/data/embeds/logging_embed.py b/didier/data/embeds/logging_embed.py index 3a803a1..40556f2 100644 --- a/didier/data/embeds/logging_embed.py +++ b/didier/data/embeds/logging_embed.py @@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"] def create_logging_embed(level: int, message: str) -> discord.Embed: """Create an embed to send to the logging channel""" colours = { - logging.DEBUG: discord.Colour.light_grey(), + logging.DEBUG: discord.Colour.light_gray, logging.ERROR: discord.Colour.red(), logging.INFO: discord.Colour.blue(), logging.WARNING: discord.Colour.yellow(), diff --git a/didier/data/scrapers/google.py b/didier/data/scrapers/google.py index 9c10716..389e9ae 100644 --- a/didier/data/scrapers/google.py +++ b/didier/data/scrapers/google.py @@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]: return list(dict.fromkeys(results)) -async def google_search(http_session: ClientSession, query: str): +async def google_search(http_client: ClientSession, query: str): """Get the first 10 Google search results""" query = urlencode({"q": query}) # Request 20 results in case of duplicates, bad matches, ... - async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: + async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: # Something went wrong if response.status != http.HTTPStatus.OK: return SearchData(query, response.status) diff --git a/didier/didier.py b/didier/didier.py index 33e6e4b..cf9ed1d 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -17,7 +17,7 @@ from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.logging_embed import create_logging_embed from didier.data.embeds.schedules import Schedule, parse_schedule -from didier.exceptions import GetNoneException, HTTPException, NoMatch +from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix from didier.utils.discord.snipe import should_snipe from didier.utils.easter_eggs import detect_easter_egg @@ -33,7 +33,7 @@ class Didier(commands.Bot): """DIDIER <3""" database_caches: CacheManager - error_channel: Optional[discord.abc.Messageable] = None + error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession schedules: dict[settings.ScheduleType, Schedule] = {} @@ -56,17 +56,12 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) - # I'm not creating a custom tree, this is the way to do it - self.tree.on_error = self.on_app_command_error # type: ignore[method-assign] + self.tree.on_error = self.on_app_command_error @cached_property def main_guild(self) -> discord.Guild: """Obtain a reference to the main guild""" - guild = self.get_guild(settings.DISCORD_MAIN_GUILD) - if guild is None: - raise GetNoneException("Main guild could not be found in the bot's cache") - - return guild + return self.get_guild(settings.DISCORD_MAIN_GUILD) @property def postgres_session(self) -> AsyncSession: @@ -98,6 +93,12 @@ class Didier(commands.Bot): await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") + # Configure channel to send errors to + if settings.ERRORS_CHANNEL is not None: + self.error_channel = self.get_channel(settings.ERRORS_CHANNEL) + else: + self.error_channel = self.get_user(self.owner_id) + def _create_ignored_directories(self): """Create directories that store ignored data""" ignored = ["files/schedules"] @@ -151,27 +152,18 @@ class Didier(commands.Bot): original message instead """ if ctx.message.reference is not None: - return await self.resolve_message(ctx.message.reference) or ctx.message + return await self.resolve_message(ctx.message.reference) return ctx.message - async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]: + async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: """Fetch a message from a reference""" # Message is in the cache, return it if reference.cached_message is not None: return reference.cached_message - if reference.message_id is None: - return None - # For older messages: fetch them from the API channel = self.get_channel(reference.channel_id) - if channel is None or isinstance( - channel, - (discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel), - ): # Logically this can't happen, but we have to please Mypy - return None - return await channel.fetch_message(reference.message_id) async def confirm_message(self, message: discord.Message): @@ -192,7 +184,7 @@ class Didier(commands.Bot): } methods.get(level, logger.error)(message) - if log_to_discord and self.error_channel is not None: + if log_to_discord: embed = create_logging_embed(level, message) await self.error_channel.send(embed=embed) @@ -261,9 +253,10 @@ class Didier(commands.Bot): await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True) - if self.error_channel is not None: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(await commands.Context.from_interaction(interaction), exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_command_completion(self, ctx: commands.Context): """Event triggered when a message command completes successfully""" @@ -288,7 +281,7 @@ class Didier(commands.Bot): # Hybrid command errors are wrapped in an additional error, so wrap it back out if isinstance(exception, commands.HybridCommandError): - exception = exception.original # type: ignore[assignment] + exception = exception.original # Ignore exceptions that aren't important if isinstance( @@ -339,9 +332,10 @@ class Didier(commands.Bot): # Print everything that we care about to the logs/stderr await super().on_command_error(ctx, exception) - if self.error_channel is not None: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(ctx, exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_message(self, message: discord.Message, /) -> None: """Event triggered when a message is sent""" @@ -350,7 +344,7 @@ class Didier(commands.Bot): return # Boos react to people that say Dider - if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id: + if "dider" in message.content.lower() and message.author.id != self.user.id: await message.add_reaction(settings.DISCORD_BOOS_REACT) # Potential custom command @@ -380,7 +374,7 @@ class Didier(commands.Bot): # If the edited message is currently present in the snipe cache, # don't update the , but instead change the - existing = self.sniped.get(before.channel.id) + existing = self.sniped.get(before.channel.id, None) if existing is not None and existing[0].id == before.id: before = existing[0] @@ -395,9 +389,10 @@ class Didier(commands.Bot): async def on_task_error(self, exception: Exception): """Event triggered when a task raises an exception""" - if self.error_channel: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(None, exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_thread_create(self, thread: discord.Thread): """Event triggered when a new thread is created""" diff --git a/didier/exceptions/__init__.py b/didier/exceptions/__init__.py index fa5ad13..1335dd4 100644 --- a/didier/exceptions/__init__.py +++ b/didier/exceptions/__init__.py @@ -1,14 +1,6 @@ -from .get_none_exception import GetNoneException from .http_exception import HTTPException from .missing_env import MissingEnvironmentVariable from .no_match import NoMatch, expect from .not_in_main_guild_exception import NotInMainGuildException -__all__ = [ - "GetNoneException", - "HTTPException", - "MissingEnvironmentVariable", - "NoMatch", - "expect", - "NotInMainGuildException", -] +__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"] diff --git a/didier/exceptions/get_none_exception.py b/didier/exceptions/get_none_exception.py deleted file mode 100644 index cbd2f77..0000000 --- a/didier/exceptions/get_none_exception.py +++ /dev/null @@ -1,5 +0,0 @@ -__all__ = ["GetNoneException"] - - -class GetNoneException(RuntimeError): - """Exception raised when a Bot.get()-method returned None""" diff --git a/didier/exceptions/not_in_main_guild_exception.py b/didier/exceptions/not_in_main_guild_exception.py index 5279686..5572c44 100644 --- a/didier/exceptions/not_in_main_guild_exception.py +++ b/didier/exceptions/not_in_main_guild_exception.py @@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError): def __init__(self, user: Union[discord.User, discord.Member]): super().__init__( - f"User {user.display_name} (id `{user.id}`) " - f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)." + f"User {user.display_name} (id {user.id}) " + f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})." ) diff --git a/didier/utils/discord/channels.py b/didier/utils/discord/channels.py deleted file mode 100644 index 26739f8..0000000 --- a/didier/utils/discord/channels.py +++ /dev/null @@ -1,5 +0,0 @@ -import discord - -__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"] - -NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel) diff --git a/didier/utils/discord/prefix.py b/didier/utils/discord/prefix.py index 694a4b6..f3fa7c4 100644 --- a/didier/utils/discord/prefix.py +++ b/didier/utils/discord/prefix.py @@ -15,14 +15,11 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]: This is done dynamically through regexes to allow case-insensitivity and variable amounts of whitespace among other things. """ - mention = f"<@!?{client.user.id}>" if client.user else None + mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" # Check which prefix was used for prefix in [*constants.PREFIXES, mention]: - if prefix is None: - continue - match = re.match(regex.format(prefix), message.content, flags=re.I) if match is not None: diff --git a/didier/views/modals/bookmarks.py b/didier/views/modals/bookmarks.py index acd4c8f..f77b608 100644 --- a/didier/views/modals/bookmarks.py +++ b/didier/views/modals/bookmarks.py @@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"): @overrides async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) - label = self.name.value.strip() try: async with self.client.postgres_session as session: bm = await create_bookmark(session, interaction.user.id, label, self.jump_url) - return await interaction.followup.send( - f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)." + return await interaction.response.send_message( + f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True ) except DuplicateInsertException: # Label is already in use - return await interaction.followup.send(f"You already have a bookmark named `{label}`.") + return await interaction.response.send_message( + f"You already have a bookmark named `{label}`.", ephemeral=True + ) except ForbiddenNameException: # Label isn't allowed - return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.") + return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py index 5ebfab7..c3b2f67 100644 --- a/didier/views/modals/dad_jokes.py +++ b/didier/views/modals/dad_jokes.py @@ -26,14 +26,12 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): @overrides async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) - async with self.client.postgres_session as session: joke = await add_dad_joke(session, str(self.joke.value)) - await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}") + await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py index 71acea6..e7b92b4 100644 --- a/didier/views/modals/events.py +++ b/didier/views/modals/events.py @@ -10,8 +10,6 @@ from didier import Didier __all__ = ["AddEvent"] -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES - class AddEvent(discord.ui.Modal, title="Add Event"): """Modal to add a new event""" @@ -35,20 +33,15 @@ class AddEvent(discord.ui.Modal, title="Add Event"): @overrides async def on_submit(self, interaction: discord.Interaction) -> None: - await interaction.response.defer(ephemeral=True) - try: parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) except ParserError: - return await interaction.followup.send("Unable to parse date argument.") + return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True) - channel = self.client.get_channel(int(self.channel.value)) - - if channel is None: - return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.") + if self.client.get_channel(int(self.channel.value)) is None: + return await interaction.response.send_message( + f"Unable to find channel `{self.channel.value}`", ephemeral=True + ) async with self.client.postgres_session as session: event = await add_event( @@ -59,10 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"): channel_id=int(self.channel.value), ) - await interaction.followup.send(f"Successfully added event `{event.event_id}`.") + await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) self.client.dispatch("event_create", event) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/pyproject.toml b/pyproject.toml index d28abe3..acd06c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ omit = [ profile = "black" [tool.mypy] -check_untyped_defs = true files = [ "database/**/*.py", "didier/**/*.py", @@ -36,6 +35,7 @@ files = [ ] plugins = [ "pydantic.mypy", + "sqlalchemy.ext.mypy.plugin" ] [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"] diff --git a/requirements.txt b/requirements.txt index 4b0ffa3..a7b6db2 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ markdownify==0.11.6 overrides==7.3.1 pydantic==2.0.2 python-dateutil==2.8.2 -sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18 +sqlalchemy[asyncio]==2.0.18 diff --git a/settings.py b/settings.py index a862fde..32bd5e0 100644 --- a/settings.py +++ b/settings.py @@ -111,7 +111,7 @@ class ScheduleInfo: role_id: Optional[int] schedule_url: Optional[str] - name: ScheduleType + name: Optional[str] = None SCHEDULE_DATA = [