Merge pull request #134 from stijndcl/reminders

Reminders
pull/122/head
Stijn De Clercq 2022-09-23 15:11:53 +02:00 committed by GitHub
commit 8922489a41
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 270 additions and 31 deletions

View File

@ -0,0 +1,38 @@
"""Add reminders
Revision ID: a64876b41af2
Revises: c1f9ee875616
Create Date: 2022-09-23 13:37:10.331840
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a64876b41af2"
down_revision = "c1f9ee875616"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"reminders",
sa.Column("reminder_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("category", sa.Enum("LES", name="remindercategory"), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("reminder_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("reminders")
# ### end Alembic commands ###

View File

@ -0,0 +1,42 @@
from typing import Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.enums import ReminderCategory
from database.schemas import Reminder
__all__ = ["get_all_reminders_for_category", "toggle_reminder"]
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
"""Get a list of all Reminders for a given category"""
statement = select(Reminder).where(Reminder.category == category)
return (await session.execute(statement)).scalars().all()
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:
"""Switch a category on/off
Returns the new value for the category
"""
await get_or_add_user(session, user_id)
select_statement = select(Reminder).where(Reminder.user_id == user_id).where(Reminder.category == category)
reminder: Optional[Reminder] = (await session.execute(select_statement)).scalar_one_or_none()
# No reminder set yet
if reminder is None:
reminder = Reminder(user_id=user_id, category=category)
session.add(reminder)
await session.commit()
return True
# Reminder found -> delete it
delete_statement = delete(Reminder).where(Reminder.reminder_id == reminder.reminder_id)
await session.execute(delete_statement)
await session.commit()
return False

View File

@ -1,11 +1,14 @@
import enum
__all__ = ["TaskType"]
__all__ = ["ReminderCategory", "TaskType"]
class ReminderCategory(enum.IntEnum):
"""Enum for reminder categories"""
LES = enum.auto()
# There is a bug in typeshed that causes an incorrect PyCharm warning
# https://github.com/python/typeshed/issues/8286
# noinspection PyArgumentList
class TaskType(enum.IntEnum):
"""Enum for the different types of tasks"""

View File

@ -37,6 +37,7 @@ __all__ = [
"Link",
"MemeTemplate",
"NightlyData",
"Reminder",
"Task",
"UforaAnnouncement",
"UforaCourse",
@ -219,6 +220,18 @@ class NightlyData(Base):
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class Reminder(Base):
"""Something that a user should be reminded of"""
__tablename__ = "reminders"
reminder_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False)
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin")
class Task(Base):
"""A Didier task"""
@ -303,6 +316,9 @@ class User(Base):
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
reminders: list[Reminder] = relationship(
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)

View File

@ -83,7 +83,7 @@ class Fun(commands.Cog):
@memes_slash.command(name="generate")
async def memegen_slash(self, interaction: discord.Interaction, template: str):
"""Generate a meme with template `template`."""
"""Generate a meme."""
async with self.client.postgres_session as session:
result = expect(await get_meme_by_name(session, template), entity_type="meme", argument=template)

View File

@ -4,6 +4,8 @@ from typing import Optional
from discord.ext import commands
from database.crud.reminders import toggle_reminder
from database.enums import ReminderCategory
from didier import Didier
@ -20,6 +22,31 @@ class Meta(commands.Cog):
"""Get Didier's latency."""
return await ctx.reply(f"Polo! {round(self.client.latency * 1000)}ms", mention_author=False)
@commands.command(name="remind", aliases=["remindme"])
async def remind(self, ctx: commands.Context, category: str):
"""Make Didier send you reminders every day."""
category = category.lower()
available_categories = [
(
"les",
ReminderCategory.LES,
)
]
for name, category_mapping in available_categories:
if name == category:
async with self.client.postgres_session as session:
new_value = await toggle_reminder(session, ctx.author.id, category_mapping)
toggle = "on" if new_value else "off"
return await ctx.reply(
f"Reminders for category `{name}` have been toggled {toggle}.", mention_author=False
)
# No match found
return await ctx.reply(f"`{category}` is not a supported category.", mention_author=False)
@commands.command(name="source", aliases=["src"])
async def source(self, ctx: commands.Context, *, command_name: Optional[str] = None):
"""Get a link to the source code of Didier.

View File

@ -11,7 +11,7 @@ from didier import Didier
from didier.data.apis.hydra import fetch_menu
from didier.data.embeds.deadlines import Deadlines
from didier.data.embeds.hydra import no_menu_found
from didier.data.embeds.schedules import Schedule, get_schedule_for_user
from didier.data.embeds.schedules import Schedule, get_schedule_for_day
from didier.exceptions import HTTPException, NotInMainGuildException
from didier.utils.discord.converters.time import DateTransformer
from didier.utils.discord.flags.school import StudyGuideFlags
@ -55,10 +55,11 @@ class School(commands.Cog):
try:
member_instance = to_main_guild_member(self.client, ctx.author)
roles = {role.id for role in member_instance.roles}
# Always make sure there is at least one schedule in case it returns None
# this allows proper error messages
schedule = get_schedule_for_user(self.client, member_instance, day_dt) or Schedule()
schedule = (get_schedule_for_day(self.client, day_dt) or Schedule()).personalize(roles)
return await ctx.reply(embed=schedule.to_embed(day=day_dt), mention_author=False)

View File

@ -1,6 +1,5 @@
import datetime
import random
import traceback
import discord
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
@ -9,10 +8,16 @@ from overrides import overrides
import settings
from database import enums
from database.crud.birthdays import get_birthdays_on_day
from database.crud.reminders import get_all_reminders_for_category
from database.crud.ufora_announcements import remove_old_announcements
from database.crud.wordle import set_daily_word
from database.schemas import Reminder
from didier import Didier
from didier.data.embeds.schedules import Schedule, parse_schedule_from_content
from didier.data.embeds.schedules import (
Schedule,
get_schedule_for_day,
parse_schedule_from_content,
)
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner
@ -44,6 +49,7 @@ class Tasks(commands.Cog):
self._tasks = {
"birthdays": self.check_birthdays,
"schedules": self.pull_schedules,
"reminders": self.reminders,
"ufora": self.pull_ufora_announcements,
"remove_ufora": self.remove_old_ufora_announcements,
"wordle": self.reset_wordle_word,
@ -61,6 +67,7 @@ class Tasks(commands.Cog):
self.remove_old_ufora_announcements.start()
# Start other tasks
self.reminders.start()
self.reset_wordle_word.start()
self.pull_schedules.start()
@ -135,7 +142,7 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as session:
for data in settings.SCHEDULE_DATA:
if data.schedule_url is None:
return
continue
async with self.client.http_session.get(data.schedule_url) as response:
# If a schedule couldn't be fetched, log it and move on
@ -180,6 +187,50 @@ class Tasks(commands.Cog):
async def _before_ufora_announcements(self):
await self.client.wait_until_ready()
async def _send_les_reminders(self, entries: list[Reminder]):
today = tz_aware_now().date()
# Create the main schedule for the day once here, to avoid doing it repeatedly
daily_schedule = get_schedule_for_day(self.client, today)
# No class today
if not daily_schedule:
return
for entry in entries:
member = self.client.main_guild.get_member(entry.user_id)
if not member:
continue
roles = {role.id for role in member.roles}
personal_schedule = daily_schedule.personalize(roles)
# No class today
if not personal_schedule:
continue
await member.send(embed=personal_schedule.to_embed(day=today))
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
async def reminders(self, **kwargs):
"""Send daily reminders to people"""
_ = kwargs
async with self.client.postgres_session as session:
for category in enums.ReminderCategory:
entries = await get_all_reminders_for_category(session, category)
if not entries:
continue
# This is slightly ugly, but it's the best way to go about it
# There won't be a lot of categories anyway
if category == enums.ReminderCategory.LES:
await self._send_les_reminders(entries)
@reminders.before_loop
async def _before_reminders(self):
await self.client.wait_until_ready()
@tasks.loop(hours=24)
async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day"""
@ -200,12 +251,12 @@ class Tasks(commands.Cog):
@check_birthdays.error
@pull_schedules.error
@pull_ufora_announcements.error
@reminders.error
@remove_old_ufora_announcements.error
@reset_wordle_word.error
async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
self.client.dispatch("task_error")
self.client.dispatch("task_error", error)
async def setup(client: Didier):

