From a614e9a9f1dcc0e4095b0631f1e571d4cc77a984 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Wed, 10 Aug 2022 01:04:19 +0200 Subject: [PATCH 1/5] Rework links --- .../versions/3962636f3a3d_add_custom_links.py | 35 +++++++++++++ database/crud/links.py | 45 ++++++++++++++++ database/exceptions/__init__.py | 5 ++ database/schemas/relational.py | 11 ++++ database/utils/caches.py | 33 ++++++++---- didier/cogs/other.py | 37 ++++++++++++++ didier/cogs/owner.py | 51 ++++++++++++++----- didier/cogs/school.py | 4 +- didier/didier.py | 11 ++++ didier/views/modals/__init__.py | 3 +- didier/views/modals/links.py | 37 ++++++++++++++ tests/test_database/test_utils/test_caches.py | 4 +- 12 files changed, 246 insertions(+), 30 deletions(-) create mode 100644 alembic/versions/3962636f3a3d_add_custom_links.py create mode 100644 database/crud/links.py create mode 100644 didier/views/modals/links.py diff --git a/alembic/versions/3962636f3a3d_add_custom_links.py b/alembic/versions/3962636f3a3d_add_custom_links.py new file mode 100644 index 0000000..ef4f13e --- /dev/null +++ b/alembic/versions/3962636f3a3d_add_custom_links.py @@ -0,0 +1,35 @@ +"""Add custom links + +Revision ID: 3962636f3a3d +Revises: 346b408c362a +Create Date: 2022-08-10 00:54:05.668255 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3962636f3a3d" +down_revision = "346b408c362a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "links", + sa.Column("link_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("url", sa.Text(), nullable=False), + sa.PrimaryKeyConstraint("link_id"), + sa.UniqueConstraint("name"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("links") + # ### end Alembic commands ### diff --git a/database/crud/links.py b/database/crud/links.py new file mode 100644 index 0000000..e97c328 --- /dev/null +++ b/database/crud/links.py @@ -0,0 +1,45 @@ +from typing import Optional + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions import NoResultFoundException +from database.schemas.relational import Link + +__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] + + +async def get_all_links(session: AsyncSession) -> list[Link]: + """Get a list of all links""" + statement = select(Link) + return (await session.execute(statement)).scalars().all() + + +async def add_link(session: AsyncSession, name: str, url: str) -> Link: + """Add a new link into the database""" + if name.islower(): + name = name.capitalize() + + instance = Link(name=name, url=url) + session.add(instance) + await session.commit() + + return instance + + +async def get_link_by_name(session: AsyncSession, name: str) -> Optional[Link]: + """Get a link by its name""" + statement = select(Link).where(func.lower(Link.name) == name.lower()) + return (await session.execute(statement)).scalar_one_or_none() + + +async def edit_link(session: AsyncSession, name: str, new_url: str): + """Edit an existing link""" + link: Optional[Link] = await get_link_by_name(session, name) + + if link is None: + raise NoResultFoundException + + link.url = new_url + session.add(link) + await session.commit() diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py index e69de29..1751bc5 100644 --- a/database/exceptions/__init__.py +++ b/database/exceptions/__init__.py @@ -0,0 +1,5 @@ +from .constraints import DuplicateInsertException +from .currency import DoubleNightly, NotEnoughDinks +from .not_found import NoResultFoundException + +__all__ = ["DuplicateInsertException", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException"] diff --git a/database/schemas/relational.py b/database/schemas/relational.py index f0fb6e4..b0bc5e7 100644 --- a/database/schemas/relational.py +++ b/database/schemas/relational.py @@ -28,6 +28,7 @@ __all__ = [ "CustomCommand", "CustomCommandAlias", "DadJoke", + "Link", "NightlyData", "Task", "UforaAnnouncement", @@ -109,6 +110,16 @@ class DadJoke(Base): joke: str = Column(Text, nullable=False) +class Link(Base): + """Useful links that go useful places""" + + __tablename__ = "links" + + link_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + url: str = Column(Text, nullable=False) + + class NightlyData(Base): """Data for a user's Nightly stats""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 4e35147..b08417d 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -4,10 +4,10 @@ from typing import Generic, TypeVar from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import ufora_courses, wordle +from database.crud import links, ufora_courses, wordle from database.mongo_types import MongoDatabase -__all__ = ["CacheManager", "UforaCourseCache"] +__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] T = TypeVar("T") @@ -35,12 +35,8 @@ class DatabaseCache(ABC, Generic[T]): self.data.clear() @abstractmethod - async def refresh(self, database_session: T): - """Refresh the data stored in this cache""" - async def invalidate(self, database_session: T): """Invalidate the data stored in this cache""" - await self.refresh(database_session) def get_autocomplete_suggestions(self, query: str): """Filter the cache to find everything that matches the search query""" @@ -49,6 +45,19 @@ class DatabaseCache(ABC, Generic[T]): return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] +class LinkCache(DatabaseCache[AsyncSession]): + """Cache to store the names of links""" + + @overrides + async def invalidate(self, database_session: AsyncSession): + self.clear() + + all_links = await links.get_all_links(database_session) + self.data = list(map(lambda l: l.name, all_links)) + self.data.sort() + self.data_transformed = list(map(str.lower, self.data)) + + class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" @@ -61,11 +70,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): super().clear() @overrides - async def refresh(self, database_session: AsyncSession): + async def invalidate(self, database_session: AsyncSession): self.clear() courses = await ufora_courses.get_all_courses(database_session) - self.data = list(map(lambda c: c.name, courses)) # Load the aliases @@ -97,7 +105,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): class WordleCache(DatabaseCache[MongoDatabase]): """Cache to store the current daily Wordle word""" - async def refresh(self, database_session: MongoDatabase): + async def invalidate(self, database_session: MongoDatabase): word = await wordle.get_daily_word(database_session) if word is not None: self.data = [word] @@ -106,14 +114,17 @@ class WordleCache(DatabaseCache[MongoDatabase]): class CacheManager: """Class that keeps track of all caches""" + links: LinkCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): + self.links = LinkCache() self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" - await self.ufora_courses.refresh(postgres_session) - await self.wordle_word.refresh(mongo_db) + await self.links.invalidate(postgres_session) + await self.ufora_courses.invalidate(postgres_session) + await self.wordle_word.invalidate(mongo_db) diff --git a/didier/cogs/other.py b/didier/cogs/other.py index de642a7..e00ab0b 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -1,6 +1,11 @@ +from typing import Optional + +import discord from discord import app_commands from discord.ext import commands +from database.crud.links import get_link_by_name +from database.schemas.relational import Link from didier import Didier from didier.data.apis import urban_dictionary from didier.data.embeds.google import GoogleSearch @@ -34,6 +39,38 @@ class Other(commands.Cog): embed = GoogleSearch(results).to_embed() await ctx.reply(embed=embed, mention_author=False) + async def _get_link(self, name: str) -> Optional[Link]: + async with self.client.postgres_session as session: + return await get_link_by_name(session, name.lower()) + + @commands.command(name="Link", aliases=["Links"], usage="[Name]") + async def link_msg(self, ctx: commands.Context, name: str): + """Message command to get the link to something""" + link = await self._get_link(name) + if link is None: + return await ctx.reply(f"Found no links matching `{name}`.", mention_author=False) + + target_message = await self.client.get_reply_target(ctx) + await target_message.reply(link.url, mention_author=False) + + @app_commands.command(name="link", description="Get the link to something") + @app_commands.describe(name="The name of the link") + async def link_slash(self, interaction: discord.Interaction, name: str): + """Slash command to get the link to something""" + link = await self._get_link(name) + if link is None: + return await interaction.response.send_message(f"Found no links matching `{name}`.", ephemeral=True) + + return await interaction.response.send_message(link.url) + + @link_slash.autocomplete("name") + async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'name'-parameter""" + return [ + app_commands.Choice(name=name, value=name.lower()) + for name in self.client.database_caches.links.get_autocomplete_suggestions(current) + ] + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 7d623b0..badd0c2 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -5,12 +5,17 @@ from discord import app_commands from discord.ext import commands import settings -from database.crud import custom_commands +from database.crud import custom_commands, links from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags -from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand +from didier.views.modals import ( + AddDadJoke, + AddLink, + CreateCustomCommand, + EditCustomCommand, +) class Owner(commands.Cog): @@ -80,17 +85,6 @@ class Owner(commands.Cog): async def add_msg(self, ctx: commands.Context): """Command group for [add X] message commands""" - @add_msg.command(name="Custom") - async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): - """Add a new custom command""" - async with self.client.postgres_session as session: - try: - await custom_commands.create_command(session, name, response) - await self.client.confirm_message(ctx.message) - except DuplicateInsertException: - await ctx.reply("There is already a command with this name.") - await self.client.reject_message(ctx.message) - @add_msg.command(name="Alias") async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" @@ -105,6 +99,26 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) + @add_msg.command(name="Custom") + async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): + """Add a new custom command""" + async with self.client.postgres_session as session: + try: + await custom_commands.create_command(session, name, response) + await self.client.confirm_message(ctx.message) + except DuplicateInsertException: + await ctx.reply("There is already a command with this name.") + await self.client.reject_message(ctx.message) + + @add_msg.command(name="Link") + async def add_link_msg(self, ctx: commands.Context, name: str, url: str): + """Add a new link""" + async with self.client.postgres_session as session: + await links.add_link(session, name, url) + await self.client.database_caches.links.invalidate(session) + + await self.client.confirm_message(ctx.message) + @add_slash.command(name="custom", description="Add a custom command") async def add_custom_slash(self, interaction: discord.Interaction): """Slash command to add a custom command""" @@ -123,6 +137,15 @@ class Owner(commands.Cog): modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="link", description="Add a new link") + async def add_link_slash(self, interaction: discord.Interaction): + """Slash command to add new links""" + if not await self.client.is_owner(interaction.user): + return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) + + modal = AddLink(self.client) + await interaction.response.send_modal(modal) + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" @@ -135,7 +158,7 @@ class Owner(commands.Cog): await custom_commands.edit_command(session, command, flags.name, flags.response) return await self.client.confirm_message(ctx.message) except NoResultFoundException: - await ctx.reply(f"No command found matching ``{command}``.") + await ctx.reply(f"No command found matching `{command}`.") return await self.client.reject_message(ctx.message) @edit_slash.command(name="custom", description="Edit a custom command") diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 32eb8b9..881d0d5 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -72,7 +72,7 @@ class School(commands.Cog): ufora_course = await ufora_courses.get_course_by_name(session, course) if ufora_course is None: - return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True) + return await ctx.reply(f"Found no course matching `{course}`", ephemeral=True) return await ctx.reply( f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}", @@ -80,7 +80,7 @@ class School(commands.Cog): ) @study_guide.autocomplete("course") - async def study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocompletion for the 'course'-parameter""" return [ app_commands.Choice(name=course, value=course) diff --git a/didier/didier.py b/didier/didier.py index f305cad..3217b61 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -111,6 +111,17 @@ class Didier(commands.Bot): for line in fp: self.wordle_words.add(line.strip()) + async def get_reply_target(self, ctx: commands.Context) -> discord.Message: + """Get the target message that should be replied to + + In case the invoking message is a reply to something, reply to the + original message instead + """ + if ctx.message.reference is not None: + return await self.resolve_message(ctx.message.reference) + + return ctx.message + async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: """Fetch a message from a reference""" # Message is in the cache, return it diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index b28a4de..8a64d4d 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -1,4 +1,5 @@ from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke +from .links import AddLink -__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand"] +__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand", "AddLink"] diff --git a/didier/views/modals/links.py b/didier/views/modals/links.py new file mode 100644 index 0000000..3d4b1be --- /dev/null +++ b/didier/views/modals/links.py @@ -0,0 +1,37 @@ +import traceback + +import discord.ui +from overrides import overrides + +from database.crud.links import add_link +from didier import Didier + +__all__ = ["AddLink"] + + +class AddLink(discord.ui.Modal, title="Add Link"): + """Modal to add a new link""" + + name = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, placeholder="Source") + url = discord.ui.TextInput( + label="URL", style=discord.TextStyle.short, placeholder="https://github.com/stijndcl/didier" + ) + + client: Didier + + def __init__(self, client: Didier, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + + @overrides + async def on_submit(self, interaction: discord.Interaction): + async with self.client.postgres_session as session: + await add_link(session, self.name.value, self.url.value) + await self.client.database_caches.links.invalidate(session) + + await interaction.response.send_message(f"Successfully added `{self.name.value.capitalize()}`.", ephemeral=True) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index b613737..3dc6adb 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -7,7 +7,7 @@ from database.utils.caches import UforaCourseCache async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's empty""" cache = UforaCourseCache() - await cache.refresh(postgres) + await cache.invalidate(postgres) assert len(cache.data) == 1 assert cache.data == ["test"] @@ -20,7 +20,7 @@ async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufor cache.data = ["Something"] cache.data_transformed = ["something"] - await cache.refresh(postgres) + await cache.invalidate(postgres) assert len(cache.data) == 1 assert cache.data == ["test"] From 28cf094ea3811f3e46e65c0b942b0f383f73c532 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Wed, 10 Aug 2022 01:12:28 +0200 Subject: [PATCH 2/5] Typing --- didier/views/modals/links.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/didier/views/modals/links.py b/didier/views/modals/links.py index 3d4b1be..5d70a93 100644 --- a/didier/views/modals/links.py +++ b/didier/views/modals/links.py @@ -12,8 +12,8 @@ __all__ = ["AddLink"] class AddLink(discord.ui.Modal, title="Add Link"): """Modal to add a new link""" - name = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, placeholder="Source") - url = discord.ui.TextInput( + name: discord.ui.TextInput = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, placeholder="Source") + url: discord.ui.TextInput = discord.ui.TextInput( label="URL", style=discord.TextStyle.short, placeholder="https://github.com/stijndcl/didier" ) @@ -25,6 +25,12 @@ class AddLink(discord.ui.Modal, title="Add Link"): @overrides async def on_submit(self, interaction: discord.Interaction): + if self.name.value is None: + return await interaction.response.send_message("Required field `Name` cannot be empty.", ephemeral=True) + + if self.url.value is None: + return await interaction.response.send_message("Required field `URL` cannot be empty.", ephemeral=True) + async with self.client.postgres_session as session: await add_link(session, self.name.value, self.url.value) await self.client.database_caches.links.invalidate(session) From 107e4fb580eeec936cec92d27cf9344eabab21f9 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 13 Aug 2022 00:07:48 +0200 Subject: [PATCH 3/5] Displaying deadlines --- alembic/versions/08d21b2d1a0a_deadlines.py | 39 ++++++++++++++ database/crud/deadlines.py | 31 +++++++++++ database/schemas/relational.py | 17 ++++++ didier/cogs/school.py | 11 ++++ didier/data/embeds/deadlines.py | 63 ++++++++++++++++++++++ didier/utils/types/string.py | 9 +++- 6 files changed, 169 insertions(+), 1 deletion(-) create mode 100644 alembic/versions/08d21b2d1a0a_deadlines.py create mode 100644 database/crud/deadlines.py create mode 100644 didier/data/embeds/deadlines.py diff --git a/alembic/versions/08d21b2d1a0a_deadlines.py b/alembic/versions/08d21b2d1a0a_deadlines.py new file mode 100644 index 0000000..25147cf --- /dev/null +++ b/alembic/versions/08d21b2d1a0a_deadlines.py @@ -0,0 +1,39 @@ +"""Deadlines + +Revision ID: 08d21b2d1a0a +Revises: 3962636f3a3d +Create Date: 2022-08-12 23:44:13.947011 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "08d21b2d1a0a" +down_revision = "3962636f3a3d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "deadlines", + sa.Column("deadline_id", sa.Integer(), nullable=False), + sa.Column("course_id", sa.Integer(), nullable=True), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("deadline", sa.DateTime(timezone=True), nullable=False), + sa.ForeignKeyConstraint( + ["course_id"], + ["ufora_courses.course_id"], + ), + sa.PrimaryKeyConstraint("deadline_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("deadlines") + # ### end Alembic commands ### diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py new file mode 100644 index 0000000..338a4c3 --- /dev/null +++ b/database/crud/deadlines.py @@ -0,0 +1,31 @@ +from zoneinfo import ZoneInfo + +from dateutil.parser import parse +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from database.schemas.relational import Deadline + +__all__ = ["add_deadline", "get_deadlines"] + + +async def add_deadline(session: AsyncSession, course_id: int, name: str, date_str: str): + """Add a new deadline""" + date_dt = parse(date_str).replace(tzinfo=ZoneInfo("Europe/Brussels")) + + if date_dt.hour == date_dt.minute == date_dt.second == 0: + date_dt.replace(hour=23, minute=59, second=59) + + deadline = Deadline(course_id=course_id, name=name, deadline=date_dt) + session.add(deadline) + await session.commit() + + +async def get_deadlines(session: AsyncSession) -> list[Deadline]: + """Get a list of all deadlines that are currently known + + This includes deadlines that have passed already + """ + statement = select(Deadline).options(selectinload(Deadline.course)) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas/relational.py b/database/schemas/relational.py index b0bc5e7..6fd27d0 100644 --- a/database/schemas/relational.py +++ b/database/schemas/relational.py @@ -28,6 +28,7 @@ __all__ = [ "CustomCommand", "CustomCommandAlias", "DadJoke", + "Deadline", "Link", "NightlyData", "Task", @@ -110,6 +111,19 @@ class DadJoke(Base): joke: str = Column(Text, nullable=False) +class Deadline(Base): + """A deadline for a university project""" + + __tablename__ = "deadlines" + + deadline_id: int = Column(Integer, primary_key=True) + course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) + name: str = Column(Text, nullable=False) + deadline: datetime = Column(DateTime(timezone=True), nullable=False) + + course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") + + class Link(Base): """Useful links that go useful places""" @@ -160,6 +174,9 @@ class UforaCourse(Base): aliases: list[UforaCourseAlias] = relationship( "UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) + deadlines: list[Deadline] = relationship( + "Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin" + ) class UforaCourseAlias(Base): diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 881d0d5..3ccf412 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -5,7 +5,9 @@ from discord import app_commands from discord.ext import commands from database.crud import ufora_courses +from database.crud.deadlines import get_deadlines from didier import Didier +from didier.data.embeds.deadlines import Deadlines from didier.utils.discord.flags.school import StudyGuideFlags @@ -27,6 +29,15 @@ class School(commands.Cog): """Remove the commands when the cog is unloaded""" self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + @commands.hybrid_command(name="deadlines", description="Show upcoming deadlines") + async def deadlines(self, ctx: commands.Context): + """Show upcoming deadlines""" + async with self.client.postgres_session as session: + deadlines = await get_deadlines(session) + + embed = await Deadlines(deadlines).to_embed() + await ctx.reply(embed=embed, mention_author=False, ephemeral=False) + @commands.command(name="Pin", usage="[Message]") async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None): """Pin a message in the current channel""" diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py new file mode 100644 index 0000000..0c07c40 --- /dev/null +++ b/didier/data/embeds/deadlines.py @@ -0,0 +1,63 @@ +import itertools +from datetime import datetime + +import discord +from overrides import overrides + +from database.schemas.relational import Deadline +from didier.data.embeds.base import EmbedBaseModel +from didier.utils.types.datetime import tz_aware_now +from didier.utils.types.string import get_edu_year_name + +__all__ = ["Deadlines"] + + +class Deadlines(EmbedBaseModel): + """Embed that shows all the deadlines of a semester""" + + deadlines: list[Deadline] + + def __init__(self, deadlines: list[Deadline]): + self.deadlines = deadlines + self.deadlines.sort(key=lambda deadline: deadline.deadline) + + @overrides + async def to_embed(self, **kwargs: dict) -> discord.Embed: + embed = discord.Embed(colour=discord.Colour.dark_gold()) + embed.set_author(name="Upcoming Deadlines") + now = tz_aware_now() + + has_active_deadlines = False + deadlines_grouped: dict[int, list[str]] = {} + + deadline: Deadline + for year, deadline in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year): + if year not in deadlines_grouped: + deadlines_grouped[year] = [] + + passed = deadline.deadline <= now + if passed: + has_active_deadlines = True + + deadline_str = ( + f"{deadline.course.name} - {deadline.name}: " + ) + + # Strike through deadlines that aren't active anymore + deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") + + if not has_active_deadlines: + embed.description = "There are currently no upcoming deadlines." + embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif") + return embed + + for i in range(5): + if i not in deadlines_grouped: + continue + + name = get_edu_year_name(i) + description = "\n".join(deadlines_grouped[i]) + + embed.add_field(name=name, value=description, inline=False) + + return embed diff --git a/didier/utils/types/string.py b/didier/utils/types/string.py index 015996a..3a26658 100644 --- a/didier/utils/types/string.py +++ b/didier/utils/types/string.py @@ -1,7 +1,7 @@ import math from typing import Optional -__all__ = ["abbreviate", "leading", "pluralize"] +__all__ = ["abbreviate", "leading", "pluralize", "get_edu_year_name"] def abbreviate(text: str, max_length: int) -> str: @@ -43,3 +43,10 @@ def pluralize(word: str, amount: int, plural_form: Optional[str] = None) -> str: return word return plural_form or (word + "s") + + +def get_edu_year_name(year: int) -> str: + """Get the string representation of a university year""" + years = ["1st Bachelor", "2nd Bachelor", "3rd Bachelor", "1st Master", "2nd Master"] + + return years[year] From e2959c27ad6a84170100462f640fbfcba7e0b694 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 13 Aug 2022 00:41:47 +0200 Subject: [PATCH 4/5] Adding new deadlines --- database/crud/deadlines.py | 13 ++++++-- didier/cogs/owner.py | 21 ++++++++++--- didier/data/embeds/deadlines.py | 1 + didier/views/modals/__init__.py | 3 +- didier/views/modals/custom_commands.py | 2 +- didier/views/modals/deadlines.py | 43 ++++++++++++++++++++++++++ 6 files changed, 73 insertions(+), 10 deletions(-) create mode 100644 didier/views/modals/deadlines.py diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index 338a4c3..8a7ad66 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -1,3 +1,4 @@ +from typing import Optional from zoneinfo import ZoneInfo from dateutil.parser import parse @@ -5,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from database.schemas.relational import Deadline +from database.schemas.relational import Deadline, UforaCourse __all__ = ["add_deadline", "get_deadlines"] @@ -14,6 +15,7 @@ async def add_deadline(session: AsyncSession, course_id: int, name: str, date_st """Add a new deadline""" date_dt = parse(date_str).replace(tzinfo=ZoneInfo("Europe/Brussels")) + # If we only have a day, assume it's the end of the day if date_dt.hour == date_dt.minute == date_dt.second == 0: date_dt.replace(hour=23, minute=59, second=59) @@ -22,10 +24,15 @@ async def add_deadline(session: AsyncSession, course_id: int, name: str, date_st await session.commit() -async def get_deadlines(session: AsyncSession) -> list[Deadline]: +async def get_deadlines(session: AsyncSession, *, course: Optional[UforaCourse] = None) -> list[Deadline]: """Get a list of all deadlines that are currently known This includes deadlines that have passed already """ - statement = select(Deadline).options(selectinload(Deadline.course)) + statement = select(Deadline) + + if course is not None: + statement = statement.where(Deadline.course_id == course.course_id) + + statement = statement.options(selectinload(Deadline.course)) return (await session.execute(statement)).scalars().all() diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index badd0c2..69560ca 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -5,13 +5,14 @@ from discord import app_commands from discord.ext import commands import settings -from database.crud import custom_commands, links +from database.crud import custom_commands, links, ufora_courses from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags from didier.views.modals import ( AddDadJoke, + AddDeadline, AddLink, CreateCustomCommand, EditCustomCommand, @@ -137,6 +138,18 @@ class Owner(commands.Cog): modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="deadline", description="Add a deadline") + async def add_deadline_slash(self, interaction: discord.Interaction, course: str): + """Slash command to add a deadline""" + async with self.client.postgres_session as session: + course_instance = await ufora_courses.get_course_by_name(session, course) + + if course_instance is None: + return await interaction.response.send_message(f"No course found matching `{course}`.", ephemeral=True) + + modal = AddDeadline(self.client, course_instance) + await interaction.response.send_modal(modal) + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" @@ -166,15 +179,13 @@ class Owner(commands.Cog): async def edit_custom_slash(self, interaction: discord.Interaction, command: str): """Slash command to edit a custom command""" if not await self.client.is_owner(interaction.user): - return interaction.response.send_message( - "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True - ) + return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) async with self.client.postgres_session as session: _command = await custom_commands.get_command(session, command) if _command is None: return await interaction.response.send_message( - f"Geen commando gevonden voor ``{command}``.", ephemeral=True + f"No command found matching `{command}`.", ephemeral=True ) modal = EditCustomCommand(self.client, _command.name, _command.response) diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py index 0c07c40..244d811 100644 --- a/didier/data/embeds/deadlines.py +++ b/didier/data/embeds/deadlines.py @@ -46,6 +46,7 @@ class Deadlines(EmbedBaseModel): # Strike through deadlines that aren't active anymore deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") + # Send an easter egg when there are no deadlines if not has_active_deadlines: embed.description = "There are currently no upcoming deadlines." embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif") diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index 8a64d4d..42c2ce8 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -1,5 +1,6 @@ from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke +from .deadlines import AddDeadline from .links import AddLink -__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand", "AddLink"] +__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink"] diff --git a/didier/views/modals/custom_commands.py b/didier/views/modals/custom_commands.py index 2116bd9..738e74a 100644 --- a/didier/views/modals/custom_commands.py +++ b/didier/views/modals/custom_commands.py @@ -71,7 +71,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): async with self.client.postgres_session as session: await edit_command(session, self.original_name, name_field.value, response_field.value) - await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) + await interaction.response.send_message(f"Successfully edited `{self.original_name}`.", ephemeral=True) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore diff --git a/didier/views/modals/deadlines.py b/didier/views/modals/deadlines.py new file mode 100644 index 0000000..bd8882b --- /dev/null +++ b/didier/views/modals/deadlines.py @@ -0,0 +1,43 @@ +import traceback + +import discord.ui +from discord import Interaction +from overrides import overrides + +from database.crud.deadlines import add_deadline +from database.schemas.relational import UforaCourse + +__all__ = ["AddDeadline"] + +from didier import Didier + + +class AddDeadline(discord.ui.Modal, title="Add Deadline"): + """Modal to add a new deadline""" + + client: Didier + ufora_course: UforaCourse + + name: discord.ui.TextInput = discord.ui.TextInput( + label="Name", placeholder="Project 9001", required=True, style=discord.TextStyle.short + ) + deadline: discord.ui.TextInput = discord.ui.TextInput( + label="Deadline", placeholder="DD/MM/YYYY HH:MM:SS*", required=True, style=discord.TextStyle.short + ) + + def __init__(self, client: Didier, ufora_course: UforaCourse, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + self.ufora_course = ufora_course + + @overrides + async def on_submit(self, interaction: Interaction): + async with self.client.postgres_session as session: + await add_deadline(session, self.ufora_course.course_id, self.name.value, self.deadline.value) + + await interaction.response.send_message("Successfully added new deadline.", ephemeral=True) + + @overrides + async def on_error(self, interaction: Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) From b26421b8752c261ff16cc4423651cd1d241e051f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 13 Aug 2022 01:10:50 +0200 Subject: [PATCH 5/5] Fix bug, add autocomplete, make cache autocompletion slightly cleaner --- database/crud/deadlines.py | 2 +- database/utils/caches.py | 12 ++++++++---- didier/cogs/other.py | 7 ++----- didier/cogs/owner.py | 8 ++++++++ didier/cogs/school.py | 13 ++++++------- didier/data/embeds/deadlines.py | 28 ++++++++++++++-------------- didier/views/modals/deadlines.py | 3 +++ requirements-dev.txt | 1 + 8 files changed, 43 insertions(+), 31 deletions(-) diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index 8a7ad66..c1b2885 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -13,7 +13,7 @@ __all__ = ["add_deadline", "get_deadlines"] async def add_deadline(session: AsyncSession, course_id: int, name: str, date_str: str): """Add a new deadline""" - date_dt = parse(date_str).replace(tzinfo=ZoneInfo("Europe/Brussels")) + date_dt = parse(date_str, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) # If we only have a day, assume it's the end of the day if date_dt.hour == date_dt.minute == date_dt.second == 0: diff --git a/database/utils/caches.py b/database/utils/caches.py index b08417d..4f34419 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,6 +1,7 @@ from abc import ABC, abstractmethod from typing import Generic, TypeVar +from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession @@ -38,11 +39,13 @@ class DatabaseCache(ABC, Generic[T]): async def invalidate(self, database_session: T): """Invalidate the data stored in this cache""" - def get_autocomplete_suggestions(self, query: str): + def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: """Filter the cache to find everything that matches the search query""" query = query.lower() # Return the original (non-transformed) version of the data for pretty display in Discord - return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] + suggestions = [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] + + return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] class LinkCache(DatabaseCache[AsyncSession]): @@ -86,7 +89,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) @overrides - def get_autocomplete_suggestions(self, query: str): + def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: query = query.lower() results = set() @@ -99,7 +102,8 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): if query in alias: results.add(course) - return sorted(list(results)) + suggestions = sorted(list(results)) + return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] class WordleCache(DatabaseCache[MongoDatabase]): diff --git a/didier/cogs/other.py b/didier/cogs/other.py index e00ab0b..870c63b 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -64,12 +64,9 @@ class Other(commands.Cog): return await interaction.response.send_message(link.url) @link_slash.autocomplete("name") - async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def _link_name_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: """Autocompletion for the 'name'-parameter""" - return [ - app_commands.Choice(name=name, value=name.lower()) - for name in self.client.database_caches.links.get_autocomplete_suggestions(current) - ] + return self.client.database_caches.links.get_autocomplete_suggestions(current) async def setup(client: Didier): diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 69560ca..d030344 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -139,6 +139,7 @@ class Owner(commands.Cog): await interaction.response.send_modal(modal) @add_slash.command(name="deadline", description="Add a deadline") + @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") async def add_deadline_slash(self, interaction: discord.Interaction, course: str): """Slash command to add a deadline""" async with self.client.postgres_session as session: @@ -150,6 +151,13 @@ class Owner(commands.Cog): modal = AddDeadline(self.client, course_instance) await interaction.response.send_modal(modal) + @add_deadline_slash.autocomplete("course") + async def _add_deadline_course_autocomplete( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'course'-parameter""" + return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 3ccf412..dc3807f 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -35,7 +35,7 @@ class School(commands.Cog): async with self.client.postgres_session as session: deadlines = await get_deadlines(session) - embed = await Deadlines(deadlines).to_embed() + embed = Deadlines(deadlines).to_embed() await ctx.reply(embed=embed, mention_author=False, ephemeral=False) @commands.command(name="Pin", usage="[Message]") @@ -76,7 +76,7 @@ class School(commands.Cog): @commands.hybrid_command( name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"] ) - @app_commands.describe(course="vak") + @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): """Create links to study guides""" async with self.client.postgres_session as session: @@ -91,12 +91,11 @@ class School(commands.Cog): ) @study_guide.autocomplete("course") - async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + async def _study_guide_course_autocomplete( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: """Autocompletion for the 'course'-parameter""" - return [ - app_commands.Choice(name=course, value=course) - for course in self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) - ] + return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) async def setup(client: Didier): diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py index 244d811..6f46bbd 100644 --- a/didier/data/embeds/deadlines.py +++ b/didier/data/embeds/deadlines.py @@ -22,7 +22,7 @@ class Deadlines(EmbedBaseModel): self.deadlines.sort(key=lambda deadline: deadline.deadline) @overrides - async def to_embed(self, **kwargs: dict) -> discord.Embed: + def to_embed(self, **kwargs: dict) -> discord.Embed: embed = discord.Embed(colour=discord.Colour.dark_gold()) embed.set_author(name="Upcoming Deadlines") now = tz_aware_now() @@ -30,33 +30,33 @@ class Deadlines(EmbedBaseModel): has_active_deadlines = False deadlines_grouped: dict[int, list[str]] = {} - deadline: Deadline - for year, deadline in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year): + for year, deadlines in itertools.groupby(self.deadlines, key=lambda _deadline: _deadline.course.year): if year not in deadlines_grouped: deadlines_grouped[year] = [] - passed = deadline.deadline <= now - if passed: - has_active_deadlines = True + for deadline in deadlines: + passed = deadline.deadline <= now + if not passed: + has_active_deadlines = True - deadline_str = ( - f"{deadline.course.name} - {deadline.name}: " - ) + deadline_str = ( + f"{deadline.course.name} - {deadline.name}: " + ) - # Strike through deadlines that aren't active anymore - deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") + # Strike through deadlines that aren't active anymore + deadlines_grouped[year].append(deadline_str if not passed else f"~~{deadline_str}~~") - # Send an easter egg when there are no deadlines + # Send an Easter egg when there are no deadlines if not has_active_deadlines: embed.description = "There are currently no upcoming deadlines." embed.set_image(url="https://c.tenor.com/RUzJ3lDGQUsAAAAC/iron-man-you-can-rest-now.gif") return embed - for i in range(5): + for i in range(1, 6): if i not in deadlines_grouped: continue - name = get_edu_year_name(i) + name = get_edu_year_name(i - 1) description = "\n".join(deadlines_grouped[i]) embed.add_field(name=name, value=description, inline=False) diff --git a/didier/views/modals/deadlines.py b/didier/views/modals/deadlines.py index bd8882b..cd2a26c 100644 --- a/didier/views/modals/deadlines.py +++ b/didier/views/modals/deadlines.py @@ -32,6 +32,9 @@ class AddDeadline(discord.ui.Modal, title="Add Deadline"): @overrides async def on_submit(self, interaction: Interaction): + if not self.name.value or not self.deadline.value: + return await interaction.response.send_message("Required fields cannot be empty.", ephemeral=True) + async with self.client.postgres_session as session: await add_deadline(session, self.ufora_course.course_id, self.name.value, self.deadline.value) diff --git a/requirements-dev.txt b/requirements-dev.txt index 64b8467..9307b33 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -8,6 +8,7 @@ pytest-asyncio==0.18.3 pytest-env==0.6.2 sqlalchemy2-stubs==0.0.2a23 types-beautifulsoup4==4.11.3 +types-python-dateutil==2.8.19 types-pytz==2021.3.8 # Flake8 + plugins