From 5deb31247451cec2e0a1c096a9cc16ff794aa209 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 2 Feb 2023 23:16:51 +0100 Subject: [PATCH 1/5] Start working on events --- .../versions/954ad804f057_add_events_table.py | 36 +++++++++++++ database/crud/events.py | 41 ++++++++++++++ database/schemas.py | 13 +++++ didier/cogs/discord.py | 48 ++++++++++++++++- didier/cogs/owner.py | 10 ++++ didier/views/modals/__init__.py | 2 + didier/views/modals/events.py | 54 +++++++++++++++++++ 7 files changed, 203 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/954ad804f057_add_events_table.py create mode 100644 database/crud/events.py create mode 100644 didier/views/modals/events.py diff --git a/alembic/versions/954ad804f057_add_events_table.py b/alembic/versions/954ad804f057_add_events_table.py new file mode 100644 index 0000000..c066446 --- /dev/null +++ b/alembic/versions/954ad804f057_add_events_table.py @@ -0,0 +1,36 @@ +"""Add events table + +Revision ID: 954ad804f057 +Revises: 9fb84b4d9f0b +Create Date: 2023-02-02 22:20:23.107931 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "954ad804f057" +down_revision = "9fb84b4d9f0b" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "events", + sa.Column("event_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("notification_channel", sa.BigInteger(), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.PrimaryKeyConstraint("event_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("events") + # ### end Alembic commands ### diff --git a/database/crud/events.py b/database/crud/events.py new file mode 100644 index 0000000..5d25550 --- /dev/null +++ b/database/crud/events.py @@ -0,0 +1,41 @@ +from typing import Optional +from zoneinfo import ZoneInfo + +from dateutil.parser import parse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import Event + +__all__ = ["add_event", "get_event_by_id", "get_events", "get_next_event"] + + +async def add_event( + session: AsyncSession, *, name: str, description: Optional[str], date_str: str, channel_id: int +) -> Event: + """Create a new event""" + date_dt = parse(date_str, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) + + event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id) + session.add(event) + await session.commit() + + return event + + +async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]: + """Get an event by its id""" + statement = select(Event).where(Event.event_id == event_id) + return (await session.execute(statement)).scalar_one_or_none() + + +async def get_events(session: AsyncSession) -> list[Event]: + """Get a list of all upcoming events""" + statement = select(Event) + return (await session.execute(statement)).scalars().all() + + +async def get_next_event(session: AsyncSession) -> Optional[Event]: + """Get the first upcoming event""" + statement = select(Event).order_by(Event.timestamp) + return (await session.execute(statement)).scalar_one_or_none() diff --git a/database/schemas.py b/database/schemas.py index ffda6ba..3300e6a 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -33,6 +33,7 @@ __all__ = [ "DadJoke", "Deadline", "EasterEgg", + "Event", "FreeGame", "GitHubLink", "Link", @@ -175,6 +176,18 @@ class EasterEgg(Base): startswith: bool = Column(Boolean, nullable=False, server_default="1") +class Event(Base): + """A scheduled event""" + + __tablename__ = "events" + + event_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False) + description: Optional[str] = Column(Text, nullable=True) + notification_channel: int = Column(BigInteger, nullable=False) + timestamp: datetime = Column(DateTime(timezone=True), nullable=False) + + class FreeGame(Base): """A temporarily free game""" diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index df91673..e2d9426 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -4,7 +4,7 @@ import discord from discord import app_commands from discord.ext import commands -from database.crud import birthdays, bookmarks, github +from database.crud import birthdays, bookmarks, events, github from database.exceptions import ( DuplicateInsertException, Forbidden, @@ -200,6 +200,52 @@ class Discord(commands.Cog): modal = CreateBookmark(self.client, message.jump_url) await interaction.response.send_modal(modal) + @commands.hybrid_command(name="events") + @app_commands.rename(event_id="id") + @app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.") + async def events(self, ctx: commands.Context, event_id: Optional[int] = None): + """Show information about the event with id `event_id`. + + If no value for `event_id` is supplied, this shows all upcoming events instead. + """ + async with ctx.typing(): + async with self.client.postgres_session as session: + if event_id is None: + upcoming = await events.get_events(session) + + embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) + if not upcoming: + embed.colour = discord.Colour.red() + embed.description = "There are currently no upcoming events scheduled." + return await ctx.reply(embed=embed, mention_author=False) + + upcoming.sort(key=lambda e: e.timestamp.timestamp()) + description_items = [] + + for event in upcoming: + description_items.append( + f"`{event.event_id}`: {event.name} ({discord.utils.format_dt(event.timestamp, style='R')})" + ) + + embed.description = "\n".join(description_items) + return await ctx.reply(embed=embed, mention_author=False) + else: + event = await events.get_event_by_id(session, event_id) + if event is None: + return await ctx.reply(f"Found no event with id `{event_id}`.", mention_author=False) + + embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) + embed.add_field(name="Name", value=event.name, inline=True) + embed.add_field(name="Id", value=event.event_id, inline=True) + embed.add_field( + name="Timer", value=discord.utils.format_dt(event.timestamp, style="R"), inline=True + ) + embed.add_field( + name="Channel", value=self.client.get_channel(event.notification_channel).mention, inline=False + ) + embed.description = event.description + return await ctx.reply(embed=embed, mention_author=False) + @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): """Show a user's GitHub links. diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 1c2e76b..139f02c 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -13,6 +13,7 @@ from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags from didier.views.modals import ( AddDadJoke, AddDeadline, + AddEvent, AddLink, CreateCustomCommand, EditCustomCommand, @@ -173,6 +174,15 @@ class Owner(commands.Cog): """Autocompletion for the 'course'-parameter""" return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) + @add_slash.command(name="event", description="Add a new event") + async def add_event_slash(self, interaction: discord.Interaction): + """Slash command to add new events""" + if not await self.client.is_owner(interaction.user): + return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) + + modal = AddEvent(self.client) + await interaction.response.send_modal(modal) + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index e9f92f0..fe3c352 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -2,6 +2,7 @@ from .bookmarks import CreateBookmark from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke from .deadlines import AddDeadline +from .events import AddEvent from .links import AddLink from .memes import GenerateMeme @@ -9,6 +10,7 @@ __all__ = [ "CreateBookmark", "AddDadJoke", "AddDeadline", + "AddEvent", "CreateCustomCommand", "EditCustomCommand", "AddLink", diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py new file mode 100644 index 0000000..a02b963 --- /dev/null +++ b/didier/views/modals/events.py @@ -0,0 +1,54 @@ +from zoneinfo import ZoneInfo + +import discord +from dateutil.parser import ParserError, parse +from overrides import overrides + +from database.crud.events import add_event +from didier import Didier + +__all__ = ["AddEvent"] + + +class AddEvent(discord.ui.Modal, title="Add Event"): + """Modal to add a new event""" + + name: discord.ui.TextInput = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, required=True) + description: discord.ui.TextInput = discord.ui.TextInput( + label="Description", style=discord.TextStyle.paragraph, required=False, default=None + ) + channel: discord.ui.TextInput = discord.ui.TextInput( + label="Channel id", style=discord.TextStyle.short, required=True, placeholder="676713433567199232" + ) + timestamp: discord.ui.TextInput = discord.ui.TextInput( + label="Date", style=discord.TextStyle.short, required=True, placeholder="21/02/2020 21:21:00" + ) + + client: Didier + + def __init__(self, client: Didier, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + + @overrides + async def on_submit(self, interaction: discord.Interaction) -> None: + try: + parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) + except ParserError: + return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True) + + if self.client.get_channel(int(self.channel.value)) is None: + return await interaction.response.send_message( + f"Unable to find channel `{self.channel.value}`", ephemeral=True + ) + + async with self.client.postgres_session as session: + event = await add_event( + session, + name=self.name.value, + description=self.description.value, + date_str=self.timestamp.value, + channel_id=int(self.channel.value), + ) + + return await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) From 1831446f6513b5ffa9f75af42d6ac8b2221d93bd Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 3 Feb 2023 00:16:49 +0100 Subject: [PATCH 2/5] Create first implementation of events --- database/crud/events.py | 21 +++++++++---- didier/cogs/discord.py | 33 ++++++++++++++++++-- didier/utils/timer.py | 57 +++++++++++++++++++++++++++++++++++ didier/views/modals/events.py | 9 +++++- 4 files changed, 111 insertions(+), 9 deletions(-) create mode 100644 didier/utils/timer.py diff --git a/database/crud/events.py b/database/crud/events.py index 5d25550..8636d1b 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -1,13 +1,14 @@ +import datetime from typing import Optional from zoneinfo import ZoneInfo from dateutil.parser import parse -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from database.schemas import Event -__all__ = ["add_event", "get_event_by_id", "get_events", "get_next_event"] +__all__ = ["add_event", "delete_event_by_id", "get_event_by_id", "get_events", "get_next_event"] async def add_event( @@ -19,23 +20,31 @@ async def add_event( event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id) session.add(event) await session.commit() + await session.refresh(event) return event +async def delete_event_by_id(session: AsyncSession, event_id: int): + """Delete an event by its id""" + statement = delete(Event).where(Event.event_id == event_id) + await session.execute(statement) + await session.commit() + + async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]: """Get an event by its id""" statement = select(Event).where(Event.event_id == event_id) return (await session.execute(statement)).scalar_one_or_none() -async def get_events(session: AsyncSession) -> list[Event]: +async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" - statement = select(Event) + statement = select(Event).where(Event.timestamp > now) return (await session.execute(statement)).scalars().all() -async def get_next_event(session: AsyncSession) -> Optional[Event]: +async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: """Get the first upcoming event""" - statement = select(Event).order_by(Event.timestamp) + statement = select(Event).where(Event.timestamp > now).order_by(Event.timestamp) return (await session.execute(statement)).scalar_one_or_none() diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index e2d9426..4ff9c3e 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -11,13 +11,15 @@ from database.exceptions import ( ForbiddenNameException, NoResultFoundException, ) +from database.schemas import Event from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.constants import Limits -from didier.utils.types.datetime import str_to_date +from didier.utils.timer import Timer +from didier.utils.types.datetime import str_to_date, tz_aware_now from didier.utils.types.string import abbreviate, leading from didier.views.modals import CreateBookmark @@ -26,6 +28,7 @@ class Discord(commands.Cog): """Commands related to Discord itself, which work with resources like servers and members.""" client: Didier + timer: Timer # Context-menu references _bookmark_ctx_menu: app_commands.ContextMenu @@ -38,12 +41,38 @@ class Discord(commands.Cog): self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu) + self.timer = Timer(self.client) async def cog_unload(self) -> None: """Remove the commands when the cog is unloaded""" self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + @commands.Cog.listener("event_create") + async def on_event_create(self, event: Event): + """Custom listener called when an event is created""" + self.timer.maybe_replace_task(event) + + @commands.Cog.listener("timer_end") + async def on_timer_end(self, event_id: int): + """Custom listener called when an event timer ends""" + async with self.client.postgres_session as session: + event = await events.get_event_by_id(session, event_id) + + channel = self.client.get_channel(event.notification_channel) + + embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) + embed.add_field(name="Event", value=event.name, inline=False) + embed.description = event.description + + await channel.send(embed=embed) + + # Remove the database entry + await events.delete_event_by_id(session, event.event_id) + + # Set the next timer + self.client.loop.create_task(self.timer.update()) + @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of `user`. @@ -211,7 +240,7 @@ class Discord(commands.Cog): async with ctx.typing(): async with self.client.postgres_session as session: if event_id is None: - upcoming = await events.get_events(session) + upcoming = await events.get_events(session, now=tz_aware_now()) embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) if not upcoming: diff --git a/didier/utils/timer.py b/didier/utils/timer.py new file mode 100644 index 0000000..ae72df4 --- /dev/null +++ b/didier/utils/timer.py @@ -0,0 +1,57 @@ +import asyncio +from datetime import datetime +from typing import Optional + +import discord.utils + +from database.crud.events import get_next_event +from database.schemas import Event +from didier import Didier +from didier.utils.types.datetime import tz_aware_now + +__all__ = ["Timer"] + + +class Timer: + """Class for scheduled timers""" + + client: Didier + upcoming_timer: Optional[datetime] + upcoming_event_id: Optional[int] + _task: Optional[asyncio.Task] + + def __init__(self, client: Didier): + self.client = client + + self.upcoming_timer = None + self.upcoming_event_id = None + self._task = None + + self.client.loop.create_task(self.update()) + + async def update(self): + """Get & schedule the closest reminder""" + async with self.client.postgres_session as session: + event = await get_next_event(session, now=tz_aware_now()) + + # No upcoming events + if event is None: + return + + self.maybe_replace_task(event) + + def maybe_replace_task(self, event: Event): + """Replace the current task if necessary""" + # If there is a current (pending) task, and the new timer is sooner than the + # pending one, cancel it + if self._task is not None and not self._task.done() and self.upcoming_timer > event.timestamp: + self._task.cancel() + + self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) + self.upcoming_timer = event.timestamp + self.upcoming_event_id = event.event_id + + async def end_timer(self, *, endtime: datetime, event_id: int): + """Wait until a timer runs out, and then trigger an event to send the message""" + await discord.utils.sleep_until(endtime) + self.client.dispatch("timer_end", event_id) diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py index a02b963..e7b92b4 100644 --- a/didier/views/modals/events.py +++ b/didier/views/modals/events.py @@ -1,3 +1,4 @@ +import traceback from zoneinfo import ZoneInfo import discord @@ -51,4 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"): channel_id=int(self.channel.value), ) - return await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) + await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) + self.client.dispatch("event_create", event) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) From 4548f0e0036b47ca68946a710bbb7cddec17688d Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 3 Feb 2023 00:25:29 +0100 Subject: [PATCH 3/5] Fix listeners --- didier/cogs/discord.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 4ff9c3e..38787bc 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -48,12 +48,12 @@ class Discord(commands.Cog): self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) - @commands.Cog.listener("event_create") + @commands.Cog.listener() async def on_event_create(self, event: Event): """Custom listener called when an event is created""" self.timer.maybe_replace_task(event) - @commands.Cog.listener("timer_end") + @commands.Cog.listener() async def on_timer_end(self, event_id: int): """Custom listener called when an event timer ends""" async with self.client.postgres_session as session: From 4e205a02c7fd970b83ba93665066345138a8a3fb Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 3 Feb 2023 00:57:12 +0100 Subject: [PATCH 4/5] Last fixes --- didier/utils/timer.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/didier/utils/timer.py b/didier/utils/timer.py index ae72df4..306c073 100644 --- a/didier/utils/timer.py +++ b/didier/utils/timer.py @@ -44,8 +44,12 @@ class Timer: """Replace the current task if necessary""" # If there is a current (pending) task, and the new timer is sooner than the # pending one, cancel it - if self._task is not None and not self._task.done() and self.upcoming_timer > event.timestamp: - self._task.cancel() + if self._task is not None and not self._task.done(): + if self.upcoming_timer > event.timestamp: + self._task.cancel() + else: + # The new task happens after the existing task, it has to wait for its turn + return self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) self.upcoming_timer = event.timestamp @@ -54,4 +58,6 @@ class Timer: async def end_timer(self, *, endtime: datetime, event_id: int): """Wait until a timer runs out, and then trigger an event to send the message""" await discord.utils.sleep_until(endtime) + self.upcoming_timer = None + self.upcoming_event_id = None self.client.dispatch("timer_end", event_id) From deba3ababa6ddaf8c2adc9e6144a61a70183750f Mon Sep 17 00:00:00 2001 From: Stijn De Clercq Date: Thu, 16 Feb 2023 23:57:56 +0100 Subject: [PATCH 5/5] Typing --- didier/cogs/discord.py | 19 ++++++++++++------- didier/utils/timer.py | 5 +++-- 2 files changed, 15 insertions(+), 9 deletions(-) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 38787bc..ab00717 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -59,6 +59,9 @@ class Discord(commands.Cog): async with self.client.postgres_session as session: event = await events.get_event_by_id(session, event_id) + if event is None: + return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True) + channel = self.client.get_channel(event.notification_channel) embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) @@ -259,20 +262,22 @@ class Discord(commands.Cog): embed.description = "\n".join(description_items) return await ctx.reply(embed=embed, mention_author=False) else: - event = await events.get_event_by_id(session, event_id) - if event is None: + result_event = await events.get_event_by_id(session, event_id) + if result_event is None: return await ctx.reply(f"Found no event with id `{event_id}`.", mention_author=False) embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue()) - embed.add_field(name="Name", value=event.name, inline=True) - embed.add_field(name="Id", value=event.event_id, inline=True) + embed.add_field(name="Name", value=result_event.name, inline=True) + embed.add_field(name="Id", value=result_event.event_id, inline=True) embed.add_field( - name="Timer", value=discord.utils.format_dt(event.timestamp, style="R"), inline=True + name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True ) embed.add_field( - name="Channel", value=self.client.get_channel(event.notification_channel).mention, inline=False + name="Channel", + value=self.client.get_channel(result_event.notification_channel).mention, + inline=False, ) - embed.description = event.description + embed.description = result_event.description return await ctx.reply(embed=embed, mention_author=False) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) diff --git a/didier/utils/timer.py b/didier/utils/timer.py index 306c073..e5b216b 100644 --- a/didier/utils/timer.py +++ b/didier/utils/timer.py @@ -45,15 +45,16 @@ class Timer: # If there is a current (pending) task, and the new timer is sooner than the # pending one, cancel it if self._task is not None and not self._task.done(): - if self.upcoming_timer > event.timestamp: + # The upcoming timer will never be None at this point, but Mypy is mad + if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp: self._task.cancel() else: # The new task happens after the existing task, it has to wait for its turn return - self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) self.upcoming_timer = event.timestamp self.upcoming_event_id = event.event_id + self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) async def end_timer(self, *, endtime: datetime, event_id: int): """Wait until a timer runs out, and then trigger an event to send the message"""