mirror of https://github.com/stijndcl/didier
				
				
				
			Create command to list custom commands, add shortcuts to memegen commands
							parent
							
								
									8922489a41
								
							
						
					
					
						commit
						bf32a5ef47
					
				| 
						 | 
				
			
			@ -12,6 +12,7 @@ __all__ = [
 | 
			
		|||
    "create_alias",
 | 
			
		||||
    "create_command",
 | 
			
		||||
    "edit_command",
 | 
			
		||||
    "get_all_commands",
 | 
			
		||||
    "get_command",
 | 
			
		||||
    "get_command_by_alias",
 | 
			
		||||
    "get_command_by_name",
 | 
			
		||||
| 
						 | 
				
			
			@ -55,6 +56,12 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
 | 
			
		|||
    return alias_instance
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
 | 
			
		||||
    """Get a list of all commands"""
 | 
			
		||||
    statement = select(CustomCommand)
 | 
			
		||||
    return (await session.execute(statement)).scalars().all()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
 | 
			
		||||
    """Try to get a command out of a message"""
 | 
			
		||||
    # Search lowercase & without spaces
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -14,7 +14,6 @@ from database.exceptions import (
 | 
			
		|||
from didier import Didier
 | 
			
		||||
from didier.exceptions import expect
 | 
			
		||||
from didier.menus.bookmarks import BookmarkSource
 | 
			
		||||
from didier.menus.common import Menu
 | 
			
		||||
from didier.utils.discord import colours
 | 
			
		||||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
 | 
			
		||||
from didier.utils.discord.constants import Limits
 | 
			
		||||
| 
						 | 
				
			
			@ -186,9 +185,7 @@ class Discord(commands.Cog):
 | 
			
		|||
            embed.description = "You haven't created any bookmarks yet."
 | 
			
		||||
            return await ctx.reply(embed=embed, mention_author=False)
 | 
			
		||||
 | 
			
		||||
        source = BookmarkSource(ctx, results)
 | 
			
		||||
        menu = Menu(source)
 | 
			
		||||
        await menu.start(ctx)
 | 
			
		||||
        await BookmarkSource(ctx, results).start()
 | 
			
		||||
 | 
			
		||||
    async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message):
 | 
			
		||||
        """Create a bookmark out of this message"""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
import shlex
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import discord
 | 
			
		||||
from discord import app_commands
 | 
			
		||||
| 
						 | 
				
			
			@ -9,7 +10,6 @@ from database.crud.memes import get_all_memes, get_meme_by_name
 | 
			
		|||
from didier import Didier
 | 
			
		||||
from didier.data.apis.imgflip import generate_meme
 | 
			
		||||
from didier.exceptions.no_match import expect
 | 
			
		||||
from didier.menus.common import Menu
 | 
			
		||||
from didier.menus.memes import MemeSource
 | 
			
		||||
from didier.views.modals import GenerateMeme
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -42,7 +42,7 @@ class Fun(commands.Cog):
 | 
			
		|||
            return await ctx.reply(joke.joke, mention_author=False)
 | 
			
		||||
 | 
			
		||||
    @commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True)
 | 
			
		||||
    async def memegen_msg(self, ctx: commands.Context, template: str, *, fields: str):
 | 
			
		||||
    async def memegen_msg(self, ctx: commands.Context, template: Optional[str] = None, *, fields: Optional[str] = None):
 | 
			
		||||
        """Generate a meme with template `template` and fields `fields`.
 | 
			
		||||
 | 
			
		||||
        The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes".
 | 
			
		||||
| 
						 | 
				
			
			@ -55,7 +55,17 @@ class Fun(commands.Cog):
 | 
			
		|||
 | 
			
		||||
        Example: if template `a` only has 1 field,
 | 
			
		||||
        `memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]`
 | 
			
		||||
 | 
			
		||||
        When no arguments are provided, this is a shortcut to `memegen list`.
 | 
			
		||||
 | 
			
		||||
        When only a template is provided, this is a shortcut to `memegen preview`.
 | 
			
		||||
        """
 | 
			
		||||
        if template is None:
 | 
			
		||||
            return await self.memegen_ls_msg(ctx)
 | 
			
		||||
 | 
			
		||||
        if fields is None:
 | 
			
		||||
            return await self.memegen_preview_msg(ctx, template)
 | 
			
		||||
 | 
			
		||||
        async with ctx.typing():
 | 
			
		||||
            meme = await self._do_generate_meme(template, shlex.split(fields))
 | 
			
		||||
            return await ctx.reply(meme, mention_author=False)
 | 
			
		||||
| 
						 | 
				
			
			@ -69,9 +79,7 @@ class Fun(commands.Cog):
 | 
			
		|||
        async with self.client.postgres_session as session:
 | 
			
		||||
            results = await get_all_memes(session)
 | 
			
		||||
 | 
			
		||||
        source = MemeSource(ctx, results)
 | 
			
		||||
        menu = Menu(source)
 | 
			
		||||
        await menu.start(ctx)
 | 
			
		||||
        await MemeSource(ctx, results).start()
 | 
			
		||||
 | 
			
		||||
    @memegen_msg.command(name="preview", aliases=["p"])
 | 
			
		||||
    async def memegen_preview_msg(self, ctx: commands.Context, template: str):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,9 +4,11 @@ from typing import Optional
 | 
			
		|||
 | 
			
		||||
