mirror of https://github.com/stijndcl/didier
Memegen works
parent
dbb570420b
commit
7d7ab98254
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import MemeTemplate
|
from database.schemas.relational import MemeTemplate
|
||||||
|
|
||||||
__all__ = ["add_meme", "get_all_memes"]
|
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]
|
||||||
|
|
||||||
|
|
||||||
async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]:
|
async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]:
|
||||||
|
@ -20,6 +20,12 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
|
||||||
|
"""Get a list of all memes"""
|
||||||
|
statement = select(MemeTemplate)
|
||||||
|
return (await session.execute(statement)).scalars().all()
|
||||||
|
|
||||||
|
|
||||||
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:
|
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:
|
||||||
"""Try to find a meme by its name
|
"""Try to find a meme by its name
|
||||||
|
|
||||||
|
@ -27,9 +33,3 @@ async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTe
|
||||||
"""
|
"""
|
||||||
statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%"))
|
statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%"))
|
||||||
return (await session.execute(statement)).scalar()
|
return (await session.execute(statement)).scalar()
|
||||||
|
|
||||||
|
|
||||||
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
|
|
||||||
"""Get a list of all memes"""
|
|
||||||
statement = select(MemeTemplate)
|
|
||||||
return (await session.execute(statement)).scalars().all()
|
|
||||||
|
|
|
@ -3,7 +3,9 @@ from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from database.crud.dad_jokes import get_random_dad_joke
|
from database.crud.dad_jokes import get_random_dad_joke
|
||||||
|
from database.crud.memes import get_meme_by_name
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
from didier.views.modals import GenerateMeme
|
||||||
|
|
||||||
|
|
||||||
class Fun(commands.Cog):
|
class Fun(commands.Cog):
|
||||||
|
@ -11,7 +13,8 @@ class Fun(commands.Cog):
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
|
|
||||||
memegen_slash = app_commands.Group(name="meme", description="Commands to generate memes")
|
# Slash groups
|
||||||
|
memes_slash = app_commands.Group(name="meme", description="Commands to generate memes", guild_only=False)
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
@ -31,9 +34,16 @@ class Fun(commands.Cog):
|
||||||
async def memegen_ctx(self, ctx: commands.Context):
|
async def memegen_ctx(self, ctx: commands.Context):
|
||||||
"""Command group for meme-related commands"""
|
"""Command group for meme-related commands"""
|
||||||
|
|
||||||
@memegen_slash.command(name="generate", description="Generate a meme")
|
@memes_slash.command(name="generate", description="Generate a meme")
|
||||||
async def memegen_slash(self, ctx: commands.Context, meme: str):
|
async def memegen_slash(self, interaction: discord.Interaction, meme: str):
|
||||||
"""Slash command to generate a meme"""
|
"""Slash command to generate a meme"""
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
result = await get_meme_by_name(session, meme)
|
||||||
|
if result is None:
|
||||||
|
return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True)
|
||||||
|
|
||||||
|
modal = GenerateMeme(self.client, result)
|
||||||
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
@memegen_slash.autocomplete("meme")
|
@memegen_slash.autocomplete("meme")
|
||||||
async def _memegen_slash_autocomplete_meme(
|
async def _memegen_slash_autocomplete_meme(
|
||||||
|
|
|
@ -9,14 +9,15 @@ from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||||
__all__ = ["generate_meme"]
|
__all__ = ["generate_meme"]
|
||||||
|
|
||||||
|
|
||||||
def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[dict[str:str]]:
|
def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]:
|
||||||
"""Generate the template boxes for Imgflip"""
|
"""Generate the template boxes for Imgflip"""
|
||||||
# If a meme only has 1 field, join all the arguments together into one string
|
# If a meme only has 1 field, join all the arguments together into one string
|
||||||
if meme.field_count == 1:
|
if meme.field_count == 1:
|
||||||
fields = [" ".join(fields)]
|
fields = [" ".join(fields)]
|
||||||
|
|
||||||
fields = fields[: min(20, meme.field_count)]
|
fields = fields[: min(20, meme.field_count)]
|
||||||
return [{"text": text} for text in fields]
|
# TODO manipulate the text if necessary
|
||||||
|
return fields
|
||||||
|
|
||||||
|
|
||||||
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
|
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
|
||||||
|
@ -31,12 +32,13 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields:
|
||||||
raise MissingEnvironmentVariable("IMGFLIP_PASSWORD")
|
raise MissingEnvironmentVariable("IMGFLIP_PASSWORD")
|
||||||
|
|
||||||
boxes = generate_boxes(meme, fields)
|
boxes = generate_boxes(meme, fields)
|
||||||
payload = {"template_id": meme.template_id, "username": name, "password": password, "boxes": boxes}
|
payload = {"template_id": meme.template_id, "username": name, "password": password}
|
||||||
|
for i, box in enumerate(boxes):
|
||||||
|
payload[f"boxes[{i}][text]"] = box
|
||||||
|
|
||||||
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
|
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
|
||||||
if response.status != 200:
|
if response.status != 200:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
data = await response.json()
|
data = await response.json()
|
||||||
|
|
||||||
return data["data"]["url"]
|
return data["data"]["url"]
|
||||||
|
|
|
@ -2,5 +2,6 @@ from .custom_commands import CreateCustomCommand, EditCustomCommand
|
||||||
from .dad_jokes import AddDadJoke
|
from .dad_jokes import AddDadJoke
|
||||||
from .deadlines import AddDeadline
|
from .deadlines import AddDeadline
|
||||||
from .links import AddLink
|
from .links import AddLink
|
||||||
|
from .memes import GenerateMeme
|
||||||
|
|
||||||
__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink"]
|
__all__ = ["AddDadJoke", "AddDeadline", "CreateCustomCommand", "EditCustomCommand", "AddLink", "GenerateMeme"]
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import discord.ui
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
from database.schemas.relational import MemeTemplate
|
||||||
|
from didier import Didier
|
||||||
|
from didier.data.apis.imgflip import generate_meme
|
||||||
|
|
||||||
|
__all__ = ["GenerateMeme"]
|
||||||
|
|
||||||
|
|
||||||
|
class GenerateMeme(discord.ui.Modal, title="Generate Meme"):
|
||||||
|
"""Modal to generate a meme"""
|
||||||
|
|
||||||
|
client: Didier
|
||||||
|
meme: MemeTemplate
|
||||||
|
|
||||||
|
def __init__(self, client: Didier, meme: MemeTemplate, *args, **kwargs):
|
||||||
|
super().__init__(*args, **kwargs)
|
||||||
|
self.client = client
|
||||||
|
self.meme = meme
|
||||||
|
|
||||||
|
for i in range(meme.field_count):
|
||||||
|
self.add_item(
|
||||||
|
discord.ui.TextInput(
|
||||||
|
label=f"Field #{i + 1}",
|
||||||
|
placeholder="Here be funny text",
|
||||||
|
style=discord.TextStyle.long,
|
||||||
|
required=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def on_submit(self, interaction: discord.Interaction):
|
||||||
|
await interaction.response.defer()
|
||||||
|
|
||||||
|
fields = [item.value for item in self.children if isinstance(item, discord.ui.TextInput)]
|
||||||
|
|
||||||
|
meme_url = await generate_meme(self.client.http_session, self.meme, fields)
|
||||||
|
|
||||||
|
if meme_url is None:
|
||||||
|
return await interaction.followup.send("Something went wrong.")
|
||||||
|
|
||||||
|
await interaction.followup.send(meme_url)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||||
|
traceback.print_tb(error.__traceback__)
|
||||||
|
await interaction.followup.send("Something went wrong.", ephemeral=True)
|
Loading…
Reference in New Issue