fix typing issues

pull/128/head
stijndcl 2022-08-28 22:15:03 +02:00
parent 654fbcd46b
commit 0186a0793a
12 changed files with 88 additions and 33 deletions

View File

@ -27,10 +27,13 @@ class Currency(commands.Cog):
@commands.command(name="Award") @commands.command(name="Award")
@commands.check(is_owner) @commands.check(is_owner)
async def award(self, ctx: commands.Context, user: discord.User, amount: abbreviated_number): # type: ignore async def award(
self,
ctx: commands.Context,
user: discord.User,
amount: typing.Annotated[int, abbreviated_number],
):
"""Award a user a given amount of Didier Dinks""" """Award a user a given amount of Didier Dinks"""
amount = typing.cast(int, amount)
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
await crud.add_dinks(session, user.id, amount) await crud.add_dinks(session, user.id, amount)
plural = pluralize("Didier Dink", amount) plural = pluralize("Didier Dink", amount)
@ -116,10 +119,8 @@ class Currency(commands.Cog):
await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False)
@commands.command(name="Invest", aliases=["Deposit", "Dep"]) @commands.command(name="Invest", aliases=["Deposit", "Dep"])
async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]):
"""Invest a given amount of Didier Dinks""" """Invest a given amount of Didier Dinks"""
amount = typing.cast(typing.Union[str, int], amount)
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
invested = await crud.invest(session, ctx.author.id, amount) invested = await crud.invest(session, ctx.author.id, amount)
plural = pluralize("Didier Dink", invested) plural = pluralize("Didier Dink", invested)

View File

@ -1,3 +1,6 @@
from datetime import datetime
from typing import Optional
import discord import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
@ -5,7 +8,10 @@ from discord.ext import commands
from database.crud import ufora_courses from database.crud import ufora_courses
from database.crud.deadlines import get_deadlines from database.crud.deadlines import get_deadlines
from didier import Didier from didier import Didier
from didier.data.apis.hydra import fetch_menu
from didier.data.embeds.deadlines import Deadlines from didier.data.embeds.deadlines import Deadlines
from didier.data.embeds.hydra import no_menu_found
from didier.exceptions import HTTPException
from didier.utils.discord.flags.school import StudyGuideFlags from didier.utils.discord.flags.school import StudyGuideFlags
@ -26,6 +32,23 @@ class School(commands.Cog):
embed = Deadlines(deadlines).to_embed() embed = Deadlines(deadlines).to_embed()
await ctx.reply(embed=embed, mention_author=False, ephemeral=False) await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
@commands.hybrid_command(
name="menu", description="Show the menu in the Ghent University restaurants", aliases=["Eten", "Food"]
)
async def menu(self, ctx: commands.Context, day: Optional[str] = None):
"""Get the menu for a given day in the restaurants"""
# TODO time converter (transformer) for [DAY]
# TODO autocompletion for [DAY]
async with ctx.typing():
day_dt = datetime.now()
try:
menu = await fetch_menu(self.client.http_session, day_dt)
embed = menu.to_embed(day_dt=day_dt)
except HTTPException:
embed = no_menu_found(day_dt)
await ctx.reply(embed=embed, mention_author=False)
@commands.hybrid_command( @commands.hybrid_command(
name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"] name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"]
) )

View File

@ -11,5 +11,5 @@ __all__ = ["fetch_menu"]
async def fetch_menu(http_session: ClientSession, day_dt: datetime) -> Menu: async def fetch_menu(http_session: ClientSession, day_dt: datetime) -> Menu:
"""Fetch the menu for a given day""" """Fetch the menu for a given day"""
endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json" endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json"
async with ensure_get(http_session, endpoint) as response: async with ensure_get(http_session, endpoint, log_exceptions=False) as response:
return Menu.parse_obj(response) return Menu.parse_obj(response)

View File

@ -13,7 +13,7 @@ class EmbedBaseModel(ABC):
"""Abstract base class for a model that can be turned into a Discord embed""" """Abstract base class for a model that can be turned into a Discord embed"""
@abstractmethod @abstractmethod
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
"""Turn this model into a Discord embed""" """Turn this model into a Discord embed"""
raise NotImplementedError raise NotImplementedError

View File

@ -22,7 +22,7 @@ class Deadlines(EmbedBaseModel):
self.deadlines.sort(key=lambda deadline: deadline.deadline) self.deadlines.sort(key=lambda deadline: deadline.deadline)
@overrides @overrides
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(title="Upcoming Deadlines", colour=discord.Colour.dark_gold()) embed = discord.Embed(title="Upcoming Deadlines", colour=discord.Colour.dark_gold())
now = tz_aware_now() now = tz_aware_now()

View File

@ -32,7 +32,7 @@ class GoogleSearch(EmbedBaseModel):
return embed return embed
@overrides @overrides
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
if not self.data.results or self.data.status_code != HTTPStatus.OK: if not self.data.results or self.data.status_code != HTTPStatus.OK:
return self._error_embed() return self._error_embed()

View File

@ -1,3 +1,3 @@
from .menu import Menu from .menu import Menu, no_menu_found
__all__ = ["Menu"] __all__ = ["Menu", "no_menu_found"]

View File