View File

@ -1,4 +1,5 @@
import traceback
from typing import Optional
import discord
from discord.ext import commands
@ -18,7 +19,7 @@ def _get_traceback(exception: Exception) -> str:
if line.strip().startswith("The above exception was the direct cause of"):
break
# Escape Discord markdown formatting
# Escape Discord Markdown formatting
error_string += line.replace(r"*", r"\*").replace(r"_", r"\_")
if line.strip():
error_string += "\n"
@ -26,12 +27,17 @@ def _get_traceback(exception: Exception) -> str:
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8)
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed:
def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> discord.Embed:
"""Create an embed for the traceback of an exception"""
message = str(exception)
# Wrap the traceback in a codeblock for readability
description = _get_traceback(exception).strip()
description = f"```\n{description}\n```"
embed = discord.Embed(title="Error", colour=discord.Colour.red())
if ctx is not None:
if ctx.guild is None:
origin = "DM"
else:
@ -39,10 +45,12 @@ def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.E
invocation = f"{ctx.author.display_name} in {origin}"
embed = discord.Embed(title="Error", colour=discord.Colour.red())
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
embed.add_field(name="Context", value=invocation, inline=True)
embed.add_field(name="Exception", value=str(exception), inline=False)
if message:
embed.add_field(name="Exception", value=message, inline=False)
embed.add_field(name="Traceback", value=description, inline=False)
return embed

