mirror of https://github.com/stijndcl/didier
Create first implementation of events
parent
5deb312474
commit
1831446f65
|
@ -1,13 +1,14 @@
|
||||||
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
from dateutil.parser import parse
|
from dateutil.parser import parse
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas import Event
|
from database.schemas import Event
|
||||||
|
|
||||||
__all__ = ["add_event", "get_event_by_id", "get_events", "get_next_event"]
|
__all__ = ["add_event", "delete_event_by_id", "get_event_by_id", "get_events", "get_next_event"]
|
||||||
|
|
||||||
|
|
||||||
async def add_event(
|
async def add_event(
|
||||||
|
@ -19,23 +20,31 @@ async def add_event(
|
||||||
event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id)
|
event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id)
|
||||||
session.add(event)
|
session.add(event)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.refresh(event)
|
||||||
|
|
||||||
return event
|
return event
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_event_by_id(session: AsyncSession, event_id: int):
|
||||||
|
"""Delete an event by its id"""
|
||||||
|
statement = delete(Event).where(Event.event_id == event_id)
|
||||||
|
await session.execute(statement)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]:
|
async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]:
|
||||||
"""Get an event by its id"""
|
"""Get an event by its id"""
|
||||||
statement = select(Event).where(Event.event_id == event_id)
|
statement = select(Event).where(Event.event_id == event_id)
|
||||||
return (await session.execute(statement)).scalar_one_or_none()
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def get_events(session: AsyncSession) -> list[Event]:
|
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
|
||||||
"""Get a list of all upcoming events"""
|
"""Get a list of all upcoming events"""
|
||||||
statement = select(Event)
|
statement = select(Event).where(Event.timestamp > now)
|
||||||
return (await session.execute(statement)).scalars().all()
|
return (await session.execute(statement)).scalars().all()
|
||||||
|
|
||||||
|
|
||||||
async def get_next_event(session: AsyncSession) -> Optional[Event]:
|
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:
|
||||||
"""Get the first upcoming event"""
|
"""Get the first upcoming event"""
|
||||||
statement = select(Event).order_by(Event.timestamp)
|
statement = select(Event).where(Event.timestamp > now).order_by(Event.timestamp)
|
||||||
return (await session.execute(statement)).scalar_one_or_none()
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
|
@ -11,13 +11,15 @@ from database.exceptions import (
|
||||||
ForbiddenNameException,
|
ForbiddenNameException,
|
||||||
NoResultFoundException,
|
NoResultFoundException,
|
||||||
)
|
)
|
||||||
|
from database.schemas import Event
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.exceptions import expect
|
from didier.exceptions import expect
|
||||||
from didier.menus.bookmarks import BookmarkSource
|
from didier.menus.bookmarks import BookmarkSource
|
||||||
from didier.utils.discord import colours
|
from didier.utils.discord import colours
|
||||||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
||||||
from didier.utils.discord.constants import Limits
|
from didier.utils.discord.constants import Limits
|
||||||
from didier.utils.types.datetime import str_to_date
|
from didier.utils.timer import Timer
|
||||||
|
from didier.utils.types.datetime import str_to_date, tz_aware_now
|
||||||
from didier.utils.types.string import abbreviate, leading
|
from didier.utils.types.string import abbreviate, leading
|
||||||
from didier.views.modals import CreateBookmark
|
from didier.views.modals import CreateBookmark
|
||||||
|
|
||||||
|
@ -26,6 +28,7 @@ class Discord(commands.Cog):
|
||||||
"""Commands related to Discord itself, which work with resources like servers and members."""
|
"""Commands related to Discord itself, which work with resources like servers and members."""
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
|
timer: Timer
|
||||||
|
|
||||||
# Context-menu references
|
# Context-menu references
|
||||||
_bookmark_ctx_menu: app_commands.ContextMenu
|
_bookmark_ctx_menu: app_commands.ContextMenu
|
||||||
|
@ -38,12 +41,38 @@ class Discord(commands.Cog):
|
||||||
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
|
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
|
||||||
self.client.tree.add_command(self._bookmark_ctx_menu)
|
self.client.tree.add_command(self._bookmark_ctx_menu)
|
||||||
self.client.tree.add_command(self._pin_ctx_menu)
|
self.client.tree.add_command(self._pin_ctx_menu)
|
||||||
|
self.timer = Timer(self.client)
|
||||||
|
|
||||||
async def cog_unload(self) -> None:
|
async def cog_unload(self) -> None:
|
||||||
"""Remove the commands when the cog is unloaded"""
|
"""Remove the commands when the cog is unloaded"""
|
||||||
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
|
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
|
||||||
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||||
|
|
||||||
|
@commands.Cog.listener("event_create")
|
||||||
|
async def on_event_create(self, event: Event):
|
||||||
|
"""Custom listener called when an event is created"""
|
||||||
|
self.timer.maybe_replace_task(event)
|
||||||
|
|
||||||
|
@commands.Cog.listener("timer_end")
|
||||||
|
async def on_timer_end(self, event_id: int):
|
||||||
|
"""Custom listener called when an event timer ends"""
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
event = await events.get_event_by_id(session, event_id)
|
||||||
|
|
||||||
|
channel = self.client.get_channel(event.notification_channel)
|
||||||
|
|
||||||
|
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
|
||||||
|
embed.add_field(name="Event", value=event.name, inline=False)
|
||||||
|
embed.description = event.description
|
||||||
|
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
||||||
|
# Remove the database entry
|
||||||
|
await events.delete_event_by_id(session, event.event_id)
|
||||||
|
|
||||||
|
# Set the next timer
|
||||||
|
self.client.loop.create_task(self.timer.update())
|
||||||
|
|
||||||
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
||||||
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
||||||
"""Command to check the birthday of `user`.
|
"""Command to check the birthday of `user`.
|
||||||
|
@ -211,7 +240,7 @@ class Discord(commands.Cog):
|
||||||
async with ctx.typing():
|
async with ctx.typing():
|
||||||
async with self.client.postgres_session as session:
|
async with self.client.postgres_session as session:
|
||||||
if event_id is None:
|
if event_id is None:
|
||||||
upcoming = await events.get_events(session)
|
upcoming = await events.get_events(session, now=tz_aware_now())
|
||||||
|
|
||||||
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
|
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
|
||||||
if not upcoming:
|
if not upcoming:
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord.utils
|
||||||
|
|
||||||
|
from database.crud.events import get_next_event
|
||||||
|
from database.schemas import Event
|
||||||
|
from didier import Didier
|
||||||
|
from didier.utils.types.datetime import tz_aware_now
|
||||||
|
|
||||||
|
__all__ = ["Timer"]
|
||||||
|
|
||||||
|
|
||||||
|
class Timer:
|
||||||
|
"""Class for scheduled timers"""
|
||||||
|
|
||||||
|
client: Didier
|
||||||
|
upcoming_timer: Optional[datetime]
|
||||||
|
upcoming_event_id: Optional[int]
|
||||||
|
_task: Optional[asyncio.Task]
|
||||||
|
|
||||||
|
def __init__(self, client: Didier):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
self.upcoming_timer = None
|
||||||
|
self.upcoming_event_id = None
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
self.client.loop.create_task(self.update())
|
||||||
|
|
||||||
|
async def update(self):
|
||||||
|
"""Get & schedule the closest reminder"""
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
event = await get_next_event(session, now=tz_aware_now())
|
||||||
|
|
||||||
|
# No upcoming events
|
||||||
|
if event is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
self.maybe_replace_task(event)
|
||||||
|
|
||||||
|
def maybe_replace_task(self, event: Event):
|
||||||
|
"""Replace the current task if necessary"""
|
||||||
|
# If there is a current (pending) task, and the new timer is sooner than the
|
||||||
|
# pending one, cancel it
|
||||||
|
if self._task is not None and not self._task.done() and self.upcoming_timer > event.timestamp:
|
||||||
|
self._task.cancel()
|
||||||
|
|
||||||
|
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
|
||||||
|
self.upcoming_timer = event.timestamp
|
||||||
|
self.upcoming_event_id = event.event_id
|
||||||
|
|
||||||
|
async def end_timer(self, *, endtime: datetime, event_id: int):
|
||||||
|
"""Wait until a timer runs out, and then trigger an event to send the message"""
|
||||||
|
await discord.utils.sleep_until(endtime)
|
||||||
|
self.client.dispatch("timer_end", event_id)
|
|
@ -1,3 +1,4 @@
|
||||||
|
import traceback
|
||||||
from zoneinfo import ZoneInfo
|
from zoneinfo import ZoneInfo
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
@ -51,4 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
|
||||||
channel_id=int(self.channel.value),
|
channel_id=int(self.channel.value),
|
||||||
)
|
)
|
||||||
|
|
||||||
return await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
|
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
|
||||||
|
self.client.dispatch("event_create", event)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||||
|
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||||
|
traceback.print_tb(error.__traceback__)
|
||||||
|
|
Loading…
Reference in New Issue