Make fancy functions for database & http stuff, meme preview

pull/127/head
stijndcl 2022-08-26 20:02:54 +02:00
parent d1d10ee853
commit a0c1b986cd
7 changed files with 110 additions and 24 deletions

View File

@ -8,6 +8,7 @@ from database.crud.dad_jokes import get_random_dad_joke
from database.crud.memes import get_meme_by_name
from didier import Didier
from didier.data.apis.imgflip import generate_meme
from didier.exceptions.no_match import expect
from didier.views.modals import GenerateMeme
@ -22,6 +23,12 @@ class Fun(commands.Cog):
def __init__(self, client: Didier):
self.client = client
async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str:
async with self.client.postgres_session as session:
result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name)
meme = await generate_meme(self.client.http_session, result, fields)
return meme
@commands.hybrid_command(
name="dadjoke",
aliases=["Dad", "Dj"],
@ -37,26 +44,22 @@ class Fun(commands.Cog):
async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str):
"""Command group for meme-related commands"""
async with ctx.typing():
async with self.client.postgres_session as session:
result = await get_meme_by_name(session, meme_name)
meme = await self._do_generate_meme(meme_name, shlex.split(fields))
return await ctx.reply(meme, mention_author=False)
if result is None:
return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False)
meme = await generate_meme(self.client.http_session, result, shlex.split(fields))
if meme is None:
return await ctx.reply("Something went wrong.", mention_author=False)
return await ctx.reply(meme)
@memegen_msg.command(name="Preview", aliases=["P"])
async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str):
"""Generate a preview for a meme, to see how the fields are structured"""
async with ctx.typing():
fields = [f"Field #{i + 1}" for i in range(20)]
meme = await self._do_generate_meme(meme_name, fields)
return await ctx.reply(meme, mention_author=False)
@memes_slash.command(name="generate", description="Generate a meme")
async def memegen_slash(self, interaction: discord.Interaction, meme: str):
"""Slash command to generate a meme"""
async with self.client.postgres_session as session:
result = await get_meme_by_name(session, meme)
if result is None:
return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True)
result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme)
modal = GenerateMeme(self.client, result)
await interaction.response.send_modal(modal)

View File

@ -1,10 +1,9 @@
from typing import Optional
from aiohttp import ClientSession
import settings
from database.schemas.relational import MemeTemplate
from didier.exceptions.missing_env import MissingEnvironmentVariable
from didier.utils.http.requests import ensure_post
__all__ = ["generate_meme"]
@ -20,7 +19,7 @@ def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]:
return fields
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str:
"""Make a request to Imgflip to generate a meme"""
name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD
@ -36,9 +35,5 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields:
for i, box in enumerate(boxes):
payload[f"boxes[{i}][text]"] = box
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
if response.status != 200:
return None
data = await response.json()
return data["data"]["url"]
async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response:
return response["data"]["url"]

View File

@ -0,0 +1,8 @@
__all__ = ["HTTPException"]
class HTTPException(RuntimeError):
"""Error raised when an API call fails"""
def __init__(self, status_code: int):
super().__init__(f"Something went wrong (status {status_code}).")

View File

@ -0,0 +1,24 @@
from typing import Optional, TypeVar
__all__ = ["NoMatch", "expect"]
class NoMatch(ValueError):
"""Error raised when a database lookup failed"""
def __init__(self, entity_type: str, argument: str):
super().__init__(f"Found no {entity_type} matching `{argument}`.")
T = TypeVar("T")
def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T:
"""Mark a database instance as expected, otherwise raise a custom exception
This is not just done in the database layer because it's not always preferable
"""
if instance is None:
raise NoMatch(entity_type, argument)
return instance

View File

View File

@ -0,0 +1,56 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from aiohttp import ClientResponse, ClientSession
from didier.exceptions.http_exception import HTTPException
logger = logging.getLogger(__name__)
__all__ = ["ensure_get", "ensure_post"]
def request_successful(response: ClientResponse) -> bool:
"""Check if a request was successful or not"""
return 200 <= response.status < 300
@asynccontextmanager
async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]:
"""Context manager that automatically raises an exception if a GET-request fails"""
async with http_session.get(endpoint) as response:
if not request_successful(response):
logger.error(
"Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json()
)
raise HTTPException(response.status)
yield await response.json()
@asynccontextmanager
async def ensure_post(
http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True
) -> AsyncGenerator[dict, None]:
"""Context manager that automatically raises an exception if a POST-request fails"""
async with http_session.post(endpoint, data=payload) as response:
if not request_successful(response):
logger.error(
"Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s",
endpoint,
response.status,
payload,
await response.json(),
)
raise HTTPException(response.status)
if expect_return:
yield await response.json()
else:
# Always return A dict so you can always "use" the result without having to check
# if it is None or not
yield {}

View File

@ -19,7 +19,7 @@ def setup_logging():
max_log_size = 32 * 1024 * 1024
# Configure Didier handler
didier_log = logging.getLogger("didier")
didier_log = logging.getLogger(__name__)
didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))