Compare commits

...

8 Commits

Author SHA1 Message Date
stijndcl 773491e2ff Clap 2022-09-23 20:30:00 +02:00
stijndcl bf32a5ef47 Create command to list custom commands, add shortcuts to memegen commands 2022-09-23 18:06:33 +02:00
Stijn De Clercq 8922489a41
Merge pull request #134 from stijndcl/reminders
Reminders
2022-09-23 15:11:53 +02:00
stijndcl 3e495d8291 Remove debug date 2022-09-23 15:07:13 +02:00
stijndcl 0a9f73af8c Make tasks log exceptions 2022-09-23 14:59:47 +02:00
stijndcl 185aaadce1 Merge sequential slots into one 2022-09-23 14:47:42 +02:00
stijndcl ddd632ffd5 Fix typo 2022-09-23 14:27:56 +02:00
stijndcl d03ece6f58 Create reminders, fix bugs in schedule parsing 2022-09-23 14:25:13 +02:00
19 changed files with 448 additions and 94 deletions

View File

@ -0,0 +1,38 @@
"""Add reminders
Revision ID: a64876b41af2
Revises: c1f9ee875616
Create Date: 2022-09-23 13:37:10.331840
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "a64876b41af2"
down_revision = "c1f9ee875616"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"reminders",
sa.Column("reminder_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("category", sa.Enum("LES", name="remindercategory"), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("reminder_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("reminders")
# ### end Alembic commands ###

View File

@ -12,6 +12,7 @@ __all__ = [
"create_alias", "create_alias",
"create_command", "create_command",
"edit_command", "edit_command",
"get_all_commands",
"get_command", "get_command",
"get_command_by_alias", "get_command_by_alias",
"get_command_by_name", "get_command_by_name",
@ -55,6 +56,12 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
return alias_instance return alias_instance
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
"""Get a list of all commands"""
statement = select(CustomCommand)
return (await session.execute(statement)).scalars().all()
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message""" """Try to get a command out of a message"""
# Search lowercase & without spaces # Search lowercase & without spaces

View File

@ -0,0 +1,42 @@
from typing import Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.enums import ReminderCategory
from database.schemas import Reminder
__all__ = ["get_all_reminders_for_category", "toggle_reminder"]
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
"""Get a list of all Reminders for a given category"""
statement = select(Reminder).where(Reminder.category == category)
return (await session.execute(statement)).scalars().all()
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:
"""Switch a category on/off
Returns the new value for the category
"""
await get_or_add_user(session, user_id)
select_statement = select(Reminder).where(Reminder.user_id == user_id).where(Reminder.category == category)
reminder: Optional[Reminder] = (await session.execute(select_statement)).scalar_one_or_none()
# No reminder set yet
if reminder is None:
reminder = Reminder(user_id=user_id, category=category)
session.add(reminder)
await session.commit()
return True
# Reminder found -> delete it
delete_statement = delete(Reminder).where(Reminder.reminder_id == reminder.reminder_id)
await session.execute(delete_statement)
await session.commit()
return False

View File

@ -1,11 +1,14 @@
import enum import enum
__all__ = ["TaskType"] __all__ = ["ReminderCategory", "TaskType"]
class ReminderCategory(enum.IntEnum):
"""Enum for reminder categories"""
LES = enum.auto()
# There is a bug in typeshed that causes an incorrect PyCharm warning
# https://github.com/python/typeshed/issues/8286
# noinspection PyArgumentList
class TaskType(enum.IntEnum): class TaskType(enum.IntEnum):
"""Enum for the different types of tasks""" """Enum for the different types of tasks"""

View File