@ -1,4 +1,5 @@
from typing import Literal, Optional from datetime import datetime
from typing import Literal, Optional, cast
import discord import discord
from overrides import overrides from overrides import overrides
@ -6,8 +7,10 @@ from pydantic import BaseModel
from didier.data.embeds.base import EmbedPydantic from didier.data.embeds.base import EmbedPydantic
from didier.utils.discord.colours import ghent_university_blue from didier.utils.discord.colours import ghent_university_blue
from didier.utils.types.datetime import int_to_weekday
from didier.utils.types.string import leading
__all__ = ["Menu"] __all__ = ["Menu", "no_menu_found"]
class _Meal(BaseModel): class _Meal(BaseModel):
@ -16,7 +19,7 @@ class _Meal(BaseModel):
kind: Literal["meat", "fish", "soup", "vegetarian", "vegan"] kind: Literal["meat", "fish", "soup", "vegetarian", "vegan"]
name: str name: str
price: str price: str
type: Literal["main", "side"] type: Literal["cold", "main", "side"]
class Menu(EmbedPydantic): class Menu(EmbedPydantic):
@ -28,7 +31,18 @@ class Menu(EmbedPydantic):
message: Optional[str] = None message: Optional[str] = None
@overrides @overrides
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(title="Menu", colour=ghent_university_blue()) day_dt: datetime = cast(datetime, kwargs.get("day_dt"))
weekday = int_to_weekday(day_dt.weekday())
formatted_date = f"{leading('0', str(day_dt.day))}/{leading('0', str(day_dt.month))}/{day_dt.year}"
embed = discord.Embed(title=f"Menu - {weekday} {formatted_date}", colour=ghent_university_blue())
return embed return embed
def no_menu_found(day_dt: datetime) -> discord.Embed:
"""Return a different embed if no menu could be found"""
embed = discord.Embed(title="Menu", colour=discord.Colour.red())
embed.description = f"Unable to retrieve menu for {day_dt.strftime('%d/%m/%Y')}."
return embed

View File

@ -48,7 +48,7 @@ class UforaNotification(EmbedBaseModel):
self.published_dt = self._published_datetime() self.published_dt = self._published_datetime()
self._published = self._get_published() self._published = self._get_published()
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
"""Turn the notification into an embed""" """Turn the notification into an embed"""
embed = discord.Embed(title=self._title, colour=ghent_university_blue()) embed = discord.Embed(title=self._title, colour=ghent_university_blue())

View File

@ -46,7 +46,7 @@ class Definition(EmbedPydantic):
return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH) return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH)
@overrides @overrides
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(title="Urban Dictionary", colour=colours.urban_dictionary_green()) embed = discord.Embed(title="Urban Dictionary", colour=colours.urban_dictionary_green())
embed.add_field(name="Term", value=self.word, inline=True) embed.add_field(name="Term", value=self.word, inline=True)

View File

@ -126,7 +126,7 @@ class WordleErrorEmbed(EmbedBaseModel):
message: str message: str
@overrides @overrides
def to_embed(self, **kwargs: dict) -> discord.Embed: def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle") embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
embed.description = self.message embed.description = self.message
embed.set_footer(text=footer()) embed.set_footer(text=footer())

View File

@ -2,7 +2,7 @@ import logging
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from typing import AsyncGenerator from typing import AsyncGenerator
from aiohttp import ClientResponse, ClientSession from aiohttp import ClientResponse, ClientSession, ContentTypeError
from didier.exceptions.http_exception import HTTPException from didier.exceptions.http_exception import HTTPException
@ -18,13 +18,19 @@ def request_successful(response: ClientResponse) -> bool:
@asynccontextmanager @asynccontextmanager
async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerator[dict, None]: async def ensure_get(
http_session: ClientSession, endpoint: str, *, log_exceptions: bool = True
) -> AsyncGenerator[dict, None]:
"""Context manager that automatically raises an exception if a GET-request fails""" """Context manager that automatically raises an exception if a GET-request fails"""
async with http_session.get(endpoint) as response: async with http_session.get(endpoint) as response:
try:
content = await response.json()
except ContentTypeError:
content = await response.text()
if not request_successful(response): if not request_successful(response):
logger.error( if log_exceptions:
"Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, await response.json() logger.error("Failed HTTP request to %s (status %s)\nResponse: %s", endpoint, response.status, content)
)
raise HTTPException(response.status) raise HTTPException(response.status)
@ -33,17 +39,28 @@ async def ensure_get(http_session: ClientSession, endpoint: str) -> AsyncGenerat
@asynccontextmanager @asynccontextmanager
async def ensure_post( async def ensure_post(
http_session: ClientSession, endpoint: str, payload: dict, *, expect_return: bool = True http_session: ClientSession,
endpoint: str,
payload: dict,
*,
log_exceptions: bool = True,
expect_return: bool = True
) -> AsyncGenerator[dict, None]: ) -> AsyncGenerator[dict, None]:
"""Context manager that automatically raises an exception if a POST-request fails""" """Context manager that automatically raises an exception if a POST-request fails"""
async with http_session.post(endpoint, data=payload) as response: async with http_session.post(endpoint, data=payload) as response:
if not request_successful(response): if not request_successful(response):
try:
content = await response.json()
except ContentTypeError:
content = await response.text()
if log_exceptions:
logger.error( logger.error(
"Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s", "Failed HTTP request to %s (status %s)\nPayload: %s\nResponse: %s",
endpoint, endpoint,
response.status, response.status,
payload, payload,
await response.json(), content,
) )
raise HTTPException(response.status) raise HTTPException(response.status)