From dbb570420b50b3eca615436bac2659d66b56c5b5 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 25 Aug 2022 02:07:02 +0200 Subject: [PATCH 1/9] Command to add memes --- .../versions/36300b558ef1_meme_templates.py | 37 ++++++++++++++++ database/crud/memes.py | 35 ++++++++++++++++ database/schemas/relational.py | 12 ++++++ database/utils/caches.py | 18 +++++++- didier/cogs/fun.py | 19 +++++++++ didier/cogs/owner.py | 15 ++++++- didier/data/apis/imgflip.py | 42 +++++++++++++++++++ didier/exceptions/missing_env.py | 8 ++++ settings.py | 4 ++ 9 files changed, 188 insertions(+), 2 deletions(-) create mode 100644 alembic/versions/36300b558ef1_meme_templates.py create mode 100644 database/crud/memes.py create mode 100644 didier/data/apis/imgflip.py create mode 100644 didier/exceptions/missing_env.py diff --git a/alembic/versions/36300b558ef1_meme_templates.py b/alembic/versions/36300b558ef1_meme_templates.py new file mode 100644 index 0000000..275133a --- /dev/null +++ b/alembic/versions/36300b558ef1_meme_templates.py @@ -0,0 +1,37 @@ +"""Meme templates + +Revision ID: 36300b558ef1 +Revises: 08d21b2d1a0a +Create Date: 2022-08-25 01:34:22.845955 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "36300b558ef1" +down_revision = "08d21b2d1a0a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "meme", + sa.Column("meme_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("template_id", sa.Integer(), nullable=False), + sa.Column("field_count", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("meme_id"), + sa.UniqueConstraint("name"), + sa.UniqueConstraint("template_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("meme") + # ### end Alembic commands ### diff --git a/database/crud/memes.py b/database/crud/memes.py new file mode 100644 index 0000000..cccdcfb --- /dev/null +++ b/database/crud/memes.py @@ -0,0 +1,35 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas.relational import MemeTemplate + +__all__ = ["add_meme", "get_all_memes"] + + +async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]: + """Add a new meme into the database""" + try: + meme = MemeTemplate(name=name, template_id=template_id, field_count=field_count) + session.add(meme) + await session.commit() + return meme + except IntegrityError: + return None + + +async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: + """Try to find a meme by its name + + Returns the first match found by PSQL + """ + statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%")) + return (await session.execute(statement)).scalar() + + +async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: + """Get a list of all memes""" + statement = select(MemeTemplate) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas/relational.py b/database/schemas/relational.py index 6fd27d0..904459e 100644 --- a/database/schemas/relational.py +++ b/database/schemas/relational.py @@ -30,6 +30,7 @@ __all__ = [ "DadJoke", "Deadline", "Link", + "MemeTemplate", "NightlyData", "Task", "UforaAnnouncement", @@ -134,6 +135,17 @@ class Link(Base): url: str = Column(Text, nullable=False) +class MemeTemplate(Base): + """A meme template for the Imgflip API""" + + __tablename__ = "meme" + + meme_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + template_id: int = Column(Integer, nullable=False, unique=True) + field_count: int = Column(Integer, nullable=False) + + class NightlyData(Base): """Data for a user's Nightly stats""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 4f34419..ba26ab2 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -5,7 +5,7 @@ from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import links, ufora_courses, wordle +from database.crud import links, memes, ufora_courses, wordle from database.mongo_types import MongoDatabase __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] @@ -61,6 +61,19 @@ class LinkCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) +class MemeCache(DatabaseCache[AsyncSession]): + """Cache to store the names of meme templates""" + + @overrides + async def invalidate(self, database_session: AsyncSession): + self.clear() + + all_memes = await memes.get_all_memes(database_session) + self.data = list(map(lambda m: m.name, all_memes)) + self.data.sort() + self.data_transformed = list(map(str.lower, self.data)) + + class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" @@ -119,16 +132,19 @@ class CacheManager: """Class that keeps track of all caches""" links: LinkCache + memes: MemeCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): self.links = LinkCache() + self.memes = MemeCache() self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" await self.links.invalidate(postgres_session) + await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) await self.wordle_word.invalidate(mongo_db) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 0aade01..b014ddd 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -1,3 +1,5 @@ +import discord +from discord import app_commands from discord.ext import commands from database.crud.dad_jokes import get_random_dad_joke @@ -9,6 +11,8 @@ class Fun(commands.Cog): client: Didier + memegen_slash = app_commands.Group(name="meme", description="Commands to generate memes") + def __init__(self, client: Didier): self.client = client @@ -23,6 +27,21 @@ class Fun(commands.Cog): joke = await get_random_dad_joke(session) return await ctx.reply(joke.joke, mention_author=False) + @commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True) + async def memegen_ctx(self, ctx: commands.Context): + """Command group for meme-related commands""" + + @memegen_slash.command(name="generate", description="Generate a meme") + async def memegen_slash(self, ctx: commands.Context, meme: str): + """Slash command to generate a meme""" + + @memegen_slash.autocomplete("meme") + async def _memegen_slash_autocomplete_meme( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'meme'-parameter""" + return self.client.database_caches.memes.get_autocomplete_suggestions(current) + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index d030344..c5ee213 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -5,7 +5,7 @@ from discord import app_commands from discord.ext import commands import settings -from database.crud import custom_commands, links, ufora_courses +from database.crud import custom_commands, links, memes, ufora_courses from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier @@ -167,6 +167,19 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="meme", description="Add a new meme") + async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): + """Slash command to add new memes""" + await interaction.response.defer(ephemeral=True) + + async with self.client.postgres_session as session: + meme = await memes.add_meme(session, name, imgflip_id, field_count) + if meme is None: + return await interaction.followup.send("A meme with this name (or id) already exists.") + + await interaction.followup.send(f"Added meme `{meme.meme_id}`.") + await self.client.database_caches.memes.invalidate() + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py new file mode 100644 index 0000000..7447fe3 --- /dev/null +++ b/didier/data/apis/imgflip.py @@ -0,0 +1,42 @@ +from typing import Optional + +from aiohttp import ClientSession + +import settings +from database.schemas.relational import MemeTemplate +from didier.exceptions.missing_env import MissingEnvironmentVariable + +__all__ = ["generate_meme"] + + +def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[dict[str:str]]: + """Generate the template boxes for Imgflip""" + # If a meme only has 1 field, join all the arguments together into one string + if meme.field_count == 1: + fields = [" ".join(fields)] + + fields = fields[: min(20, meme.field_count)] + return [{"text": text} for text in fields] + + +async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]: + """Make a request to Imgflip to generate a meme""" + name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD + + # Ensure credentials exist + if name is None: + raise MissingEnvironmentVariable("IMGFLIP_NAME") + + if password is None: + raise MissingEnvironmentVariable("IMGFLIP_PASSWORD") + + boxes = generate_boxes(meme, fields) + payload = {"template_id": meme.template_id, "username": name, "password": password, "boxes": boxes} + + async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response: + if response.status != 200: + return None + + data = await response.json() + + return data["data"]["url"] diff --git a/didier/exceptions/missing_env.py b/didier/exceptions/missing_env.py new file mode 100644 index 0000000..863a092 --- /dev/null +++ b/didier/exceptions/missing_env.py @@ -0,0 +1,8 @@ +__all__ = ["MissingEnvironmentVariable"] + + +class MissingEnvironmentVariable(RuntimeError): + """Exception raised when an environment variable is missing""" + + def __init__(self, variable_name): + super().__init__(f"Missing environment variable: {variable_name}") diff --git a/settings.py b/settings.py index ec6039c..a041397 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ __all__ = [ "UFORA_ANNOUNCEMENTS_CHANNEL", "UFORA_RSS_TOKEN", "URBAN_DICTIONARY_TOKEN", + "IMGFLIP_NAME", + "IMGFLIP_PASSWORD", ] @@ -64,3 +66,5 @@ UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNE """API Keys""" UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None) URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None) +IMGFLIP_NAME: Optional[str] = env.str("IMGFLIP_NAME", None) +IMGFLIP_PASSWORD: Optional[str] = env.str("IMGFLIP_PASSWORD", None) From 7d7ab98254b4fab081342f105cb5d12c6f368796 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 25 Aug 2022 11:04:25 +0200 Subject: [PATCH 2/9] Memegen works --- database/crud/memes.py | 14 ++++----- didier/cogs/fun.py | 16 +++++++++-- didier/data/apis/imgflip.py | 10 ++++--- didier/views/modals/__init__.py | 3 +- didier/views/modals/memes.py | 50 +++++++++++++++++++++++++++++++++ 5 files changed, 78 insertions(+), 15 deletions(-) create mode 100644 didier/views/modals/memes.py diff --git a/database/crud/memes.py b/database/crud/memes.py index cccdcfb..f92f6ef 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.schemas.relational import MemeTemplate -__all__ = ["add_meme", "get_all_memes"] +__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"] async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]: @@ -20,6 +20,12 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou return None +async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: + """Get a list of all memes""" + statement = select(MemeTemplate) + return (await session.execute(statement)).scalars().all() + + async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: """Try to find a meme by its name @@ -27,9 +33,3 @@ async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTe """ statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%")) return (await session.execute(statement)).scalar() - - -async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: - """Get a list of all memes""" - statement = select(MemeTemplate) - return (await session.execute(statement)).scalars().all() diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index b014ddd..7d00bc1 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -3,7 +3,9 @@ from discord import app_commands from discord.ext import commands from database.crud.dad_jokes import get_random_dad_joke +from database.crud.memes import get_meme_by_name from didier import Didier +from didier.views.modals import GenerateMeme class Fun(commands.Cog): @@ -11,7 +13,8 @@ class Fun(commands.Cog): client: Didier - memegen_slash = app_commands.Group(name="meme", description="Commands to generate memes") + # Slash groups + memes_slash = app_commands.Group(name="meme", description="Commands to generate memes", guild_only=False) def __init__(self, client: Didier): self.client = client @@ -31,9 +34,16 @@ class Fun(commands.Cog): async def memegen_ctx(self, ctx: commands.Context): """Command group for meme-related commands""" - @memegen_slash.command(name="generate", description="Generate a meme") - async def memegen_slash(self, ctx: commands.Context, meme: str): + @memes_slash.command(name="generate", description="Generate a meme") + async def memegen_slash(self, interaction: discord.Interaction, meme: str): """Slash command to generate a meme""" + async with self.client.postgres_session as session: + result = await get_meme_by_name(session, meme) + if result is None: + return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True) + + modal = GenerateMeme(self.client, result) + await interaction.response.send_modal(modal) @memegen_slash.autocomplete("meme") async def _memegen_slash_autocomplete_meme( diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py index 7447fe3..a136a72 100644 --- a/didier/data/apis/imgflip.py +++ b/didier/data/apis/imgflip.py @@ -9,14 +9,15 @@ from didier.exceptions.missing_env import MissingEnvironmentVariable __all__ = ["generate_meme"] -def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[dict[str:str]]: +def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]: """Generate the template boxes for Imgflip""" # If a meme only has 1 field, join all the arguments together into one string if meme.field_count == 1: fields = [" ".join(fields)] fields = fields[: min(20, meme.field_count)] - return [{"text": text} for text in fields] + # TODO manipulate the text if necessary + return fields async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]: @@ -31,12 +32,13 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: raise MissingEnvironmentVariable("IMGFLIP_PASSWORD") boxes = generate_boxes(meme, fields) - payload = {"template_id": meme.template_id, "username": name, "password": password, "boxes": boxes} + payload = {"template_id": meme.template_id, "username": name, "password": password} + for i, box in enumerate(boxes): + payload[f"boxes[{i}][text]"] = box async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response: if response.status != 200: return None data = await response.json() - return data["data"]["url"] diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index 42c2ce8..62d700f 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -2,5 +2,6 @@ from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke from .deadlines import AddDeadline from .links import AddLink +from .memes import GenerateMeme -__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink"] +__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink", "GenerateMeme"] diff --git a/didier/views/modals/memes.py b/didier/views/modals/memes.py new file mode 100644 index 0000000..8448dbf --- /dev/null +++ b/didier/views/modals/memes.py @@ -0,0 +1,50 @@ +import traceback + +import discord.ui +from overrides import overrides + +from database.schemas.relational import MemeTemplate +from didier import Didier +from didier.data.apis.imgflip import generate_meme + +__all__ = ["GenerateMeme"] + + +class GenerateMeme(discord.ui.Modal, title="Generate Meme"): + """Modal to generate a meme""" + + client: Didier + meme: MemeTemplate + + def __init__(self, client: Didier, meme: MemeTemplate, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + self.meme = meme + + for i in range(meme.field_count): + self.add_item( + discord.ui.TextInput( + label=f"Field #{i + 1}", + placeholder="Here be funny text", + style=discord.TextStyle.long, + required=True, + ) + ) + + @overrides + async def on_submit(self, interaction: discord.Interaction): + await interaction.response.defer() + + fields = [item.value for item in self.children if isinstance(item, discord.ui.TextInput)] + + meme_url = await generate_meme(self.client.http_session, self.meme, fields) + + if meme_url is None: + return await interaction.followup.send("Something went wrong.") + + await interaction.followup.send(meme_url) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + traceback.print_tb(error.__traceback__) + await interaction.followup.send("Something went wrong.", ephemeral=True) From 8fb990cea82ed45e40f4b249076631372c2d74b6 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 18:32:53 +0200 Subject: [PATCH 3/9] Add missing translations, memegen message command --- didier/cogs/currency.py | 28 +++++++++++++++------------- didier/cogs/fun.py | 22 +++++++++++++++++++--- 2 files changed, 34 insertions(+), 16 deletions(-) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 1b3dea5..05446f0 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -46,12 +46,12 @@ class Currency(commands.Cog): bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) - embed.set_author(name=f"Bank van {ctx.author.display_name}") + embed.set_author(name=f"{ctx.author.display_name}'s Bank") embed.set_thumbnail(url=ctx.author.avatar.url) embed.add_field(name="Interest level", value=bank.interest_level) - embed.add_field(name="Capaciteit level", value=bank.capacity_level) - embed.add_field(name="Momenteel geïnvesteerd", value=bank.invested, inline=False) + embed.add_field(name="Capacity level", value=bank.capacity_level) + embed.add_field(name="Currently invested", value=bank.invested, inline=False) await ctx.reply(embed=embed, mention_author=False) @@ -68,11 +68,11 @@ class Currency(commands.Cog): name=f"Interest ({bank.interest_level})", value=str(interest_upgrade_price(bank.interest_level)) ) embed.add_field( - name=f"Capaciteit ({bank.capacity_level})", value=str(capacity_upgrade_price(bank.capacity_level)) + name=f"Capacity ({bank.capacity_level})", value=str(capacity_upgrade_price(bank.capacity_level)) ) embed.add_field(name=f"Rob ({bank.rob_level})", value=str(rob_upgrade_price(bank.rob_level))) - embed.set_footer(text="Didier Bank Upgrade [Categorie]") + embed.set_footer(text="Didier Bank Upgrade [Category]") await ctx.reply(embed=embed, mention_author=False) @@ -84,7 +84,7 @@ class Currency(commands.Cog): await crud.upgrade_capacity(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @bank_upgrades.command(name="Interest", aliases=["I"]) @@ -95,7 +95,7 @@ class Currency(commands.Cog): await crud.upgrade_interest(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @bank_upgrades.command(name="Rob", aliases=["R"]) @@ -106,7 +106,7 @@ class Currency(commands.Cog): await crud.upgrade_rob(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @commands.hybrid_command(name="dinks") @@ -115,7 +115,7 @@ class Currency(commands.Cog): async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) plural = pluralize("Didier Dink", bank.dinks) - await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) + await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) @commands.command(name="Invest", aliases=["Deposit", "Dep"]) async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore @@ -127,10 +127,10 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", invested) if invested == 0: - await ctx.reply("Je hebt geen Didier Dinks om te investeren.", mention_author=False) + await ctx.reply("You don't have any Didier Dinks to invest.", mention_author=False) else: await ctx.reply( - f"**{ctx.author.display_name}** heeft **{invested}** {plural} geïnvesteerd.", mention_author=False + f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False ) @commands.hybrid_command(name="nightly") @@ -139,9 +139,11 @@ class Currency(commands.Cog): async with self.client.postgres_session as session: try: await crud.claim_nightly(session, ctx.author.id) - await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") + await ctx.reply(f"You've claimed your daily **{crud.NIGHTLY_AMOUNT}** Didier Dinks.") except DoubleNightly: - await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True) + await ctx.reply( + "You've already claimed your Didier Nightly today.", mention_author=False, ephemeral=True + ) async def setup(client: Didier): diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 7d00bc1..57bac27 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -1,3 +1,5 @@ +import shlex + import discord from discord import app_commands from discord.ext import commands @@ -5,6 +7,7 @@ from discord.ext import commands from database.crud.dad_jokes import get_random_dad_joke from database.crud.memes import get_meme_by_name from didier import Didier +from didier.data.apis.imgflip import generate_meme from didier.views.modals import GenerateMeme @@ -31,16 +34,29 @@ class Fun(commands.Cog): return await ctx.reply(joke.joke, mention_author=False) @commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True) - async def memegen_ctx(self, ctx: commands.Context): + async def memegen_ctx(self, ctx: commands.Context, meme_name: str, *, fields: str): """Command group for meme-related commands""" + async with ctx.typing(): + async with self.client.postgres_session as session: + result = await get_meme_by_name(session, meme_name) + + if result is None: + return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False) + + meme = await generate_meme(self.client.http_session, result, shlex.split(fields)) + if meme is None: + return await ctx.reply("Something went wrong.", mention_author=False) + + return await ctx.reply(meme) @memes_slash.command(name="generate", description="Generate a meme") async def memegen_slash(self, interaction: discord.Interaction, meme: str): """Slash command to generate a meme""" async with self.client.postgres_session as session: result = await get_meme_by_name(session, meme) - if result is None: - return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True) + + if result is None: + return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True) modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) From d1d10ee8532fa697d679618640d22e6517970b55 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 18:51:34 +0200 Subject: [PATCH 4/9] Fix naming & missing argument --- didier/cogs/fun.py | 2 +- didier/cogs/owner.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 57bac27..2d53b0c 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -34,7 +34,7 @@ class Fun(commands.Cog): return await ctx.reply(joke.joke, mention_author=False) @commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True) - async def memegen_ctx(self, ctx: commands.Context, meme_name: str, *, fields: str): + async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str): """Command group for meme-related commands""" async with ctx.typing(): async with self.client.postgres_session as session: diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index c5ee213..2fa4b4e 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -177,8 +177,8 @@ class Owner(commands.Cog): if meme is None: return await interaction.followup.send("A meme with this name (or id) already exists.") - await interaction.followup.send(f"Added meme `{meme.meme_id}`.") - await self.client.database_caches.memes.invalidate() + await interaction.followup.send(f"Added meme `{meme.meme_id}`.") + await self.client.database_caches.memes.invalidate(session) @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): From a0c1b986cdb7da265dd0add5f2a47b14e2c58b3e Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 20:02:54 +0200 Subject: [PATCH 5/9] Make fancy functions for database & http stuff, meme preview --- didier/cogs/fun.py | 31 ++++++++-------- didier/data/apis/imgflip.py | 13 +++---- didier/exceptions/http_exception.py | 8 +++++ didier/exceptions/no_match.py | 24 +++++++++++++ didier/utils/http/__init__.py | 0 didier/utils/http/requests.py | 56 +++++++++++++++++++++++++++++ main.py | 2 +- 7 files changed, 110 insertions(+), 24 deletions(-) create mode 100644 didier/exceptions/http_exception.py create mode 100644 didier/exceptions/no_match.py create mode 100644 didier/utils/http/__init__.py create mode 100644 didier/utils/http/requests.py diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 2d53b0c..85f0d95 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -8,6 +8,7 @@ from database.crud.dad_jokes import get_random_dad_joke from database.crud.memes import get_meme_by_name from didier import Didier from didier.data.apis.imgflip import generate_meme +from didier.exceptions.no_match import expect from didier.views.modals import GenerateMeme @@ -22,6 +23,12 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client + async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str: + async with self.client.postgres_session as session: + result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name) + meme = await generate_meme(self.client.http_session, result, fields) + return meme + @commands.hybrid_command( name="dadjoke", aliases=["Dad", "Dj"], @@ -37,26 +44,22 @@ class Fun(commands.Cog): async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str): """Command group for meme-related commands""" async with ctx.typing(): - async with self.client.postgres_session as session: - result = await get_meme_by_name(session, meme_name) + meme = await self._do_generate_meme(meme_name, shlex.split(fields)) + return await ctx.reply(meme, mention_author=False) - if result is None: - return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False) - - meme = await generate_meme(self.client.http_session, result, shlex.split(fields)) - if meme is None: - return await ctx.reply("Something went wrong.", mention_author=False) - - return await ctx.reply(meme) + @memegen_msg.command(name="Preview", aliases=["P"]) + async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str): + """Generate a preview for a meme, to see how the fields are structured""" + async with ctx.typing(): + fields = [f"Field #{i + 1}" for i in range(20)] + meme = await self._do_generate_meme(meme_name, fields) + return await ctx.reply(meme, mention_author=False) @memes_slash.command(name="generate", description="Generate a meme") async def memegen_slash(self, interaction: discord.Interaction, meme: str): """Slash command to generate a meme""" async with self.client.postgres_session as session: - result = await get_meme_by_name(session, meme) - - if result is None: - return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True) + result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme) modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py index a136a72..b373897 100644 --- a/didier/data/apis/imgflip.py +++ b/didier/data/apis/imgflip.py @@ -1,10 +1,9 @@ -from typing import Optional - from aiohttp import ClientSession import settings from database.schemas.relational import MemeTemplate from didier.exceptions.missing_env import MissingEnvironmentVariable +from didier.utils.http.requests import ensure_post __all__ = ["generate_meme"] @@ -20,7 +19,7 @@ def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]: return fields -async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]: +async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str: """Make a request to Imgflip to generate a meme""" name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD @@ -36,9 +35,5 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: for i, box in enumerate(boxes): payload[f"boxes[{i}][text]"] = box - async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response: - if response.status != 200: - return None - - data = await response.json() - return data["data"]["url"] + async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response: + return response["data"]["url"] diff --git a/didier/exceptions/http_exception.py b/didier/exceptions/http_exception.py new file mode 100644 index 0000000..c4e9f2c --- /dev/null +++ b/didier/exceptions/http_exception.py @@ -0,0 +1,8 @@ +__all__ = ["HTTPException"] + + +class HTTPException(RuntimeError): + """Error raised when an API call fails""" + + def __init__(self, status_code: int): + super().__init__(f"Something went wrong (status {status_code}).") diff --git a/didier/exceptions/no_match.py b/didier/exceptions/no_match.py new file mode 100644 index 0000000..fed1a6d --- /dev/null +++ b/didier/exceptions/no_match.py @@ -0,0 +1,24 @@ +from typing import Optional, TypeVar + +__all__ = ["NoMatch", "expect"] + + +class NoMatch(ValueError): + """Error raised when a database lookup failed""" + + def __init__(self, entity_type: str, argument: str): + super().__init__(f"Found no {entity_type} matching `{argument}`.") + + +T = TypeVar("T") + + +def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T: + """Mark a database instance as expected, otherwise raise a custom exception + + This is not just done in the database layer because it's not always preferable + """ + if instance is None: + raise NoMatch(entity_type, argument) + + return instance diff --git a/didier/utils/http/__init__.py b/didier/utils/http/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/http/requests.py b/didier/utils/http/requests.py new file mode 100644 index 0000000..f649701 --- /dev/null +++ b/didier/utils/http/requests.py @@ -0,0 +1,56 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from aiohttp import ClientResponse, ClientSession + +from didier.exceptions.http_exception import HTTPException + +logger = logging.getLogger(__name__) + + +__all__ = ["ensure_get", "ensure_post"] + + +def request_successful(response: ClientResponse) -> bool: + """Check if a request was successful or not""" + return 200 <= response.status < 300 + + +@asynccontextmanager +async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a GET-request fails""" + async with http_session.get(endpoint) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json() + ) + + raise HTTPException(response.status) + + yield await response.json() + + +@asynccontextmanager +async def ensure_post( + http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True +) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a POST-request fails""" + async with http_session.post(endpoint, data=payload) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s", + endpoint, + response.status, + payload, + await response.json(), + ) + + raise HTTPException(response.status) + + if expect_return: + yield await response.json() + else: + # Always return A dict so you can always "use" the result without having to check + # if it is None or not + yield {} diff --git a/main.py b/main.py index fdbc027..f791621 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ def setup_logging(): max_log_size = 32 * 1024 * 1024 # Configure Didier handler - didier_log = logging.getLogger("didier") + didier_log = logging.getLogger(__name__) didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) From f9083e84ed0240eacdc891166bcd31bf35f0d74c Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 21:35:49 +0200 Subject: [PATCH 6/9] Meme preview slash command --- didier/cogs/fun.py | 13 +++++++++++++ didier/views/modals/memes.py | 4 ---- 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 85f0d95..495ff55 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -64,7 +64,20 @@ class Fun(commands.Cog): modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) + @memes_slash.command( + name="preview", description="Generate a preview for a meme, to see how the fields are structured" + ) + async def memegen_preview_slash(self, interaction: discord.Interaction, meme: str): + """Slash command to generate a meme preview""" + await interaction.response.defer() + + fields = [f"Field #{i + 1}" for i in range(20)] + meme_url = await self._do_generate_meme(meme, fields) + + await interaction.followup.send(meme_url) + @memegen_slash.autocomplete("meme") + @memegen_preview_slash.autocomplete("meme") async def _memegen_slash_autocomplete_meme( self, _: discord.Interaction, current: str ) -> list[app_commands.Choice[str]]: diff --git a/didier/views/modals/memes.py b/didier/views/modals/memes.py index 8448dbf..c98e17f 100644 --- a/didier/views/modals/memes.py +++ b/didier/views/modals/memes.py @@ -38,10 +38,6 @@ class GenerateMeme(discord.ui.Modal, title="Generate Meme"): fields = [item.value for item in self.children if isinstance(item, discord.ui.TextInput)] meme_url = await generate_meme(self.client.http_session, self.meme, fields) - - if meme_url is None: - return await interaction.followup.send("Something went wrong.") - await interaction.followup.send(meme_url) @overrides From 966eb6316556367c175126753caa88c5b0d68b5a Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 21:43:32 +0200 Subject: [PATCH 7/9] broken type --- didier/cogs/fun.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 495ff55..f5eaa6f 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -59,7 +59,7 @@ class Fun(commands.Cog): async def memegen_slash(self, interaction: discord.Interaction, meme: str): """Slash command to generate a meme""" async with self.client.postgres_session as session: - result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme) + result = expect(await get_meme_by_name(session, meme), entity_type="meme", argument=meme) modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) From 8a42e24c34edc449d5ff82ba7387713b033ef4a4 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 21:50:13 +0200 Subject: [PATCH 8/9] Make meme preview ephemeral --- didier/cogs/fun.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index f5eaa6f..d604b36 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -74,7 +74,7 @@ class Fun(commands.Cog): fields = [f"Field #{i + 1}" for i in range(20)] meme_url = await self._do_generate_meme(meme, fields) - await interaction.followup.send(meme_url) + await interaction.followup.send(meme_url, ephemeral=True) @memegen_slash.autocomplete("meme") @memegen_preview_slash.autocomplete("meme") From 6f0ac487cc0e6fd2284707761395d57c3c0c23cc Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 22:55:42 +0200 Subject: [PATCH 9/9] Handle custom exceptions --- didier/didier.py | 26 ++++++++++++++++++++++++++ didier/exceptions/__init__.py | 5 +++++ 2 files changed, 31 insertions(+) diff --git a/didier/didier.py b/didier/didier.py index 3217b61..337b342 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -4,6 +4,7 @@ import os import discord import motor.motor_asyncio from aiohttp import ClientSession +from discord.app_commands import AppCommandError from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession @@ -12,6 +13,7 @@ from database.crud import custom_commands from database.engine import DBSession, mongo_client from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed +from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix __all__ = ["Didier"] @@ -46,6 +48,8 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) + self.tree.on_error = self.on_app_command_error + @property def postgres_session(self) -> AsyncSession: """Obtain a session for the PostgreSQL database""" @@ -197,6 +201,18 @@ class Didier(commands.Bot): """Event triggered when a new thread is created""" await thread.join() + async def on_app_command_error(self, interaction: discord.Interaction, exception: AppCommandError): + """Event triggered when an application command errors""" + # If commands have their own error handler, let it handle the error instead + if hasattr(interaction.command, "on_error"): + return + + if isinstance(exception, (NoMatch, discord.app_commands.CommandInvokeError)): + if interaction.response.is_done(): + return await interaction.response.send_message(str(exception.original), ephemeral=True) + else: + return await interaction.followup.send(str(exception.original), ephemeral=True) + async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /): """Event triggered when a regular command errors""" # If working locally, print everything to your console @@ -219,6 +235,16 @@ class Didier(commands.Bot): ): return + # Responses to things that go wrong during processing of commands + if isinstance(exception, commands.CommandInvokeError) and isinstance( + exception.original, + ( + NoMatch, + HTTPException, + ), + ): + return await ctx.reply(str(exception.original), mention_author=False) + # Print everything that we care about to the logs/stderr await super().on_command_error(ctx, exception) diff --git a/didier/exceptions/__init__.py b/didier/exceptions/__init__.py index e69de29..4321cae 100644 --- a/didier/exceptions/__init__.py +++ b/didier/exceptions/__init__.py @@ -0,0 +1,5 @@ +from .http_exception import HTTPException +from .missing_env import MissingEnvironmentVariable +from .no_match import NoMatch, expect + +__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect"]