mirror of https://github.com/stijndcl/didier
Make fancy functions for database & http stuff, meme preview
parent
d1d10ee853
commit
a0c1b986cd
|
@ -8,6 +8,7 @@ from database.crud.dad_jokes import get_random_dad_joke
|
|||
from database.crud.memes import get_meme_by_name
|
||||
from didier import Didier
|
||||
from didier.data.apis.imgflip import generate_meme
|
||||
from didier.exceptions.no_match import expect
|
||||
from didier.views.modals import GenerateMeme
|
||||
|
||||
|
||||
|
@ -22,6 +23,12 @@ class Fun(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str:
|
||||
async with self.client.postgres_session as session:
|
||||
result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name)
|
||||
meme = await generate_meme(self.client.http_session, result, fields)
|
||||
return meme
|
||||
|
||||
@commands.hybrid_command(
|
||||
name="dadjoke",
|
||||
aliases=["Dad", "Dj"],
|
||||
|
@ -37,26 +44,22 @@ class Fun(commands.Cog):
|
|||
async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str):
|
||||
"""Command group for meme-related commands"""
|
||||
async with ctx.typing():
|
||||
async with self.client.postgres_session as session:
|
||||
result = await get_meme_by_name(session, meme_name)
|
||||
meme = await self._do_generate_meme(meme_name, shlex.split(fields))
|
||||
return await ctx.reply(meme, mention_author=False)
|
||||
|
||||
if result is None:
|
||||
return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False)
|
||||
|
||||
meme = await generate_meme(self.client.http_session, result, shlex.split(fields))
|
||||
if meme is None:
|
||||
return await ctx.reply("Something went wrong.", mention_author=False)
|
||||
|
||||
return await ctx.reply(meme)
|
||||
@memegen_msg.command(name="Preview", aliases=["P"])
|
||||
async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str):
|
||||
"""Generate a preview for a meme, to see how the fields are structured"""
|
||||
async with ctx.typing():
|
||||
fields = [f"Field #{i + 1}" for i in range(20)]
|
||||
meme = await self._do_generate_meme(meme_name, fields)
|
||||
return await ctx.reply(meme, mention_author=False)
|
||||
|
||||
@memes_slash.command(name="generate", description="Generate a meme")
|
||||
async def memegen_slash(self, interaction: discord.Interaction, meme: str):
|
||||
"""Slash command to generate a meme"""
|
||||
async with self.client.postgres_session as session:
|
||||
result = await get_meme_by_name(session, meme)
|
||||
|
||||
if result is None:
|
||||
return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True)
|
||||
result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme)
|
||||
|
||||
modal = GenerateMeme(self.client, result)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
|
|
@ -1,10 +1,9 @@
|
|||
from typing import Optional
|
||||
|
||||
from aiohttp import ClientSession
|
||||
|
||||
import settings
|
||||
from database.schemas.relational import MemeTemplate
|
||||
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||
from didier.utils.http.requests import ensure_post
|
||||
|
||||
__all__ = ["generate_meme"]
|
||||
|
||||
|
@ -20,7 +19,7 @@ def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]:
|
|||
return fields
|
||||
|
||||
|
||||
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
|
||||
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str:
|
||||
"""Make a request to Imgflip to generate a meme"""
|
||||
name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD
|
||||
|
||||
|
@ -36,9 +35,5 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields:
|
|||
for i, box in enumerate(boxes):
|
||||
payload[f"boxes[{i}][text]"] = box
|
||||
|
||||
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
|
||||
if response.status != 200:
|
||||
return None
|
||||
|
||||
data = await response.json()
|
||||
return data["data"]["url"]
|
||||
async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response:
|
||||
return response["data"]["url"]
|
||||
|
|
|
@ -0,0 +1,8 @@
|
|||
__all__ = ["HTTPException"]
|
||||
|
||||
|
||||
class HTTPException(RuntimeError):
|
||||
"""Error raised when an API call fails"""
|
||||
|
||||
def __init__(self, status_code: int):
|
||||
super().__init__(f"Something went wrong (status {status_code}).")
|
|
@ -0,0 +1,24 @@
|
|||
from typing import Optional, TypeVar
|
||||
|
||||
__all__ = ["NoMatch", "expect"]
|
||||
|
||||
|
||||
class NoMatch(ValueError):
|
||||
"""Error raised when a database lookup failed"""
|
||||
|
||||
def __init__(self, entity_type: str, argument: str):
|
||||
super().__init__(f"Found no {entity_type} matching `{argument}`.")
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T:
|
||||
"""Mark a database instance as expected, otherwise raise a custom exception
|
||||
|
||||
This is not just done in the database layer because it's not always preferable
|
||||
"""
|
||||
if instance is None:
|
||||
raise NoMatch(entity_type, argument)
|
||||
|
||||
return instance
|
|
@ -0,0 +1,56 @@
|
|||
import logging
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from aiohttp import ClientResponse, ClientSession
|
||||
|
||||
from didier.exceptions.http_exception import HTTPException
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
__all__ = ["ensure_get", "ensure_post"]
|
||||
|
||||
|
||||
def request_successful(response: ClientResponse) -> bool:
|
||||
"""Check if a request was successful or not"""
|
||||
return 200 <= response.status < 300
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]:
|
||||
"""Context manager that automatically raises an exception if a GET-request fails"""
|
||||
async with http_session.get(endpoint) as response:
|
||||
if not request_successful(response):
|
||||
logger.error(
|
||||
"Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json()
|
||||
)
|
||||
|
||||
raise HTTPException(response.status)
|
||||
|
||||
yield await response.json()
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def ensure_post(
|
||||
http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True
|
||||
) -> AsyncGenerator[dict, None]:
|
||||
"""Context manager that automatically raises an exception if a POST-request fails"""
|
||||
async with http_session.post(endpoint, data=payload) as response:
|
||||
if not request_successful(response):
|
||||
logger.error(
|
||||
"Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s",
|
||||
endpoint,
|
||||
response.status,
|
||||
payload,
|
||||
await response.json(),
|
||||
)
|
||||
|
||||
raise HTTPException(response.status)
|
||||
|
||||
if expect_return:
|
||||
yield await response.json()
|
||||
else:
|
||||
# Always return A dict so you can always "use" the result without having to check
|
||||
# if it is None or not
|
||||
yield {}
|
2
main.py
2
main.py
|
@ -19,7 +19,7 @@ def setup_logging():
|
|||
max_log_size = 32 * 1024 * 1024
|
||||
|
||||
# Configure Didier handler
|
||||
didier_log = logging.getLogger("didier")
|
||||
didier_log = logging.getLogger(__name__)
|
||||
|
||||
didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
|
||||
didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))
|
||||
|
|
Loading…
Reference in New Issue