diff --git a/alembic/versions/36300b558ef1_meme_templates.py b/alembic/versions/36300b558ef1_meme_templates.py new file mode 100644 index 0000000..275133a --- /dev/null +++ b/alembic/versions/36300b558ef1_meme_templates.py @@ -0,0 +1,37 @@ +"""Meme templates + +Revision ID: 36300b558ef1 +Revises: 08d21b2d1a0a +Create Date: 2022-08-25 01:34:22.845955 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "36300b558ef1" +down_revision = "08d21b2d1a0a" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "meme", + sa.Column("meme_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("template_id", sa.Integer(), nullable=False), + sa.Column("field_count", sa.Integer(), nullable=False), + sa.PrimaryKeyConstraint("meme_id"), + sa.UniqueConstraint("name"), + sa.UniqueConstraint("template_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("meme") + # ### end Alembic commands ### diff --git a/database/crud/memes.py b/database/crud/memes.py new file mode 100644 index 0000000..f92f6ef --- /dev/null +++ b/database/crud/memes.py @@ -0,0 +1,35 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.exc import IntegrityError +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas.relational import MemeTemplate + +__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"] + + +async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]: + """Add a new meme into the database""" + try: + meme = MemeTemplate(name=name, template_id=template_id, field_count=field_count) + session.add(meme) + await session.commit() + return meme + except IntegrityError: + return None + + +async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: + """Get a list of all memes""" + statement = select(MemeTemplate) + return (await session.execute(statement)).scalars().all() + + +async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: + """Try to find a meme by its name + + Returns the first match found by PSQL + """ + statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%")) + return (await session.execute(statement)).scalar() diff --git a/database/schemas/relational.py b/database/schemas/relational.py index 6fd27d0..904459e 100644 --- a/database/schemas/relational.py +++ b/database/schemas/relational.py @@ -30,6 +30,7 @@ __all__ = [ "DadJoke", "Deadline", "Link", + "MemeTemplate", "NightlyData", "Task", "UforaAnnouncement", @@ -134,6 +135,17 @@ class Link(Base): url: str = Column(Text, nullable=False) +class MemeTemplate(Base): + """A meme template for the Imgflip API""" + + __tablename__ = "meme" + + meme_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + template_id: int = Column(Integer, nullable=False, unique=True) + field_count: int = Column(Integer, nullable=False) + + class NightlyData(Base): """Data for a user's Nightly stats""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 4f34419..ba26ab2 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -5,7 +5,7 @@ from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import links, ufora_courses, wordle +from database.crud import links, memes, ufora_courses, wordle from database.mongo_types import MongoDatabase __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] @@ -61,6 +61,19 @@ class LinkCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) +class MemeCache(DatabaseCache[AsyncSession]): + """Cache to store the names of meme templates""" + + @overrides + async def invalidate(self, database_session: AsyncSession): + self.clear() + + all_memes = await memes.get_all_memes(database_session) + self.data = list(map(lambda m: m.name, all_memes)) + self.data.sort() + self.data_transformed = list(map(str.lower, self.data)) + + class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" @@ -119,16 +132,19 @@ class CacheManager: """Class that keeps track of all caches""" links: LinkCache + memes: MemeCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): self.links = LinkCache() + self.memes = MemeCache() self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" await self.links.invalidate(postgres_session) + await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) await self.wordle_word.invalidate(mongo_db) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 1b3dea5..05446f0 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -46,12 +46,12 @@ class Currency(commands.Cog): bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) - embed.set_author(name=f"Bank van {ctx.author.display_name}") + embed.set_author(name=f"{ctx.author.display_name}'s Bank") embed.set_thumbnail(url=ctx.author.avatar.url) embed.add_field(name="Interest level", value=bank.interest_level) - embed.add_field(name="Capaciteit level", value=bank.capacity_level) - embed.add_field(name="Momenteel geïnvesteerd", value=bank.invested, inline=False) + embed.add_field(name="Capacity level", value=bank.capacity_level) + embed.add_field(name="Currently invested", value=bank.invested, inline=False) await ctx.reply(embed=embed, mention_author=False) @@ -68,11 +68,11 @@ class Currency(commands.Cog): name=f"Interest ({bank.interest_level})", value=str(interest_upgrade_price(bank.interest_level)) ) embed.add_field( - name=f"Capaciteit ({bank.capacity_level})", value=str(capacity_upgrade_price(bank.capacity_level)) + name=f"Capacity ({bank.capacity_level})", value=str(capacity_upgrade_price(bank.capacity_level)) ) embed.add_field(name=f"Rob ({bank.rob_level})", value=str(rob_upgrade_price(bank.rob_level))) - embed.set_footer(text="Didier Bank Upgrade [Categorie]") + embed.set_footer(text="Didier Bank Upgrade [Category]") await ctx.reply(embed=embed, mention_author=False) @@ -84,7 +84,7 @@ class Currency(commands.Cog): await crud.upgrade_capacity(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @bank_upgrades.command(name="Interest", aliases=["I"]) @@ -95,7 +95,7 @@ class Currency(commands.Cog): await crud.upgrade_interest(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @bank_upgrades.command(name="Rob", aliases=["R"]) @@ -106,7 +106,7 @@ class Currency(commands.Cog): await crud.upgrade_rob(session, ctx.author.id) await ctx.message.add_reaction("⏫") except NotEnoughDinks: - await ctx.reply("Je hebt niet genoeg Didier Dinks om dit te doen.", mention_author=False) + await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) @commands.hybrid_command(name="dinks") @@ -115,7 +115,7 @@ class Currency(commands.Cog): async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) plural = pluralize("Didier Dink", bank.dinks) - await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) + await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) @commands.command(name="Invest", aliases=["Deposit", "Dep"]) async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore @@ -127,10 +127,10 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", invested) if invested == 0: - await ctx.reply("Je hebt geen Didier Dinks om te investeren.", mention_author=False) + await ctx.reply("You don't have any Didier Dinks to invest.", mention_author=False) else: await ctx.reply( - f"**{ctx.author.display_name}** heeft **{invested}** {plural} geïnvesteerd.", mention_author=False + f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False ) @commands.hybrid_command(name="nightly") @@ -139,9 +139,11 @@ class Currency(commands.Cog): async with self.client.postgres_session as session: try: await crud.claim_nightly(session, ctx.author.id) - await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") + await ctx.reply(f"You've claimed your daily **{crud.NIGHTLY_AMOUNT}** Didier Dinks.") except DoubleNightly: - await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True) + await ctx.reply( + "You've already claimed your Didier Nightly today.", mention_author=False, ephemeral=True + ) async def setup(client: Didier): diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 0aade01..f5eaa6f 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -1,7 +1,15 @@ +import shlex + +import discord +from discord import app_commands from discord.ext import commands from database.crud.dad_jokes import get_random_dad_joke +from database.crud.memes import get_meme_by_name from didier import Didier +from didier.data.apis.imgflip import generate_meme +from didier.exceptions.no_match import expect +from didier.views.modals import GenerateMeme class Fun(commands.Cog): @@ -9,9 +17,18 @@ class Fun(commands.Cog): client: Didier + # Slash groups + memes_slash = app_commands.Group(name="meme", description="Commands to generate memes", guild_only=False) + def __init__(self, client: Didier): self.client = client + async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str: + async with self.client.postgres_session as session: + result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name) + meme = await generate_meme(self.client.http_session, result, fields) + return meme + @commands.hybrid_command( name="dadjoke", aliases=["Dad", "Dj"], @@ -23,6 +40,50 @@ class Fun(commands.Cog): joke = await get_random_dad_joke(session) return await ctx.reply(joke.joke, mention_author=False) + @commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True) + async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str): + """Command group for meme-related commands""" + async with ctx.typing(): + meme = await self._do_generate_meme(meme_name, shlex.split(fields)) + return await ctx.reply(meme, mention_author=False) + + @memegen_msg.command(name="Preview", aliases=["P"]) + async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str): + """Generate a preview for a meme, to see how the fields are structured""" + async with ctx.typing(): + fields = [f"Field #{i + 1}" for i in range(20)] + meme = await self._do_generate_meme(meme_name, fields) + return await ctx.reply(meme, mention_author=False) + + @memes_slash.command(name="generate", description="Generate a meme") + async def memegen_slash(self, interaction: discord.Interaction, meme: str): + """Slash command to generate a meme""" + async with self.client.postgres_session as session: + result = expect(await get_meme_by_name(session, meme), entity_type="meme", argument=meme) + + modal = GenerateMeme(self.client, result) + await interaction.response.send_modal(modal) + + @memes_slash.command( + name="preview", description="Generate a preview for a meme, to see how the fields are structured" + ) + async def memegen_preview_slash(self, interaction: discord.Interaction, meme: str): + """Slash command to generate a meme preview""" + await interaction.response.defer() + + fields = [f"Field #{i + 1}" for i in range(20)] + meme_url = await self._do_generate_meme(meme, fields) + + await interaction.followup.send(meme_url) + + @memegen_slash.autocomplete("meme") + @memegen_preview_slash.autocomplete("meme") + async def _memegen_slash_autocomplete_meme( + self, _: discord.Interaction, current: str + ) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'meme'-parameter""" + return self.client.database_caches.memes.get_autocomplete_suggestions(current) + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index d030344..2fa4b4e 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -5,7 +5,7 @@ from discord import app_commands from discord.ext import commands import settings -from database.crud import custom_commands, links, ufora_courses +from database.crud import custom_commands, links, memes, ufora_courses from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier @@ -167,6 +167,19 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="meme", description="Add a new meme") + async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): + """Slash command to add new memes""" + await interaction.response.defer(ephemeral=True) + + async with self.client.postgres_session as session: + meme = await memes.add_meme(session, name, imgflip_id, field_count) + if meme is None: + return await interaction.followup.send("A meme with this name (or id) already exists.") + + await interaction.followup.send(f"Added meme `{meme.meme_id}`.") + await self.client.database_caches.memes.invalidate(session) + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py new file mode 100644 index 0000000..b373897 --- /dev/null +++ b/didier/data/apis/imgflip.py @@ -0,0 +1,39 @@ +from aiohttp import ClientSession + +import settings +from database.schemas.relational import MemeTemplate +from didier.exceptions.missing_env import MissingEnvironmentVariable +from didier.utils.http.requests import ensure_post + +__all__ = ["generate_meme"] + + +def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]: + """Generate the template boxes for Imgflip""" + # If a meme only has 1 field, join all the arguments together into one string + if meme.field_count == 1: + fields = [" ".join(fields)] + + fields = fields[: min(20, meme.field_count)] + # TODO manipulate the text if necessary + return fields + + +async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str: + """Make a request to Imgflip to generate a meme""" + name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD + + # Ensure credentials exist + if name is None: + raise MissingEnvironmentVariable("IMGFLIP_NAME") + + if password is None: + raise MissingEnvironmentVariable("IMGFLIP_PASSWORD") + + boxes = generate_boxes(meme, fields) + payload = {"template_id": meme.template_id, "username": name, "password": password} + for i, box in enumerate(boxes): + payload[f"boxes[{i}][text]"] = box + + async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response: + return response["data"]["url"] diff --git a/didier/exceptions/http_exception.py b/didier/exceptions/http_exception.py new file mode 100644 index 0000000..c4e9f2c --- /dev/null +++ b/didier/exceptions/http_exception.py @@ -0,0 +1,8 @@ +__all__ = ["HTTPException"] + + +class HTTPException(RuntimeError): + """Error raised when an API call fails""" + + def __init__(self, status_code: int): + super().__init__(f"Something went wrong (status {status_code}).") diff --git a/didier/exceptions/missing_env.py b/didier/exceptions/missing_env.py new file mode 100644 index 0000000..863a092 --- /dev/null +++ b/didier/exceptions/missing_env.py @@ -0,0 +1,8 @@ +__all__ = ["MissingEnvironmentVariable"] + + +class MissingEnvironmentVariable(RuntimeError): + """Exception raised when an environment variable is missing""" + + def __init__(self, variable_name): + super().__init__(f"Missing environment variable: {variable_name}") diff --git a/didier/exceptions/no_match.py b/didier/exceptions/no_match.py new file mode 100644 index 0000000..fed1a6d --- /dev/null +++ b/didier/exceptions/no_match.py @@ -0,0 +1,24 @@ +from typing import Optional, TypeVar + +__all__ = ["NoMatch", "expect"] + + +class NoMatch(ValueError): + """Error raised when a database lookup failed""" + + def __init__(self, entity_type: str, argument: str): + super().__init__(f"Found no {entity_type} matching `{argument}`.") + + +T = TypeVar("T") + + +def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T: + """Mark a database instance as expected, otherwise raise a custom exception + + This is not just done in the database layer because it's not always preferable + """ + if instance is None: + raise NoMatch(entity_type, argument) + + return instance diff --git a/didier/utils/http/__init__.py b/didier/utils/http/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/http/requests.py b/didier/utils/http/requests.py new file mode 100644 index 0000000..f649701 --- /dev/null +++ b/didier/utils/http/requests.py @@ -0,0 +1,56 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from aiohttp import ClientResponse, ClientSession + +from didier.exceptions.http_exception import HTTPException + +logger = logging.getLogger(__name__) + + +__all__ = ["ensure_get", "ensure_post"] + + +def request_successful(response: ClientResponse) -> bool: + """Check if a request was successful or not""" + return 200 <= response.status < 300 + + +@asynccontextmanager +async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a GET-request fails""" + async with http_session.get(endpoint) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json() + ) + + raise HTTPException(response.status) + + yield await response.json() + + +@asynccontextmanager +async def ensure_post( + http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True +) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a POST-request fails""" + async with http_session.post(endpoint, data=payload) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s", + endpoint, + response.status, + payload, + await response.json(), + ) + + raise HTTPException(response.status) + + if expect_return: + yield await response.json() + else: + # Always return A dict so you can always "use" the result without having to check + # if it is None or not + yield {} diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py index 42c2ce8..62d700f 100644 --- a/didier/views/modals/__init__.py +++ b/didier/views/modals/__init__.py @@ -2,5 +2,6 @@ from .custom_commands import CreateCustomCommand, EditCustomCommand from .dad_jokes import AddDadJoke from .deadlines import AddDeadline from .links import AddLink +from .memes import GenerateMeme -__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink"] +__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink", "GenerateMeme"] diff --git a/didier/views/modals/memes.py b/didier/views/modals/memes.py new file mode 100644 index 0000000..c98e17f --- /dev/null +++ b/didier/views/modals/memes.py @@ -0,0 +1,46 @@ +import traceback + +import discord.ui +from overrides import overrides + +from database.schemas.relational import MemeTemplate +from didier import Didier +from didier.data.apis.imgflip import generate_meme + +__all__ = ["GenerateMeme"] + + +class GenerateMeme(discord.ui.Modal, title="Generate Meme"): + """Modal to generate a meme""" + + client: Didier + meme: MemeTemplate + + def __init__(self, client: Didier, meme: MemeTemplate, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + self.meme = meme + + for i in range(meme.field_count): + self.add_item( + discord.ui.TextInput( + label=f"Field #{i + 1}", + placeholder="Here be funny text", + style=discord.TextStyle.long, + required=True, + ) + ) + + @overrides + async def on_submit(self, interaction: discord.Interaction): + await interaction.response.defer() + + fields = [item.value for item in self.children if isinstance(item, discord.ui.TextInput)] + + meme_url = await generate_meme(self.client.http_session, self.meme, fields) + await interaction.followup.send(meme_url) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + traceback.print_tb(error.__traceback__) + await interaction.followup.send("Something went wrong.", ephemeral=True) diff --git a/main.py b/main.py index fdbc027..f791621 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ def setup_logging(): max_log_size = 32 * 1024 * 1024 # Configure Didier handler - didier_log = logging.getLogger("didier") + didier_log = logging.getLogger(__name__) didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s")) diff --git a/settings.py b/settings.py index ec6039c..a041397 100644 --- a/settings.py +++ b/settings.py @@ -24,6 +24,8 @@ __all__ = [ "UFORA_ANNOUNCEMENTS_CHANNEL", "UFORA_RSS_TOKEN", "URBAN_DICTIONARY_TOKEN", + "IMGFLIP_NAME", + "IMGFLIP_PASSWORD", ] @@ -64,3 +66,5 @@ UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNE """API Keys""" UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None) URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None) +IMGFLIP_NAME: Optional[str] = env.str("IMGFLIP_NAME", None) +IMGFLIP_PASSWORD: Optional[str] = env.str("IMGFLIP_PASSWORD", None)