mirror of https://github.com/stijndcl/didier
Command to add memes
parent
86dd6cb27b
commit
dbb570420b
|
@ -0,0 +1,37 @@
|
||||||
|
"""Meme templates
|
||||||
|
|
||||||
|
Revision ID: 36300b558ef1
|
||||||
|
Revises: 08d21b2d1a0a
|
||||||
|
Create Date: 2022-08-25 01:34:22.845955
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "36300b558ef1"
|
||||||
|
down_revision = "08d21b2d1a0a"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"meme",
|
||||||
|
sa.Column("meme_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("name", sa.Text(), nullable=False),
|
||||||
|
sa.Column("template_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("field_count", sa.Integer(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("meme_id"),
|
||||||
|
sa.UniqueConstraint("name"),
|
||||||
|
sa.UniqueConstraint("template_id"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("meme")
|
||||||
|
# ### end Alembic commands ###
|
|
@ -0,0 +1,35 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.schemas.relational import MemeTemplate
|
||||||
|
|
||||||
|
__all__ = ["add_meme", "get_all_memes"]
|
||||||
|
|
||||||
|
|
||||||
|
async def add_meme(session: AsyncSession, name: str, template_id: int, field_count: int) -> Optional[MemeTemplate]:
|
||||||
|
"""Add a new meme into the database"""
|
||||||
|
try:
|
||||||
|
meme = MemeTemplate(name=name, template_id=template_id, field_count=field_count)
|
||||||
|
session.add(meme)
|
||||||
|
await session.commit()
|
||||||
|
return meme
|
||||||
|
except IntegrityError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:
|
||||||
|
"""Try to find a meme by its name
|
||||||
|
|
||||||
|
Returns the first match found by PSQL
|
||||||
|
"""
|
||||||
|
statement = select(MemeTemplate).where(MemeTemplate.name.ilike(f"%{query.lower()}%"))
|
||||||
|
return (await session.execute(statement)).scalar()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
|
||||||
|
"""Get a list of all memes"""
|
||||||
|
statement = select(MemeTemplate)
|
||||||
|
return (await session.execute(statement)).scalars().all()
|
|
@ -30,6 +30,7 @@ __all__ = [
|
||||||
"DadJoke",
|
"DadJoke",
|
||||||
"Deadline",
|
"Deadline",
|
||||||
"Link",
|
"Link",
|
||||||
|
"MemeTemplate",
|
||||||
"NightlyData",
|
"NightlyData",
|
||||||
"Task",
|
"Task",
|
||||||
"UforaAnnouncement",
|
"UforaAnnouncement",
|
||||||
|
@ -134,6 +135,17 @@ class Link(Base):
|
||||||
url: str = Column(Text, nullable=False)
|
url: str = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
|
class MemeTemplate(Base):
|
||||||
|
"""A meme template for the Imgflip API"""
|
||||||
|
|
||||||
|
__tablename__ = "meme"
|
||||||
|
|
||||||
|
meme_id: int = Column(Integer, primary_key=True)
|
||||||
|
name: str = Column(Text, nullable=False, unique=True)
|
||||||
|
template_id: int = Column(Integer, nullable=False, unique=True)
|
||||||
|
field_count: int = Column(Integer, nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class NightlyData(Base):
|
class NightlyData(Base):
|
||||||
"""Data for a user's Nightly stats"""
|
"""Data for a user's Nightly stats"""
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from discord import app_commands
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import links, ufora_courses, wordle
|
from database.crud import links, memes, ufora_courses, wordle
|
||||||
from database.mongo_types import MongoDatabase
|
from database.mongo_types import MongoDatabase
|
||||||
|
|
||||||
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
||||||
|
@ -61,6 +61,19 @@ class LinkCache(DatabaseCache[AsyncSession]):
|
||||||
self.data_transformed = list(map(str.lower, self.data))
|
self.data_transformed = list(map(str.lower, self.data))
|
||||||
|
|
||||||
|
|
||||||
|
class MemeCache(DatabaseCache[AsyncSession]):
|
||||||
|
"""Cache to store the names of meme templates"""
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def invalidate(self, database_session: AsyncSession):
|
||||||
|
self.clear()
|
||||||
|
|
||||||
|
all_memes = await memes.get_all_memes(database_session)
|
||||||
|
self.data = list(map(lambda m: m.name, all_memes))
|
||||||
|
self.data.sort()
|
||||||
|
self.data_transformed = list(map(str.lower, self.data))
|
||||||
|
|
||||||
|
|
||||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||||
"""Cache to store the names of Ufora courses"""
|
"""Cache to store the names of Ufora courses"""
|
||||||
|
|
||||||
|
@ -119,16 +132,19 @@ class CacheManager:
|
||||||
"""Class that keeps track of all caches"""
|
"""Class that keeps track of all caches"""
|
||||||
|
|
||||||
links: LinkCache
|
links: LinkCache
|
||||||
|
memes: MemeCache
|
||||||
ufora_courses: UforaCourseCache
|
ufora_courses: UforaCourseCache
|
||||||
wordle_word: WordleCache
|
wordle_word: WordleCache
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.links = LinkCache()
|
self.links = LinkCache()
|
||||||
|
self.memes = MemeCache()
|
||||||
self.ufora_courses = UforaCourseCache()
|
self.ufora_courses = UforaCourseCache()
|
||||||
self.wordle_word = WordleCache()
|
self.wordle_word = WordleCache()
|
||||||
|
|
||||||
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
||||||
"""Initialize the contents of all caches"""
|
"""Initialize the contents of all caches"""
|
||||||
await self.links.invalidate(postgres_session)
|
await self.links.invalidate(postgres_session)
|
||||||
|
await self.memes.invalidate(postgres_session)
|
||||||
await self.ufora_courses.invalidate(postgres_session)
|
await self.ufora_courses.invalidate(postgres_session)
|
||||||
await self.wordle_word.invalidate(mongo_db)
|
await self.wordle_word.invalidate(mongo_db)
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from database.crud.dad_jokes import get_random_dad_joke
|
from database.crud.dad_jokes import get_random_dad_joke
|
||||||
|
@ -9,6 +11,8 @@ class Fun(commands.Cog):
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
|
|
||||||
|
memegen_slash = app_commands.Group(name="meme", description="Commands to generate memes")
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
@ -23,6 +27,21 @@ class Fun(commands.Cog):
|
||||||
joke = await get_random_dad_joke(session)
|
joke = await get_random_dad_joke(session)
|
||||||
return await ctx.reply(joke.joke, mention_author=False)
|
return await ctx.reply(joke.joke, mention_author=False)
|
||||||
|
|
||||||
|
@commands.group(name="Memegen", aliases=["Meme", "Memes"], invoke_without_command=True, case_insensitive=True)
|
||||||
|
async def memegen_ctx(self, ctx: commands.Context):
|
||||||
|
"""Command group for meme-related commands"""
|
||||||
|
|
||||||
|
@memegen_slash.command(name="generate", description="Generate a meme")
|
||||||
|
async def memegen_slash(self, ctx: commands.Context, meme: str):
|
||||||
|
"""Slash command to generate a meme"""
|
||||||
|
|
||||||
|
@memegen_slash.autocomplete("meme")
|
||||||
|
async def _memegen_slash_autocomplete_meme(
|
||||||
|
self, _: discord.Interaction, current: str
|
||||||
|
) -> list[app_commands.Choice[str]]:
|
||||||
|
"""Autocompletion for the 'meme'-parameter"""
|
||||||
|
return self.client.database_caches.memes.get_autocomplete_suggestions(current)
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog"""
|
"""Load the cog"""
|
||||||
|
|
|
@ -5,7 +5,7 @@ from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.crud import custom_commands, links, ufora_courses
|
from database.crud import custom_commands, links, memes, ufora_courses
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
@ -167,6 +167,19 @@ class Owner(commands.Cog):
|
||||||
modal = AddLink(self.client)
|
modal = AddLink(self.client)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
|
@add_slash.command(name="meme", description="Add a new meme")
|
||||||
|
async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int):
|
||||||
|
"""Slash command to add new memes"""
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
meme = await memes.add_meme(session, name, imgflip_id, field_count)
|
||||||
|
if meme is None:
|
||||||
|
return await interaction.followup.send("A meme with this name (or id) already exists.")
|
||||||
|
|
||||||
|
await interaction.followup.send(f"Added meme `{meme.meme_id}`.")
|
||||||
|
await self.client.database_caches.memes.invalidate()
|
||||||
|
|
||||||
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
|
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
|
||||||
async def edit_msg(self, ctx: commands.Context):
|
async def edit_msg(self, ctx: commands.Context):
|
||||||
"""Command group for [edit X] commands"""
|
"""Command group for [edit X] commands"""
|
||||||
|
|
|
@ -0,0 +1,42 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
import settings
|
||||||
|
from database.schemas.relational import MemeTemplate
|
||||||
|
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||||
|
|
||||||
|
__all__ = ["generate_meme"]
|
||||||
|
|
||||||
|
|
||||||
|
def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[dict[str:str]]:
|
||||||
|
"""Generate the template boxes for Imgflip"""
|
||||||
|
# If a meme only has 1 field, join all the arguments together into one string
|
||||||
|
if meme.field_count == 1:
|
||||||
|
fields = [" ".join(fields)]
|
||||||
|
|
||||||
|
fields = fields[: min(20, meme.field_count)]
|
||||||
|
return [{"text": text} for text in fields]
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
|
||||||
|
"""Make a request to Imgflip to generate a meme"""
|
||||||
|
name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD
|
||||||
|
|
||||||
|
# Ensure credentials exist
|
||||||
|
if name is None:
|
||||||
|
raise MissingEnvironmentVariable("IMGFLIP_NAME")
|
||||||
|
|
||||||
|
if password is None:
|
||||||
|
raise MissingEnvironmentVariable("IMGFLIP_PASSWORD")
|
||||||
|
|
||||||
|
boxes = generate_boxes(meme, fields)
|
||||||
|
payload = {"template_id": meme.template_id, "username": name, "password": password, "boxes": boxes}
|
||||||
|
|
||||||
|
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
|
||||||
|
if response.status != 200:
|
||||||
|
return None
|
||||||
|
|
||||||
|
data = await response.json()
|
||||||
|
|
||||||
|
return data["data"]["url"]
|
|
@ -0,0 +1,8 @@
|
||||||
|
__all__ = ["MissingEnvironmentVariable"]
|
||||||
|
|
||||||
|
|
||||||
|
class MissingEnvironmentVariable(RuntimeError):
|
||||||
|
"""Exception raised when an environment variable is missing"""
|
||||||
|
|
||||||
|
def __init__(self, variable_name):
|
||||||
|
super().__init__(f"Missing environment variable: {variable_name}")
|
|
@ -24,6 +24,8 @@ __all__ = [
|
||||||
"UFORA_ANNOUNCEMENTS_CHANNEL",
|
"UFORA_ANNOUNCEMENTS_CHANNEL",
|
||||||
"UFORA_RSS_TOKEN",
|
"UFORA_RSS_TOKEN",
|
||||||
"URBAN_DICTIONARY_TOKEN",
|
"URBAN_DICTIONARY_TOKEN",
|
||||||
|
"IMGFLIP_NAME",
|
||||||
|
"IMGFLIP_PASSWORD",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -64,3 +66,5 @@ UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNE
|
||||||
"""API Keys"""
|
"""API Keys"""
|
||||||
UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None)
|
UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None)
|
||||||
URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None)
|
URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None)
|
||||||
|
IMGFLIP_NAME: Optional[str] = env.str("IMGFLIP_NAME", None)
|
||||||
|
IMGFLIP_PASSWORD: Optional[str] = env.str("IMGFLIP_PASSWORD", None)
|
||||||
|
|
Loading…
Reference in New Issue