From 12d2017cbe4822d0569b1a25b4144d214a29d025 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 30 Aug 2022 01:32:46 +0200 Subject: [PATCH 1/4] Creating bookmarks + message command --- alembic/versions/f5da771a155d_bookmarks.py | 40 +++++++++++++ database/crud/bookmarks.py | 46 +++++++++++++++ database/exceptions/__init__.py | 10 +++- database/exceptions/constraints.py | 6 +- database/schemas.py | 19 +++++++ didier/cogs/discord.py | 65 +++++++++++++++++++--- didier/views/modals/__init__.py | 11 +++- didier/views/modals/bookmarks.py | 48 ++++++++++++++++ 8 files changed, 234 insertions(+), 11 deletions(-) create mode 100644 alembic/versions/f5da771a155d_bookmarks.py create mode 100644 database/crud/bookmarks.py create mode 100644 didier/views/modals/bookmarks.py diff --git a/alembic/versions/f5da771a155d_bookmarks.py b/alembic/versions/f5da771a155d_bookmarks.py new file mode 100644 index 0000000..154b907 --- /dev/null +++ b/alembic/versions/f5da771a155d_bookmarks.py @@ -0,0 +1,40 @@ +"""Bookmarks + +Revision ID: f5da771a155d +Revises: 38b7c29f10ee +Create Date: 2022-08-30 01:08:54.323883 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "f5da771a155d" +down_revision = "38b7c29f10ee" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "bookmarks", + sa.Column("bookmark_id", sa.Integer(), nullable=False), + sa.Column("label", sa.Text(), nullable=False), + sa.Column("jump_url", sa.Text(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("bookmark_id"), + sa.UniqueConstraint("user_id", "label"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("bookmarks") + # ### end Alembic commands ### diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py new file mode 100644 index 0000000..8876d1d --- /dev/null +++ b/database/crud/bookmarks.py @@ -0,0 +1,46 @@ +from typing import Optional + +import sqlalchemy.exc +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud.users import get_or_add_user +from database.exceptions import DuplicateInsertException, ForbiddenNameException +from database.schemas import Bookmark + +__all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] + + +async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_url: str) -> Bookmark: + """Create a new bookmark to a message""" + # Don't allow bookmarks with names of subcommands + if label.lower() in ["create", "ls", "list", "search"]: + raise ForbiddenNameException + + await get_or_add_user(session, user_id) + + try: + bookmark = Bookmark(label=label, jump_url=jump_url, user_id=user_id) + session.add(bookmark) + await session.commit() + await session.refresh(bookmark) + except sqlalchemy.exc.IntegrityError as e: + raise DuplicateInsertException from e + + return bookmark + + +async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[str] = None) -> list[Bookmark]: + """Get all a user's bookmarks""" + statement = select(Bookmark).where(Bookmark.user_id == user_id) + + if query is not None: + statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) + + return (await session.execute(statement)).scalars().all() + + +async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: + """Try to find a bookmark by its name""" + statement = select(Bookmark).where(Bookmark.user_id == user_id).where(func.lower(Bookmark.label) == query.lower()) + return (await session.execute(statement)).scalar_one_or_none() diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py index 1751bc5..f52ea80 100644 --- a/database/exceptions/__init__.py +++ b/database/exceptions/__init__.py @@ -1,5 +1,11 @@ -from .constraints import DuplicateInsertException +from .constraints import DuplicateInsertException, ForbiddenNameException from .currency import DoubleNightly, NotEnoughDinks from .not_found import NoResultFoundException -__all__ = ["DuplicateInsertException", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException"] +__all__ = [ + "DuplicateInsertException", + "ForbiddenNameException", + "DoubleNightly", + "NotEnoughDinks", + "NoResultFoundException", +] diff --git a/database/exceptions/constraints.py b/database/exceptions/constraints.py index 1087d6e..2970d6c 100644 --- a/database/exceptions/constraints.py +++ b/database/exceptions/constraints.py @@ -1,5 +1,9 @@ -__all__ = ["DuplicateInsertException"] +__all__ = ["DuplicateInsertException", "ForbiddenNameException"] class DuplicateInsertException(Exception): """Exception raised when a value already exists""" + + +class ForbiddenNameException(Exception): + """Exception raised when trying to insert something with a name that isn't allowed""" diff --git a/database/schemas.py b/database/schemas.py index 182653f..f8fa018 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -13,6 +13,7 @@ from sqlalchemy import ( ForeignKey, Integer, Text, + UniqueConstraint, ) from sqlalchemy.orm import declarative_base, relationship @@ -25,6 +26,7 @@ __all__ = [ "Base", "Bank", "Birthday", + "Bookmark", "CustomCommand", "CustomCommandAlias", "DadJoke", @@ -78,6 +80,20 @@ class Birthday(Base): user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") +class Bookmark(Base): + """A bookmark to a given message""" + + __tablename__ = "bookmarks" + __table_args__ = (UniqueConstraint("user_id", "label"),) + + bookmark_id: int = Column(Integer, primary_key=True) + label: str = Column(Text, nullable=False) + jump_url: str = Column(Text, nullable=False) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + + user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") + + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" @@ -231,6 +247,9 @@ class User(Base): birthday: Optional[Birthday] = relationship( "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) + bookmarks: list[Bookmark] = relationship( + "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + ) nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 142e7d0..23c7c35 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -4,10 +4,13 @@ import discord from discord import app_commands from discord.ext import commands -from database.crud import birthdays +from database.crud import birthdays, bookmarks +from database.exceptions import DuplicateInsertException, ForbiddenNameException from didier import Didier +from didier.exceptions import expect from didier.utils.types.datetime import str_to_date from didier.utils.types.string import leading +from didier.views.modals import CreateBookmark class Discord(commands.Cog): @@ -16,16 +19,20 @@ class Discord(commands.Cog): client: Didier # Context-menu references + _bookmark_ctx_menu: app_commands.ContextMenu _pin_ctx_menu: app_commands.ContextMenu def __init__(self, client: Didier): self.client = client - self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx) + self._bookmark_ctx_menu = app_commands.ContextMenu(name="Bookmark", callback=self._bookmark_ctx) + self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) + self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu) async def cog_unload(self) -> None: """Remove the commands when the cog is unloaded""" + self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) @commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True) @@ -35,14 +42,14 @@ class Discord(commands.Cog): async with self.client.postgres_session as session: birthday = await birthdays.get_birthday_for_user(session, user_id) - name = "Jouw" if user is None else f"{user.display_name}'s" + name = "Your" if user is None else f"{user.display_name}'s" if birthday is None: - return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False) + return await ctx.reply(f"I don't know {name} birthday.", mention_author=False) day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) - return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False) + return await ctx.reply(f"{name} birthday is set to **{day}/{month}**.", mention_author=False) @birthday.command(name="Set", aliases=["Config"]) async def birthday_set(self, ctx: commands.Context, date_str: str): @@ -56,12 +63,56 @@ class Discord(commands.Cog): date.replace(year=default_year) except ValueError: - return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) + return await ctx.reply(f"`{date_str}` is not a valid date.", mention_author=False) async with self.client.postgres_session as session: await birthdays.add_birthday(session, ctx.author.id, date) await self.client.confirm_message(ctx.message) + @commands.group(name="Bookmark", aliases=["Bm", "Bookmarks"], case_insensitive=True, invoke_without_command=True) + async def bookmark(self, ctx: commands.Context, label: str): + """Post a bookmarked message""" + async with self.client.postgres_session as session: + result = expect( + await bookmarks.get_bookmark_by_name(session, ctx.author.id, label), + entity_type="bookmark", + argument="label", + ) + await ctx.reply(result.jump_url, mention_author=False) + + @bookmark.command(name="Create", aliases=["New"]) + async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]): + """Create a new bookmark""" + # If no message was passed, allow replying to the message that should be bookmarked + if message is None and ctx.message.reference is not None: + message = await self.client.resolve_message(ctx.message.reference) + + # Didn't fix it, so no message was found + if message is None: + return await ctx.reply("Found no message to bookmark.", delete_after=10) + + # Create new bookmark + + try: + async with self.client.postgres_session as session: + bm = await bookmarks.create_bookmark(session, ctx.author.id, label, message.jump_url) + await ctx.reply(f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", mention_author=False) + except DuplicateInsertException: + # Label is already in use + return await ctx.reply(f"You already have a bookmark named `{label}`.", mention_author=False) + except ForbiddenNameException: + # Label isn't allowed + return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) + + @bookmark.command(name="Search", aliases=["List", "Ls"]) + async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): + """Search through the list of bookmarks""" + + async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message): + """Create a bookmark out of this message""" + modal = CreateBookmark(self.client, message.jump_url) + await interaction.response.send_modal(modal) + @commands.command(name="Join", usage="[Thread]") async def join(self, ctx: commands.Context, thread: discord.Thread): """Make Didier join a thread""" @@ -88,7 +139,7 @@ class Discord(commands.Cog): await message.pin(reason=f"Didier Pin by {ctx.author.display_name}") await message.add_reaction("📌") - async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message): + async def _pin_ctx(self, interaction: discord.Interaction, message: discord.Message): """Pin a message in the current channel""" # Is already pinned if message.pinned: diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index 62d700f..e9f92f0 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -1,7 +1,16 @@ +from .bookmarks import CreateBookmark from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke from .deadlines import AddDeadline from .links import AddLink from .memes import GenerateMeme -__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink", "GenerateMeme"] +__all__ = [ + "CreateBookmark", + "AddDadJoke", + "AddDeadline", + "CreateCustomCommand", + "EditCustomCommand", + "AddLink", + "GenerateMeme", +] diff --git a/didier/views/modals/bookmarks.py b/didier/views/modals/bookmarks.py new file mode 100644 index 0000000..f77b608 --- /dev/null +++ b/didier/views/modals/bookmarks.py @@ -0,0 +1,48 @@ +import traceback + +import discord.ui +from overrides import overrides + +from database.crud.bookmarks import create_bookmark +from database.exceptions import DuplicateInsertException, ForbiddenNameException +from didier import Didier + +__all__ = ["CreateBookmark"] + + +class CreateBookmark(discord.ui.Modal, title="Create Bookmark"): + """Modal to create a bookmark""" + + client: Didier + jump_url: str + + name: discord.ui.TextInput = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, required=True) + + def __init__(self, client: Didier, jump_url: str, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + self.jump_url = jump_url + + @overrides + async def on_submit(self, interaction: discord.Interaction): + label = self.name.value.strip() + + try: + async with self.client.postgres_session as session: + bm = await create_bookmark(session, interaction.user.id, label, self.jump_url) + return await interaction.response.send_message( + f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True + ) + except DuplicateInsertException: + # Label is already in use + return await interaction.response.send_message( + f"You already have a bookmark named `{label}`.", ephemeral=True + ) + except ForbiddenNameException: + # Label isn't allowed + return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) From f70736b4d55a6853dadb2b89bcecdd69f1f4944c Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 30 Aug 2022 01:55:40 +0200 Subject: [PATCH 2/4] Searching for bookmarks --- didier/cogs/discord.py | 16 +++++++++++ didier/{utils/discord => }/menus/__init__.py | 0 didier/menus/bookmarks.py | 29 ++++++++++++++++++++ didier/{utils/discord => }/menus/common.py | 16 +++++++---- didier/utils/discord/assets.py | 12 ++++++++ 5 files changed, 68 insertions(+), 5 deletions(-) rename didier/{utils/discord => }/menus/__init__.py (100%) create mode 100644 didier/menus/bookmarks.py rename didier/{utils/discord => }/menus/common.py (94%) create mode 100644 didier/utils/discord/assets.py diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 23c7c35..be6fa90 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -8,6 +8,9 @@ from database.crud import birthdays, bookmarks from database.exceptions import DuplicateInsertException, ForbiddenNameException from didier import Didier from didier.exceptions import expect +from didier.menus.bookmarks import BookmarkSource +from didier.menus.common import Menu +from didier.utils.discord.assets import get_author_avatar from didier.utils.types.datetime import str_to_date from didier.utils.types.string import leading from didier.views.modals import CreateBookmark @@ -107,6 +110,19 @@ class Discord(commands.Cog): @bookmark.command(name="Search", aliases=["List", "Ls"]) async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks""" + async with self.client.postgres_session as session: + results = await bookmarks.get_bookmarks(session, ctx.author.id, query=query) + + if not results: + embed = discord.Embed(title="Bookmarks", colour=discord.Colour.red()) + avatar_url = get_author_avatar(ctx).url + embed.set_author(name=ctx.author.display_name, icon_url=avatar_url) + embed.description = "You haven't created any bookmarks yet." + return await ctx.reply(embed=embed, mention_author=False) + + source = BookmarkSource(ctx, results) + menu = Menu(source) + await menu.start(ctx) async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message): """Create a bookmark out of this message""" diff --git a/didier/utils/discord/menus/__init__.py b/didier/menus/__init__.py similarity index 100% rename from didier/utils/discord/menus/__init__.py rename to didier/menus/__init__.py diff --git a/didier/menus/bookmarks.py b/didier/menus/bookmarks.py new file mode 100644 index 0000000..101cd6e --- /dev/null +++ b/didier/menus/bookmarks.py @@ -0,0 +1,29 @@ +import discord +from discord.ext import commands +from overrides import overrides + +from database.schemas import Bookmark +from didier.menus.common import PageSource + +__all__ = ["BookmarkSource"] + +from didier.utils.discord.assets import get_author_avatar + + +class BookmarkSource(PageSource[Bookmark]): + """PageSource for the Bookmark commands""" + + @overrides + def create_embeds(self, ctx: commands.Context): + for page in range(self.page_count): + embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue()) + avatar_url = get_author_avatar(ctx).url + embed.set_author(name=ctx.author.display_name, icon_url=avatar_url) + + description = "" + + for bookmark in self.dataset[page : page + self.per_page]: + description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n" + + embed.description = description.strip() + self.embeds.append(embed) diff --git a/didier/utils/discord/menus/common.py b/didier/menus/common.py similarity index 94% rename from didier/utils/discord/menus/common.py rename to didier/menus/common.py index 983723d..763d976 100644 --- a/didier/utils/discord/menus/common.py +++ b/didier/menus/common.py @@ -21,11 +21,11 @@ class PageSource(ABC, Generic[T]): page_count: int per_page: int - def __init__(self, dataset: list[T], *, per_page: int = 10): + def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): self.dataset = dataset self.per_page = per_page self.page_count = self._get_page_count() - self.create_embeds() + self.create_embeds(ctx) self._add_embed_page_footers() def _get_page_count(self) -> int: @@ -47,7 +47,7 @@ class PageSource(ABC, Generic[T]): embed.set_footer(text=f"{i + 1}/{self.page_count}") @abstractmethod - def create_embeds(self): + def create_embeds(self, ctx: commands.Context): """Method that builds the list of embeds from the input data""" raise NotImplementedError @@ -68,6 +68,10 @@ class Menu(discord.ui.View): def do_button_disabling(self): """Disable buttons depending on the current page""" + # No items to disable + if not self.children: + return + first_page = cast(discord.ui.Button, self.children[0]) first_page.disabled = self.current_page == 0 @@ -87,8 +91,6 @@ class Menu(discord.ui.View): """ self.do_button_disabling() - print(self.current_page, self.source[self.current_page].footer.text) - # Send the initial message if there is none yet, else edit the existing one if self.message is None: self.message = await self.ctx.reply( @@ -100,6 +102,10 @@ class Menu(discord.ui.View): async def start(self, ctx: commands.Context): """Send the initial message with this menu""" self.ctx = ctx + + if len(self.source) == 1: + self.clear_items() + await self.display_current_state() async def stop_view(self, interaction: Optional[discord.Interaction] = None): diff --git a/didier/utils/discord/assets.py b/didier/utils/discord/assets.py new file mode 100644 index 0000000..90473a6 --- /dev/null +++ b/didier/utils/discord/assets.py @@ -0,0 +1,12 @@ +from typing import Union + +import discord +from discord.ext import commands + +__all__ = ["get_author_avatar"] + + +def get_author_avatar(ctx: Union[commands.Context, discord.Interaction]) -> discord.Asset: + """Get a user's avatar asset""" + author = ctx.author if isinstance(ctx, commands.Context) else ctx.user + return author.avatar or author.default_avatar From 152f84ed1c78773a5d5f313c4f7400dfcf4610fb Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 1 Sep 2022 01:02:18 +0200 Subject: [PATCH 3/4] Deleting bookmarks, fix bug in menus --- database/crud/bookmarks.py | 33 +++++++++++++++++++++++++++--- database/exceptions/__init__.py | 2 ++ database/exceptions/forbidden.py | 5 +++++ didier/cogs/discord.py | 35 ++++++++++++++++++++++++++++++-- didier/menus/common.py | 3 ++- 5 files changed, 72 insertions(+), 6 deletions(-) create mode 100644 database/exceptions/forbidden.py diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index 8876d1d..cc61bd8 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -1,11 +1,16 @@ from typing import Optional import sqlalchemy.exc -from sqlalchemy import func, select +from sqlalchemy import delete, func, select from sqlalchemy.ext.asyncio import AsyncSession from database.crud.users import get_or_add_user -from database.exceptions import DuplicateInsertException, ForbiddenNameException +from database.exceptions import ( + DuplicateInsertException, + Forbidden, + ForbiddenNameException, + NoResultFoundException, +) from database.schemas import Bookmark __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] @@ -14,7 +19,7 @@ __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_url: str) -> Bookmark: """Create a new bookmark to a message""" # Don't allow bookmarks with names of subcommands - if label.lower() in ["create", "ls", "list", "search"]: + if label.lower() in ["create", "delete", "ls", "list", "Rm", "search"]: raise ForbiddenNameException await get_or_add_user(session, user_id) @@ -30,6 +35,28 @@ async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_ return bookmark +async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id: int): + """Find a bookmark by its id & delete it + + This fails if you don't own this bookmark + """ + statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + bookmark = (await session.execute(statement)).scalar_one_or_none() + + # No bookmark with this id + if bookmark is None: + raise NoResultFoundException + + # You don't own this bookmark + if bookmark.user_id != user_id: + raise Forbidden + + # Delete it + statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + await session.execute(statement) + await session.commit() + + async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[str] = None) -> list[Bookmark]: """Get all a user's bookmarks""" statement = select(Bookmark).where(Bookmark.user_id == user_id) diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py index f52ea80..ece1d0e 100644 --- a/database/exceptions/__init__.py +++ b/database/exceptions/__init__.py @@ -1,10 +1,12 @@ from .constraints import DuplicateInsertException, ForbiddenNameException from .currency import DoubleNightly, NotEnoughDinks +from .forbidden import Forbidden from .not_found import NoResultFoundException __all__ = [ "DuplicateInsertException", "ForbiddenNameException", + "Forbidden", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException", diff --git a/database/exceptions/forbidden.py b/database/exceptions/forbidden.py new file mode 100644 index 0000000..dadeec1 --- /dev/null +++ b/database/exceptions/forbidden.py @@ -0,0 +1,5 @@ +__all__ = ["Forbidden"] + + +class Forbidden(Exception): + """Exception raised when trying to access a resource that isn't yours""" diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index be6fa90..d90752a 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -5,7 +5,12 @@ from discord import app_commands from discord.ext import commands from database.crud import birthdays, bookmarks -from database.exceptions import DuplicateInsertException, ForbiddenNameException +from database.exceptions import ( + DuplicateInsertException, + Forbidden, + ForbiddenNameException, + NoResultFoundException, +) from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource @@ -73,8 +78,12 @@ class Discord(commands.Cog): await self.client.confirm_message(ctx.message) @commands.group(name="Bookmark", aliases=["Bm", "Bookmarks"], case_insensitive=True, invoke_without_command=True) - async def bookmark(self, ctx: commands.Context, label: str): + async def bookmark(self, ctx: commands.Context, *, label: Optional[str] = None): """Post a bookmarked message""" + # No label: shortcut to display bookmarks + if label is None: + return await self.bookmark_search(ctx, query=None) + async with self.client.postgres_session as session: result = expect( await bookmarks.get_bookmark_by_name(session, ctx.author.id, label), @@ -107,6 +116,28 @@ class Discord(commands.Cog): # Label isn't allowed return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) + @bookmark.command(name="Delete", aliases=["Rm"]) + async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): + """Delete a bookmark by its id""" + # The bookmarks are displayed with a hashtag in front of the id + # so strip it out in case people want to try and use this + bookmark_id = bookmark_id.removeprefix("#") + + try: + bookmark_id_int = int(bookmark_id) + except ValueError: + return await ctx.reply(f"`{bookmark_id}` is not a valid bookmark id.", mention_author=False) + + async with self.client.postgres_session as session: + try: + await bookmarks.delete_bookmark_by_id(session, ctx.author.id, bookmark_id_int) + except NoResultFoundException: + return await ctx.reply(f"Found no bookmark with id `#{bookmark_id_int}`.", mention_author=False) + except Forbidden: + return await ctx.reply(f"You don't own bookmark `#{bookmark_id_int}`.", mention_author=False) + + return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) + @bookmark.command(name="Search", aliases=["List", "Ls"]) async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks""" diff --git a/didier/menus/common.py b/didier/menus/common.py index 763d976..6f871d0 100644 --- a/didier/menus/common.py +++ b/didier/menus/common.py @@ -17,11 +17,12 @@ class PageSource(ABC, Generic[T]): """Base class that handles the embeds displayed in a menu""" dataset: list[T] - embeds: list[discord.Embed] = [] + embeds: list[discord.Embed] page_count: int per_page: int def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): + self.embeds = [] self.dataset = dataset self.per_page = per_page self.page_count = self._get_page_count() From 149d132e6de3f45ed4e82fa0e0d9d14155fe50ad Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 1 Sep 2022 01:26:46 +0200 Subject: [PATCH 4/4] Fix typing --- database/crud/bookmarks.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index cc61bd8..d696e50 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -40,8 +40,8 @@ async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id This fails if you don't own this bookmark """ - statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id) - bookmark = (await session.execute(statement)).scalar_one_or_none() + select_statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + bookmark = (await session.execute(select_statement)).scalar_one_or_none() # No bookmark with this id if bookmark is None: @@ -52,8 +52,8 @@ async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id raise Forbidden # Delete it - statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id) - await session.execute(statement) + delete_statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + await session.execute(delete_statement) await session.commit()