diff --git a/database/crud/events.py b/database/crud/events.py index 5d25550..8636d1b 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -1,13 +1,14 @@ +import datetime from typing import Optional from zoneinfo import ZoneInfo from dateutil.parser import parse -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from database.schemas import Event -__all__ = ["add_event", "get_event_by_id", "get_events", "get_next_event"] +__all__ = ["add_event", "delete_event_by_id", "get_event_by_id", "get_events", "get_next_event"] async def add_event( @@ -19,23 +20,31 @@ async def add_event( event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id) session.add(event) await session.commit() + await session.refresh(event) return event +async def delete_event_by_id(session: AsyncSession, event_id: int): + """Delete an event by its id""" + statement = delete(Event).where(Event.event_id == event_id) + await session.execute(statement) + await session.commit() + + async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]: """Get an event by its id""" statement = select(Event).where(Event.event_id == event_id) return (await session.execute(statement)).scalar_one_or_none() -async def get_events(session: AsyncSession) -> list[Event]: +async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" - statement = select(Event) + statement = select(Event).where(Event.timestamp > now) return (await session.execute(statement)).scalars().all() -async def get_next_event(session: AsyncSession) -> Optional[Event]: +async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: """Get the first upcoming event""" - statement = select(Event).order_by(Event.timestamp) + statement = select(Event).where(Event.timestamp > now).order_by(Event.timestamp) return (await session.execute(statement)).scalar_one_or_none() diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index e2d9426..4ff9c3e 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -11,13 +11,15 @@ from database.exceptions import ( ForbiddenNameException, NoResultFoundException, ) +from database.schemas import Event from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.constants import Limits -from didier.utils.types.datetime import str_to_date +from didier.utils.timer import Timer +from didier.utils.types.datetime import str_to_date, tz_aware_now from didier.utils.types.string import abbreviate, leading from didier.views.modals import CreateBookmark @@ -26,6 +28,7 @@ class Discord(commands.Cog): """Commands related to Discord itself, which work with resources like servers and members.""" client: Didier + timer: Timer # Context-menu references _bookmark_ctx_menu: app_commands.ContextMenu @@ -38,12 +41,38 @@ class Discord(commands.Cog): self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu) + self.timer = Timer(self.client) async def cog_unload(self) -> None: """Remove the commands when the cog is unloaded""" self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + @commands.Cog.listener("event_create") + async def on_event_create(self, event: Event): + """Custom listener called when an event is created""" + self.timer.maybe_replace_task(event) + + @commands.Cog.listener("timer_end") + async def on_timer_end(self, event_id: int): + """Custom listener called when an event timer ends""" + async with self.client.postgres_session as session: + event = await events.get_event_by_id(session, event_id) + + channel = self.client.get_channel(event.notification_channel) + + embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) + embed.add_field(name="Event", value=event.name, inline=False) + embed.description = event.description + + await channel.send(embed=embed) + + # Remove the database entry + await events.delete_event_by_id(session, event.event_id) + + # Set the next timer + self.client.loop.create_task(self.timer.update()) + @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of `user`. @@ -211,7 +240,7 @@ class Discord(commands.Cog): async with ctx.typing(): async with self.client.postgres_session as session: if event_id is None: - upcoming = await events.get_events(session) + upcoming = await events.get_events(session, now=tz_aware_now()) embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) if not upcoming: diff --git a/didier/utils/timer.py b/didier/utils/timer.py new file mode 100644 index 0000000..ae72df4 --- /dev/null +++ b/didier/utils/timer.py @@ -0,0 +1,57 @@ +import asyncio +from datetime import datetime +from typing import Optional + +import discord.utils + +from database.crud.events import get_next_event +from database.schemas import Event +from didier import Didier +from didier.utils.types.datetime import tz_aware_now + +__all__ = ["Timer"] + + +class Timer: + """Class for scheduled timers""" + + client: Didier + upcoming_timer: Optional[datetime] + upcoming_event_id: Optional[int] + _task: Optional[asyncio.Task] + + def __init__(self, client: Didier): + self.client = client + + self.upcoming_timer = None + self.upcoming_event_id = None + self._task = None + + self.client.loop.create_task(self.update()) + + async def update(self): + """Get & schedule the closest reminder""" + async with self.client.postgres_session as session: + event = await get_next_event(session, now=tz_aware_now()) + + # No upcoming events + if event is None: + return + + self.maybe_replace_task(event) + + def maybe_replace_task(self, event: Event): + """Replace the current task if necessary""" + # If there is a current (pending) task, and the new timer is sooner than the + # pending one, cancel it + if self._task is not None and not self._task.done() and self.upcoming_timer > event.timestamp: + self._task.cancel() + + self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) + self.upcoming_timer = event.timestamp + self.upcoming_event_id = event.event_id + + async def end_timer(self, *, endtime: datetime, event_id: int): + """Wait until a timer runs out, and then trigger an event to send the message""" + await discord.utils.sleep_until(endtime) + self.client.dispatch("timer_end", event_id) diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py index a02b963..e7b92b4 100644 --- a/didier/views/modals/events.py +++ b/didier/views/modals/events.py @@ -1,3 +1,4 @@ +import traceback from zoneinfo import ZoneInfo import discord @@ -51,4 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"): channel_id=int(self.channel.value), ) - return await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) + await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) + self.client.dispatch("event_create", event) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__)