From 9e9be39358b419d6f3066e1a0f4ab11eeb284a9f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 5 Nov 2022 22:26:48 +0100 Subject: [PATCH] Only show upcoming deadlines for courses the user is actually subscribed to --- database/crud/deadlines.py | 13 ++++++++----- didier/cogs/school.py | 14 +++++++++----- didier/utils/discord/users.py | 18 +++++++++++++++++- 3 files changed, 34 insertions(+), 11 deletions(-) diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index ce7fba5..78d623f 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -1,3 +1,4 @@ +import datetime from typing import Optional from zoneinfo import ZoneInfo @@ -24,13 +25,15 @@ async def add_deadline(session: AsyncSession, course_id: int, name: str, date_st await session.commit() -async def get_deadlines(session: AsyncSession, *, course: Optional[UforaCourse] = None) -> list[Deadline]: - """Get a list of all deadlines that are currently known - - This includes deadlines that have passed already - """ +async def get_deadlines( + session: AsyncSession, *, after: Optional[datetime.date] = None, course: Optional[UforaCourse] = None +) -> list[Deadline]: + """Get a list of all upcoming deadlines""" statement = select(Deadline) + if after is not None: + statement = statement.where(Deadline.deadline > after) + if course is not None: statement = statement.where(Deadline.course_id == course.course_id) diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 3ff1886..7af8a81 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -15,7 +15,7 @@ from didier.data.embeds.schedules import Schedule, get_schedule_for_day from didier.exceptions import HTTPException, NotInMainGuildException from didier.utils.discord.converters.time import DateTransformer from didier.utils.discord.flags.school import StudyGuideFlags -from didier.utils.discord.users import to_main_guild_member +from didier.utils.discord.users import has_course, to_main_guild_member from didier.utils.types.datetime import skip_weekends, tz_aware_today @@ -30,11 +30,15 @@ class School(commands.Cog): @commands.hybrid_command(name="deadlines") async def deadlines(self, ctx: commands.Context): """Show upcoming deadlines.""" - async with self.client.postgres_session as session: - deadlines = await get_deadlines(session) + async with ctx.typing(): + async with self.client.postgres_session as session: + deadlines = await get_deadlines(session, after=tz_aware_today()) - embed = Deadlines(deadlines).to_embed() - await ctx.reply(embed=embed, mention_author=False, ephemeral=False) + member = to_main_guild_member(self.client, ctx.author) + deadlines = list(filter(lambda d: has_course(member, d.course), deadlines)) + + embed = Deadlines(deadlines).to_embed() + await ctx.reply(embed=embed, mention_author=False, ephemeral=False) @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) @app_commands.rename(day_dt="date") diff --git a/didier/utils/discord/users.py b/didier/utils/discord/users.py index 27feaa3..89789cd 100644 --- a/didier/utils/discord/users.py +++ b/didier/utils/discord/users.py @@ -2,10 +2,26 @@ from typing import Union import discord +from database.schemas import UforaCourse from didier import Didier from didier.exceptions import NotInMainGuildException -__all__ = ["to_main_guild_member"] +__all__ = ["has_course", "to_main_guild_member"] + + +def has_course(member: discord.Member, course: UforaCourse) -> bool: + """Check if a member is taking a Ufora course""" + for role in member.roles: + if role.id == course.role_id: + return True + + if course.overarching_role_id is not None and course.overarching_role_id == role.id: + return True + + if course.alternative_overarching_role_id is not None and course.alternative_overarching_role_id == role.id: + return True + + return False def to_main_guild_member(client: Didier, user: Union[discord.User, discord.Member]) -> discord.Member: