Merge pull request #167 from stijndcl/events

Add event reminders
pull/172/head
Stijn De Clercq 2023-02-16 23:01:45 +00:00 committed by GitHub
commit c570cd2db2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 318 additions and 2 deletions

View File

@ -0,0 +1,36 @@
"""Add events table
Revision ID: 954ad804f057
Revises: 9fb84b4d9f0b
Create Date: 2023-02-02 22:20:23.107931
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "954ad804f057"
down_revision = "9fb84b4d9f0b"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"events",
sa.Column("event_id", sa.Integer(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("notification_channel", sa.BigInteger(), nullable=False),
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
sa.PrimaryKeyConstraint("event_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("events")
# ### end Alembic commands ###

View File

@ -0,0 +1,50 @@
import datetime
from typing import Optional
from zoneinfo import ZoneInfo
from dateutil.parser import parse
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import Event
__all__ = ["add_event", "delete_event_by_id", "get_event_by_id", "get_events", "get_next_event"]
async def add_event(
session: AsyncSession, *, name: str, description: Optional[str], date_str: str, channel_id: int
) -> Event:
"""Create a new event"""
date_dt = parse(date_str, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
event = Event(name=name, description=description, timestamp=date_dt, notification_channel=channel_id)
session.add(event)
await session.commit()
await session.refresh(event)
return event
async def delete_event_by_id(session: AsyncSession, event_id: int):
"""Delete an event by its id"""
statement = delete(Event).where(Event.event_id == event_id)
await session.execute(statement)
await session.commit()
async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Event]:
"""Get an event by its id"""
statement = select(Event).where(Event.event_id == event_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
"""Get a list of all upcoming events"""
statement = select(Event).where(Event.timestamp > now)
return (await session.execute(statement)).scalars().all()
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:
"""Get the first upcoming event"""
statement = select(Event).where(Event.timestamp > now).order_by(Event.timestamp)
return (await session.execute(statement)).scalar_one_or_none()

View File

@ -33,6 +33,7 @@ __all__ = [
"DadJoke", "DadJoke",
"Deadline", "Deadline",
"EasterEgg", "EasterEgg",
"Event",
"FreeGame", "FreeGame",
"GitHubLink", "GitHubLink",
"Link", "Link",
@ -175,6 +176,18 @@ class EasterEgg(Base):
startswith: bool = Column(Boolean, nullable=False, server_default="1") startswith: bool = Column(Boolean, nullable=False, server_default="1")
class Event(Base):
"""A scheduled event"""
__tablename__ = "events"
event_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False)
description: Optional[str] = Column(Text, nullable=True)
notification_channel: int = Column(BigInteger, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
class FreeGame(Base): class FreeGame(Base):
"""A temporarily free game""" """A temporarily free game"""

View File

@ -4,20 +4,22 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud import birthdays, bookmarks, github from database.crud import birthdays, bookmarks, events, github
from database.exceptions import ( from database.exceptions import (
DuplicateInsertException, DuplicateInsertException,
Forbidden, Forbidden,
ForbiddenNameException, ForbiddenNameException,
NoResultFoundException, NoResultFoundException,
) )
from database.schemas import Event
from didier import Didier from didier import Didier
from didier.exceptions import expect from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.utils.discord import colours from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
from didier.utils.types.datetime import str_to_date from didier.utils.timer import Timer
from didier.utils.types.datetime import str_to_date, tz_aware_now
from didier.utils.types.string import abbreviate, leading from didier.utils.types.string import abbreviate, leading
from didier.views.modals import CreateBookmark from didier.views.modals import CreateBookmark
@ -26,6 +28,7 @@ class Discord(commands.Cog):
"""Commands related to Discord itself, which work with resources like servers and members.""" """Commands related to Discord itself, which work with resources like servers and members."""
client: Didier client: Didier
timer: Timer
# Context-menu references # Context-menu references
_bookmark_ctx_menu: app_commands.ContextMenu _bookmark_ctx_menu: app_commands.ContextMenu
@ -38,12 +41,41 @@ class Discord(commands.Cog):
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._bookmark_ctx_menu)
self.client.tree.add_command(self._pin_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu)
self.timer = Timer(self.client)
async def cog_unload(self) -> None: async def cog_unload(self) -> None:
"""Remove the commands when the cog is unloaded""" """Remove the commands when the cog is unloaded"""
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
@commands.Cog.listener()
async def on_event_create(self, event: Event):
"""Custom listener called when an event is created"""
self.timer.maybe_replace_task(event)
@commands.Cog.listener()
async def on_timer_end(self, event_id: int):
"""Custom listener called when an event timer ends"""
async with self.client.postgres_session as session:
event = await events.get_event_by_id(session, event_id)
if event is None:
return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True)
channel = self.client.get_channel(event.notification_channel)
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
embed.add_field(name="Event", value=event.name, inline=False)
embed.description = event.description
await channel.send(embed=embed)
# Remove the database entry
await events.delete_event_by_id(session, event.event_id)
# Set the next timer
self.client.loop.create_task(self.timer.update())
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None): async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of `user`. """Command to check the birthday of `user`.
@ -200,6 +232,54 @@ class Discord(commands.Cog):
modal = CreateBookmark(self.client, message.jump_url) modal = CreateBookmark(self.client, message.jump_url)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@commands.hybrid_command(name="events")
@app_commands.rename(event_id="id")
@app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.")
async def events(self, ctx: commands.Context, event_id: Optional[int] = None):
"""Show information about the event with id `event_id`.
If no value for `event_id` is supplied, this shows all upcoming events instead.
"""
async with ctx.typing():
async with self.client.postgres_session as session:
if event_id is None:
upcoming = await events.get_events(session, now=tz_aware_now())
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
if not upcoming:
embed.colour = discord.Colour.red()
embed.description = "There are currently no upcoming events scheduled."
return await ctx.reply(embed=embed, mention_author=False)
upcoming.sort(key=lambda e: e.timestamp.timestamp())
description_items = []
for event in upcoming:
description_items.append(
f"`{event.event_id}`: {event.name} ({discord.utils.format_dt(event.timestamp, style='R')})"
)
embed.description = "\n".join(description_items)
return await ctx.reply(embed=embed, mention_author=False)
else:
result_event = await events.get_event_by_id(session, event_id)
if result_event is None:
return await ctx.reply(f"Found no event with id `{event_id}`.", mention_author=False)
embed = discord.Embed(title="Upcoming Events", colour=discord.Colour.blue())
embed.add_field(name="Name", value=result_event.name, inline=True)
embed.add_field(name="Id", value=result_event.event_id, inline=True)
embed.add_field(
name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True
)
embed.add_field(
name="Channel",
value=self.client.get_channel(result_event.notification_channel).mention,
inline=False,
)
embed.description = result_event.description
return await ctx.reply(embed=embed, mention_author=False)
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None):
"""Show a user's GitHub links. """Show a user's GitHub links.

View File

@ -13,6 +13,7 @@ from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags
from didier.views.modals import ( from didier.views.modals import (
AddDadJoke, AddDadJoke,
AddDeadline, AddDeadline,
AddEvent,
AddLink, AddLink,
CreateCustomCommand, CreateCustomCommand,
EditCustomCommand, EditCustomCommand,
@ -173,6 +174,15 @@ class Owner(commands.Cog):
"""Autocompletion for the 'course'-parameter""" """Autocompletion for the 'course'-parameter"""
return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
@add_slash.command(name="event", description="Add a new event")
async def add_event_slash(self, interaction: discord.Interaction):
"""Slash command to add new events"""
if not await self.client.is_owner(interaction.user):
return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
modal = AddEvent(self.client)
await interaction.response.send_modal(modal)
@add_slash.command(name="link", description="Add a new link") @add_slash.command(name="link", description="Add a new link")
async def add_link_slash(self, interaction: discord.Interaction): async def add_link_slash(self, interaction: discord.Interaction):
"""Slash command to add new links""" """Slash command to add new links"""

View File

@ -0,0 +1,64 @@
import asyncio
from datetime import datetime
from typing import Optional
import discord.utils
from database.crud.events import get_next_event
from database.schemas import Event
from didier import Didier
from didier.utils.types.datetime import tz_aware_now
__all__ = ["Timer"]
class Timer:
"""Class for scheduled timers"""
client: Didier
upcoming_timer: Optional[datetime]
upcoming_event_id: Optional[int]
_task: Optional[asyncio.Task]
def __init__(self, client: Didier):
self.client = client
self.upcoming_timer = None
self.upcoming_event_id = None
self._task = None
self.client.loop.create_task(self.update())
async def update(self):
"""Get & schedule the closest reminder"""
async with self.client.postgres_session as session:
event = await get_next_event(session, now=tz_aware_now())
# No upcoming events
if event is None:
return
self.maybe_replace_task(event)
def maybe_replace_task(self, event: Event):
"""Replace the current task if necessary"""
# If there is a current (pending) task, and the new timer is sooner than the
# pending one, cancel it
if self._task is not None and not self._task.done():
# The upcoming timer will never be None at this point, but Mypy is mad
if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp:
self._task.cancel()
else:
# The new task happens after the existing task, it has to wait for its turn
return
self.upcoming_timer = event.timestamp
self.upcoming_event_id = event.event_id
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
async def end_timer(self, *, endtime: datetime, event_id: int):
"""Wait until a timer runs out, and then trigger an event to send the message"""
await discord.utils.sleep_until(endtime)
self.upcoming_timer = None
self.upcoming_event_id = None
self.client.dispatch("timer_end", event_id)

View File

@ -2,6 +2,7 @@ from .bookmarks import CreateBookmark
from .custom_commands import CreateCustomCommand, EditCustomCommand from .custom_commands import CreateCustomCommand, EditCustomCommand
from .dad_jokes import AddDadJoke from .dad_jokes import AddDadJoke
from .deadlines import AddDeadline from .deadlines import AddDeadline
from .events import AddEvent
from .links import AddLink from .links import AddLink
from .memes import GenerateMeme from .memes import GenerateMeme
@ -9,6 +10,7 @@ __all__ = [
"CreateBookmark", "CreateBookmark",
"AddDadJoke", "AddDadJoke",
"AddDeadline", "AddDeadline",
"AddEvent",
"CreateCustomCommand", "CreateCustomCommand",
"EditCustomCommand", "EditCustomCommand",
"AddLink", "AddLink",

View File

@ -0,0 +1,61 @@
import traceback
from zoneinfo import ZoneInfo
import discord
from dateutil.parser import ParserError, parse
from overrides import overrides
from database.crud.events import add_event
from didier import Didier
__all__ = ["AddEvent"]
class AddEvent(discord.ui.Modal, title="Add Event"):
"""Modal to add a new event"""
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, required=True)
description: discord.ui.TextInput = discord.ui.TextInput(
label="Description", style=discord.TextStyle.paragraph, required=False, default=None
)
channel: discord.ui.TextInput = discord.ui.TextInput(
label="Channel id", style=discord.TextStyle.short, required=True, placeholder="676713433567199232"
)
timestamp: discord.ui.TextInput = discord.ui.TextInput(
label="Date", style=discord.TextStyle.short, required=True, placeholder="21/02/2020 21:21:00"
)
client: Didier
def __init__(self, client: Didier, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = client
@overrides
async def on_submit(self, interaction: discord.Interaction) -> None:
try:
parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
except ParserError:
return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True)
if self.client.get_channel(int(self.channel.value)) is None:
return await interaction.response.send_message(
f"Unable to find channel `{self.channel.value}`", ephemeral=True
)
async with self.client.postgres_session as session:
event = await add_event(
session,
name=self.name.value,
description=self.description.value,
date_str=self.timestamp.value,
channel_id=int(self.channel.value),
)
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
self.client.dispatch("event_create", event)
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)