Create first implementation of events

pull/167/head
stijndcl 2023-02-03 00:16:49 +01:00 committed by Stijn De Clercq
parent 5deb312474
commit 1831446f65
4 changed files with 111 additions and 9 deletions

View File

@ -1,13 +1,14 @@
import datetime
from typing import Optional
from zoneinfo import ZoneInfo
from dateutil.parser import parse
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import Event
__all__ = ["add_event", "get_event_by_id", "get_events", "get_next_event"]
__all__ = ["add_event", "delete_event_by_id", "get_event_by_id", "get_events", "get_next_event"]
async def add_event(
@ -19,23 +20,31 @@ async def add_event(
event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id)
session.add(event)
await session.commit()
await session.refresh(event)
return event
async def delete_event_by_id(session: AsyncSession, event_id: int):
"""Delete an event by its id"""
statement = delete(Event).where(Event.event_id == event_id)
await session.execute(statement)
await session.commit()
async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]:
"""Get an event by its id"""
statement = select(Event).where(Event.event_id == event_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_events(session: AsyncSession) -> list[Event]:
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
"""Get a list of all upcoming events"""
statement = select(Event)
statement = select(Event).where(Event.timestamp > now)
return (await session.execute(statement)).scalars().all()
async def get_next_event(session: AsyncSession) -> Optional[Event]:
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:
"""Get the first upcoming event"""
statement = select(Event).order_by(Event.timestamp)
statement = select(Event).where(Event.timestamp > now).order_by(Event.timestamp)
return (await session.execute(statement)).scalar_one_or_none()

View File

@ -11,13 +11,15 @@ from database.exceptions import (
ForbiddenNameException,
NoResultFoundException,
)
from database.schemas import Event
from didier import Didier
from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource
from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.constants import Limits
from didier.utils.types.datetime import str_to_date
from didier.utils.timer import Timer
from didier.utils.types.datetime import str_to_date, tz_aware_now
from didier.utils.types.string import abbreviate, leading
from didier.views.modals import CreateBookmark
@ -26,6 +28,7 @@ class Discord(commands.Cog):
"""Commands related to Discord itself, which work with resources like servers and members."""
client: Didier
timer: Timer
# Context-menu references
_bookmark_ctx_menu: app_commands.ContextMenu
@ -38,12 +41,38 @@ class Discord(commands.Cog):
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
self.client.tree.add_command(self._bookmark_ctx_menu)
self.client.tree.add_command(self._pin_ctx_menu)
self.timer = Timer(self.client)
async def cog_unload(self) -> None:
"""Remove the commands when the cog is unloaded"""
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
@commands.Cog.listener("event_create")
async def on_event_create(self, event: Event):
"""Custom listener called when an event is created"""
self.timer.maybe_replace_task(event)
@commands.Cog.listener("timer_end")
async def on_timer_end(self, event_id: int):
"""Custom listener called when an event timer ends"""
async with self.client.postgres_session as session:
event = await events.get_event_by_id(session, event_id)
channel = self.client.get_channel(event.notification_channel)
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
embed.add_field(name="Event", value=event.name, inline=False)
embed.description = event.description
await channel.send(embed=embed)
# Remove the database entry
await events.delete_event_by_id(session, event.event_id)
# Set the next timer
self.client.loop.create_task(self.timer.update())
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of `user`.
@ -211,7 +240,7 @@ class Discord(commands.Cog):
async with ctx.typing():
async with self.client.postgres_session as session:
if event_id is None:
upcoming = await events.get_events(session)
upcoming = await events.get_events(session, now=tz_aware_now())
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
if not upcoming:

View File

@ -0,0 +1,57 @@
import asyncio
from datetime import datetime
from typing import Optional
import discord.utils
from database.crud.events import get_next_event
from database.schemas import Event
from didier import Didier
from didier.utils.types.datetime import tz_aware_now
__all__ = ["Timer"]
class Timer:
"""Class for scheduled timers"""
client: Didier
upcoming_timer: Optional[datetime]
upcoming_event_id: Optional[int]
_task: Optional[asyncio.Task]
def __init__(self, client: Didier):
self.client = client
self.upcoming_timer = None
self.upcoming_event_id = None
self._task = None
self.client.loop.create_task(self.update())
async def update(self):
"""Get & schedule the closest reminder"""
async with self.client.postgres_session as session:
event = await get_next_event(session, now=tz_aware_now())
# No upcoming events
if event is None:
return
self.maybe_replace_task(event)
def maybe_replace_task(self, event: Event):
"""Replace the current task if necessary"""
# If there is a current (pending) task, and the new timer is sooner than the
# pending one, cancel it
if self._task is not None and not self._task.done() and self.upcoming_timer > event.timestamp:
self._task.cancel()
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
self.upcoming_timer = event.timestamp
self.upcoming_event_id = event.event_id
async def end_timer(self, *, endtime: datetime, event_id: int):
"""Wait until a timer runs out, and then trigger an event to send the message"""
await discord.utils.sleep_until(endtime)
self.client.dispatch("timer_end", event_id)

View File

@ -1,3 +1,4 @@
import traceback
from zoneinfo import ZoneInfo
import discord
@ -51,4 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
channel_id=int(self.channel.value),
)
return await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
self.client.dispatch("event_create", event)
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)