@ -37,6 +37,7 @@ __all__ = [
"Link", "Link",
"MemeTemplate", "MemeTemplate",
"NightlyData", "NightlyData",
"Reminder",
"Task", "Task",
"UforaAnnouncement", "UforaAnnouncement",
"UforaCourse", "UforaCourse",
@ -219,6 +220,18 @@ class NightlyData(Base):
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class Reminder(Base):
"""Something that a user should be reminded of"""
__tablename__ = "reminders"
reminder_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False)
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin")
class Task(Base): class Task(Base):
"""A Didier task""" """A Didier task"""
@ -303,6 +316,9 @@ class User(Base):
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
reminders: list[Reminder] = relationship(
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_guesses: list[WordleGuess] = relationship( wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" "WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )

View File

@ -14,7 +14,6 @@ from database.exceptions import (
from didier import Didier from didier import Didier
from didier.exceptions import expect from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.menus.common import Menu
from didier.utils.discord import colours from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
@ -186,9 +185,7 @@ class Discord(commands.Cog):
embed.description = "You haven't created any bookmarks yet." embed.description = "You haven't created any bookmarks yet."
return await ctx.reply(embed=embed, mention_author=False) return await ctx.reply(embed=embed, mention_author=False)
source = BookmarkSource(ctx, results) await BookmarkSource(ctx, results).start()
menu = Menu(source)
await menu.start(ctx)
async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message): async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message):
"""Create a bookmark out of this message""" """Create a bookmark out of this message"""

View File