View File

@ -2,6 +2,9 @@ import logging
import discord
from didier.utils.discord.constants import Limits
from didier.utils.types.string import abbreviate
__all__ = ["create_logging_embed"]
@ -16,6 +19,6 @@ def create_logging_embed(level: int, message: str) -> discord.Embed:
colour = colours.get(level, discord.Colour.red())
embed = discord.Embed(colour=colour, title="Logging")
embed.description = message
embed.description = abbreviate(message, Limits.EMBED_DESCRIPTION_LENGTH)
return embed

View File

@ -22,7 +22,7 @@ from didier.utils.types.datetime import LOCAL_TIMEZONE, int_to_weekday, time_str
from didier.utils.types.string import leading
from settings import ScheduleType
__all__ = ["Schedule", "get_schedule_for_user", "parse_schedule_from_content", "parse_schedule"]
__all__ = ["Schedule", "get_schedule_for_day", "parse_schedule_from_content", "parse_schedule"]
@dataclass
@ -48,6 +48,10 @@ class Schedule(EmbedBaseModel):
def personalize(self, roles: set[int]) -> Schedule:
"""Personalize a schedule for a user, only adding courses they follow"""
# If the schedule is already empty, just return instantly
if not self.slots:
return Schedule()
personal_slots = set()
for slot in self.slots:
role_found = slot.role_id is not None and slot.role_id in roles
@ -57,6 +61,21 @@ class Schedule(EmbedBaseModel):
return Schedule(personal_slots)
def simplify(self):
"""Merge sequential slots in the same location into one
Note: this is done in-place instead of returning a new schedule!
(The operation is O(n^2))
Example:
13:00 - 14:30: AD3 in S9
14:30 - 1600: AD3 in S9
"""
for first in self.slots:
for second in self.slots:
if first == second:
continue
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
day: date = kwargs.get("day", date.today())
@ -104,10 +123,12 @@ class ScheduleSlot:
def __post_init__(self):
"""Fix some properties to display more nicely"""
# Re-format the location data
room, building, campus = re.search(r"(.*)\. Gebouw (.*)\. Campus (.*)\. ", self.location).groups()
room, building, campus = re.search(r"(.*)\. (?:Gebouw )?(.*)\. (?:Campus )?(.*)\. ", self.location).groups()
room = room.replace("PC / laptoplokaal ", "PC-lokaal")
self.location = f"{campus} {building} {room}"
# The same course can only start once at the same moment,
# so this is guaranteed to be unique
self._hash = hash(f"{self.course.course_id} {str(self.start_time)}")
@property
@ -131,15 +152,33 @@ class ScheduleSlot:
return self._hash == other._hash
def could_merge_with(self, other: ScheduleSlot) -> bool:
"""Check if two slots are actually one with a 15-min break in-between
def get_schedule_for_user(client: Didier, member: discord.Member, day_dt: date) -> Optional[Schedule]:
"""Get a user's schedule"""
roles: set[int] = {role.id for role in member.roles}
If they are, merge the two into one (this edits the first slot in-place!)
"""
if self.course.course_id != other.course.course_id:
return False
if self.location != other.location:
return False
if self.start_time == other.end_time:
other.end_time = self.end_time
return True
elif self.end_time == other.start_time:
self.end_time = other.end_time
return True
return False
def get_schedule_for_day(client: Didier, day_dt: date) -> Optional[Schedule]:
"""Get a schedule for an entire day"""
main_schedule: Optional[Schedule] = None
for schedule in client.schedules.values():
personalized_schedule = schedule.on_day(day_dt).personalize(roles)
personalized_schedule = schedule.on_day(day_dt)
if not personalized_schedule:
continue
@ -204,6 +243,10 @@ async def parse_schedule_from_content(content: str, *, database_session: AsyncSe
location=event.location,
)
# Slot extends another one, don't add it
if any(s.could_merge_with(slot) for s in slots):
continue
slots.add(slot)
return Schedule(slots=slots)

View File

@ -363,6 +363,13 @@ class Didier(commands.Bot):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_task_error(self, exception: Exception):
"""Event triggered when a task raises an exception"""
if settings.ERRORS_CHANNEL is not None:
embed = create_error_embed(None, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)
async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a new thread is created"""
# Join threads automatically