diff --git a/alembic/versions/36300b558ef1_meme_templates.py b/alembic/versions/36300b558ef1_meme_templates.py new file mode 100644 index 0000000..275133a --- /dev/null +++ b/alembic/versions/36300b558ef1_meme_templates.py @@ -0,0 +1,37 @@ +"""Meme templates + +Revision ID: 36300b558ef1 +Revises: 08d21b2d1a0a +Create Date: 2022-08-25 01:34:22.845955 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "36300b558ef1" +down_revision = "08d21b2d1a0a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "meme", + sa.Column("meme_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("template_id", sa.Integer(), nullable=False), + sa.Column("field_count", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("meme_id"), + sa.UniqueConstraint("name"), + sa.UniqueConstraint("template_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("meme") + # ### end Alembic commands ### diff --git a/database/crud/memes.py b/database/crud/memes.py new file mode 100644 index 0000000..cccdcfb --- /dev/null +++ b/database/crud/memes.py @@ -0,0 +1,35 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas.relational import MemeTemplate + +__all__ = ["add_meme", "get_all_memes"] + + +async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]: + """Add a new meme into the database""" + try: + meme = MemeTemplate(name=name, template_id=template_id, field_count=field_count) + session.add(meme) + await session.commit() + return meme + except IntegrityError: + return None + + +async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: + """Try to find a meme by its name + + Returns the first match found by PSQL + """ + statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%")) + return (await session.execute(statement)).scalar() + + +async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: + """Get a list of all memes""" + statement = select(MemeTemplate) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas/relational.py b/database/schemas/relational.py index 6fd27d0..904459e 100644 --- a/database/schemas/relational.py +++ b/database/schemas/relational.py @@ -30,6 +30,7 @@ __all__ = [ "DadJoke", "Deadline", "Link", + "MemeTemplate", "NightlyData", "Task", "UforaAnnouncement", @@ -134,6 +135,17 @@ class Link(Base): url: str = Column(Text, nullable=False) +class MemeTemplate(Base): + """A meme template for the Imgflip API""" + + __tablename__ = "meme" + + meme_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + template_id: int = Column(Integer, nullable=False, unique=True) + field_count: int = Column(Integer, nullable=False) + + class NightlyData(Base): """Data for a user's Nightly stats""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 4f34419..ba26ab2 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -5,7 +5,7 @@ from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import links, ufora_courses, wordle +from database.crud import links, memes, ufora_courses, wordle from database.mongo_types import MongoDatabase __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] @@ -61,6 +61,19 @@ class LinkCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) +class MemeCache(DatabaseCache[AsyncSession]): + """Cache to store the names of meme templates""" + + @overrides + async def invalidate(self, database_session: AsyncSession): + self.clear() + + all_memes = await memes.get_all_memes(database_session) + self.data = list(map(lambda m: m.name, all_memes)) + self.data.sort() + self.data_transformed = list(map(str.lower, self.data)) + + class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" @@ -119,16 +132,19 @@ class CacheManager: """Class that keeps track of all caches""" links: LinkCache + memes: MemeCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): self.links = LinkCache() + self.memes = MemeCache() self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" await self.links.invalidate(postgres_session) + await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) await self.wordle_word.invalidate(mongo_db) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 0aade01..b014ddd 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -1,3 +1,5 @@ +import discord +from discord import app_commands from discord.ext import commands from database.crud.dad_jokes import get_random_dad_joke @@ -9,6 +11,8 @@ class Fun(commands.Cog): client: Didier + memegen_slash = app_commands.Group(name="meme", description="Commands to generate memes") + def __init__(self, client: Didier): self.client = client @@ -23,6 +27,21 @@ class Fun(commands.Cog): joke = await get_random_dad_joke(session) return await ctx.reply(joke.joke, mention_author=False) + @commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True) + async def memegen_ctx(self, ctx: commands.Context): + """Command group for meme-related commands""" + + @memegen_slash.command(name="generate", description="Generate a meme") + async def memegen_slash(self, ctx: commands.Context, meme: str): + """Slash command to generate a meme""" + + @memegen_slash.autocomplete("meme") + async def _memegen_slash_autocomplete_meme( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'meme'-parameter""" + return self.client.database_caches.memes.get_autocomplete_suggestions(current) + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index d030344..c5ee213 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -5,7 +5,7 @@ from discord import app_commands from discord.ext import commands import settings -from database.crud import custom_commands, links, ufora_courses +from database.crud import custom_commands, links, memes, ufora_courses from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier @@ -167,6 +167,19 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="meme", description="Add a new meme") + async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): + """Slash command to add new memes""" + await interaction.response.defer(ephemeral=True) + + async with self.client.postgres_session as session: + meme = await memes.add_meme(session, name, imgflip_id, field_count) + if meme is None: + return await interaction.followup.send("A meme with this name (or id) already exists.") + + await interaction.followup.send(f"Added meme `{meme.meme_id}`.") + await self.client.database_caches.memes.invalidate() + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py new file mode 100644 index 0000000..7447fe3 --- /dev/null +++ b/didier/data/apis/imgflip.py @@ -0,0 +1,42 @@ +from typing import Optional + +from aiohttp import ClientSession + +import settings +from database.schemas.relational import MemeTemplate +from didier.exceptions.missing_env import MissingEnvironmentVariable + +__all__ = ["generate_meme"] + + +def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[dict[str:str]]: + """Generate the template boxes for Imgflip""" + # If a meme only has 1 field, join all the arguments together into one string + if meme.field_count == 1: + fields = [" ".join(fields)] + + fields = fields[: min(20, meme.field_count)] + return [{"text": text} for text in fields] + + +async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]: + """Make a request to Imgflip to generate a meme""" + name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD + + # Ensure credentials exist + if name is None: + raise MissingEnvironmentVariable("IMGFLIP_NAME") + + if password is None: + raise MissingEnvironmentVariable("IMGFLIP_PASSWORD") + + boxes = generate_boxes(meme, fields) + payload = {"template_id": meme.template_id, "username": name, "password": password, "boxes": boxes} + + async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response: + if response.status != 200: + return None + + data = await response.json() + + return data["data"]["url"] diff --git a/didier/exceptions/missing_env.py b/didier/exceptions/missing_env.py new file mode 100644 index 0000000..863a092 --- /dev/null +++ b/didier/exceptions/missing_env.py @@ -0,0 +1,8 @@ +__all__ = ["MissingEnvironmentVariable"] + + +class MissingEnvironmentVariable(RuntimeError): + """Exception raised when an environment variable is missing""" + + def __init__(self, variable_name): + super().__init__(f"Missing environment variable: {variable_name}") diff --git a/settings.py b/settings.py index ec6039c..a041397 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ __all__ = [ "UFORA_ANNOUNCEMENTS_CHANNEL", "UFORA_RSS_TOKEN", "URBAN_DICTIONARY_TOKEN", + "IMGFLIP_NAME", + "IMGFLIP_PASSWORD", ] @@ -64,3 +66,5 @@ UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNE """API Keys""" UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None) URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None) +IMGFLIP_NAME: Optional[str] = env.str("IMGFLIP_NAME", None) +IMGFLIP_PASSWORD: Optional[str] = env.str("IMGFLIP_PASSWORD", None)