@ -1,4 +1,5 @@
import shlex import shlex
from typing import Optional
import discord import discord
from discord import app_commands from discord import app_commands
@ -9,8 +10,8 @@ from database.crud.memes import get_all_memes, get_meme_by_name
from didier import Didier from didier import Didier
from didier.data.apis.imgflip import generate_meme from didier.data.apis.imgflip import generate_meme
from didier.exceptions.no_match import expect from didier.exceptions.no_match import expect
from didier.menus.common import Menu
from didier.menus.memes import MemeSource from didier.menus.memes import MemeSource
from didier.utils.discord import constants
from didier.views.modals import GenerateMeme from didier.views.modals import GenerateMeme
@ -25,6 +26,22 @@ class Fun(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.hybrid_command(name="clap")
async def clap(self, ctx: commands.Context, *, text: str):
"""Clap a message with emojis for extra dramatic effect"""
chars = list(filter(lambda c: c.isalnum(), text))
if not chars:
return await ctx.reply("👏", mention_author=False)
text = "👏".join(list(map(lambda c: constants.EMOJI_MAP.get(c), chars)))
text = f"👏{text}👏"
if len(text) > constants.Limits.MESSAGE_LENGTH:
return await ctx.reply("Message is too long.", mention_author=False)
return await ctx.reply(text, mention_author=False)
async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str: async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str:
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name) result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name)
@ -42,7 +59,7 @@ class Fun(commands.Cog):
return await ctx.reply(joke.joke, mention_author=False) return await ctx.reply(joke.joke, mention_author=False)
@commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True) @commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True)
async def memegen_msg(self, ctx: commands.Context, template: str, *, fields: str): async def memegen_msg(self, ctx: commands.Context, template: Optional[str] = None, *, fields: Optional[str] = None):
"""Generate a meme with template `template` and fields `fields`. """Generate a meme with template `template` and fields `fields`.
The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes". The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes".
@ -55,7 +72,17 @@ class Fun(commands.Cog):
Example: if template `a` only has 1 field, Example: if template `a` only has 1 field,
`memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]` `memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]`
When no arguments are provided, this is a shortcut to `memegen list`.
When only a template is provided, this is a shortcut to `memegen preview`.
""" """
if template is None:
return await self.memegen_ls_msg(ctx)
if fields is None:
return await self.memegen_preview_msg(ctx, template)
async with ctx.typing(): async with ctx.typing():
meme = await self._do_generate_meme(template, shlex.split(fields)) meme = await self._do_generate_meme(template, shlex.split(fields))
return await ctx.reply(meme, mention_author=False) return await ctx.reply(meme, mention_author=False)
@ -69,9 +96,7 @@ class Fun(commands.Cog):
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
results = await get_all_memes(session) results = await get_all_memes(session)
source = MemeSource(ctx, results) await MemeSource(ctx, results).start()
menu = Menu(source)
await menu.start(ctx)
@memegen_msg.command(name="preview", aliases=["p"]) @memegen_msg.command(name="preview", aliases=["p"])
async def memegen_preview_msg(self, ctx: commands.Context, template: str): async def memegen_preview_msg(self, ctx: commands.Context, template: str):
@ -83,7 +108,7 @@ class Fun(commands.Cog):
@memes_slash.command(name="generate") @memes_slash.command(name="generate")
async def memegen_slash(self, interaction: discord.Interaction, template: str): async def memegen_slash(self, interaction: discord.Interaction, template: str):
"""Generate a meme with template `template`.""" """Generate a meme."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
result = expect(await get_meme_by_name(session, template), entity_type="meme", argument=template) result = expect(await get_meme_by_name(session, template), entity_type="meme", argument=template)

View File

@ -4,7 +4,11 @@ from typing import Optional
from discord.ext import commands from discord.ext import commands
from database.crud.custom_commands import get_all_commands
from database.crud.reminders import toggle_reminder
from database.enums import ReminderCategory
from didier import Didier from didier import Didier
from didier.menus.custom_commands import CustomCommandSource
class Meta(commands.Cog): class Meta(commands.Cog):
@ -15,11 +19,45 @@ class Meta(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.command(name="custom")
async def custom(self, ctx: commands.Context):
"""Get a list of all custom commands that are registered."""
async with self.client.postgres_session as session:
custom_commands = await get_all_commands(session)
custom_commands.sort(key=lambda c: c.name.lower())
await CustomCommandSource(ctx, custom_commands).start()
@commands.command(name="marco") @commands.command(name="marco")
async def marco(self, ctx: commands.Context): async def marco(self, ctx: commands.Context):
"""Get Didier's latency.""" """Get Didier's latency."""
return await ctx.reply(f"Polo! {round(self.client.latency * 1000)}ms", mention_author=False) return await ctx.reply(f"Polo! {round(self.client.latency * 1000)}ms", mention_author=False)
@commands.command(name="remind", aliases=["remindme"])
async def remind(self, ctx: commands.Context, category: str):
"""Make Didier send you reminders every day."""
category = category.lower()
available_categories = [
(
"les",
ReminderCategory.LES,
)
]
for name, category_mapping in available_categories:
if name == category:
async with self.client.postgres_session as session:
new_value = await toggle_reminder(session, ctx.author.id, category_mapping)
toggle = "on" if new_value else "off"
return await ctx.reply(
f"Reminders for category `{name}` have been toggled {toggle}.", mention_author=False
)
# No match found
return await ctx.reply(f"`{category}` is not a supported category.", mention_author=False)
@commands.command(name="source", aliases=["src"]) @commands.command(name="source", aliases=["src"])
async def source(self, ctx: commands.Context, *, command_name: Optional[str] = None): async def source(self, ctx: commands.Context, *, command_name: Optional[str] = None):
"""Get a link to the source code of Didier. """Get a link to the source code of Didier.

View File

@ -11,7 +11,7 @@ from didier import Didier
from didier.data.apis.hydra import fetch_menu from didier.data.apis.hydra import fetch_menu
from didier.data.embeds.deadlines import Deadlines from didier.data.embeds.deadlines import Deadlines
from didier.data.embeds.hydra import no_menu_found from didier.data.embeds.hydra import no_menu_found
from didier.data.embeds.schedules import Schedule, get_schedule_for_user from didier.data.embeds.schedules import Schedule, get_schedule_for_day
from didier.exceptions import HTTPException, NotInMainGuildException from didier.exceptions import HTTPException, NotInMainGuildException
from didier.utils.discord.converters.time import DateTransformer from didier.utils.discord.converters.time import DateTransformer
from didier.utils.discord.flags.school import StudyGuideFlags from didier.utils.discord.flags.school import StudyGuideFlags
@ -55,10 +55,11 @@ class School(commands.Cog):
try: try:
member_instance = to_main_guild_member(self.client, ctx.author) member_instance = to_main_guild_member(self.client, ctx.author)
roles = {role.id for role in member_instance.roles}
# Always make sure there is at least one schedule in case it returns None # Always make sure there is at least one schedule in case it returns None
# this allows proper error messages # this allows proper error messages
schedule = get_schedule_for_user(self.client, member_instance, day_dt) or Schedule() schedule = (get_schedule_for_day(self.client, day_dt) or Schedule()).personalize(roles)
return await ctx.reply(embed=schedule.to_embed(day=day_dt), mention_author=False) return await ctx.reply(embed=schedule.to_embed(day=day_dt), mention_author=False)

View File

@ -1,6 +1,5 @@
import datetime import datetime
import random import random
import traceback
import discord import discord
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
@ -9,10 +8,16 @@ from overrides import overrides
import settings import settings
from database import enums from database import enums
from database.crud.birthdays import get_birthdays_on_day from database.crud.birthdays import get_birthdays_on_day
from database.crud.reminders import get_all_reminders_for_category
from database.crud.ufora_announcements import remove_old_announcements from database.crud.ufora_announcements import remove_old_announcements
from database.crud.wordle import set_daily_word from database.crud.wordle import set_daily_word
from database.schemas import Reminder
from didier import Didier from didier import Didier
from didier.data.embeds.schedules import Schedule, parse_schedule_from_content from didier.data.embeds.schedules import (
Schedule,
get_schedule_for_day,
parse_schedule_from_content,
)
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner from didier.utils.discord.checks import is_owner
@ -44,6 +49,7 @@ class Tasks(commands.Cog):
self._tasks = { self._tasks = {
"birthdays": self.check_birthdays, "birthdays": self.check_birthdays,
"schedules": self.pull_schedules, "schedules": self.pull_schedules,
"reminders": self.reminders,
"ufora": self.pull_ufora_announcements, "ufora": self.pull_ufora_announcements,
"remove_ufora": self.remove_old_ufora_announcements, "remove_ufora": self.remove_old_ufora_announcements,
"wordle": self.reset_wordle_word, "wordle": self.reset_wordle_word,
@ -61,6 +67,7 @@ class Tasks(commands.Cog):
self.remove_old_ufora_announcements.start() self.remove_old_ufora_announcements.start()
# Start other tasks # Start other tasks
self.reminders.start()
self.reset_wordle_word.start() self.reset_wordle_word.start()
self.pull_schedules.start() self.pull_schedules.start()
@ -135,7 +142,7 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
for data in settings.SCHEDULE_DATA: for data in settings.SCHEDULE_DATA:
if data.schedule_url is None: if data.schedule_url is None:
return continue
async with self.client.http_session.get(data.schedule_url) as response: async with self.client.http_session.get(data.schedule_url) as response:
# If a schedule couldn't be fetched, log it and move on # If a schedule couldn't be fetched, log it and move on
@ -180,6 +187,50 @@ class Tasks(commands.Cog):
async def _before_ufora_announcements(self): async def _before_ufora_announcements(self):
await self.client.wait_until_ready() await self.client.wait_until_ready()
async def _send_les_reminders(self, entries: list[Reminder]):
today = tz_aware_now().date()
# Create the main schedule for the day once here, to avoid doing it repeatedly
daily_schedule = get_schedule_for_day(self.client, today)
# No class today
if not daily_schedule:
return
for entry in entries:
member = self.client.main_guild.get_member(entry.user_id)
if not member:
continue
roles = {role.id for role in member.roles}
personal_schedule = daily_schedule.personalize(roles)
# No class today
if not personal_schedule:
continue
await member.send(embed=personal_schedule.to_embed(day=today))
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
async def reminders(self, **kwargs):
"""Send daily reminders to people"""
_ = kwargs
async with self.client.postgres_session as session:
for category in enums.ReminderCategory:
entries = await get_all_reminders_for_category(session, category)
if not entries:
continue
# This is slightly ugly, but it's the best way to go about it
# There won't be a lot of categories anyway
if category == enums.ReminderCategory.LES:
await self._send_les_reminders(entries)
@reminders.before_loop
async def _before_reminders(self):
await self.client.wait_until_ready()
@tasks.loop(hours=24) @tasks.loop(hours=24)
async def remove_old_ufora_announcements(self): async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day""" """Remove all announcements that are over 1 week old, once per day"""
@ -200,12 +251,12 @@ class Tasks(commands.Cog):
@check_birthdays.error @check_birthdays.error
@pull_schedules.error @pull_schedules.error
@pull_ufora_announcements.error @pull_ufora_announcements.error
@reminders.error
@remove_old_ufora_announcements.error @remove_old_ufora_announcements.error
@reset_wordle_word.error @reset_wordle_word.error
async def _on_tasks_error(self, error: BaseException): async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks""" """Error handler for all tasks"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__))) self.client.dispatch("task_error", error)
self.client.dispatch("task_error")
async def setup(client: Didier): async def setup(client: Didier):

