Create command to list custom commands, add shortcuts to memegen commands

pull/122/head
stijndcl 2022-09-23 18:06:33 +02:00
parent 8922489a41
commit bf32a5ef47
8 changed files with 120 additions and 62 deletions

View File

@ -12,6 +12,7 @@ __all__ = [
"create_alias", "create_alias",
"create_command", "create_command",
"edit_command", "edit_command",
"get_all_commands",
"get_command", "get_command",
"get_command_by_alias", "get_command_by_alias",
"get_command_by_name", "get_command_by_name",
@ -55,6 +56,12 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
return alias_instance return alias_instance
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
"""Get a list of all commands"""
statement = select(CustomCommand)
return (await session.execute(statement)).scalars().all()
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
"""Try to get a command out of a message""" """Try to get a command out of a message"""
# Search lowercase & without spaces # Search lowercase & without spaces

View File

@ -14,7 +14,6 @@ from database.exceptions import (
from didier import Didier from didier import Didier
from didier.exceptions import expect from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.menus.common import Menu
from didier.utils.discord import colours from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
@ -186,9 +185,7 @@ class Discord(commands.Cog):
embed.description = "You haven't created any bookmarks yet." embed.description = "You haven't created any bookmarks yet."
return await ctx.reply(embed=embed, mention_author=False) return await ctx.reply(embed=embed, mention_author=False)
source = BookmarkSource(ctx, results) await BookmarkSource(ctx, results).start()
menu = Menu(source)
await menu.start(ctx)
async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message): async def _bookmark_ctx(self, interaction: discord.Interaction, message: discord.Message):
"""Create a bookmark out of this message""" """Create a bookmark out of this message"""

View File

@ -1,4 +1,5 @@
import shlex import shlex
from typing import Optional
import discord import discord
from discord import app_commands from discord import app_commands
@ -9,7 +10,6 @@ from database.crud.memes import get_all_memes, get_meme_by_name
from didier import Didier from didier import Didier
from didier.data.apis.imgflip import generate_meme from didier.data.apis.imgflip import generate_meme
from didier.exceptions.no_match import expect from didier.exceptions.no_match import expect
from didier.menus.common import Menu
from didier.menus.memes import MemeSource from didier.menus.memes import MemeSource
from didier.views.modals import GenerateMeme from didier.views.modals import GenerateMeme
@ -42,7 +42,7 @@ class Fun(commands.Cog):
return await ctx.reply(joke.joke, mention_author=False) return await ctx.reply(joke.joke, mention_author=False)
@commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True) @commands.group(name="memegen", aliases=["meme", "memes"], invoke_without_command=True, case_insensitive=True)
async def memegen_msg(self, ctx: commands.Context, template: str, *, fields: str): async def memegen_msg(self, ctx: commands.Context, template: Optional[str] = None, *, fields: Optional[str] = None):
"""Generate a meme with template `template` and fields `fields`. """Generate a meme with template `template` and fields `fields`.
The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes". The arguments are parsed based on spaces. Arguments that contain spaces should be wrapped in "quotes".
@ -55,7 +55,17 @@ class Fun(commands.Cog):
Example: if template `a` only has 1 field, Example: if template `a` only has 1 field,
`memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]` `memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]`
When no arguments are provided, this is a shortcut to `memegen list`.
When only a template is provided, this is a shortcut to `memegen preview`.
""" """
if template is None:
return await self.memegen_ls_msg(ctx)
if fields is None:
return await self.memegen_preview_msg(ctx, template)
async with ctx.typing(): async with ctx.typing():
meme = await self._do_generate_meme(template, shlex.split(fields)) meme = await self._do_generate_meme(template, shlex.split(fields))
return await ctx.reply(meme, mention_author=False) return await ctx.reply(meme, mention_author=False)
@ -69,9 +79,7 @@ class Fun(commands.Cog):
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
results = await get_all_memes(session) results = await get_all_memes(session)
source = MemeSource(ctx, results) await MemeSource(ctx, results).start()
menu = Menu(source)
await menu.start(ctx)
@memegen_msg.command(name="preview", aliases=["p"]) @memegen_msg.command(name="preview", aliases=["p"])
async def memegen_preview_msg(self, ctx: commands.Context, template: str): async def memegen_preview_msg(self, ctx: commands.Context, template: str):

View File

