mirror of https://github.com/stijndcl/didier
Compare commits
8 Commits
00a146cb2b
...
773491e2ff
| Author | SHA1 | Date |
|---|---|---|
|
|
773491e2ff | |
|
|
bf32a5ef47 | |
|
|
8922489a41 | |
|
|
3e495d8291 | |
|
|
0a9f73af8c | |
|
|
185aaadce1 | |
|
|
ddd632ffd5 | |
|
|
d03ece6f58 |
|
|
@ -0,0 +1,38 @@
|
|||
"""Add reminders
|
||||
|
||||
Revision ID: a64876b41af2
|
||||
Revises: c1f9ee875616
|
||||
Create Date: 2022-09-23 13:37:10.331840
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "a64876b41af2"
|
||||
down_revision = "c1f9ee875616"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"reminders",
|
||||
sa.Column("reminder_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("category", sa.Enum("LES", name="remindercategory"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("reminder_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("reminders")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -12,6 +12,7 @@ __all__ = [
|
|||
"create_alias",
|
||||
"create_command",
|
||||
"edit_command",
|
||||
"get_all_commands",
|
||||
"get_command",
|
||||
"get_command_by_alias",
|
||||
"get_command_by_name",
|
||||
|
|
@ -55,6 +56,12 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
|
|||
return alias_instance
|
||||
|
||||
|
||||
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
|
||||
"""Get a list of all commands"""
|
||||
statement = select(CustomCommand)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
|
||||
|
||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
"""Try to get a command out of a message"""
|
||||
# Search lowercase & without spaces
|
||||
|
|
|
|||
|
|
@ -0,0 +1,42 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.enums import ReminderCategory
|
||||
from database.schemas import Reminder
|
||||
|
||||
__all__ = ["get_all_reminders_for_category", "toggle_reminder"]
|
||||
|
||||
|
||||
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
|
||||
"""Get a list of all Reminders for a given category"""
|
||||
statement = select(Reminder).where(Reminder.category == category)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
|
||||
|
||||
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:
|
||||
"""Switch a category on/off
|
||||
|
||||
Returns the new value for the category
|
||||
"""
|
||||
await get_or_add_user(session, user_id)
|
||||
|
||||
select_statement = select(Reminder).where(Reminder.user_id == user_id).where(Reminder.category == category)
|
||||
reminder: Optional[Reminder] = (await session.execute(select_statement)).scalar_one_or_none()
|
||||
|
||||
# No reminder set yet
|
||||
if reminder is None:
|
||||
reminder = Reminder(user_id=user_id, category=category)
|
||||
session.add(reminder)
|
||||
await session.commit()
|
||||
|
||||
return True
|
||||
|
||||
# Reminder found -> delete it
|
||||
delete_statement = delete(Reminder).where(Reminder.reminder_id == reminder.reminder_id)
|
||||
await session.execute(delete_statement)
|
||||
await session.commit()
|
||||
|
||||
return False
|
||||
|
|
@ -1,11 +1,14 @@
|
|||
import enum
|
||||
|
||||
__all__ = ["TaskType"]
|
||||
__all__ = ["ReminderCategory", "TaskType"]
|
||||
|
||||
|
||||
class ReminderCategory(enum.IntEnum):
|
||||
"""Enum for reminder categories"""
|
||||
|
||||
LES = enum.auto()
|
||||
|
||||
|
||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||
# https://github.com/python/typeshed/issues/8286
|
||||
# noinspection PyArgumentList
|
||||
class TaskType(enum.IntEnum):
|
||||
"""Enum for the different types of tasks"""
|
||||
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ __all__ = [
|
|||
"Link",
|
||||
"MemeTemplate",
|
||||
"NightlyData",
|
||||
"Reminder",
|
||||
"Task",
|
||||
"UforaAnnouncement",
|
||||
"UforaCourse",
|
||||
|
|
@ -219,6 +220,18 @@ class NightlyData(Base):
|
|||
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Reminder(Base):
|
||||
"""Something that a user should be reminded of"""
|
||||
|
||||
__tablename__ = "reminders"
|
||||
|
||||
reminder_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Task(Base):
|
||||
"""A Didier task"""
|
||||
|
||||
|
|
@ -303,6 +316,9 @@ class User(Base):
|
|||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
reminders: list[Reminder] = relationship(
|
||||
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_guesses: list[WordleGuess] = relationship(
|
||||
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -14,7 +14,6 @@ from database.exceptions import (
|
|||
from didier import Didier
|
||||
from didier.exceptions import expect
|
||||
from didier.menus.bookmarks import BookmarkSource
|
||||
from didier.menus.common import Menu
|
||||
from didier.utils.discord import colours
|
||||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
||||
from didier.utils.discord.constants import Limits
|
||||
|
|
@ -186,9 +185,7 @@ class Discord(commands.Cog):
|
|||
embed.description = "You haven't created any bookmarks yet."
|
||||
return await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
source = BookmarkSource(ctx, results)
|
||||
menu = Menu(source)
|
||||
await menu.start(ctx)
|
||||
await BookmarkSource(ctx, results).start()
|
||||
|
||||
async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
||||
"""Create a bookmark out of this message"""
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import shlex
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
|
@ -9,8 +10,8 @@ from database.crud.memes import get_all_memes, get_meme_by_name
|
|||
from didier import Didier
|
||||
from didier.data.apis.imgflip import generate_meme
|
||||
from didier.exceptions.no_match import expect
|
||||
from didier.menus.common import Menu
|
||||
from didier.menus.memes import MemeSource
|
||||
from didier.utils.discord import constants
|
||||
from didier.views.modals import GenerateMeme
|
||||
|
||||
|
||||
|
|
@ -25,6 +26,22 @@ class Fun(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.hybrid_command(name="clap")
|
||||
async def clap(self, ctx: commands.Context, *, text: str):
|
||||
"""Clap a message with emojis for extra dramatic effect"""
|
||||
chars = list(filter(lambda c: c.isalnum(), text))
|
||||
|
||||
if not chars:
|
||||
return await ctx.reply("👏", mention_author=False)
|
||||
|
||||
text = "👏".join(list(map(lambda c: constants.EMOJI_MAP.get(c), chars)))
|
||||
text = f"👏{text}👏"
|
||||
|
||||
if len(text) > constants.Limits.MESSAGE_LENGTH:
|
||||
return await ctx.reply("Message is too long.", mention_author=False)
|
||||
|
||||
return await ctx.reply(text, mention_author=False)
|
||||
|
||||
async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str:
|
||||
async with self.client.postgres_session as session:
|
||||
result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name)
|
||||
|
|
@ -42,7 +59,7 @@ class Fun(commands.Cog):
|
|||
return await ctx.reply(joke.joke, mention_author=False)
|
||||
|
||||
@commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True)
|
||||
async def memegen_msg(self, ctx: commands.Context, template: str, *, fields: str):
|
||||
async def memegen_msg(self, ctx: commands.Context, template: Optional[str] = None, *, fields: Optional[str] = None):
|
||||
"""Generate a meme with template `template` and fields `fields`.
|
||||
|
||||
The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes".
|
||||
|
|
@ -55,7 +72,17 @@ class Fun(commands.Cog):
|
|||
|
||||
Example: if template `a` only has 1 field,
|
||||
`memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]`
|
||||
|
||||
When no arguments are provided, this is a shortcut to `memegen list`.
|
||||
|
||||
When only a template is provided, this is a shortcut to `memegen preview`.
|
||||
"""
|
||||
if template is None:
|
||||
return await self.memegen_ls_msg(ctx)
|
||||
|
||||
if fields is None:
|
||||
return await self.memegen_preview_msg(ctx, template)
|
||||
|
||||
async with ctx.typing():
|
||||
meme = await self._do_generate_meme(template, shlex.split(fields))
|
||||
return await ctx.reply(meme, mention_author=False)
|
||||
|
|
@ -69,9 +96,7 @@ class Fun(commands.Cog):
|
|||
async with self.client.postgres_session as session:
|
||||
results = await get_all_memes(session)
|
||||
|
||||
source = MemeSource(ctx, results)
|
||||
menu = Menu(source)
|
||||
await menu.start(ctx)
|
||||
await MemeSource(ctx, results).start()
|
||||
|
||||
@memegen_msg.command(name="preview", aliases=["p"])
|
||||
async def memegen_preview_msg(self, ctx: commands.Context, template: str):
|
||||
|
|
@ -83,7 +108,7 @@ class Fun(commands.Cog):
|
|||
|
||||
@memes_slash.command(name="generate")
|
||||
async def memegen_slash(self, interaction: discord.Interaction, template: str):
|
||||
"""Generate a meme with template `template`."""
|
||||
"""Generate a meme."""
|
||||
async with self.client.postgres_session as session:
|
||||
result = expect(await get_meme_by_name(session, template), entity_type="meme", argument=template)
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,11 @@ from typing import Optional
|
|||
|
||||
from discord.ext import commands
|
||||
|
||||
from database.crud.custom_commands import get_all_commands
|
||||
from database.crud.reminders import toggle_reminder
|
||||
from database.enums import ReminderCategory
|
||||
from didier import Didier
|
||||
from didier.menus.custom_commands import CustomCommandSource
|
||||
|
||||
|
||||
class Meta(commands.Cog):
|
||||
|
|
@ -15,11 +19,45 @@ class Meta(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.command(name="custom")
|
||||
async def custom(self, ctx: commands.Context):
|
||||
"""Get a list of all custom commands that are registered."""
|
||||
async with self.client.postgres_session as session:
|
||||
custom_commands = await get_all_commands(session)
|
||||
|
||||
custom_commands.sort(key=lambda c: c.name.lower())
|
||||
await CustomCommandSource(ctx, custom_commands).start()
|
||||
|
||||
@commands.command(name="marco")
|
||||
async def marco(self, ctx: commands.Context):
|
||||
"""Get Didier's latency."""
|
||||
return await ctx.reply(f"Polo! {round(self.client.latency * 1000)}ms", mention_author=False)
|
||||
|
||||
@commands.command(name="remind", aliases=["remindme"])
|
||||
async def remind(self, ctx: commands.Context, category: str):
|
||||
"""Make Didier send you reminders every day."""
|
||||
category = category.lower()
|
||||
|
||||
available_categories = [
|
||||
(
|
||||
"les",
|
||||
ReminderCategory.LES,
|
||||
)
|
||||
]
|
||||
|
||||
for name, category_mapping in available_categories:
|
||||
if name == category:
|
||||
async with self.client.postgres_session as session:
|
||||
new_value = await toggle_reminder(session, ctx.author.id, category_mapping)
|
||||
|
||||
toggle = "on" if new_value else "off"
|
||||
return await ctx.reply(
|
||||
f"Reminders for category `{name}` have been toggled {toggle}.", mention_author=False
|
||||
)
|
||||
|
||||
# No match found
|
||||
return await ctx.reply(f"`{category}` is not a supported category.", mention_author=False)
|
||||
|
||||
@commands.command(name="source", aliases=["src"])
|
||||
async def source(self, ctx: commands.Context, *, command_name: Optional[str] = None):
|
||||
"""Get a link to the source code of Didier.
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ from didier import Didier
|
|||
from didier.data.apis.hydra import fetch_menu
|
||||
from didier.data.embeds.deadlines import Deadlines
|
||||
from didier.data.embeds.hydra import no_menu_found
|
||||
from didier.data.embeds.schedules import Schedule, get_schedule_for_user
|
||||
from didier.data.embeds.schedules import Schedule, get_schedule_for_day
|
||||
from didier.exceptions import HTTPException, NotInMainGuildException
|
||||
from didier.utils.discord.converters.time import DateTransformer
|
||||
from didier.utils.discord.flags.school import StudyGuideFlags
|
||||
|
|
@ -55,10 +55,11 @@ class School(commands.Cog):
|
|||
|
||||
try:
|
||||
member_instance = to_main_guild_member(self.client, ctx.author)
|
||||
roles = {role.id for role in member_instance.roles}
|
||||
|
||||
# Always make sure there is at least one schedule in case it returns None
|
||||
# this allows proper error messages
|
||||
schedule = get_schedule_for_user(self.client, member_instance, day_dt) or Schedule()
|
||||
schedule = (get_schedule_for_day(self.client, day_dt) or Schedule()).personalize(roles)
|
||||
|
||||
return await ctx.reply(embed=schedule.to_embed(day=day_dt), mention_author=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import datetime
|
||||
import random
|
||||
import traceback
|
||||
|
||||
import discord
|
||||
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
|
||||
|
|
@ -9,10 +8,16 @@ from overrides import overrides
|
|||
import settings
|
||||
from database import enums
|
||||
from database.crud.birthdays import get_birthdays_on_day
|
||||
from database.crud.reminders import get_all_reminders_for_category
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
from database.crud.wordle import set_daily_word
|
||||
from database.schemas import Reminder
|
||||
from didier import Didier
|
||||
from didier.data.embeds.schedules import Schedule, parse_schedule_from_content
|
||||
from didier.data.embeds.schedules import (
|
||||
Schedule,
|
||||
get_schedule_for_day,
|
||||
parse_schedule_from_content,
|
||||
)
|
||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||
from didier.decorators.tasks import timed_task
|
||||
from didier.utils.discord.checks import is_owner
|
||||
|
|
@ -44,6 +49,7 @@ class Tasks(commands.Cog):
|
|||
self._tasks = {
|
||||
"birthdays": self.check_birthdays,
|
||||
"schedules": self.pull_schedules,
|
||||
"reminders": self.reminders,
|
||||
"ufora": self.pull_ufora_announcements,
|
||||
"remove_ufora": self.remove_old_ufora_announcements,
|
||||
"wordle": self.reset_wordle_word,
|
||||
|
|
@ -61,6 +67,7 @@ class Tasks(commands.Cog):
|
|||
self.remove_old_ufora_announcements.start()
|
||||
|
||||
# Start other tasks
|
||||
self.reminders.start()
|
||||
self.reset_wordle_word.start()
|
||||
self.pull_schedules.start()
|
||||
|
||||
|
|
@ -135,7 +142,7 @@ class Tasks(commands.Cog):
|
|||
async with self.client.postgres_session as session:
|
||||
for data in settings.SCHEDULE_DATA:
|
||||
if data.schedule_url is None:
|
||||
return
|
||||
continue
|
||||
|
||||
async with self.client.http_session.get(data.schedule_url) as response:
|
||||
# If a schedule couldn't be fetched, log it and move on
|
||||
|
|
@ -180,6 +187,50 @@ class Tasks(commands.Cog):
|
|||
async def _before_ufora_announcements(self):
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
async def _send_les_reminders(self, entries: list[Reminder]):
|
||||
today = tz_aware_now().date()
|
||||
|
||||
# Create the main schedule for the day once here, to avoid doing it repeatedly
|
||||
daily_schedule = get_schedule_for_day(self.client, today)
|
||||
|
||||
# No class today
|
||||
if not daily_schedule:
|
||||
return
|
||||
|
||||
for entry in entries:
|
||||
member = self.client.main_guild.get_member(entry.user_id)
|
||||
if not member:
|
||||
continue
|
||||
|
||||
roles = {role.id for role in member.roles}
|
||||
personal_schedule = daily_schedule.personalize(roles)
|
||||
|
||||
# No class today
|
||||
if not personal_schedule:
|
||||
continue
|
||||
|
||||
await member.send(embed=personal_schedule.to_embed(day=today))
|
||||
|
||||
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
||||
async def reminders(self, **kwargs):
|
||||
"""Send daily reminders to people"""
|
||||
_ = kwargs
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
for category in enums.ReminderCategory:
|
||||
entries = await get_all_reminders_for_category(session, category)
|
||||
if not entries:
|
||||
continue
|
||||
|
||||
# This is slightly ugly, but it's the best way to go about it
|
||||
# There won't be a lot of categories anyway
|
||||
if category == enums.ReminderCategory.LES:
|
||||
await self._send_les_reminders(entries)
|
||||
|
||||
@reminders.before_loop
|
||||
async def _before_reminders(self):
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
@tasks.loop(hours=24)
|
||||
async def remove_old_ufora_announcements(self):
|
||||
"""Remove all announcements that are over 1 week old, once per day"""
|
||||
|
|
@ -200,12 +251,12 @@ class Tasks(commands.Cog):
|
|||
@check_birthdays.error
|
||||
@pull_schedules.error
|
||||
@pull_ufora_announcements.error
|
||||
@reminders.error
|
||||
@remove_old_ufora_announcements.error
|
||||
@reset_wordle_word.error
|
||||
async def _on_tasks_error(self, error: BaseException):
|
||||
"""Error handler for all tasks"""
|
||||
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
|
||||
self.client.dispatch("task_error")
|
||||
self.client.dispatch("task_error", error)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
import traceback
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
|
@ -18,7 +19,7 @@ def _get_traceback(exception: Exception) -> str:
|
|||
if line.strip().startswith("The above exception was the direct cause of"):
|
||||
break
|
||||
|
||||
# Escape Discord markdown formatting
|
||||
# Escape Discord Markdown formatting
|
||||
error_string += line.replace(r"*", r"\*").replace(r"_", r"\_")
|
||||
if line.strip():
|
||||
error_string += "\n"
|
||||
|
|
@ -26,23 +27,30 @@ def _get_traceback(exception: Exception) -> str:
|
|||
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8)
|
||||
|
||||
|
||||
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed:
|
||||
def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> discord.Embed:
|
||||
"""Create an embed for the traceback of an exception"""
|
||||
message = str(exception)
|
||||
|
||||
# Wrap the traceback in a codeblock for readability
|
||||
description = _get_traceback(exception).strip()
|
||||
description = f"```\n{description}\n```"
|
||||
|
||||
if ctx.guild is None:
|
||||
origin = "DM"
|
||||
else:
|
||||
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
|
||||
|
||||
invocation = f"{ctx.author.display_name} in {origin}"
|
||||
|
||||
embed = discord.Embed(title="Error", colour=discord.Colour.red())
|
||||
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
|
||||
embed.add_field(name="Context", value=invocation, inline=True)
|
||||
embed.add_field(name="Exception", value=str(exception), inline=False)
|
||||
|
||||
if ctx is not None:
|
||||
if ctx.guild is None:
|
||||
origin = "DM"
|
||||
else:
|
||||
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
|
||||
|
||||
invocation = f"{ctx.author.display_name} in {origin}"
|
||||
|
||||
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
|
||||
embed.add_field(name="Context", value=invocation, inline=True)
|
||||
|
||||
if message:
|
||||
embed.add_field(name="Exception", value=message, inline=False)
|
||||
|
||||
embed.add_field(name="Traceback", value=description, inline=False)
|
||||
|
||||
return embed
|
||||
|
|
|
|||
|
|
@ -2,6 +2,9 @@ import logging
|
|||
|
||||
import discord
|
||||
|
||||
from didier.utils.discord.constants import Limits
|
||||
from didier.utils.types.string import abbreviate
|
||||
|
||||
__all__ = ["create_logging_embed"]
|
||||
|
||||
|
||||
|
|
@ -16,6 +19,6 @@ def create_logging_embed(level: int, message: str) -> discord.Embed:
|
|||
|
||||
colour = colours.get(level, discord.Colour.red())
|
||||
embed = discord.Embed(colour=colour, title="Logging")
|
||||
embed.description = message
|
||||
embed.description = abbreviate(message, Limits.EMBED_DESCRIPTION_LENGTH)
|
||||
|
||||
return embed
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ from didier.utils.types.datetime import LOCAL_TIMEZONE, int_to_weekday, time_str
|
|||
from didier.utils.types.string import leading
|
||||
from settings import ScheduleType
|
||||
|
||||
__all__ = ["Schedule", "get_schedule_for_user", "parse_schedule_from_content", "parse_schedule"]
|
||||
__all__ = ["Schedule", "get_schedule_for_day", "parse_schedule_from_content", "parse_schedule"]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -48,6 +48,10 @@ class Schedule(EmbedBaseModel):
|
|||
|
||||
def personalize(self, roles: set[int]) -> Schedule:
|
||||
"""Personalize a schedule for a user, only adding courses they follow"""
|
||||
# If the schedule is already empty, just return instantly
|
||||
if not self.slots:
|
||||
return Schedule()
|
||||
|
||||
personal_slots = set()
|
||||
for slot in self.slots:
|
||||
role_found = slot.role_id is not None and slot.role_id in roles
|
||||
|
|
@ -57,6 +61,21 @@ class Schedule(EmbedBaseModel):
|
|||
|
||||
return Schedule(personal_slots)
|
||||
|
||||
def simplify(self):
|
||||
"""Merge sequential slots in the same location into one
|
||||
|
||||
Note: this is done in-place instead of returning a new schedule!
|
||||
(The operation is O(n^2))
|
||||
|
||||
Example:
|
||||
13:00 - 14:30: AD3 in S9
|
||||
14:30 - 1600: AD3 in S9
|
||||
"""
|
||||
for first in self.slots:
|
||||
for second in self.slots:
|
||||
if first == second:
|
||||
continue
|
||||
|
||||
@overrides
|
||||
def to_embed(self, **kwargs) -> discord.Embed:
|
||||
day: date = kwargs.get("day", date.today())
|
||||
|
|
@ -104,10 +123,12 @@ class ScheduleSlot:
|
|||
def __post_init__(self):
|
||||
"""Fix some properties to display more nicely"""
|
||||
# Re-format the location data
|
||||
room, building, campus = re.search(r"(.*)\. Gebouw (.*)\. Campus (.*)\. ", self.location).groups()
|
||||
room, building, campus = re.search(r"(.*)\. (?:Gebouw )?(.*)\. (?:Campus )?(.*)\. ", self.location).groups()
|
||||
room = room.replace("PC / laptoplokaal ", "PC-lokaal")
|
||||
self.location = f"{campus} {building} {room}"
|
||||
|
||||
# The same course can only start once at the same moment,
|
||||
# so this is guaranteed to be unique
|
||||
self._hash = hash(f"{self.course.course_id} {str(self.start_time)}")
|
||||
|
||||
@property
|
||||
|
|
@ -131,15 +152,33 @@ class ScheduleSlot:
|
|||
|
||||
return self._hash == other._hash
|
||||
|
||||
def could_merge_with(self, other: ScheduleSlot) -> bool:
|
||||
"""Check if two slots are actually one with a 15-min break in-between
|
||||
|
||||
def get_schedule_for_user(client: Didier, member: discord.Member, day_dt: date) -> Optional[Schedule]:
|
||||
"""Get a user's schedule"""
|
||||
roles: set[int] = {role.id for role in member.roles}
|
||||
If they are, merge the two into one (this edits the first slot in-place!)
|
||||
"""
|
||||
if self.course.course_id != other.course.course_id:
|
||||
return False
|
||||
|
||||
if self.location != other.location:
|
||||
return False
|
||||
|
||||
if self.start_time == other.end_time:
|
||||
other.end_time = self.end_time
|
||||
return True
|
||||
elif self.end_time == other.start_time:
|
||||
self.end_time = other.end_time
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_schedule_for_day(client: Didier, day_dt: date) -> Optional[Schedule]:
|
||||
"""Get a schedule for an entire day"""
|
||||
main_schedule: Optional[Schedule] = None
|
||||
|
||||
for schedule in client.schedules.values():
|
||||
personalized_schedule = schedule.on_day(day_dt).personalize(roles)
|
||||
personalized_schedule = schedule.on_day(day_dt)
|
||||
|
||||
if not personalized_schedule:
|
||||
continue
|
||||
|
|
@ -204,6 +243,10 @@ async def parse_schedule_from_content(content: str, *, database_session: AsyncSe
|
|||
location=event.location,
|
||||
)
|
||||
|
||||
# Slot extends another one, don't add it
|
||||
if any(s.could_merge_with(slot) for s in slots):
|
||||
continue
|
||||
|
||||
slots.add(slot)
|
||||
|
||||
return Schedule(slots=slots)
|
||||
|
|
|
|||
|
|
@ -363,6 +363,13 @@ class Didier(commands.Bot):
|
|||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
||||
async def on_task_error(self, exception: Exception):
|
||||
"""Event triggered when a task raises an exception"""
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
embed = create_error_embed(None, exception)
|
||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
await channel.send(embed=embed)
|
||||
|
||||
async def on_thread_create(self, thread: discord.Thread):
|
||||
"""Event triggered when a new thread is created"""
|
||||
# Join threads automatically
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
from overrides import overrides
|
||||
|
||||
from database.schemas import Bookmark
|
||||
|
|
@ -14,16 +13,16 @@ class BookmarkSource(PageSource[Bookmark]):
|
|||
"""PageSource for the Bookmark commands"""
|
||||
|
||||
@overrides
|
||||
def create_embeds(self, ctx: commands.Context):
|
||||
def create_embeds(self):
|
||||
for page in range(self.page_count):
|
||||
embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue())
|
||||
avatar_url = get_author_avatar(ctx).url
|
||||
embed.set_author(name=ctx.author.display_name, icon_url=avatar_url)
|
||||
avatar_url = get_author_avatar(self.ctx).url
|
||||
embed.set_author(name=self.ctx.author.display_name, icon_url=avatar_url)
|
||||
|
||||
description = ""
|
||||
description_data = []
|
||||
|
||||
for bookmark in self.get_page_data(page):
|
||||
description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n"
|
||||
description_data.append(f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})")
|
||||
|
||||
embed.description = description.strip()
|
||||
embed.description = "\n".join(description_data)
|
||||
self.embeds.append(embed)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,5 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, Optional, TypeVar, cast
|
||||
|
||||
|
|
@ -13,50 +15,6 @@ __all__ = ["Menu", "PageSource"]
|
|||
T = TypeVar("T")
|
||||
|
||||
|
||||
class PageSource(ABC, Generic[T]):
|
||||
"""Base class that handles the embeds displayed in a menu"""
|
||||
|
||||
dataset: list[T]
|
||||
embeds: list[discord.Embed]
|
||||
page_count: int
|
||||
per_page: int
|
||||
|
||||
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
|
||||
self.embeds = []
|
||||
self.dataset = dataset
|
||||
self.per_page = per_page
|
||||
self.page_count = self._get_page_count()
|
||||
self.create_embeds(ctx)
|
||||
self._add_embed_page_footers()
|
||||
|
||||
def _get_page_count(self) -> int:
|
||||
"""Calculate the amount of pages required"""
|
||||
if len(self.dataset) % self.per_page == 0:
|
||||
return len(self.dataset) // self.per_page
|
||||
|
||||
return (len(self.dataset) // self.per_page) + 1
|
||||
|
||||
def __getitem__(self, index: int) -> discord.Embed:
|
||||
return self.embeds[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.page_count
|
||||
|
||||
def _add_embed_page_footers(self):
|
||||
"""Add the current page in the footer of every embed"""
|
||||
for i, embed in enumerate(self.embeds):
|
||||
embed.set_footer(text=f"{i + 1}/{self.page_count}")
|
||||
|
||||
@abstractmethod
|
||||
def create_embeds(self, ctx: commands.Context):
|
||||
"""Method that builds the list of embeds from the input data"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_page_data(self, page: int) -> list[T]:
|
||||
"""Get the chunk of the dataset for page [page]"""
|
||||
return self.dataset[page : page + self.per_page]
|
||||
|
||||
|
||||
class Menu(discord.ui.View):
|
||||
"""Base class for a menu"""
|
||||
|
||||
|
|
@ -166,3 +124,58 @@ class Menu(discord.ui.View):
|
|||
"""Button to show the last page"""
|
||||
self.current_page = len(self.source) - 1
|
||||
await self.display_current_state(interaction)
|
||||
|
||||
|
||||
class PageSource(ABC, Generic[T]):
|
||||
"""Base class that handles the embeds displayed in a menu"""
|
||||
|
||||
ctx: commands.Context
|
||||
dataset: list[T]
|
||||
embeds: list[discord.Embed]
|
||||
page_count: int
|
||||
per_page: int
|
||||
|
||||
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
|
||||
self.ctx = ctx
|
||||
self.embeds = []
|
||||
self.dataset = dataset
|
||||
self.per_page = per_page
|
||||
self.page_count = self._get_page_count()
|
||||
self.create_embeds()
|
||||
self._add_embed_page_footers()
|
||||
|
||||
def _get_page_count(self) -> int:
|
||||
"""Calculate the amount of pages required"""
|
||||
if len(self.dataset) % self.per_page == 0:
|
||||
return len(self.dataset) // self.per_page
|
||||
|
||||
return (len(self.dataset) // self.per_page) + 1
|
||||
|
||||
def __getitem__(self, index: int) -> discord.Embed:
|
||||
return self.embeds[index]
|
||||
|
||||
def __len__(self):
|
||||
return self.page_count
|
||||
|
||||
def _add_embed_page_footers(self):
|
||||
"""Add the current page in the footer of every embed"""
|
||||
for i, embed in enumerate(self.embeds):
|
||||
embed.set_footer(text=f"{i + 1}/{self.page_count}")
|
||||
|
||||
@abstractmethod
|
||||
def create_embeds(self):
|
||||
"""Method that builds the list of embeds from the input data"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_page_data(self, page: int) -> list[T]:
|
||||
"""Get the chunk of the dataset for page [page]"""
|
||||
return self.dataset[page : page + self.per_page]
|
||||
|
||||
async def start(self, *, ephemeral: bool = False, timeout: Optional[int] = None) -> Menu:
|
||||
"""Shortcut to creating (and starting) a Menu with this source
|
||||
|
||||
This returns the created menu
|
||||
"""
|
||||
menu = Menu(self, ephemeral=ephemeral, timeout=timeout)
|
||||
await menu.start(self.ctx)
|
||||
return menu
|
||||
|
|
|
|||
|
|
@ -0,0 +1,24 @@
|
|||
import discord
|
||||
from overrides import overrides
|
||||
|
||||
from database.schemas import CustomCommand
|
||||
from didier.menus.common import PageSource
|
||||
|
||||
__all__ = ["CustomCommandSource"]
|
||||
|
||||
|
||||
class CustomCommandSource(PageSource[CustomCommand]):
|
||||
"""PageSource for custom commands"""
|
||||
|
||||
@overrides
|
||||
def create_embeds(self):
|
||||
for page in range(self.page_count):
|
||||
embed = discord.Embed(colour=discord.Colour.blue(), title="Custom Commands")
|
||||
|
||||
description_data = []
|
||||
|
||||
for command in self.get_page_data(page):
|
||||
description_data.append(command.name.title())
|
||||
|
||||
embed.description = "\n".join(description_data)
|
||||
self.embeds.append(embed)
|
||||
|
|
@ -1,5 +1,4 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
from overrides import overrides
|
||||
|
||||
from database.schemas import MemeTemplate
|
||||
|
|
@ -12,7 +11,7 @@ class MemeSource(PageSource[MemeTemplate]):
|
|||
"""PageSource for meme templates"""
|
||||
|
||||
@overrides
|
||||
def create_embeds(self, ctx: commands.Context):
|
||||
def create_embeds(self):
|
||||
for page in range(self.page_count):
|
||||
# The colour of the embed is (69,4,20) with the values +100 because they were too dark
|
||||
embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120))
|
||||
|
|
|
|||
|
|
@ -1,6 +1,46 @@
|
|||
from enum import Enum
|
||||
|
||||
__all__ = ["Limits"]
|
||||
__all__ = ["EMOJI_MAP", "Limits"]
|
||||
|
||||
|
||||
EMOJI_MAP = {
|
||||
"a": "🇦",
|
||||
"b": "🇧",
|
||||
"c": "🇨",
|
||||
"d": "🇩",
|
||||
"e": "🇪",
|
||||
"f": "🇫",
|
||||
"g": "🇬",
|
||||
"h": "🇭",
|
||||
"i": "🇮",
|
||||
"j": "🇯",
|
||||
"k": "🇰",
|
||||
"l": "🇱",
|
||||
"m": "🇲",
|
||||
"n": "🇳",
|
||||
"o": "🇴",
|
||||
"p": "🇵",
|
||||
"q": "🇶",
|
||||
"r": "🇷",
|
||||
"s": "🇸",
|
||||
"t": "🇹",
|
||||
"u": "🇺",
|
||||
"v": "🇻",
|
||||
"w": "🇼",
|
||||
"x": "🇽",
|
||||
"y": "🇾",
|
||||
"z": "🇿",
|
||||
"0": "0⃣",
|
||||
"1": "1️⃣",
|
||||
"2": "2️⃣",
|
||||
"3": "3️⃣",
|
||||
"4": "4️⃣",
|
||||
"5": "5️⃣",
|
||||
"6": "6️⃣",
|
||||
"7": "7️⃣",
|
||||
"8": "8️⃣",
|
||||
"9": "9️⃣",
|
||||
}
|
||||
|
||||
|
||||
class Limits(int, Enum):
|
||||
|
|
|
|||
Loading…
Reference in New Issue