from discord.ext import commands
 | 
			
		||||
 | 
			
		||||
from database.crud.custom_commands import get_all_commands
 | 
			
		||||
from database.crud.reminders import toggle_reminder
 | 
			
		||||
from database.enums import ReminderCategory
 | 
			
		||||
from didier import Didier
 | 
			
		||||
from didier.menus.custom_commands import CustomCommandSource
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Meta(commands.Cog):
 | 
			
		||||
| 
						 | 
				
			
			@ -17,6 +19,15 @@ class Meta(commands.Cog):
 | 
			
		|||
    def __init__(self, client: Didier):
 | 
			
		||||
        self.client = client
 | 
			
		||||
 | 
			
		||||
    @commands.command(name="custom")
 | 
			
		||||
    async def custom(self, ctx: commands.Context):
 | 
			
		||||
        """Get a list of all custom commands that are registered."""
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            custom_commands = await get_all_commands(session)
 | 
			
		||||
 | 
			
		||||
        custom_commands.sort(key=lambda c: c.name.lower())
 | 
			
		||||
        await CustomCommandSource(ctx, custom_commands).start()
 | 
			
		||||
 | 
			
		||||
    @commands.command(name="marco")
 | 
			
		||||
    async def marco(self, ctx: commands.Context):
 | 
			
		||||
        """Get Didier's latency."""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,5 +1,4 @@
 | 
			
		|||
import discord
 | 
			
		||||
from discord.ext import commands
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.schemas import Bookmark
 | 
			
		||||
| 
						 | 
				
			
			@ -14,16 +13,16 @@ class BookmarkSource(PageSource[Bookmark]):
 | 
			
		|||
    """PageSource for the Bookmark commands"""
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    def create_embeds(self, ctx: commands.Context):
 | 
			
		||||
    def create_embeds(self):
 | 
			
		||||
        for page in range(self.page_count):
 | 
			
		||||
            embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue())
 | 
			
		||||
            avatar_url = get_author_avatar(ctx).url
 | 
			
		||||
            embed.set_author(name=ctx.author.display_name, icon_url=avatar_url)
 | 
			
		||||
            avatar_url = get_author_avatar(self.ctx).url
 | 
			
		||||
            embed.set_author(name=self.ctx.author.display_name, icon_url=avatar_url)
 | 
			
		||||
 | 
			
		||||
            description = ""
 | 
			
		||||
            description_data = []
 | 
			
		||||
 | 
			
		||||
            for bookmark in self.get_page_data(page):
 | 
			
		||||
                description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n"
 | 
			
		||||
                description_data.append(f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})")
 | 
			
		||||
 | 
			
		||||
            embed.description = description.strip()
 | 
			
		||||
            embed.description = "\n".join(description_data)
 | 
			
		||||
            self.embeds.append(embed)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,3 +1,5 @@
 | 
			
		|||