View File

@ -1,4 +1,5 @@
import traceback import traceback
from typing import Optional
import discord import discord
from discord.ext import commands from discord.ext import commands
@ -18,7 +19,7 @@ def _get_traceback(exception: Exception) -> str:
if line.strip().startswith("The above exception was the direct cause of"): if line.strip().startswith("The above exception was the direct cause of"):
break break
# Escape Discord markdown formatting # Escape Discord Markdown formatting
error_string += line.replace(r"*", r"\*").replace(r"_", r"\_") error_string += line.replace(r"*", r"\*").replace(r"_", r"\_")
if line.strip(): if line.strip():
error_string += "\n" error_string += "\n"
@ -26,23 +27,30 @@ def _get_traceback(exception: Exception) -> str:
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8) return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8)
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed: def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> discord.Embed:
"""Create an embed for the traceback of an exception""" """Create an embed for the traceback of an exception"""
message = str(exception)
# Wrap the traceback in a codeblock for readability # Wrap the traceback in a codeblock for readability
description = _get_traceback(exception).strip() description = _get_traceback(exception).strip()
description = f"```\n{description}\n```" description = f"```\n{description}\n```"
if ctx.guild is None:
origin = "DM"
else:
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
invocation = f"{ctx.author.display_name} in {origin}"
embed = discord.Embed(title="Error", colour=discord.Colour.red()) embed = discord.Embed(title="Error", colour=discord.Colour.red())
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
embed.add_field(name="Context", value=invocation, inline=True) if ctx is not None:
embed.add_field(name="Exception", value=str(exception), inline=False) if ctx.guild is None:
origin = "DM"
else:
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
invocation = f"{ctx.author.display_name} in {origin}"
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
embed.add_field(name="Context", value=invocation, inline=True)
if message:
embed.add_field(name="Exception", value=message, inline=False)
embed.add_field(name="Traceback", value=description, inline=False) embed.add_field(name="Traceback", value=description, inline=False)
return embed return embed