@ -4,9 +4,11 @@ from typing import Optional
from discord.ext import commands from discord.ext import commands
from database.crud.custom_commands import get_all_commands
from database.crud.reminders import toggle_reminder from database.crud.reminders import toggle_reminder
from database.enums import ReminderCategory from database.enums import ReminderCategory
from didier import Didier from didier import Didier
from didier.menus.custom_commands import CustomCommandSource
class Meta(commands.Cog): class Meta(commands.Cog):
@ -17,6 +19,15 @@ class Meta(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.command(name="custom")
async def custom(self, ctx: commands.Context):
"""Get a list of all custom commands that are registered."""
async with self.client.postgres_session as session:
custom_commands = await get_all_commands(session)
custom_commands.sort(key=lambda c: c.name.lower())
await CustomCommandSource(ctx, custom_commands).start()
@commands.command(name="marco") @commands.command(name="marco")
async def marco(self, ctx: commands.Context): async def marco(self, ctx: commands.Context):
"""Get Didier's latency.""" """Get Didier's latency."""

View File

@ -1,5 +1,4 @@
import discord import discord
from discord.ext import commands
from overrides import overrides from overrides import overrides
from database.schemas import Bookmark from database.schemas import Bookmark
@ -14,16 +13,16 @@ class BookmarkSource(PageSource[Bookmark]):
"""PageSource for the Bookmark commands""" """PageSource for the Bookmark commands"""
@overrides @overrides
def create_embeds(self, ctx: commands.Context): def create_embeds(self):
for page in range(self.page_count): for page in range(self.page_count):
embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue()) embed = discord.Embed(title="Bookmarks", colour=discord.Colour.blue())
avatar_url = get_author_avatar(ctx).url avatar_url = get_author_avatar(self.ctx).url
embed.set_author(name=ctx.author.display_name, icon_url=avatar_url) embed.set_author(name=self.ctx.author.display_name, icon_url=avatar_url)
description = "" description_data = []
for bookmark in self.get_page_data(page): for bookmark in self.get_page_data(page):
description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n" description_data.append(f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})")
embed.description = description.strip() embed.description = "\n".join(description_data)
self.embeds.append(embed) self.embeds.append(embed)

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, Optional, TypeVar, cast from typing import Generic, Optional, TypeVar, cast
@ -13,50 +15,6 @@ __all__ = ["Menu", "PageSource"]
T = TypeVar("T") T = TypeVar("T")
class PageSource(ABC, Generic[T]):
"""Base class that handles the embeds displayed in a menu"""
dataset: list[T]
embeds: list[discord.Embed]
page_count: int
per_page: int
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
self.embeds = []
self.dataset = dataset
self.per_page = per_page
self.page_count = self._get_page_count()
self.create_embeds(ctx)
self._add_embed_page_footers()
def _get_page_count(self) -> int:
"""Calculate the amount of pages required"""
if len(self.dataset) % self.per_page == 0:
return len(self.dataset) // self.per_page
return (len(self.dataset) // self.per_page) + 1
def __getitem__(self, index: int) -> discord.Embed:
return self.embeds[index]
def __len__(self):
return self.page_count
def _add_embed_page_footers(self):
"""Add the current page in the footer of every embed"""
for i, embed in enumerate(self.embeds):
embed.set_footer(text=f"{i + 1}/{self.page_count}")
@abstractmethod
def create_embeds(self, ctx: commands.Context):
"""Method that builds the list of embeds from the input data"""
raise NotImplementedError
def get_page_data(self, page: int) -> list[T]:
"""Get the chunk of the dataset for page [page]"""
return self.dataset[page : page + self.per_page]
class Menu(discord.ui.View): class Menu(discord.ui.View):
"""Base class for a menu""" """Base class for a menu"""
@ -166,3 +124,58 @@ class Menu(discord.ui.View):
"""Button to show the last page""" """Button to show the last page"""
self.current_page = len(self.source) - 1 self.current_page = len(self.source) - 1
await self.display_current_state(interaction) await self.display_current_state(interaction)
class PageSource(ABC, Generic[T]):
"""Base class that handles the embeds displayed in a menu"""
ctx: commands.Context
dataset: list[T]
embeds: list[discord.Embed]
page_count: int
per_page: int
def __init__(self, ctx: commands.Context, dataset: list[T], *, per_page: int = 10):
self.ctx = ctx
self.embeds = []
self.dataset = dataset
self.per_page = per_page
self.page_count = self._get_page_count()
self.create_embeds()
self._add_embed_page_footers()
def _get_page_count(self) -> int:
"""Calculate the amount of pages required"""
if len(self.dataset) % self.per_page == 0:
return len(self.dataset) // self.per_page
return (len(self.dataset) // self.per_page) + 1
def __getitem__(self, index: int) -> discord.Embed:
return self.embeds[index]
def __len__(self):
return self.page_count
def _add_embed_page_footers(self):
"""Add the current page in the footer of every embed"""
for i, embed in enumerate(self.embeds):
embed.set_footer(text=f"{i + 1}/{self.page_count}")
@abstractmethod
def create_embeds(self):
"""Method that builds the list of embeds from the input data"""
raise NotImplementedError
def get_page_data(self, page: int) -> list[T]:
"""Get the chunk of the dataset for page [page]"""
return self.dataset[page : page + self.per_page]
async def start(self, *, ephemeral: bool = False, timeout: Optional[int] = None) -> Menu:
"""Shortcut to creating (and starting) a Menu with this source
This returns the created menu
"""
menu = Menu(self, ephemeral=ephemeral, timeout=timeout)
await menu.start(self.ctx)
return menu

View File

@ -0,0 +1,24 @@
import discord
from overrides import overrides
from database.schemas import CustomCommand
from didier.menus.common import PageSource
__all__ = ["CustomCommandSource"]
class CustomCommandSource(PageSource[CustomCommand]):
"""PageSource for custom commands"""
@overrides
def create_embeds(self):
for page in range(self.page_count):
embed = discord.Embed(colour=discord.Colour.blue(), title="Custom Commands")
description_data = []
for command in self.get_page_data(page):
description_data.append(command.name.title())
embed.description = "\n".join(description_data)
self.embeds.append(embed)

View File

@ -1,5 +1,4 @@
import discord import discord
from discord.ext import commands
from overrides import overrides from overrides import overrides
from database.schemas import MemeTemplate from database.schemas import MemeTemplate
@ -12,7 +11,7 @@ class MemeSource(PageSource[MemeTemplate]):
"""PageSource for meme templates""" """PageSource for meme templates"""
@overrides @overrides
def create_embeds(self, ctx: commands.Context): def create_embeds(self):
for page in range(self.page_count): for page in range(self.page_count):
# The colour of the embed is (69,4,20) with the values +100 because they were too dark # The colour of the embed is (69,4,20) with the values +100 because they were too dark
embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120)) embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120))