mirror of https://github.com/stijndcl/didier
Make fancy functions for database & http stuff, meme preview
parent
d1d10ee853
commit
a0c1b986cd
|
@ -8,6 +8,7 @@ from database.crud.dad_jokes import get_random_dad_joke
|
||||||
from database.crud.memes import get_meme_by_name
|
from database.crud.memes import get_meme_by_name
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.apis.imgflip import generate_meme
|
from didier.data.apis.imgflip import generate_meme
|
||||||
|
from didier.exceptions.no_match import expect
|
||||||
from didier.views.modals import GenerateMeme
|
from didier.views.modals import GenerateMeme
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,6 +23,12 @@ class Fun(commands.Cog):
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str:
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name)
|
||||||
|
meme = await generate_meme(self.client.http_session, result, fields)
|
||||||
|
return meme
|
||||||
|
|
||||||
@commands.hybrid_command(
|
@commands.hybrid_command(
|
||||||
name="dadjoke",
|
name="dadjoke",
|
||||||
aliases=["Dad", "Dj"],
|
aliases=["Dad", "Dj"],
|
||||||
|
@ -37,26 +44,22 @@ class Fun(commands.Cog):
|
||||||
async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str):
|
async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str):
|
||||||
"""Command group for meme-related commands"""
|
"""Command group for meme-related commands"""
|
||||||
async with ctx.typing():
|
async with ctx.typing():
|
||||||
async with self.client.postgres_session as session:
|
meme = await self._do_generate_meme(meme_name, shlex.split(fields))
|
||||||
result = await get_meme_by_name(session, meme_name)
|
return await ctx.reply(meme, mention_author=False)
|
||||||
|
|
||||||
if result is None:
|
@memegen_msg.command(name="Preview", aliases=["P"])
|
||||||
return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False)
|
async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str):
|
||||||
|
"""Generate a preview for a meme, to see how the fields are structured"""
|
||||||
meme = await generate_meme(self.client.http_session, result, shlex.split(fields))
|
async with ctx.typing():
|
||||||
if meme is None:
|
fields = [f"Field #{i + 1}" for i in range(20)]
|
||||||
return await ctx.reply("Something went wrong.", mention_author=False)
|
meme = await self._do_generate_meme(meme_name, fields)
|
||||||
|
return await ctx.reply(meme, mention_author=False)
|
||||||
return await ctx.reply(meme)
|
|
||||||
|
|
||||||
@memes_slash.command(name="generate", description="Generate a meme")
|
@memes_slash.command(name="generate", description="Generate a meme")
|
||||||
async def memegen_slash(self, interaction: discord.Interaction, meme: str):
|
async def memegen_slash(self, interaction: discord.Interaction, meme: str):
|
||||||
"""Slash command to generate a meme"""
|
"""Slash command to generate a meme"""
|
||||||
async with self.client.postgres_session as session:
|
async with self.client.postgres_session as session:
|
||||||
result = await get_meme_by_name(session, meme)
|
result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme)
|
||||||
|
|
||||||
if result is None:
|
|
||||||
return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True)
|
|
||||||
|
|
||||||
modal = GenerateMeme(self.client, result)
|
modal = GenerateMeme(self.client, result)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.schemas.relational import MemeTemplate
|
from database.schemas.relational import MemeTemplate
|
||||||
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||||
|
from didier.utils.http.requests import ensure_post
|
||||||
|
|
||||||
__all__ = ["generate_meme"]
|
__all__ = ["generate_meme"]
|
||||||
|
|
||||||
|
@ -20,7 +19,7 @@ def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]:
|
||||||
return fields
|
return fields
|
||||||
|
|
||||||
|
|
||||||
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]:
|
async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str:
|
||||||
"""Make a request to Imgflip to generate a meme"""
|
"""Make a request to Imgflip to generate a meme"""
|
||||||
name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD
|
name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD
|
||||||
|
|
||||||
|
@ -36,9 +35,5 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields:
|
||||||
for i, box in enumerate(boxes):
|
for i, box in enumerate(boxes):
|
||||||
payload[f"boxes[{i}][text]"] = box
|
payload[f"boxes[{i}][text]"] = box
|
||||||
|
|
||||||
async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response:
|
async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response:
|
||||||
if response.status != 200:
|
return response["data"]["url"]
|
||||||
return None
|
|
||||||
|
|
||||||
data = await response.json()
|
|
||||||
return data["data"]["url"]
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
__all__ = ["HTTPException"]
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPException(RuntimeError):
|
||||||
|
"""Error raised when an API call fails"""
|
||||||
|
|
||||||
|
def __init__(self, status_code: int):
|
||||||
|
super().__init__(f"Something went wrong (status {status_code}).")
|
|
@ -0,0 +1,24 @@
|
||||||
|
from typing import Optional, TypeVar
|
||||||
|
|
||||||
|
__all__ = ["NoMatch", "expect"]
|
||||||
|
|
||||||
|
|
||||||
|
class NoMatch(ValueError):
|
||||||
|
"""Error raised when a database lookup failed"""
|
||||||
|
|
||||||
|
def __init__(self, entity_type: str, argument: str):
|
||||||
|
super().__init__(f"Found no {entity_type} matching `{argument}`.")
|
||||||
|
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T:
|
||||||
|
"""Mark a database instance as expected, otherwise raise a custom exception
|
||||||
|
|
||||||
|
This is not just done in the database layer because it's not always preferable
|
||||||
|
"""
|
||||||
|
if instance is None:
|
||||||
|
raise NoMatch(entity_type, argument)
|
||||||
|
|
||||||
|
return instance
|
|
@ -0,0 +1,56 @@
|
||||||
|
import logging
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from aiohttp import ClientResponse, ClientSession
|
||||||
|
|
||||||
|
from didier.exceptions.http_exception import HTTPException
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["ensure_get", "ensure_post"]
|
||||||
|
|
||||||
|
|
||||||
|
def request_successful(response: ClientResponse) -> bool:
|
||||||
|
"""Check if a request was successful or not"""
|
||||||
|
return 200 <= response.status < 300
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]:
|
||||||
|
"""Context manager that automatically raises an exception if a GET-request fails"""
|
||||||
|
async with http_session.get(endpoint) as response:
|
||||||
|
if not request_successful(response):
|
||||||
|
logger.error(
|
||||||
|
"Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json()
|
||||||
|
)
|
||||||
|
|
||||||
|
raise HTTPException(response.status)
|
||||||
|
|
||||||
|
yield await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def ensure_post(
|
||||||
|
http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True
|
||||||
|
) -> AsyncGenerator[dict, None]:
|
||||||
|
"""Context manager that automatically raises an exception if a POST-request fails"""
|
||||||
|
async with http_session.post(endpoint, data=payload) as response:
|
||||||
|
if not request_successful(response):
|
||||||
|
logger.error(
|
||||||
|
"Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s",
|
||||||
|
endpoint,
|
||||||
|
response.status,
|
||||||
|
payload,
|
||||||
|
await response.json(),
|
||||||
|
)
|
||||||
|
|
||||||
|
raise HTTPException(response.status)
|
||||||
|
|
||||||
|
if expect_return:
|
||||||
|
yield await response.json()
|
||||||
|
else:
|
||||||
|
# Always return A dict so you can always "use" the result without having to check
|
||||||
|
# if it is None or not
|
||||||
|
yield {}
|
2
main.py
2
main.py
|
@ -19,7 +19,7 @@ def setup_logging():
|
||||||
max_log_size = 32 * 1024 * 1024
|
max_log_size = 32 * 1024 * 1024
|
||||||
|
|
||||||
# Configure Didier handler
|
# Configure Didier handler
|
||||||
didier_log = logging.getLogger("didier")
|
didier_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
|
didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5)
|
||||||
didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))
|
didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))
|
||||||
|
|
Loading…
Reference in New Issue