View File

@ -2,6 +2,9 @@ import logging
import discord import discord
from didier.utils.discord.constants import Limits
from didier.utils.types.string import abbreviate
__all__ = ["create_logging_embed"] __all__ = ["create_logging_embed"]
@ -16,6 +19,6 @@ def create_logging_embed(level: int, message: str) -> discord.Embed:
colour = colours.get(level, discord.Colour.red()) colour = colours.get(level, discord.Colour.red())
embed = discord.Embed(colour=colour, title="Logging") embed = discord.Embed(colour=colour, title="Logging")
embed.description = message embed.description = abbreviate(message, Limits.EMBED_DESCRIPTION_LENGTH)
return embed return embed

View File

@ -22,7 +22,7 @@ from didier.utils.types.datetime import LOCAL_TIMEZONE, int_to_weekday, time_str
from didier.utils.types.string import leading from didier.utils.types.string import leading
from settings import ScheduleType from settings import ScheduleType
__all__ = ["Schedule", "get_schedule_for_user", "parse_schedule_from_content", "parse_schedule"] __all__ = ["Schedule", "get_schedule_for_day", "parse_schedule_from_content", "parse_schedule"]
@dataclass @dataclass
@ -48,6 +48,10 @@ class Schedule(EmbedBaseModel):
def personalize(self, roles: set[int]) -> Schedule: def personalize(self, roles: set[int]) -> Schedule:
"""Personalize a schedule for a user, only adding courses they follow""" """Personalize a schedule for a user, only adding courses they follow"""
# If the schedule is already empty, just return instantly
if not self.slots:
return Schedule()
personal_slots = set() personal_slots = set()
for slot in self.slots: for slot in self.slots:
role_found = slot.role_id is not None and slot.role_id in roles role_found = slot.role_id is not None and slot.role_id in roles
@ -57,6 +61,21 @@ class Schedule(EmbedBaseModel):
return Schedule(personal_slots) return Schedule(personal_slots)
def simplify(self):
"""Merge sequential slots in the same location into one
Note: this is done in-place instead of returning a new schedule!
(The operation is O(n^2))
Example:
13:00 - 14:30: AD3 in S9
14:30 - 1600: AD3 in S9
"""
for first in self.slots:
for second in self.slots:
if first == second:
continue
@overrides @overrides
def to_embed(self, **kwargs) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
day: date = kwargs.get("day", date.today()) day: date = kwargs.get("day", date.today())
@ -104,10 +123,12 @@ class ScheduleSlot:
def __post_init__(self): def __post_init__(self):
"""Fix some properties to display more nicely""" """Fix some properties to display more nicely"""
# Re-format the location data # Re-format the location data
room, building, campus = re.search(r"(.*)\. Gebouw (.*)\. Campus (.*)\. ", self.location).groups() room, building, campus = re.search(r"(.*)\. (?:Gebouw )?(.*)\. (?:Campus )?(.*)\. ", self.location).groups()
room = room.replace("PC / laptoplokaal ", "PC-lokaal") room = room.replace("PC / laptoplokaal ", "PC-lokaal")
self.location = f"{campus} {building} {room}" self.location = f"{campus} {building} {room}"
# The same course can only start once at the same moment,
# so this is guaranteed to be unique
self._hash = hash(f"{self.course.course_id} {str(self.start_time)}") self._hash = hash(f"{self.course.course_id} {str(self.start_time)}")
@property @property
@ -131,15 +152,33 @@ class ScheduleSlot:
return self._hash == other._hash return self._hash == other._hash
def could_merge_with(self, other: ScheduleSlot) -> bool:
"""Check if two slots are actually one with a 15-min break in-between
def get_schedule_for_user(client: Didier, member: discord.Member, day_dt: date) -> Optional[Schedule]: If they are, merge the two into one (this edits the first slot in-place!)
"""Get a user's schedule""" """
roles: set[int] = {role.id for role in member.roles} if self.course.course_id != other.course.course_id:
return False
if self.location != other.location:
return False
if self.start_time == other.end_time:
other.end_time = self.end_time
return True
elif self.end_time == other.start_time:
self.end_time = other.end_time
return True
return False
def get_schedule_for_day(client: Didier, day_dt: date) -> Optional[Schedule]:
"""Get a schedule for an entire day"""
main_schedule: Optional[Schedule] = None main_schedule: Optional[Schedule] = None
for schedule in client.schedules.values(): for schedule in client.schedules.values():
personalized_schedule = schedule.on_day(day_dt).personalize(roles) personalized_schedule = schedule.on_day(day_dt)
if not personalized_schedule: if not personalized_schedule:
continue continue
@ -204,6 +243,10 @@ async def parse_schedule_from_content(content: str, *, database_session: AsyncSe
location=event.location, location=event.location,
) )
# Slot extends another one, don't add it
if any(s.could_merge_with(slot) for s in slots):
continue
slots.add(slot) slots.add(slot)
return Schedule(slots=slots) return Schedule(slots=slots)

