mirror of https://github.com/stijndcl/didier
				
				
				
			Fix bug, add autocomplete, make cache autocompletion slightly cleaner
							parent
							
								
									e2959c27ad
								
							
						
					
					
						commit
						b26421b875
					
				| 
						 | 
					@ -13,7 +13,7 @@ __all__ = ["add_deadline", "get_deadlines"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def add_deadline(session: AsyncSession, course_id: int, name: str, date_str: str):
 | 
					async def add_deadline(session: AsyncSession, course_id: int, name: str, date_str: str):
 | 
				
			||||||
    """Add a new deadline"""
 | 
					    """Add a new deadline"""
 | 
				
			||||||
    date_dt = parse(date_str).replace(tzinfo=ZoneInfo("Europe/Brussels"))
 | 
					    date_dt = parse(date_str, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    # If we only have a day, assume it's the end of the day
 | 
					    # If we only have a day, assume it's the end of the day
 | 
				
			||||||
    if date_dt.hour == date_dt.minute == date_dt.second == 0:
 | 
					    if date_dt.hour == date_dt.minute == date_dt.second == 0:
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,6 +1,7 @@
 | 
				
			||||||
from abc import ABC, abstractmethod
 | 
					from abc import ABC, abstractmethod
 | 
				
			||||||
from typing import Generic, TypeVar
 | 
					from typing import Generic, TypeVar
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					from discord import app_commands
 | 
				
			||||||
from overrides import overrides
 | 
					from overrides import overrides
 | 
				
			||||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
					from sqlalchemy.ext.asyncio import AsyncSession
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -38,11 +39,13 @@ class DatabaseCache(ABC, Generic[T]):
 | 
				
			||||||
    async def invalidate(self, database_session: T):
 | 
					    async def invalidate(self, database_session: T):
 | 
				
			||||||
        """Invalidate the data stored in this cache"""
 | 
					        """Invalidate the data stored in this cache"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_autocomplete_suggestions(self, query: str):
 | 
					    def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
 | 
				
			||||||
        """Filter the cache to find everything that matches the search query"""
 | 
					        """Filter the cache to find everything that matches the search query"""
 | 
				
			||||||
        query = query.lower()
 | 
					        query = query.lower()
 | 
				
			||||||
        # Return the original (non-transformed) version of the data for pretty display in Discord
 | 
					        # Return the original (non-transformed) version of the data for pretty display in Discord
 | 
				
			||||||
        return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
 | 
					        suggestions = [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class LinkCache(DatabaseCache[AsyncSession]):
 | 
					class LinkCache(DatabaseCache[AsyncSession]):
 | 
				
			||||||
| 
						 | 
					@ -86,7 +89,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
 | 
				
			||||||
        self.data_transformed = list(map(str.lower, self.data))
 | 
					        self.data_transformed = list(map(str.lower, self.data))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @overrides
 | 
					    @overrides
 | 
				
			||||||
    def get_autocomplete_suggestions(self, query: str):
 | 
					    def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
 | 
				
			||||||
        query = query.lower()
 | 
					        query = query.lower()
 | 
				
			||||||
        results = set()
 | 
					        results = set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -99,7 +102,8 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
 | 
				
			||||||
            if query in alias:
 | 
					            if query in alias:
 | 
				
			||||||
                results.add(course)
 | 
					                results.add(course)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return sorted(list(results))
 | 
					        suggestions = sorted(list(results))
 | 
				
			||||||
 | 
					        return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class WordleCache(DatabaseCache[MongoDatabase]):
 | 
					class WordleCache(DatabaseCache[MongoDatabase]):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -64,12 +64,9 @@ class Other(commands.Cog):
 | 
				
			||||||
        return await interaction.response.send_message(link.url)
 | 
					        return await interaction.response.send_message(link.url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @link_slash.autocomplete("name")
 | 
					    @link_slash.autocomplete("name")
 | 
				
			||||||
    async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
					    async def _link_name_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
				
			||||||
        """Autocompletion for the 'name'-parameter"""
 | 
					        """Autocompletion for the 'name'-parameter"""
 | 
				
			||||||
        return [
 | 
					        return self.client.database_caches.links.get_autocomplete_suggestions(current)
 | 
				
			||||||
            app_commands.Choice(name=name, value=name.lower())
 | 
					 | 
				
			||||||
            for name in self.client.database_caches.links.get_autocomplete_suggestions(current)
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def setup(client: Didier):
 | 
					async def setup(client: Didier):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -139,6 +139,7 @@ class Owner(commands.Cog):
 | 
				
			||||||
        await interaction.response.send_modal(modal)
 | 
					        await interaction.response.send_modal(modal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @add_slash.command(name="deadline", description="Add a deadline")
 | 
					    @add_slash.command(name="deadline", description="Add a deadline")
 | 
				
			||||||
 | 
					    @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)")
 | 
				
			||||||
    async def add_deadline_slash(self, interaction: discord.Interaction, course: str):
 | 
					    async def add_deadline_slash(self, interaction: discord.Interaction, course: str):
 | 
				
			||||||
        """Slash command to add a deadline"""
 | 
					        """Slash command to add a deadline"""
 | 
				
			||||||
        async with self.client.postgres_session as session:
 | 
					        async with self.client.postgres_session as session:
 | 
				
			||||||
| 
						 | 
					@ -150,6 +151,13 @@ class Owner(commands.Cog):
 | 
				
			||||||
        modal = AddDeadline(self.client, course_instance)
 | 
					        modal = AddDeadline(self.client, course_instance)
 | 
				
			||||||
        await interaction.response.send_modal(modal)
 | 
					        await interaction.response.send_modal(modal)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @add_deadline_slash.autocomplete("course")
 | 
				
			||||||
 | 
					    async def _add_deadline_course_autocomplete(
 | 
				
			||||||
 | 
					        self, _: discord.Interaction, current: str
 | 
				
			||||||
 | 
					    ) -> list[app_commands.Choice[str]]:
 | 
				
			||||||
 | 
					        """Autocompletion for the 'course'-parameter"""
 | 
				
			||||||
 | 
					        return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @add_slash.command(name="link", description="Add a new link")
 | 
					    @add_slash.command(name="link", description="Add a new link")
 | 
				
			||||||
    async def add_link_slash(self, interaction: discord.Interaction):
 | 
					    async def add_link_slash(self, interaction: discord.Interaction):
 | 
				
			||||||
        """Slash command to add new links"""
 | 
					        """Slash command to add new links"""
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -35,7 +35,7 @@ class School(commands.Cog):
 | 
				
			||||||
        async with self.client.postgres_session as session:
 | 
					        async with self.client.postgres_session as session:
 | 
				
			||||||
            deadlines = await get_deadlines(session)
 | 
					            deadlines = await get_deadlines(session)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        embed = await Deadlines(deadlines).to_embed()
 | 
					        embed = Deadlines(deadlines).to_embed()
 | 
				
			||||||
        await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
 | 
					        await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @commands.command(name="Pin", usage="[Message]")
 | 
					    @commands.command(name="Pin", usage="[Message]")
 | 
				
			||||||
| 
						 | 
					@ -76,7 +76,7 @@ class School(commands.Cog):
 | 
				
			||||||
    @commands.hybrid_command(
 | 
					    @commands.hybrid_command(
 | 
				
			||||||
        name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"]
 | 
					        name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"]
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
    @app_commands.describe(course="vak")
 | 
					    @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)")
 | 
				
			||||||
    async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
 | 
					    async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
 | 
				
			||||||
        """Create links to study guides"""
 | 
					        """Create links to study guides"""
 | 
				
			||||||
        async with self.client.postgres_session as session:
 | 
					        async with self.client.postgres_session as session:
 | 
				
			||||||
| 
						 | 
					@ -91,12 +91,11 @@ class School(commands.Cog):
 | 
				
			||||||
        )
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @study_guide.autocomplete("course")
 | 
					    @study_guide.autocomplete("course")
 | 
				
			||||||
    async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
					    async def _study_guide_course_autocomplete(
 | 
				
			||||||
 | 
					        self, _: discord.Interaction, current: str
 | 
				
			||||||
 | 
					    ) -> list[app_commands.Choice[str]]:
 | 
				
			||||||
        """Autocompletion for the 'course'-parameter"""
 | 
					        """Autocompletion for the 'course'-parameter"""
 | 
				
			||||||
        return [
 | 
					        return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
 | 
				
			||||||
            app_commands.Choice(name=course, value=course)
 | 
					 | 
				
			||||||
            for course in self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
 | 
					 | 
				
			||||||
        ]
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
async def setup(client: Didier):
 | 
					async def setup(client: Didier):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -22,7 +22,7 @@ class Deadlines(EmbedBaseModel):
 | 
				
			||||||
        self.deadlines.sort(key=lambda deadline: deadline.deadline)
 | 
					        self.deadlines.sort(key=lambda deadline: deadline.deadline)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @overrides
 | 
					    @overrides
 | 
				
			||||||
    async def to_embed(self, **kwargs: dict) -> discord.Embed:
 | 
					    def to_embed(self, **kwargs: dict) -> discord.Embed:
 | 
				
			||||||
        embed = discord.Embed(colour=discord.Colour.dark_gold())
 | 
					        embed = discord.Embed(colour=discord.Colour.dark_gold())
 | 
				
			||||||
        embed.set_author(name="Upcoming Deadlines")
 | 
					        embed.set_author(name="Upcoming Deadlines")
 | 
				
			||||||
        now = tz_aware_now()
 | 
					        now = tz_aware_now()
 | 
				
			||||||
| 
						 | 
					@ -30,33 +30,33 @@ class Deadlines(EmbedBaseModel):
 | 
				
			||||||
        has_active_deadlines = False
 | 
					        has_active_deadlines = False
 | 
				
			||||||
        deadlines_grouped: dict[int, list[str]] = {}
 | 
					        deadlines_grouped: dict[int, list[str]] = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        deadline: Deadline
 | 
					        for year, deadlines in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year):
 | 
				
			||||||
        for year, deadline in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year):
 | 
					 | 
				
			||||||
            if year not in deadlines_grouped:
 | 
					            if year not in deadlines_grouped:
 | 
				
			||||||
                deadlines_grouped[year] = []
 | 
					                deadlines_grouped[year] = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            passed = deadline.deadline <= now
 | 
					            for deadline in deadlines:
 | 
				
			||||||
            if passed:
 | 
					                passed = deadline.deadline <= now
 | 
				
			||||||
                has_active_deadlines = True
 | 
					                if not passed:
 | 
				
			||||||
 | 
					                    has_active_deadlines = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            deadline_str = (
 | 
					                deadline_str = (
 | 
				
			||||||
                f"{deadline.course.name} - {deadline.name}: <t:{round(datetime.timestamp(deadline.deadline))}:R>"
 | 
					                    f"{deadline.course.name} - {deadline.name}: <t:{round(datetime.timestamp(deadline.deadline))}:R>"
 | 
				
			||||||
            )
 | 
					                )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            # Strike through deadlines that aren't active anymore
 | 
					                # Strike through deadlines that aren't active anymore
 | 
				
			||||||
            deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~")
 | 
					                deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # Send an easter egg when there are no deadlines
 | 
					        # Send an Easter egg when there are no deadlines
 | 
				
			||||||
        if not has_active_deadlines:
 | 
					        if not has_active_deadlines:
 | 
				
			||||||
            embed.description = "There are currently no upcoming deadlines."
 | 
					            embed.description = "There are currently no upcoming deadlines."
 | 
				
			||||||
            embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif")
 | 
					            embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif")
 | 
				
			||||||
            return embed
 | 
					            return embed
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for i in range(5):
 | 
					        for i in range(1, 6):
 | 
				
			||||||
            if i not in deadlines_grouped:
 | 
					            if i not in deadlines_grouped:
 | 
				
			||||||
                continue
 | 
					                continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            name = get_edu_year_name(i)
 | 
					            name = get_edu_year_name(i - 1)
 | 
				
			||||||
            description = "\n".join(deadlines_grouped[i])
 | 
					            description = "\n".join(deadlines_grouped[i])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            embed.add_field(name=name, value=description, inline=False)
 | 
					            embed.add_field(name=name, value=description, inline=False)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -32,6 +32,9 @@ class AddDeadline(discord.ui.Modal, title="Add Deadline"):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @overrides
 | 
					    @overrides
 | 
				
			||||||
    async def on_submit(self, interaction: Interaction):
 | 
					    async def on_submit(self, interaction: Interaction):
 | 
				
			||||||
 | 
					        if not self.name.value or not self.deadline.value:
 | 
				
			||||||
 | 
					            return await interaction.response.send_message("Required fields cannot be empty.", ephemeral=True)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        async with self.client.postgres_session as session:
 | 
					        async with self.client.postgres_session as session:
 | 
				
			||||||
            await add_deadline(session, self.ufora_course.course_id, self.name.value, self.deadline.value)
 | 
					            await add_deadline(session, self.ufora_course.course_id, self.name.value, self.deadline.value)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -8,6 +8,7 @@ pytest-asyncio==0.18.3
 | 
				
			||||||
pytest-env==0.6.2
 | 
					pytest-env==0.6.2
 | 
				
			||||||
sqlalchemy2-stubs==0.0.2a23
 | 
					sqlalchemy2-stubs==0.0.2a23
 | 
				
			||||||
types-beautifulsoup4==4.11.3
 | 
					types-beautifulsoup4==4.11.3
 | 
				
			||||||
 | 
					types-python-dateutil==2.8.19
 | 
				
			||||||
types-pytz==2021.3.8
 | 
					types-pytz==2021.3.8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# Flake8 + plugins
 | 
					# Flake8 + plugins
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue