From a0c1b986cdb7da265dd0add5f2a47b14e2c58b3e Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 26 Aug 2022 20:02:54 +0200 Subject: [PATCH] Make fancy functions for database & http stuff, meme preview --- didier/cogs/fun.py | 31 ++++++++-------- didier/data/apis/imgflip.py | 13 +++---- didier/exceptions/http_exception.py | 8 +++++ didier/exceptions/no_match.py | 24 +++++++++++++ didier/utils/http/__init__.py | 0 didier/utils/http/requests.py | 56 +++++++++++++++++++++++++++++ main.py | 2 +- 7 files changed, 110 insertions(+), 24 deletions(-) create mode 100644 didier/exceptions/http_exception.py create mode 100644 didier/exceptions/no_match.py create mode 100644 didier/utils/http/__init__.py create mode 100644 didier/utils/http/requests.py diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 2d53b0c..85f0d95 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -8,6 +8,7 @@ from database.crud.dad_jokes import get_random_dad_joke from database.crud.memes import get_meme_by_name from didier import Didier from didier.data.apis.imgflip import generate_meme +from didier.exceptions.no_match import expect from didier.views.modals import GenerateMeme @@ -22,6 +23,12 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client + async def _do_generate_meme(self, meme_name: str, fields: list[str]) -> str: + async with self.client.postgres_session as session: + result = expect(await get_meme_by_name(session, meme_name), entity_type="meme", argument=meme_name) + meme = await generate_meme(self.client.http_session, result, fields) + return meme + @commands.hybrid_command( name="dadjoke", aliases=["Dad", "Dj"], @@ -37,26 +44,22 @@ class Fun(commands.Cog): async def memegen_msg(self, ctx: commands.Context, meme_name: str, *, fields: str): """Command group for meme-related commands""" async with ctx.typing(): - async with self.client.postgres_session as session: - result = await get_meme_by_name(session, meme_name) + meme = await self._do_generate_meme(meme_name, shlex.split(fields)) + return await ctx.reply(meme, mention_author=False) - if result is None: - return await ctx.reply(f"Found no meme matching `{meme_name}`.", mention_author=False) - - meme = await generate_meme(self.client.http_session, result, shlex.split(fields)) - if meme is None: - return await ctx.reply("Something went wrong.", mention_author=False) - - return await ctx.reply(meme) + @memegen_msg.command(name="Preview", aliases=["P"]) + async def memegen_preview_msg(self, ctx: commands.Context, meme_name: str): + """Generate a preview for a meme, to see how the fields are structured""" + async with ctx.typing(): + fields = [f"Field #{i + 1}" for i in range(20)] + meme = await self._do_generate_meme(meme_name, fields) + return await ctx.reply(meme, mention_author=False) @memes_slash.command(name="generate", description="Generate a meme") async def memegen_slash(self, interaction: discord.Interaction, meme: str): """Slash command to generate a meme""" async with self.client.postgres_session as session: - result = await get_meme_by_name(session, meme) - - if result is None: - return await interaction.response.send_message(f"Found no meme matching `{meme}`.", ephemeral=True) + result = await expect(get_meme_by_name(session, meme), entity_type="meme", argument=meme) modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py index a136a72..b373897 100644 --- a/didier/data/apis/imgflip.py +++ b/didier/data/apis/imgflip.py @@ -1,10 +1,9 @@ -from typing import Optional - from aiohttp import ClientSession import settings from database.schemas.relational import MemeTemplate from didier.exceptions.missing_env import MissingEnvironmentVariable +from didier.utils.http.requests import ensure_post __all__ = ["generate_meme"] @@ -20,7 +19,7 @@ def generate_boxes(meme: MemeTemplate, fields: list[str]) -> list[str]: return fields -async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> Optional[str]: +async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: list[str]) -> str: """Make a request to Imgflip to generate a meme""" name, password = settings.IMGFLIP_NAME, settings.IMGFLIP_PASSWORD @@ -36,9 +35,5 @@ async def generate_meme(http_session: ClientSession, meme: MemeTemplate, fields: for i, box in enumerate(boxes): payload[f"boxes[{i}][text]"] = box - async with http_session.post("https://api.imgflip.com/caption_image", data=payload) as response: - if response.status != 200: - return None - - data = await response.json() - return data["data"]["url"] + async with ensure_post(http_session, "https://api.imgflip.com/caption_image", payload=payload) as response: + return response["data"]["url"] diff --git a/didier/exceptions/http_exception.py b/didier/exceptions/http_exception.py new file mode 100644 index 0000000..c4e9f2c --- /dev/null +++ b/didier/exceptions/http_exception.py @@ -0,0 +1,8 @@ +__all__ = ["HTTPException"] + + +class HTTPException(RuntimeError): + """Error raised when an API call fails""" + + def __init__(self, status_code: int): + super().__init__(f"Something went wrong (status {status_code}).") diff --git a/didier/exceptions/no_match.py b/didier/exceptions/no_match.py new file mode 100644 index 0000000..fed1a6d --- /dev/null +++ b/didier/exceptions/no_match.py @@ -0,0 +1,24 @@ +from typing import Optional, TypeVar + +__all__ = ["NoMatch", "expect"] + + +class NoMatch(ValueError): + """Error raised when a database lookup failed""" + + def __init__(self, entity_type: str, argument: str): + super().__init__(f"Found no {entity_type} matching `{argument}`.") + + +T = TypeVar("T") + + +def expect(instance: Optional[T], *, entity_type: str, argument: str) -> T: + """Mark a database instance as expected, otherwise raise a custom exception + + This is not just done in the database layer because it's not always preferable + """ + if instance is None: + raise NoMatch(entity_type, argument) + + return instance diff --git a/didier/utils/http/__init__.py b/didier/utils/http/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/http/requests.py b/didier/utils/http/requests.py new file mode 100644 index 0000000..f649701 --- /dev/null +++ b/didier/utils/http/requests.py @@ -0,0 +1,56 @@ +import logging +from contextlib import asynccontextmanager +from typing import AsyncGenerator + +from aiohttp import ClientResponse, ClientSession + +from didier.exceptions.http_exception import HTTPException + +logger = logging.getLogger(__name__) + + +__all__ = ["ensure_get", "ensure_post"] + + +def request_successful(response: ClientResponse) -> bool: + """Check if a request was successful or not""" + return 200 <= response.status < 300 + + +@asynccontextmanager +async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a GET-request fails""" + async with http_session.get(endpoint) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json() + ) + + raise HTTPException(response.status) + + yield await response.json() + + +@asynccontextmanager +async def ensure_post( + http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True +) -> AsyncGenerator[dict, None]: + """Context manager that automatically raises an exception if a POST-request fails""" + async with http_session.post(endpoint, data=payload) as response: + if not request_successful(response): + logger.error( + "Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s", + endpoint, + response.status, + payload, + await response.json(), + ) + + raise HTTPException(response.status) + + if expect_return: + yield await response.json() + else: + # Always return A dict so you can always "use" the result without having to check + # if it is None or not + yield {} diff --git a/main.py b/main.py index fdbc027..f791621 100644 --- a/main.py +++ b/main.py @@ -19,7 +19,7 @@ def setup_logging(): max_log_size = 32 * 1024 * 1024 # Configure Didier handler - didier_log = logging.getLogger("didier") + didier_log = logging.getLogger(__name__) didier_handler = RotatingFileHandler(settings.LOGFILE, mode="a", maxBytes=max_log_size, backupCount=5) didier_handler.setFormatter(logging.Formatter("[%(asctime)s] [%(levelname)s]: %(message)s"))