View File

@ -363,6 +363,13 @@ class Didier(commands.Bot):
"""Event triggered when the bot is ready""" """Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE) print(settings.DISCORD_READY_MESSAGE)
async def on_task_error(self, exception: Exception):
"""Event triggered when a task raises an exception"""
if settings.ERRORS_CHANNEL is not None:
embed = create_error_embed(None, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)
async def on_thread_create(self, thread: discord.Thread): async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a new thread is created""" """Event triggered when a new thread is created"""
# Join threads automatically # Join threads automatically

View File

@ -1,5 +1,4 @@
import discord import discord
from discord.ext import commands
from overrides import overrides from overrides import overrides
from database.schemas import Bookmark from database.schemas import Bookmark
@ -14,16 +13,16 @@ class BookmarkSource(PageSource[Bookmark]):
"""PageSource for the Bookmark commands""" """PageSource for the Bookmark commands"""
@overrides @overrides
def create_embeds(self, ctx: commands.Context): def create_embeds(self):
for page in range(self.page_count): for page in range(self.page_count):
embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue()) embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue())
avatar_url = get_author_avatar(ctx).url avatar_url = get_author_avatar(self.ctx).url
embed.set_author(name=ctx.author.display_name, icon_url=avatar_url) embed.set_author(name=self.ctx.author.display_name, icon_url=avatar_url)
description = "" description_data = []
for bookmark in self.get_page_data(page): for bookmark in self.get_page_data(page):
description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n" description_data.append(f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})")
embed.description = description.strip() embed.description = "\n".join(description_data)
self.embeds.append(embed) self.embeds.append(embed)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, cast from typing import Generic, Optional, TypeVar, cast
@ -13,50 +15,6 @@ __all__ = ["Menu", "PageSource"]
T = TypeVar("T") T = TypeVar("T")
class PageSource(ABC, Generic[T]):
"""Base class that handles the embeds displayed in a menu"""
dataset: list[T]
embeds: list[discord.Embed]
page_count: int
per_page: int
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
self.embeds = []
self.dataset = dataset
self.per_page = per_page
self.page_count = self._get_page_count()
self.create_embeds(ctx)
self._add_embed_page_footers()
def _get_page_count(self) -> int:
"""Calculate the amount of pages required"""
if len(self.dataset) % self.per_page == 0:
return len(self.dataset) // self.per_page
return (len(self.dataset) // self.per_page) + 1
def __getitem__(self, index: int) -> discord.Embed:
return self.embeds[index]
def __len__(self):
return self.page_count
def _add_embed_page_footers(self):
"""Add the current page in the footer of every embed"""
for i, embed in enumerate(self.embeds):
embed.set_footer(text=f"{i + 1}/{self.page_count}")
@abstractmethod
def create_embeds(self, ctx: commands.Context):
"""Method that builds the list of embeds from the input data"""
raise NotImplementedError
def get_page_data(self, page: int) -> list[T]:
"""Get the chunk of the dataset for page [page]"""
return self.dataset[page : page + self.per_page]
class Menu(discord.ui.View): class Menu(discord.ui.View):
"""Base class for a menu""" """Base class for a menu"""
@ -166,3 +124,58 @@ class Menu(discord.ui.View):
"""Button to show the last page""" """Button to show the last page"""
self.current_page = len(self.source) - 1 self.current_page = len(self.source) - 1
await self.display_current_state(interaction) await self.display_current_state(interaction)
class PageSource(ABC, Generic[T]):
"""Base class that handles the embeds displayed in a menu"""
ctx: commands.Context
dataset: list[T]
embeds: list[discord.Embed]
page_count: int
per_page: int
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
self.ctx = ctx
self.embeds = []
self.dataset = dataset
self.per_page = per_page
self.page_count = self._get_page_count()
self.create_embeds()
self._add_embed_page_footers()
def _get_page_count(self) -> int:
"""Calculate the amount of pages required"""
if len(self.dataset) % self.per_page == 0:
return len(self.dataset) // self.per_page
return (len(self.dataset) // self.per_page) + 1
def __getitem__(self, index: int) -> discord.Embed:
return self.embeds[index]
def __len__(self):
return self.page_count
def _add_embed_page_footers(self):
"""Add the current page in the footer of every embed"""
for i, embed in enumerate(self.embeds):
embed.set_footer(text=f"{i + 1}/{self.page_count}")
@abstractmethod
def create_embeds(self):
"""Method that builds the list of embeds from the input data"""
raise NotImplementedError
def get_page_data(self, page: int) -> list[T]:
"""Get the chunk of the dataset for page [page]"""
return self.dataset[page : page + self.per_page]
async def start(self, *, ephemeral: bool = False, timeout: Optional[int] = None) -> Menu:
"""Shortcut to creating (and starting) a Menu with this source
This returns the created menu
"""
menu = Menu(self, ephemeral=ephemeral, timeout=timeout)
await menu.start(self.ctx)
return menu

View File

@ -0,0 +1,24 @@
import discord
from overrides import overrides
from database.schemas import CustomCommand
from didier.menus.common import PageSource
__all__ = ["CustomCommandSource"]
class CustomCommandSource(PageSource[CustomCommand]):
"""PageSource for custom commands"""
@overrides
def create_embeds(self):
for page in range(self.page_count):
embed = discord.Embed(colour=discord.Colour.blue(), title="Custom Commands")
description_data = []
for command in self.get_page_data(page):
description_data.append(command.name.title())
embed.description = "\n".join(description_data)
self.embeds.append(embed)

View File

@ -1,5 +1,4 @@
import discord import discord
from discord.ext import commands
from overrides import overrides from overrides import overrides
from database.schemas import MemeTemplate from database.schemas import MemeTemplate
@ -12,7 +11,7 @@ class MemeSource(PageSource[MemeTemplate]):
"""PageSource for meme templates""" """PageSource for meme templates"""
@overrides @overrides
def create_embeds(self, ctx: commands.Context): def create_embeds(self):
for page in range(self.page_count): for page in range(self.page_count):
# The colour of the embed is (69,4,20) with the values +100 because they were too dark # The colour of the embed is (69,4,20) with the values +100 because they were too dark
embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120)) embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120))

View File

@ -1,6 +1,46 @@
from enum import Enum from enum import Enum
__all__ = ["Limits"] __all__ = ["EMOJI_MAP", "Limits"]
EMOJI_MAP = {
"a": "🇦",
"b": "🇧",
"c": "🇨",
"d": "🇩",
"e": "🇪",
"f": "🇫",
"g": "🇬",
"h": "🇭",
"i": "🇮",
"j": "🇯",
"k": "🇰",
"l": "🇱",
"m": "🇲",
"n": "🇳",
"o": "🇴",
"p": "🇵",
"q": "🇶",
"r": "🇷",
"s": "🇸",
"t": "🇹",
"u": "🇺",
"v": "🇻",
"w": "🇼",
"x": "🇽",
"y": "🇾",
"z": "🇿",
"0": "0⃣",
"1": "1",
"2": "2",
"3": "3",
"4": "4",
"5": "5",
"6": "6",
"7": "7",
"8": "8",
"9": "9",
}
class Limits(int, Enum): class Limits(int, Enum):