From b26421b8752c261ff16cc4423651cd1d241e051f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 13 Aug 2022 01:10:50 +0200 Subject: [PATCH] Fix bug, add autocomplete, make cache autocompletion slightly cleaner --- database/crud/deadlines.py | 2 +- database/utils/caches.py | 12 ++++++++---- didier/cogs/other.py | 7 ++----- didier/cogs/owner.py | 8 ++++++++ didier/cogs/school.py | 13 ++++++------- didier/data/embeds/deadlines.py | 28 ++++++++++++++-------------- didier/views/modals/deadlines.py | 3 +++ requirements-dev.txt | 1 + 8 files changed, 43 insertions(+), 31 deletions(-) diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index 8a7ad66..c1b2885 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -13,7 +13,7 @@ __all__ = ["add_deadline", "get_deadlines"] async def add_deadline(session: AsyncSession, course_id: int, name: str, date_str: str): """Add a new deadline""" - date_dt = parse(date_str).replace(tzinfo=ZoneInfo("Europe/Brussels")) + date_dt = parse(date_str, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) # If we only have a day, assume it's the end of the day if date_dt.hour == date_dt.minute == date_dt.second == 0: diff --git a/database/utils/caches.py b/database/utils/caches.py index b08417d..4f34419 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar +from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession @@ -38,11 +39,13 @@ class DatabaseCache(ABC, Generic[T]): async def invalidate(self, database_session: T): """Invalidate the data stored in this cache""" - def get_autocomplete_suggestions(self, query: str): + def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: """Filter the cache to find everything that matches the search query""" query = query.lower() # Return the original (non-transformed) version of the data for pretty display in Discord - return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] + suggestions = [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] + + return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] class LinkCache(DatabaseCache[AsyncSession]): @@ -86,7 +89,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) @overrides - def get_autocomplete_suggestions(self, query: str): + def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: query = query.lower() results = set() @@ -99,7 +102,8 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): if query in alias: results.add(course) - return sorted(list(results)) + suggestions = sorted(list(results)) + return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] class WordleCache(DatabaseCache[MongoDatabase]): diff --git a/didier/cogs/other.py b/didier/cogs/other.py index e00ab0b..870c63b 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -64,12 +64,9 @@ class Other(commands.Cog): return await interaction.response.send_message(link.url) @link_slash.autocomplete("name") - async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def _link_name_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocompletion for the 'name'-parameter""" - return [ - app_commands.Choice(name=name, value=name.lower()) - for name in self.client.database_caches.links.get_autocomplete_suggestions(current) - ] + return self.client.database_caches.links.get_autocomplete_suggestions(current) async def setup(client: Didier): diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 69560ca..d030344 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -139,6 +139,7 @@ class Owner(commands.Cog): await interaction.response.send_modal(modal) @add_slash.command(name="deadline", description="Add a deadline") + @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") async def add_deadline_slash(self, interaction: discord.Interaction, course: str): """Slash command to add a deadline""" async with self.client.postgres_session as session: @@ -150,6 +151,13 @@ class Owner(commands.Cog): modal = AddDeadline(self.client, course_instance) await interaction.response.send_modal(modal) + @add_deadline_slash.autocomplete("course") + async def _add_deadline_course_autocomplete( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'course'-parameter""" + return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 3ccf412..dc3807f 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -35,7 +35,7 @@ class School(commands.Cog): async with self.client.postgres_session as session: deadlines = await get_deadlines(session) - embed = await Deadlines(deadlines).to_embed() + embed = Deadlines(deadlines).to_embed() await ctx.reply(embed=embed, mention_author=False, ephemeral=False) @commands.command(name="Pin", usage="[Message]") @@ -76,7 +76,7 @@ class School(commands.Cog): @commands.hybrid_command( name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"] ) - @app_commands.describe(course="vak") + @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): """Create links to study guides""" async with self.client.postgres_session as session: @@ -91,12 +91,11 @@ class School(commands.Cog): ) @study_guide.autocomplete("course") - async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def _study_guide_course_autocomplete( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: """Autocompletion for the 'course'-parameter""" - return [ - app_commands.Choice(name=course, value=course) - for course in self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) - ] + return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) async def setup(client: Didier): diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py index 244d811..6f46bbd 100644 --- a/didier/data/embeds/deadlines.py +++ b/didier/data/embeds/deadlines.py @@ -22,7 +22,7 @@ class Deadlines(EmbedBaseModel): self.deadlines.sort(key=lambda deadline: deadline.deadline) @overrides - async def to_embed(self, **kwargs: dict) -> discord.Embed: + def to_embed(self, **kwargs: dict) -> discord.Embed: embed = discord.Embed(colour=discord.Colour.dark_gold()) embed.set_author(name="Upcoming Deadlines") now = tz_aware_now() @@ -30,33 +30,33 @@ class Deadlines(EmbedBaseModel): has_active_deadlines = False deadlines_grouped: dict[int, list[str]] = {} - deadline: Deadline - for year, deadline in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year): + for year, deadlines in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year): if year not in deadlines_grouped: deadlines_grouped[year] = [] - passed = deadline.deadline <= now - if passed: - has_active_deadlines = True + for deadline in deadlines: + passed = deadline.deadline <= now + if not passed: + has_active_deadlines = True - deadline_str = ( - f"{deadline.course.name} - {deadline.name}: " - ) + deadline_str = ( + f"{deadline.course.name} - {deadline.name}: " + ) - # Strike through deadlines that aren't active anymore - deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") + # Strike through deadlines that aren't active anymore + deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") - # Send an easter egg when there are no deadlines + # Send an Easter egg when there are no deadlines if not has_active_deadlines: embed.description = "There are currently no upcoming deadlines." embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif") return embed - for i in range(5): + for i in range(1, 6): if i not in deadlines_grouped: continue - name = get_edu_year_name(i) + name = get_edu_year_name(i - 1) description = "\n".join(deadlines_grouped[i]) embed.add_field(name=name, value=description, inline=False) diff --git a/didier/views/modals/deadlines.py b/didier/views/modals/deadlines.py index bd8882b..cd2a26c 100644 --- a/didier/views/modals/deadlines.py +++ b/didier/views/modals/deadlines.py @@ -32,6 +32,9 @@ class AddDeadline(discord.ui.Modal, title="Add Deadline"): @overrides async def on_submit(self, interaction: Interaction): + if not self.name.value or not self.deadline.value: + return await interaction.response.send_message("Required fields cannot be empty.", ephemeral=True) + async with self.client.postgres_session as session: await add_deadline(session, self.ufora_course.course_id, self.name.value, self.deadline.value) diff --git a/requirements-dev.txt b/requirements-dev.txt index 64b8467..9307b33 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ pytest-asyncio==0.18.3 pytest-env==0.6.2 sqlalchemy2-stubs==0.0.2a23 types-beautifulsoup4==4.11.3 +types-python-dateutil==2.8.19 types-pytz==2021.3.8 # Flake8 + plugins