from __future__ import annotations
 | 
			
		||||
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Generic, Optional, TypeVar, cast
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -13,50 +15,6 @@ __all__ = ["Menu", "PageSource"]
 | 
			
		|||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PageSource(ABC, Generic[T]):
 | 
			
		||||
    """Base class that handles the embeds displayed in a menu"""
 | 
			
		||||
 | 
			
		||||
    dataset: list[T]
 | 
			
		||||
    embeds: list[discord.Embed]
 | 
			
		||||
    page_count: int
 | 
			
		||||
    per_page: int
 | 
			
		||||
 | 
			
		||||
    def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
 | 
			
		||||
        self.embeds = []
 | 
			
		||||
        self.dataset = dataset
 | 
			
		||||
        self.per_page = per_page
 | 
			
		||||
        self.page_count = self._get_page_count()
 | 
			
		||||
        self.create_embeds(ctx)
 | 
			
		||||
        self._add_embed_page_footers()
 | 
			
		||||
 | 
			
		||||
    def _get_page_count(self) -> int:
 | 
			
		||||
        """Calculate the amount of pages required"""
 | 
			
		||||
        if len(self.dataset) % self.per_page == 0:
 | 
			
		||||
            return len(self.dataset) // self.per_page
 | 
			
		||||
 | 
			
		||||
        return (len(self.dataset) // self.per_page) + 1
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index: int) -> discord.Embed:
 | 
			
		||||
        return self.embeds[index]
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self.page_count
 | 
			
		||||
 | 
			
		||||
    def _add_embed_page_footers(self):
 | 
			
		||||
        """Add the current page in the footer of every embed"""
 | 
			
		||||
        for i, embed in enumerate(self.embeds):
 | 
			
		||||
            embed.set_footer(text=f"{i + 1}/{self.page_count}")
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def create_embeds(self, ctx: commands.Context):
 | 
			
		||||
        """Method that builds the list of embeds from the input data"""
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def get_page_data(self, page: int) -> list[T]:
 | 
			
		||||
        """Get the chunk of the dataset for page [page]"""
 | 
			
		||||
        return self.dataset[page : page + self.per_page]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Menu(discord.ui.View):
 | 
			
		||||
    """Base class for a menu"""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -166,3 +124,58 @@ class Menu(discord.ui.View):
 | 
			
		|||
        """Button to show the last page"""
 | 
			
		||||
        self.current_page = len(self.source) - 1
 | 
			
		||||
        await self.display_current_state(interaction)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PageSource(ABC, Generic[T]):
 | 
			
		||||
    """Base class that handles the embeds displayed in a menu"""
 | 
			
		||||
 | 
			
		||||
    ctx: commands.Context
 | 
			
		||||
    dataset: list[T]
 | 
			
		||||
    embeds: list[discord.Embed]
 | 
			
		||||
    page_count: int
 | 
			
		||||
    per_page: int
 | 
			
		||||
 | 
			
		||||
    def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
 | 
			
		||||
        self.ctx = ctx
 | 
			
		||||
        self.embeds = []
 | 
			
		||||
        self.dataset = dataset
 | 
			
		||||
        self.per_page = per_page
 | 
			
		||||
        self.page_count = self._get_page_count()
 | 
			
		||||
        self.create_embeds()
 | 
			
		||||
        self._add_embed_page_footers()
 | 
			
		||||
 | 
			
		||||
    def _get_page_count(self) -> int:
 | 
			
		||||
        """Calculate the amount of pages required"""
 | 
			
		||||
        if len(self.dataset) % self.per_page == 0:
 | 
			
		||||
            return len(self.dataset) // self.per_page
 | 
			
		||||
 | 
			
		||||
        return (len(self.dataset) // self.per_page) + 1
 | 
			
		||||
 | 
			
		||||
    def __getitem__(self, index: int) -> discord.Embed:
 | 
			
		||||
        return self.embeds[index]
 | 
			
		||||
 | 
			
		||||
    def __len__(self):
 | 
			
		||||
        return self.page_count
 | 
			
		||||
 | 
			
		||||
    def _add_embed_page_footers(self):
 | 
			
		||||
        """Add the current page in the footer of every embed"""
 | 
			
		||||
        for i, embed in enumerate(self.embeds):
 | 
			
		||||
            embed.set_footer(text=f"{i + 1}/{self.page_count}")
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def create_embeds(self):
 | 
			
		||||
        """Method that builds the list of embeds from the input data"""
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
    def get_page_data(self, page: int) -> list[T]:
 | 
			
		||||
        """Get the chunk of the dataset for page [page]"""
 | 
			
		||||
        return self.dataset[page : page + self.per_page]
 | 
			
		||||
 | 
			
		||||
    async def start(self, *, ephemeral: bool = False, timeout: Optional[int] = None) -> Menu:
 | 
			
		||||
        """Shortcut to creating (and starting) a Menu with this source
 | 
			
		||||
 | 
			
		||||
        This returns the created menu
 | 
			
		||||
        """
 | 
			
		||||
        menu = Menu(self, ephemeral=ephemeral, timeout=timeout)
 | 
			
		||||
        await menu.start(self.ctx)
 | 
			
		||||
        return menu
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,24 @@
 | 
			
		|||
import discord
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.schemas import CustomCommand
 | 
			
		||||
from didier.menus.common import PageSource
 | 
			
		||||
 | 
			
		||||
__all__ = ["CustomCommandSource"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class CustomCommandSource(PageSource[CustomCommand]):
 | 
			
		||||
    """PageSource for custom commands"""
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    def create_embeds(self):
 | 
			
		||||
        for page in range(self.page_count):
 | 
			
		||||
            embed = discord.Embed(colour=discord.Colour.blue(), title="Custom Commands")
 | 
			
		||||
 | 
			
		||||
            description_data = []
 | 
			
		||||
 | 
			
		||||
            for command in self.get_page_data(page):
 | 
			
		||||
                description_data.append(command.name.title())
 | 
			
		||||
 | 
			
		||||
            embed.description = "\n".join(description_data)
 | 
			
		||||
            self.embeds.append(embed)
 | 
			
		||||
| 
						 | 
				
			
			@ -1,5 +1,4 @@
 | 
			
		|||
import discord
 | 
			
		||||
from discord.ext import commands
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.schemas import MemeTemplate
 | 
			
		||||
| 
						 | 
				
			
			@ -12,7 +11,7 @@ class MemeSource(PageSource[MemeTemplate]):
 | 
			
		|||
    """PageSource for meme templates"""
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    def create_embeds(self, ctx: commands.Context):
 | 
			
		||||
    def create_embeds(self):
 | 
			
		||||
        for page in range(self.page_count):
 | 
			
		||||
            # The colour of the embed is (69,4,20) with the values +100 because they were too dark
 | 
			
		||||
            embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120))
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue