From d03ece6f58a5fef6d2c316b9a6e3f00b6926f9c5 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 14:25:13 +0200 Subject: [PATCH 1/7] Create reminders, fix bugs in schedule parsing --- .../versions/a64876b41af2_add_reminders.py | 38 ++++++++++++ database/crud/reminders.py | 42 ++++++++++++++ database/enums.py | 11 ++-- database/schemas.py | 16 +++++ didier/cogs/fun.py | 2 +- didier/cogs/meta.py | 27 +++++++++ didier/cogs/school.py | 5 +- didier/cogs/tasks.py | 58 ++++++++++++++++++- didier/data/embeds/schedules.py | 17 +++--- 9 files changed, 199 insertions(+), 17 deletions(-) create mode 100644 alembic/versions/a64876b41af2_add_reminders.py create mode 100644 database/crud/reminders.py diff --git a/alembic/versions/a64876b41af2_add_reminders.py b/alembic/versions/a64876b41af2_add_reminders.py new file mode 100644 index 0000000..40d48b5 --- /dev/null +++ b/alembic/versions/a64876b41af2_add_reminders.py @@ -0,0 +1,38 @@ +"""Add reminders + +Revision ID: a64876b41af2 +Revises: c1f9ee875616 +Create Date: 2022-09-23 13:37:10.331840 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "a64876b41af2" +down_revision = "c1f9ee875616" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "reminders", + sa.Column("reminder_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("category", sa.Enum("LES", name="remindercategory"), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("reminder_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("reminders") + # ### end Alembic commands ### diff --git a/database/crud/reminders.py b/database/crud/reminders.py new file mode 100644 index 0000000..007a779 --- /dev/null +++ b/database/crud/reminders.py @@ -0,0 +1,42 @@ +from typing import Optional + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud.users import get_or_add_user +from database.enums import ReminderCategory +from database.schemas import Reminder + +__all__ = ["get_all_reminders_for_category", "toggle_reminder"] + + +async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: + """Get a list of all Reminders for a given category""" + statement = select(Reminder).where(Reminder.category == category) + return (await session.execute(statement)).scalars().all() + + +async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: + """Switch a category on/off + + Returns the new value for the category + """ + await get_or_add_user(session, user_id) + + select_statement = select(Reminder).where(Reminder.user_id == user_id).where(Reminder.category == category) + reminder: Optional[Reminder] = (await session.execute(select_statement)).scalar_one_or_none() + + # No reminder set yet + if reminder is None: + reminder = Reminder(user_id=user_id, category=category) + session.add(reminder) + await session.commit() + + return True + + # Reminder found -> delete it + delete_statement = delete(Reminder).where(Reminder.reminder_id == reminder.reminder_id) + await session.execute(delete_statement) + await session.commit() + + return False diff --git a/database/enums.py b/database/enums.py index e2cf565..3eec739 100644 --- a/database/enums.py +++ b/database/enums.py @@ -1,11 +1,14 @@ import enum -__all__ = ["TaskType"] +__all__ = ["ReminderCategory", "TaskType"] + + +class ReminderCategory(enum.IntEnum): + """Enum for reminder categories""" + + LES = enum.auto() -# There is a bug in typeshed that causes an incorrect PyCharm warning -# https://github.com/python/typeshed/issues/8286 -# noinspection PyArgumentList class TaskType(enum.IntEnum): """Enum for the different types of tasks""" diff --git a/database/schemas.py b/database/schemas.py index f497b2d..3fd8ce4 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -37,6 +37,7 @@ __all__ = [ "Link", "MemeTemplate", "NightlyData", + "Reminder", "Task", "UforaAnnouncement", "UforaCourse", @@ -219,6 +220,18 @@ class NightlyData(Base): user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") +class Reminder(Base): + """Something that a user should be reminded of""" + + __tablename__ = "reminders" + + reminder_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False) + + user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin") + + class Task(Base): """A Didier task""" @@ -303,6 +316,9 @@ class User(Base): nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) + reminders: list[Reminder] = relationship( + "Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + ) wordle_guesses: list[WordleGuess] = relationship( "WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index fe78275..25d0baf 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -83,7 +83,7 @@ class Fun(commands.Cog): @memes_slash.command(name="generate") async def memegen_slash(self, interaction: discord.Interaction, template: str): - """Generate a meme with template `template`.""" + """Generate a meme.""" async with self.client.postgres_session as session: result = expect(await get_meme_by_name(session, template), entity_type="meme", argument=template) diff --git a/didier/cogs/meta.py b/didier/cogs/meta.py index 7b7a732..ae4ea81 100644 --- a/didier/cogs/meta.py +++ b/didier/cogs/meta.py @@ -4,6 +4,8 @@ from typing import Optional from discord.ext import commands +from database.crud.reminders import toggle_reminder +from database.enums import ReminderCategory from didier import Didier @@ -20,6 +22,31 @@ class Meta(commands.Cog): """Get Didier's latency.""" return await ctx.reply(f"Polo! {round(self.client.latency * 1000)}ms", mention_author=False) + @commands.command(name="remind", aliases=["remindme"]) + async def remind(self, ctx: commands.Context, category: str): + """Make Didier send you reminders every day.""" + category = category.lower() + + available_categories = [ + ( + "les", + ReminderCategory.LES, + ) + ] + + for name, category_mapping in available_categories: + if name == category: + async with self.client.postgres_session as session: + new_value = await toggle_reminder(session, ctx.author.id, category_mapping) + + toggle = "on" if new_value else "off" + return await ctx.reply( + f"Reminders for category `{name}` have been toggled {toggle}.", mention_author=False + ) + + # No match found + return await ctx.reply(f"`{category}` is not a supported category.", mention_author=False) + @commands.command(name="source", aliases=["src"]) async def source(self, ctx: commands.Context, *, command_name: Optional[str] = None): """Get a link to the source code of Didier. diff --git a/didier/cogs/school.py b/didier/cogs/school.py index aacdc4d..a9b6b7a 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -11,7 +11,7 @@ from didier import Didier from didier.data.apis.hydra import fetch_menu from didier.data.embeds.deadlines import Deadlines from didier.data.embeds.hydra import no_menu_found -from didier.data.embeds.schedules import Schedule, get_schedule_for_user +from didier.data.embeds.schedules import Schedule, get_schedule_for_day from didier.exceptions import HTTPException, NotInMainGuildException from didier.utils.discord.converters.time import DateTransformer from didier.utils.discord.flags.school import StudyGuideFlags @@ -55,10 +55,11 @@ class School(commands.Cog): try: member_instance = to_main_guild_member(self.client, ctx.author) + roles = {role.id for role in member_instance.roles} # Always make sure there is at least one schedule in case it returns None # this allows proper error messages - schedule = get_schedule_for_user(self.client, member_instance, day_dt) or Schedule() + schedule = (get_schedule_for_day(self.client, day_dt) or Schedule()).personalize(roles) return await ctx.reply(embed=schedule.to_embed(day=day_dt), mention_author=False) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 37b1786..1657be6 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -9,10 +9,16 @@ from overrides import overrides import settings from database import enums from database.crud.birthdays import get_birthdays_on_day +from database.crud.reminders import get_all_reminders_for_category from database.crud.ufora_announcements import remove_old_announcements from database.crud.wordle import set_daily_word +from database.schemas import Reminder from didier import Didier -from didier.data.embeds.schedules import Schedule, parse_schedule_from_content +from didier.data.embeds.schedules import ( + Schedule, + get_schedule_for_day, + parse_schedule_from_content, +) from didier.data.embeds.ufora.announcements import fetch_ufora_announcements from didier.decorators.tasks import timed_task from didier.utils.discord.checks import is_owner @@ -44,6 +50,7 @@ class Tasks(commands.Cog): self._tasks = { "birthdays": self.check_birthdays, "schedules": self.pull_schedules, + "reminders": self.reminders, "ufora": self.pull_ufora_announcements, "remove_ufora": self.remove_old_ufora_announcements, "wordle": self.reset_wordle_word, @@ -61,6 +68,7 @@ class Tasks(commands.Cog): self.remove_old_ufora_announcements.start() # Start other tasks + self.reminders.start() self.reset_wordle_word.start() self.pull_schedules.start() @@ -135,7 +143,7 @@ class Tasks(commands.Cog): async with self.client.postgres_session as session: for data in settings.SCHEDULE_DATA: if data.schedule_url is None: - return + continue async with self.client.http_session.get(data.schedule_url) as response: # If a schedule couldn't be fetched, log it and move on @@ -180,6 +188,51 @@ class Tasks(commands.Cog): async def _before_ufora_announcements(self): await self.client.wait_until_ready() + async def _send_les_reminders(self, entries: list[Reminder]): + today = datetime.date(year=2022, month=9, day=26) + + # Create the main schedule for the day once here, to avoid doing it repeatedly + daily_schedule = get_schedule_for_day(self.client, today) + + # No class today + if not daily_schedule: + return + + for entry in entries: + member = self.client.main_guild.get_member(entry.user_id) + if not member: + continue + + roles = {role.id for role in member.roles} + personal_schedule = daily_schedule.personalize(roles) + + # No class today + if not personal_schedule: + continue + + await member.send(embed=personal_schedule.to_embed(day=today)) + + # @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) + @tasks.loop(hours=3) + async def reminders(self, **kwargs): + """Send daily reminders to people""" + _ = kwargs + + async with self.client.postgres_session as session: + for category in enums.ReminderCategory: + entries = await get_all_reminders_for_category(session, category) + if not entries: + continue + + # This is slightly ugly, but it's the best way to go about it + # There won't be a lot of categories anyway + if category == enums.ReminderCategory: + await self._send_les_reminders(entries) + + @reminders.before_loop + async def _before_reminders(self): + await self.client.wait_until_ready() + @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" @@ -200,6 +253,7 @@ class Tasks(commands.Cog): @check_birthdays.error @pull_schedules.error @pull_ufora_announcements.error + @reminders.error @remove_old_ufora_announcements.error @reset_wordle_word.error async def _on_tasks_error(self, error: BaseException): diff --git a/didier/data/embeds/schedules.py b/didier/data/embeds/schedules.py index ed03a33..39e16de 100644 --- a/didier/data/embeds/schedules.py +++ b/didier/data/embeds/schedules.py @@ -22,7 +22,7 @@ from didier.utils.types.datetime import LOCAL_TIMEZONE, int_to_weekday, time_str from didier.utils.types.string import leading from settings import ScheduleType -__all__ = ["Schedule", "get_schedule_for_user", "parse_schedule_from_content", "parse_schedule"] +__all__ = ["Schedule", "get_schedule_for_day", "parse_schedule_from_content", "parse_schedule"] @dataclass @@ -48,6 +48,10 @@ class Schedule(EmbedBaseModel): def personalize(self, roles: set[int]) -> Schedule: """Personalize a schedule for a user, only adding courses they follow""" + # If the schedule is already empty, just return instantly + if not self.slots: + return Schedule() + personal_slots = set() for slot in self.slots: role_found = slot.role_id is not None and slot.role_id in roles @@ -104,10 +108,9 @@ class ScheduleSlot: def __post_init__(self): """Fix some properties to display more nicely""" # Re-format the location data - room, building, campus = re.search(r"(.*)\. Gebouw (.*)\. Campus (.*)\. ", self.location).groups() + room, building, campus = re.search(r"(.*)\. (?:Gebouw )?(.*)\. (?:Campus )?(.*)\. ", self.location).groups() room = room.replace("PC / laptoplokaal ", "PC-lokaal") self.location = f"{campus} {building} {room}" - self._hash = hash(f"{self.course.course_id} {str(self.start_time)}") @property @@ -132,14 +135,12 @@ class ScheduleSlot: return self._hash == other._hash -def get_schedule_for_user(client: Didier, member: discord.Member, day_dt: date) -> Optional[Schedule]: - """Get a user's schedule""" - roles: set[int] = {role.id for role in member.roles} - +def get_schedule_for_day(client: Didier, day_dt: date) -> Optional[Schedule]: + """Get a schedule for an entire day""" main_schedule: Optional[Schedule] = None for schedule in client.schedules.values(): - personalized_schedule = schedule.on_day(day_dt).personalize(roles) + personalized_schedule = schedule.on_day(day_dt) if not personalized_schedule: continue From ddd632ffd5aaa5ff3398e3789f273ab28cc9c740 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 14:27:56 +0200 Subject: [PATCH 2/7] Fix typo --- didier/cogs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 1657be6..cf1293e 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -226,7 +226,7 @@ class Tasks(commands.Cog): # This is slightly ugly, but it's the best way to go about it # There won't be a lot of categories anyway - if category == enums.ReminderCategory: + if category == enums.ReminderCategory.LES: await self._send_les_reminders(entries) @reminders.before_loop From 185aaadce1fa681db1b0b4d32d8325a0ae427c63 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 14:47:42 +0200 Subject: [PATCH 3/7] Merge sequential slots into one --- didier/data/embeds/schedules.py | 42 +++++++++++++++++++++++++++++++++ 1 file changed, 42 insertions(+) diff --git a/didier/data/embeds/schedules.py b/didier/data/embeds/schedules.py index 39e16de..1c32bbf 100644 --- a/didier/data/embeds/schedules.py +++ b/didier/data/embeds/schedules.py @@ -61,6 +61,21 @@ class Schedule(EmbedBaseModel): return Schedule(personal_slots) + def simplify(self): + """Merge sequential slots in the same location into one + + Note: this is done in-place instead of returning a new schedule! + (The operation is O(n^2)) + + Example: + 13:00 - 14:30: AD3 in S9 + 14:30 - 1600: AD3 in S9 + """ + for first in self.slots: + for second in self.slots: + if first == second: + continue + @overrides def to_embed(self, **kwargs) -> discord.Embed: day: date = kwargs.get("day", date.today()) @@ -111,6 +126,9 @@ class ScheduleSlot: room, building, campus = re.search(r"(.*)\. (?:Gebouw )?(.*)\. (?:Campus )?(.*)\. ", self.location).groups() room = room.replace("PC / laptoplokaal ", "PC-lokaal") self.location = f"{campus} {building} {room}" + + # The same course can only start once at the same moment, + # so this is guaranteed to be unique self._hash = hash(f"{self.course.course_id} {str(self.start_time)}") @property @@ -134,6 +152,26 @@ class ScheduleSlot: return self._hash == other._hash + def could_merge_with(self, other: ScheduleSlot) -> bool: + """Check if two slots are actually one with a 15-min break in-between + + If they are, merge the two into one (this edits the first slot in-place!) + """ + if self.course.course_id != other.course.course_id: + return False + + if self.location != other.location: + return False + + if self.start_time == other.end_time: + other.end_time = self.end_time + return True + elif self.end_time == other.start_time: + self.end_time = other.end_time + return True + + return False + def get_schedule_for_day(client: Didier, day_dt: date) -> Optional[Schedule]: """Get a schedule for an entire day""" @@ -205,6 +243,10 @@ async def parse_schedule_from_content(content: str, *, database_session: AsyncSe location=event.location, ) + # Slot extends another one, don't add it + if any(s.could_merge_with(slot) for s in slots): + continue + slots.add(slot) return Schedule(slots=slots) From 0a9f73af8cfac0aaeadd1580d5e9dfce3534deee Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 14:59:47 +0200 Subject: [PATCH 4/7] Make tasks log exceptions --- didier/cogs/tasks.py | 7 ++----- didier/data/embeds/error_embed.py | 32 ++++++++++++++++++----------- didier/data/embeds/logging_embed.py | 5 ++++- didier/didier.py | 7 +++++++ 4 files changed, 33 insertions(+), 18 deletions(-) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index cf1293e..239f02a 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,6 +1,5 @@ import datetime import random -import traceback import discord from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error @@ -212,8 +211,7 @@ class Tasks(commands.Cog): await member.send(embed=personal_schedule.to_embed(day=today)) - # @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) - @tasks.loop(hours=3) + @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) async def reminders(self, **kwargs): """Send daily reminders to people""" _ = kwargs @@ -258,8 +256,7 @@ class Tasks(commands.Cog): @reset_wordle_word.error async def _on_tasks_error(self, error: BaseException): """Error handler for all tasks""" - print("".join(traceback.format_exception(type(error), error, error.__traceback__))) - self.client.dispatch("task_error") + self.client.dispatch("task_error", error) async def setup(client: Didier): diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py index 03edd25..70bc2d3 100644 --- a/didier/data/embeds/error_embed.py +++ b/didier/data/embeds/error_embed.py @@ -1,4 +1,5 @@ import traceback +from typing import Optional import discord from discord.ext import commands @@ -18,7 +19,7 @@ def _get_traceback(exception: Exception) -> str: if line.strip().startswith("The above exception was the direct cause of"): break - # Escape Discord markdown formatting + # Escape Discord Markdown formatting error_string += line.replace(r"*", r"\*").replace(r"_", r"\_") if line.strip(): error_string += "\n" @@ -26,23 +27,30 @@ def _get_traceback(exception: Exception) -> str: return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8) -def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed: +def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> discord.Embed: """Create an embed for the traceback of an exception""" + message = str(exception) + # Wrap the traceback in a codeblock for readability description = _get_traceback(exception).strip() description = f"```\n{description}\n```" - if ctx.guild is None: - origin = "DM" - else: - origin = f"{ctx.channel.mention} ({ctx.guild.name})" - - invocation = f"{ctx.author.display_name} in {origin}" - embed = discord.Embed(title="Error", colour=discord.Colour.red()) - embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True) - embed.add_field(name="Context", value=invocation, inline=True) - embed.add_field(name="Exception", value=str(exception), inline=False) + + if ctx is not None: + if ctx.guild is None: + origin = "DM" + else: + origin = f"{ctx.channel.mention} ({ctx.guild.name})" + + invocation = f"{ctx.author.display_name} in {origin}" + + embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True) + embed.add_field(name="Context", value=invocation, inline=True) + + if message: + embed.add_field(name="Exception", value=message, inline=False) + embed.add_field(name="Traceback", value=description, inline=False) return embed diff --git a/didier/data/embeds/logging_embed.py b/didier/data/embeds/logging_embed.py index 784cc3f..40556f2 100644 --- a/didier/data/embeds/logging_embed.py +++ b/didier/data/embeds/logging_embed.py @@ -2,6 +2,9 @@ import logging import discord +from didier.utils.discord.constants import Limits +from didier.utils.types.string import abbreviate + __all__ = ["create_logging_embed"] @@ -16,6 +19,6 @@ def create_logging_embed(level: int, message: str) -> discord.Embed: colour = colours.get(level, discord.Colour.red()) embed = discord.Embed(colour=colour, title="Logging") - embed.description = message + embed.description = abbreviate(message, Limits.EMBED_DESCRIPTION_LENGTH) return embed diff --git a/didier/didier.py b/didier/didier.py index 050def6..cdfa7bd 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -363,6 +363,13 @@ class Didier(commands.Bot): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) + async def on_task_error(self, exception: Exception): + """Event triggered when a task raises an exception""" + if settings.ERRORS_CHANNEL is not None: + embed = create_error_embed(None, exception) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) + async def on_thread_create(self, thread: discord.Thread): """Event triggered when a new thread is created""" # Join threads automatically From 3e495d8291bc779c243065df68cc273eae3d18d8 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 15:07:13 +0200 Subject: [PATCH 5/7] Remove debug date --- didier/cogs/tasks.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 239f02a..14e7699 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -188,7 +188,7 @@ class Tasks(commands.Cog): await self.client.wait_until_ready() async def _send_les_reminders(self, entries: list[Reminder]): - today = datetime.date(year=2022, month=9, day=26) + today = tz_aware_now().date() # Create the main schedule for the day once here, to avoid doing it repeatedly daily_schedule = get_schedule_for_day(self.client, today) From bf32a5ef47a4de6c81833a59af340feea51bb082 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 18:06:33 +0200 Subject: [PATCH 6/7] Create command to list custom commands, add shortcuts to memegen commands --- database/crud/custom_commands.py | 7 +++ didier/cogs/discord.py | 5 +- didier/cogs/fun.py | 18 ++++-- didier/cogs/meta.py | 11 ++++ didier/menus/bookmarks.py | 13 ++-- didier/menus/common.py | 101 +++++++++++++++++-------------- didier/menus/custom_commands.py | 24 ++++++++ didier/menus/memes.py | 3 +- 8 files changed, 120 insertions(+), 62 deletions(-) create mode 100644 didier/menus/custom_commands.py diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index bb6ac0c..efbc689 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -12,6 +12,7 @@ __all__ = [ "create_alias", "create_command", "edit_command", + "get_all_commands", "get_command", "get_command_by_alias", "get_command_by_name", @@ -55,6 +56,12 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo return alias_instance +async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: + """Get a list of all commands""" + statement = select(CustomCommand) + return (await session.execute(statement)).scalars().all() + + async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: """Try to get a command out of a message""" # Search lowercase & without spaces diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index fa85b35..7418994 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -14,7 +14,6 @@ from database.exceptions import ( from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource -from didier.menus.common import Menu from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.constants import Limits @@ -186,9 +185,7 @@ class Discord(commands.Cog): embed.description = "You haven't created any bookmarks yet." return await ctx.reply(embed=embed, mention_author=False) - source = BookmarkSource(ctx, results) - menu = Menu(source) - await menu.start(ctx) + await BookmarkSource(ctx, results).start() async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message): """Create a bookmark out of this message""" diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 25d0baf..d3e61f6 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -1,4 +1,5 @@ import shlex +from typing import Optional import discord from discord import app_commands @@ -9,7 +10,6 @@ from database.crud.memes import get_all_memes, get_meme_by_name from didier import Didier from didier.data.apis.imgflip import generate_meme from didier.exceptions.no_match import expect -from didier.menus.common import Menu from didier.menus.memes import MemeSource from didier.views.modals import GenerateMeme @@ -42,7 +42,7 @@ class Fun(commands.Cog): return await ctx.reply(joke.joke, mention_author=False) @commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True) - async def memegen_msg(self, ctx: commands.Context, template: str, *, fields: str): + async def memegen_msg(self, ctx: commands.Context, template: Optional[str] = None, *, fields: Optional[str] = None): """Generate a meme with template `template` and fields `fields`. The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes". @@ -55,7 +55,17 @@ class Fun(commands.Cog): Example: if template `a` only has 1 field, `memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]` + + When no arguments are provided, this is a shortcut to `memegen list`. + + When only a template is provided, this is a shortcut to `memegen preview`. """ + if template is None: + return await self.memegen_ls_msg(ctx) + + if fields is None: + return await self.memegen_preview_msg(ctx, template) + async with ctx.typing(): meme = await self._do_generate_meme(template, shlex.split(fields)) return await ctx.reply(meme, mention_author=False) @@ -69,9 +79,7 @@ class Fun(commands.Cog): async with self.client.postgres_session as session: results = await get_all_memes(session) - source = MemeSource(ctx, results) - menu = Menu(source) - await menu.start(ctx) + await MemeSource(ctx, results).start() @memegen_msg.command(name="preview", aliases=["p"]) async def memegen_preview_msg(self, ctx: commands.Context, template: str): diff --git a/didier/cogs/meta.py b/didier/cogs/meta.py index ae4ea81..c330dbd 100644 --- a/didier/cogs/meta.py +++ b/didier/cogs/meta.py @@ -4,9 +4,11 @@ from typing import Optional from discord.ext import commands +from database.crud.custom_commands import get_all_commands from database.crud.reminders import toggle_reminder from database.enums import ReminderCategory from didier import Didier +from didier.menus.custom_commands import CustomCommandSource class Meta(commands.Cog): @@ -17,6 +19,15 @@ class Meta(commands.Cog): def __init__(self, client: Didier): self.client = client + @commands.command(name="custom") + async def custom(self, ctx: commands.Context): + """Get a list of all custom commands that are registered.""" + async with self.client.postgres_session as session: + custom_commands = await get_all_commands(session) + + custom_commands.sort(key=lambda c: c.name.lower()) + await CustomCommandSource(ctx, custom_commands).start() + @commands.command(name="marco") async def marco(self, ctx: commands.Context): """Get Didier's latency.""" diff --git a/didier/menus/bookmarks.py b/didier/menus/bookmarks.py index d4c1e11..f727abe 100644 --- a/didier/menus/bookmarks.py +++ b/didier/menus/bookmarks.py @@ -1,5 +1,4 @@ import discord -from discord.ext import commands from overrides import overrides from database.schemas import Bookmark @@ -14,16 +13,16 @@ class BookmarkSource(PageSource[Bookmark]): """PageSource for the Bookmark commands""" @overrides - def create_embeds(self, ctx: commands.Context): + def create_embeds(self): for page in range(self.page_count): embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue()) - avatar_url = get_author_avatar(ctx).url - embed.set_author(name=ctx.author.display_name, icon_url=avatar_url) + avatar_url = get_author_avatar(self.ctx).url + embed.set_author(name=self.ctx.author.display_name, icon_url=avatar_url) - description = "" + description_data = [] for bookmark in self.get_page_data(page): - description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n" + description_data.append(f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})") - embed.description = description.strip() + embed.description = "\n".join(description_data) self.embeds.append(embed) diff --git a/didier/menus/common.py b/didier/menus/common.py index d1cb3b1..45b0d95 100644 --- a/didier/menus/common.py +++ b/didier/menus/common.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from abc import ABC, abstractmethod from typing import Generic, Optional, TypeVar, cast @@ -13,50 +15,6 @@ __all__ = ["Menu", "PageSource"] T = TypeVar("T") -class PageSource(ABC, Generic[T]): - """Base class that handles the embeds displayed in a menu""" - - dataset: list[T] - embeds: list[discord.Embed] - page_count: int - per_page: int - - def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): - self.embeds = [] - self.dataset = dataset - self.per_page = per_page - self.page_count = self._get_page_count() - self.create_embeds(ctx) - self._add_embed_page_footers() - - def _get_page_count(self) -> int: - """Calculate the amount of pages required""" - if len(self.dataset) % self.per_page == 0: - return len(self.dataset) // self.per_page - - return (len(self.dataset) // self.per_page) + 1 - - def __getitem__(self, index: int) -> discord.Embed: - return self.embeds[index] - - def __len__(self): - return self.page_count - - def _add_embed_page_footers(self): - """Add the current page in the footer of every embed""" - for i, embed in enumerate(self.embeds): - embed.set_footer(text=f"{i + 1}/{self.page_count}") - - @abstractmethod - def create_embeds(self, ctx: commands.Context): - """Method that builds the list of embeds from the input data""" - raise NotImplementedError - - def get_page_data(self, page: int) -> list[T]: - """Get the chunk of the dataset for page [page]""" - return self.dataset[page : page + self.per_page] - - class Menu(discord.ui.View): """Base class for a menu""" @@ -166,3 +124,58 @@ class Menu(discord.ui.View): """Button to show the last page""" self.current_page = len(self.source) - 1 await self.display_current_state(interaction) + + +class PageSource(ABC, Generic[T]): + """Base class that handles the embeds displayed in a menu""" + + ctx: commands.Context + dataset: list[T] + embeds: list[discord.Embed] + page_count: int + per_page: int + + def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): + self.ctx = ctx + self.embeds = [] + self.dataset = dataset + self.per_page = per_page + self.page_count = self._get_page_count() + self.create_embeds() + self._add_embed_page_footers() + + def _get_page_count(self) -> int: + """Calculate the amount of pages required""" + if len(self.dataset) % self.per_page == 0: + return len(self.dataset) // self.per_page + + return (len(self.dataset) // self.per_page) + 1 + + def __getitem__(self, index: int) -> discord.Embed: + return self.embeds[index] + + def __len__(self): + return self.page_count + + def _add_embed_page_footers(self): + """Add the current page in the footer of every embed""" + for i, embed in enumerate(self.embeds): + embed.set_footer(text=f"{i + 1}/{self.page_count}") + + @abstractmethod + def create_embeds(self): + """Method that builds the list of embeds from the input data""" + raise NotImplementedError + + def get_page_data(self, page: int) -> list[T]: + """Get the chunk of the dataset for page [page]""" + return self.dataset[page : page + self.per_page] + + async def start(self, *, ephemeral: bool = False, timeout: Optional[int] = None) -> Menu: + """Shortcut to creating (and starting) a Menu with this source + + This returns the created menu + """ + menu = Menu(self, ephemeral=ephemeral, timeout=timeout) + await menu.start(self.ctx) + return menu diff --git a/didier/menus/custom_commands.py b/didier/menus/custom_commands.py new file mode 100644 index 0000000..e02b647 --- /dev/null +++ b/didier/menus/custom_commands.py @@ -0,0 +1,24 @@ +import discord +from overrides import overrides + +from database.schemas import CustomCommand +from didier.menus.common import PageSource + +__all__ = ["CustomCommandSource"] + + +class CustomCommandSource(PageSource[CustomCommand]): + """PageSource for custom commands""" + + @overrides + def create_embeds(self): + for page in range(self.page_count): + embed = discord.Embed(colour=discord.Colour.blue(), title="Custom Commands") + + description_data = [] + + for command in self.get_page_data(page): + description_data.append(command.name.title()) + + embed.description = "\n".join(description_data) + self.embeds.append(embed) diff --git a/didier/menus/memes.py b/didier/menus/memes.py index d7a4be4..5d420c2 100644 --- a/didier/menus/memes.py +++ b/didier/menus/memes.py @@ -1,5 +1,4 @@ import discord -from discord.ext import commands from overrides import overrides from database.schemas import MemeTemplate @@ -12,7 +11,7 @@ class MemeSource(PageSource[MemeTemplate]): """PageSource for meme templates""" @overrides - def create_embeds(self, ctx: commands.Context): + def create_embeds(self): for page in range(self.page_count): # The colour of the embed is (69,4,20) with the values +100 because they were too dark embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120)) From 773491e2ff444a96066af2f5ca69babca4707a2f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 23 Sep 2022 20:30:00 +0200 Subject: [PATCH 7/7] Clap --- didier/cogs/fun.py | 17 +++++++++++++ didier/utils/discord/constants.py | 42 ++++++++++++++++++++++++++++++- 2 files changed, 58 insertions(+), 1 deletion(-) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index d3e61f6..4507305 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -11,6 +11,7 @@ from didier import Didier from didier.data.apis.imgflip import generate_meme from didier.exceptions.no_match import expect from didier.menus.memes import MemeSource +from didier.utils.discord import constants from didier.views.modals import GenerateMeme @@ -25,6 +26,22 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client + @commands.hybrid_command(name="clap") + async def clap(self, ctx: commands.Context, *, text: str): + """Clap a message with emojis for extra dramatic effect""" + chars = list(filter(lambda c: c.isalnum(), text)) + + if not chars: + return await ctx.reply("👏", mention_author=False) + + text = "👏".join(list(map(lambda c: constants.EMOJI_MAP.get(c), chars))) + text = f"👏{text}👏" + + if len(text) > constants.Limits.MESSAGE_LENGTH: + return await ctx.reply("Message is too long.", mention_author=False) + + return await ctx.reply(text, mention_author=False) + async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str: async with self.client.postgres_session as session: result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name) diff --git a/didier/utils/discord/constants.py b/didier/utils/discord/constants.py index 707d635..22eab4b 100644 --- a/didier/utils/discord/constants.py +++ b/didier/utils/discord/constants.py @@ -1,6 +1,46 @@ from enum import Enum -__all__ = ["Limits"] +__all__ = ["EMOJI_MAP", "Limits"] + + +EMOJI_MAP = { + "a": "🇦", + "b": "🇧", + "c": "🇨", + "d": "🇩", + "e": "🇪", + "f": "🇫", + "g": "🇬", + "h": "🇭", + "i": "🇮", + "j": "🇯", + "k": "🇰", + "l": "🇱", + "m": "🇲", + "n": "🇳", + "o": "🇴", + "p": "🇵", + "q": "🇶", + "r": "🇷", + "s": "🇸", + "t": "🇹", + "u": "🇺", + "v": "🇻", + "w": "🇼", + "x": "🇽", + "y": "🇾", + "z": "🇿", + "0": "0⃣", + "1": "1️⃣", + "2": "2️⃣", + "3": "3️⃣", + "4": "4️⃣", + "5": "5️⃣", + "6": "6️⃣", + "7": "7️⃣", + "8": "8️⃣", + "9": "9️⃣", +} class Limits(int, Enum):