mirror of https://github.com/stijndcl/didier
Deleting bookmarks, fix bug in menus
parent
f70736b4d5
commit
152f84ed1c
|
@ -1,11 +1,16 @@
|
|||
from typing import Optional
|
||||
|
||||
import sqlalchemy.exc
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.exceptions import DuplicateInsertException, ForbiddenNameException
|
||||
from database.exceptions import (
|
||||
DuplicateInsertException,
|
||||
Forbidden,
|
||||
ForbiddenNameException,
|
||||
NoResultFoundException,
|
||||
)
|
||||
from database.schemas import Bookmark
|
||||
|
||||
__all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"]
|
||||
|
@ -14,7 +19,7 @@ __all__ = ["create_bookmark", "get_bookmarks", "get_bookmark_by_name"]
|
|||
async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_url: str) -> Bookmark:
|
||||
"""Create a new bookmark to a message"""
|
||||
# Don't allow bookmarks with names of subcommands
|
||||
if label.lower() in ["create", "ls", "list", "search"]:
|
||||
if label.lower() in ["create", "delete", "ls", "list", "Rm", "search"]:
|
||||
raise ForbiddenNameException
|
||||
|
||||
await get_or_add_user(session, user_id)
|
||||
|
@ -30,6 +35,28 @@ async def create_bookmark(session: AsyncSession, user_id: int, label: str, jump_
|
|||
return bookmark
|
||||
|
||||
|
||||
async def delete_bookmark_by_id(session: AsyncSession, user_id: int, bookmark_id: int):
|
||||
"""Find a bookmark by its id & delete it
|
||||
|
||||
This fails if you don't own this bookmark
|
||||
"""
|
||||
statement = select(Bookmark).where(Bookmark.bookmark_id == bookmark_id)
|
||||
bookmark = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
# No bookmark with this id
|
||||
if bookmark is None:
|
||||
raise NoResultFoundException
|
||||
|
||||
# You don't own this bookmark
|
||||
if bookmark.user_id != user_id:
|
||||
raise Forbidden
|
||||
|
||||
# Delete it
|
||||
statement = delete(Bookmark).where(Bookmark.bookmark_id == bookmark_id)
|
||||
await session.execute(statement)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[str] = None) -> list[Bookmark]:
|
||||
"""Get all a user's bookmarks"""
|
||||
statement = select(Bookmark).where(Bookmark.user_id == user_id)
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from .constraints import DuplicateInsertException, ForbiddenNameException
|
||||
from .currency import DoubleNightly, NotEnoughDinks
|
||||
from .forbidden import Forbidden
|
||||
from .not_found import NoResultFoundException
|
||||
|
||||
__all__ = [
|
||||
"DuplicateInsertException",
|
||||
"ForbiddenNameException",
|
||||
"Forbidden",
|
||||
"DoubleNightly",
|
||||
"NotEnoughDinks",
|
||||
"NoResultFoundException",
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
__all__ = ["Forbidden"]
|
||||
|
||||
|
||||
class Forbidden(Exception):
|
||||
"""Exception raised when trying to access a resource that isn't yours"""
|
|
@ -5,7 +5,12 @@ from discord import app_commands
|
|||
from discord.ext import commands
|
||||
|
||||
from database.crud import birthdays, bookmarks
|
||||
from database.exceptions import DuplicateInsertException, ForbiddenNameException
|
||||
from database.exceptions import (
|
||||
DuplicateInsertException,
|
||||
Forbidden,
|
||||
ForbiddenNameException,
|
||||
NoResultFoundException,
|
||||
)
|
||||
from didier import Didier
|
||||
from didier.exceptions import expect
|
||||
from didier.menus.bookmarks import BookmarkSource
|
||||
|
@ -73,8 +78,12 @@ class Discord(commands.Cog):
|
|||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
@commands.group(name="Bookmark", aliases=["Bm", "Bookmarks"], case_insensitive=True, invoke_without_command=True)
|
||||
async def bookmark(self, ctx: commands.Context, label: str):
|
||||
async def bookmark(self, ctx: commands.Context, *, label: Optional[str] = None):
|
||||
"""Post a bookmarked message"""
|
||||
# No label: shortcut to display bookmarks
|
||||
if label is None:
|
||||
return await self.bookmark_search(ctx, query=None)
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
result = expect(
|
||||
await bookmarks.get_bookmark_by_name(session, ctx.author.id, label),
|
||||
|
@ -107,6 +116,28 @@ class Discord(commands.Cog):
|
|||
# Label isn't allowed
|
||||
return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False)
|
||||
|
||||
@bookmark.command(name="Delete", aliases=["Rm"])
|
||||
async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str):
|
||||
"""Delete a bookmark by its id"""
|
||||
# The bookmarks are displayed with a hashtag in front of the id
|
||||
# so strip it out in case people want to try and use this
|
||||
bookmark_id = bookmark_id.removeprefix("#")
|
||||
|
||||
try:
|
||||
bookmark_id_int = int(bookmark_id)
|
||||
except ValueError:
|
||||
return await ctx.reply(f"`{bookmark_id}` is not a valid bookmark id.", mention_author=False)
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await bookmarks.delete_bookmark_by_id(session, ctx.author.id, bookmark_id_int)
|
||||
except NoResultFoundException:
|
||||
return await ctx.reply(f"Found no bookmark with id `#{bookmark_id_int}`.", mention_author=False)
|
||||
except Forbidden:
|
||||
return await ctx.reply(f"You don't own bookmark `#{bookmark_id_int}`.", mention_author=False)
|
||||
|
||||
return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False)
|
||||
|
||||
@bookmark.command(name="Search", aliases=["List", "Ls"])
|
||||
async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None):
|
||||
"""Search through the list of bookmarks"""
|
||||
|
|
|
@ -17,11 +17,12 @@ class PageSource(ABC, Generic[T]):
|
|||
"""Base class that handles the embeds displayed in a menu"""
|
||||
|
||||
dataset: list[T]
|
||||
embeds: list[discord.Embed] = []
|
||||
embeds: list[discord.Embed]
|
||||
page_count: int
|
||||
per_page: int
|
||||
|
||||
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
|
||||
self.embeds = []
|
||||
self.dataset = dataset
|
||||
self.per_page = per_page
|
||||
self.page_count = self._get_page_count()
|
||||
|
|
Loading…
Reference in New Issue