mirror of https://github.com/stijndcl/didier
				
				
				
			Deleting bookmarks, fix bug in menus
							parent
							
								
									f70736b4d5
								
							
						
					
					
						commit
						152f84ed1c
					
				|  | @ -1,11 +1,16 @@ | |||
| from typing import Optional | ||||
| 
 | ||||
| import sqlalchemy.exc | ||||
| from sqlalchemy import func, select | ||||
| from sqlalchemy import delete, func, select | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.crud.users import get_or_add_user | ||||
| from database.exceptions import DuplicateInsertException, ForbiddenNameException | ||||
| from database.exceptions import ( | ||||
|     DuplicateInsertException, | ||||
|     Forbidden, | ||||
|     ForbiddenNameException, | ||||
|     NoResultFoundException, | ||||
| ) | ||||
| from database.schemas import Bookmark | ||||
| 
 | ||||
| __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] | ||||
|  | @ -14,7 +19,7 @@ __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"] | |||
| async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_url: str) -> Bookmark: | ||||
|     """Create a new bookmark to a message""" | ||||
|     # Don't allow bookmarks with names of subcommands | ||||
|     if label.lower() in ["create", "ls", "list", "search"]: | ||||
|     if label.lower() in ["create", "delete", "ls", "list", "Rm", "search"]: | ||||
|         raise ForbiddenNameException | ||||
| 
 | ||||
|     await get_or_add_user(session, user_id) | ||||
|  | @ -30,6 +35,28 @@ async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_ | |||
|     return bookmark | ||||
| 
 | ||||
| 
 | ||||
| async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id: int): | ||||
|     """Find a bookmark by its id & delete it | ||||
| 
 | ||||
|     This fails if you don't own this bookmark | ||||
|     """ | ||||
|     statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id) | ||||
|     bookmark = (await session.execute(statement)).scalar_one_or_none() | ||||
| 
 | ||||
|     # No bookmark with this id | ||||
|     if bookmark is None: | ||||
|         raise NoResultFoundException | ||||
| 
 | ||||
|     # You don't own this bookmark | ||||
|     if bookmark.user_id != user_id: | ||||
|         raise Forbidden | ||||
| 
 | ||||
|     # Delete it | ||||
|     statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id) | ||||
|     await session.execute(statement) | ||||
|     await session.commit() | ||||
| 
 | ||||
| 
 | ||||
| async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[str] = None) -> list[Bookmark]: | ||||
|     """Get all a user's bookmarks""" | ||||
|     statement = select(Bookmark).where(Bookmark.user_id == user_id) | ||||
|  |  | |||
|  | @ -1,10 +1,12 @@ | |||
| from .constraints import DuplicateInsertException, ForbiddenNameException | ||||
| from .currency import DoubleNightly, NotEnoughDinks | ||||
| from .forbidden import Forbidden | ||||
| from .not_found import NoResultFoundException | ||||
| 
 | ||||
| __all__ = [ | ||||
|     "DuplicateInsertException", | ||||
|     "ForbiddenNameException", | ||||
|     "Forbidden", | ||||
|     "DoubleNightly", | ||||
|     "NotEnoughDinks", | ||||
|     "NoResultFoundException", | ||||
|  |  | |||
|  | @ -0,0 +1,5 @@ | |||
| __all__ = ["Forbidden"] | ||||
| 
 | ||||
| 
 | ||||
| class Forbidden(Exception): | ||||
|     """Exception raised when trying to access a resource that isn't yours""" | ||||
|  | @ -5,7 +5,12 @@ from discord import app_commands | |||
| from discord.ext import commands | ||||
| 
 | ||||
| from database.crud import birthdays, bookmarks | ||||
| from database.exceptions import DuplicateInsertException, ForbiddenNameException | ||||
| from database.exceptions import ( | ||||
|     DuplicateInsertException, | ||||
|     Forbidden, | ||||
|     ForbiddenNameException, | ||||
|     NoResultFoundException, | ||||
| ) | ||||
| from didier import Didier | ||||
| from didier.exceptions import expect | ||||
| from didier.menus.bookmarks import BookmarkSource | ||||
|  | @ -73,8 +78,12 @@ class Discord(commands.Cog): | |||
|             await self.client.confirm_message(ctx.message) | ||||
| 
 | ||||
|     @commands.group(name="Bookmark", aliases=["Bm", "Bookmarks"], case_insensitive=True, invoke_without_command=True) | ||||
|     async def bookmark(self, ctx: commands.Context, label: str): | ||||
|     async def bookmark(self, ctx: commands.Context, *, label: Optional[str] = None): | ||||
|         """Post a bookmarked message""" | ||||
|         # No label: shortcut to display bookmarks | ||||
|         if label is None: | ||||
|             return await self.bookmark_search(ctx, query=None) | ||||
| 
 | ||||
|         async with self.client.postgres_session as session: | ||||
|             result = expect( | ||||
|                 await bookmarks.get_bookmark_by_name(session, ctx.author.id, label), | ||||
|  | @ -107,6 +116,28 @@ class Discord(commands.Cog): | |||
|             # Label isn't allowed | ||||
|             return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) | ||||
| 
 | ||||
|     @bookmark.command(name="Delete", aliases=["Rm"]) | ||||
|     async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): | ||||
|         """Delete a bookmark by its id""" | ||||
|         # The bookmarks are displayed with a hashtag in front of the id | ||||
|         # so strip it out in case people want to try and use this | ||||
|         bookmark_id = bookmark_id.removeprefix("#") | ||||
| 
 | ||||
|         try: | ||||
|             bookmark_id_int = int(bookmark_id) | ||||
|         except ValueError: | ||||
|             return await ctx.reply(f"`{bookmark_id}` is not a valid bookmark id.", mention_author=False) | ||||
| 
 | ||||
|         async with self.client.postgres_session as session: | ||||
|             try: | ||||
|                 await bookmarks.delete_bookmark_by_id(session, ctx.author.id, bookmark_id_int) | ||||
|             except NoResultFoundException: | ||||
|                 return await ctx.reply(f"Found no bookmark with id `#{bookmark_id_int}`.", mention_author=False) | ||||
|             except Forbidden: | ||||
|                 return await ctx.reply(f"You don't own bookmark `#{bookmark_id_int}`.", mention_author=False) | ||||
| 
 | ||||
|         return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) | ||||
| 
 | ||||
|     @bookmark.command(name="Search", aliases=["List", "Ls"]) | ||||
|     async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): | ||||
|         """Search through the list of bookmarks""" | ||||
|  |  | |||
|  | @ -17,11 +17,12 @@ class PageSource(ABC, Generic[T]): | |||
|     """Base class that handles the embeds displayed in a menu""" | ||||
| 
 | ||||
|     dataset: list[T] | ||||
|     embeds: list[discord.Embed] = [] | ||||
|     embeds: list[discord.Embed] | ||||
|     page_count: int | ||||
|     per_page: int | ||||
| 
 | ||||
|     def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10): | ||||
|         self.embeds = [] | ||||
|         self.dataset = dataset | ||||
|         self.per_page = per_page | ||||
|         self.page_count = self._get_page_count() | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue