diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index 8876d1d..cc61bd8 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -1,11 +1,16 @@ from typing import Optional import sqlalchemy.exc -from sqlalchemy import func, select +from sqlalchemy import delete, func, select from sqlalchemy.ext.asyncio import AsyncSession from database.crud.users import get_or_add_user -from database.exceptions import DuplicateInsertException, ForbiddenNameException +from database.exceptions import ( + DuplicateInsertException, + Forbidden, + ForbiddenNameException, + NoResultFoundException, +) from database.schemas import Bookmark __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] @@ -14,7 +19,7 @@ __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_url: str) -> Bookmark: """Create a new bookmark to a message""" # Don't allow bookmarks with names of subcommands - if label.lower() in ["create", "ls", "list", "search"]: + if label.lower() in ["create", "delete", "ls", "list", "Rm", "search"]: raise ForbiddenNameException await get_or_add_user(session, user_id) @@ -30,6 +35,28 @@ async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_ return bookmark +async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id: int): + """Find a bookmark by its id & delete it + + This fails if you don't own this bookmark + """ + statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + bookmark = (await session.execute(statement)).scalar_one_or_none() + + # No bookmark with this id + if bookmark is None: + raise NoResultFoundException + + # You don't own this bookmark + if bookmark.user_id != user_id: + raise Forbidden + + # Delete it + statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id) + await session.execute(statement) + await session.commit() + + async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[str] = None) -> list[Bookmark]: """Get all a user's bookmarks""" statement = select(Bookmark).where(Bookmark.user_id == user_id) diff --git a/database/exceptions/__init__.py b/database/exceptions/__init__.py index f52ea80..ece1d0e 100644 --- a/database/exceptions/__init__.py +++ b/database/exceptions/__init__.py @@ -1,10 +1,12 @@ from .constraints import DuplicateInsertException, ForbiddenNameException from .currency import DoubleNightly, NotEnoughDinks +from .forbidden import Forbidden from .not_found import NoResultFoundException __all__ = [ "DuplicateInsertException", "ForbiddenNameException", + "Forbidden", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException", diff --git a/database/exceptions/forbidden.py b/database/exceptions/forbidden.py new file mode 100644 index 0000000..dadeec1 --- /dev/null +++ b/database/exceptions/forbidden.py @@ -0,0 +1,5 @@ +__all__ = ["Forbidden"] + + +class Forbidden(Exception): + """Exception raised when trying to access a resource that isn't yours""" diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index be6fa90..d90752a 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -5,7 +5,12 @@ from discord import app_commands from discord.ext import commands from database.crud import birthdays, bookmarks -from database.exceptions import DuplicateInsertException, ForbiddenNameException +from database.exceptions import ( + DuplicateInsertException, + Forbidden, + ForbiddenNameException, + NoResultFoundException, +) from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource @@ -73,8 +78,12 @@ class Discord(commands.Cog): await self.client.confirm_message(ctx.message) @commands.group(name="Bookmark", aliases=["Bm", "Bookmarks"], case_insensitive=True, invoke_without_command=True) - async def bookmark(self, ctx: commands.Context, label: str): + async def bookmark(self, ctx: commands.Context, *, label: Optional[str] = None): """Post a bookmarked message""" + # No label: shortcut to display bookmarks + if label is None: + return await self.bookmark_search(ctx, query=None) + async with self.client.postgres_session as session: result = expect( await bookmarks.get_bookmark_by_name(session, ctx.author.id, label), @@ -107,6 +116,28 @@ class Discord(commands.Cog): # Label isn't allowed return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) + @bookmark.command(name="Delete", aliases=["Rm"]) + async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): + """Delete a bookmark by its id""" + # The bookmarks are displayed with a hashtag in front of the id + # so strip it out in case people want to try and use this + bookmark_id = bookmark_id.removeprefix("#") + + try: + bookmark_id_int = int(bookmark_id) + except ValueError: + return await ctx.reply(f"`{bookmark_id}` is not a valid bookmark id.", mention_author=False) + + async with self.client.postgres_session as session: + try: + await bookmarks.delete_bookmark_by_id(session, ctx.author.id, bookmark_id_int) + except NoResultFoundException: + return await ctx.reply(f"Found no bookmark with id `#{bookmark_id_int}`.", mention_author=False) + except Forbidden: + return await ctx.reply(f"You don't own bookmark `#{bookmark_id_int}`.", mention_author=False) + + return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) + @bookmark.command(name="Search", aliases=["List", "Ls"]) async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks""" diff --git a/didier/menus/common.py b/didier/menus/common.py index 763d976..6f871d0 100644 --- a/didier/menus/common.py +++ b/didier/menus/common.py @@ -17,11 +17,12 @@ class PageSource(ABC, Generic[T]): """Base class that handles the embeds displayed in a menu""" dataset: list[T] - embeds: list[discord.Embed] = [] + embeds: list[discord.Embed] page_count: int per_page: int def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): + self.embeds = [] self.dataset = dataset self.per_page = per_page self.page_count = self._get_page_count()