Compare commits

...

9 Commits

Author SHA1 Message Date
stijndcl 00a146cb2b List memes 2022-09-22 17:28:23 +02:00
stijndcl 6cfe788df5 Fix typing 2022-09-22 17:01:56 +02:00
stijndcl f049f1c80b Improve docs for memegen 2022-09-22 17:01:29 +02:00
stijndcl f19a832725 Send logs in an embed 2022-09-22 16:58:48 +02:00
stijndcl d9272f17ab Create command to monitor task execution 2022-09-22 16:49:33 +02:00
stijndcl 7f21a1cf69 Defer faster, no longer error for courses than we don't know of 2022-09-22 16:38:23 +02:00
Stijn De Clercq dc1f0f6b55
Merge pull request #133 from stijndcl/covid
Rewrite Covid
2022-09-22 16:34:39 +02:00
stijndcl 2d0babbdcb Create covid command 2022-09-22 16:30:58 +02:00
stijndcl df884f55f1 Covid api requests 2022-09-22 02:04:34 +02:00
16 changed files with 513 additions and 31 deletions

View File

@ -5,10 +5,12 @@ from discord import app_commands
from discord.ext import commands
from database.crud.dad_jokes import get_random_dad_joke
from database.crud.memes import get_meme_by_name
from database.crud.memes import get_all_memes, get_meme_by_name
from didier import Didier
from didier.data.apis.imgflip import generate_meme
from didier.exceptions.no_match import expect
from didier.menus.common import Menu
from didier.menus.memes import MemeSource
from didier.views.modals import GenerateMeme
@ -48,11 +50,29 @@ class Fun(commands.Cog):
Example: `memegen a b c d` will be parsed as `template: "a"`, `fields: ["b", "c", "d"]`
Example: `memegen "a b" "c d"` will be parsed as `template: "a b"`, `fields: ["c d"]`
In case a template only has 1 field, quotes aren't required and your arguments will be combined into one field.
Example: if template `a` only has 1 field,
`memegen a b c d` will be parsed as `template: "a"`, `fields: ["bcd"]`
"""
async with ctx.typing():
meme = await self._do_generate_meme(template, shlex.split(fields))
return await ctx.reply(meme, mention_author=False)
@memegen_msg.command(name="list", aliases=["ls"])
async def memegen_ls_msg(self, ctx: commands.Context):
"""Get a list of all available meme templates.
This command does _not_ have a /slash variant, as the memegen /slash commands provide autocompletion.
"""
async with self.client.postgres_session as session:
results = await get_all_memes(session)
source = MemeSource(ctx, results)
menu = Menu(source)
await menu.start(ctx)
@memegen_msg.command(name="preview", aliases=["p"])
async def memegen_preview_msg(self, ctx: commands.Context, template: str):
"""Generate a preview for the meme template `template`, to see how the fields are structured."""

View File

@ -7,9 +7,10 @@ from discord.ext import commands
from database.crud.links import get_link_by_name
from database.schemas import Link
from didier import Didier
from didier.data.apis import inspirobot, urban_dictionary
from didier.data.apis import disease_sh, inspirobot, urban_dictionary
from didier.data.embeds.google import GoogleSearch
from didier.data.scrapers import google
from didier.utils.discord.autocompletion.country import autocomplete_country
class Other(commands.Cog):
@ -20,16 +21,34 @@ class Other(commands.Cog):
def __init__(self, client: Didier):
self.client = client
@commands.hybrid_command(name="corona", aliases=["covid", "rona"])
async def covid(self, ctx: commands.Context, country: str = "Belgium"):
"""Show Covid-19 info for a specific country.
By default, this will fetch the numbers for Belgium.
To get worldwide stats, use `all`, `global`, `world`, or `worldwide`.
"""
async with ctx.typing():
if country.lower() in ["all", "global", "world", "worldwide"]:
data = await disease_sh.get_global_info(self.client.http_session)
else:
data = await disease_sh.get_country_info(self.client.http_session, country)
await ctx.reply(embed=data.to_embed(), mention_author=False)
@covid.autocomplete("country")
async def _covid_country_autocomplete(self, interaction: discord.Interaction, value: str):
"""Autocompletion for the 'country'-parameter"""
return autocomplete_country(value)[:25]
@commands.hybrid_command(
name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary"
)
async def define(self, ctx: commands.Context, *, query: str):
"""Look up the definition of `query` on the Urban Dictionary."""
async with ctx.typing():
status_code, definitions = await urban_dictionary.lookup(self.client.http_session, query)
if not definitions:
return await ctx.reply(f"Something went wrong (status {status_code})")
definitions = await urban_dictionary.lookup(self.client.http_session, query)
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
@commands.hybrid_command(name="google", description="Google search")

View File

@ -47,12 +47,12 @@ class School(commands.Cog):
Schedules are personalized based on your roles in the server. If your schedule doesn't look right, make sure
that you've got the correct roles selected. In case you do, ping D STIJN.
"""
if day_dt is None:
day_dt = date.today()
day_dt = skip_weekends(day_dt)
async with ctx.typing():
if day_dt is None:
day_dt = date.today()
day_dt = skip_weekends(day_dt)
try:
member_instance = to_main_guild_member(self.client, ctx.author)

View File

@ -2,6 +2,7 @@ import datetime
import random
import traceback
import discord
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
from overrides import overrides
@ -77,7 +78,14 @@ class Tasks(commands.Cog):
Invoking the group itself shows the time until the next iteration
"""
raise NotImplementedError()
embed = discord.Embed(colour=discord.Colour.blue(), title="Tasks")
for name, task in self._tasks.items():
next_iter = task.next_iteration
timestamp = f"<t:{round(next_iter.timestamp())}:R>" if next_iter is not None else "N/A"
embed.add_field(name=name, value=timestamp)
await ctx.reply(embed=embed, mention_author=False)
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]")
async def force_task(self, ctx: commands.Context, name: str):

View File

@ -0,0 +1,38 @@
from aiohttp import ClientSession
from didier.data.embeds.disease_sh import CovidData
from didier.utils.http.requests import ensure_get
__all__ = ["get_country_info", "get_global_info"]
async def get_country_info(http_session: ClientSession, country: str) -> CovidData:
"""Fetch the info for a given country for today and yesterday"""
endpoint = f"https://disease.sh/v3/covid-19/countries/{country}"
params = {"yesterday": 0, "strict": 1, "allowNull": 0}
async with ensure_get(http_session, endpoint, params=params) as response:
today = response
params["yesterday"] = 1
async with ensure_get(http_session, endpoint, params=params) as response:
yesterday = response
data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data)
async def get_global_info(http_session: ClientSession) -> CovidData:
"""Fetch the global info for today and yesterday"""
endpoint = "https://disease.sh/v3/covid-19/all"
params = {"yesterday": 0, "allowNull": 0}
async with ensure_get(http_session, endpoint, params=params) as response:
today = response
params["yesterday"] = 1
async with ensure_get(http_session, endpoint, params=params) as response:
yesterday = response
data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data)

View File

@ -1,8 +1,7 @@
from http import HTTPStatus
from aiohttp import ClientSession
from didier.data.embeds.urban_dictionary import Definition
from didier.utils.http.requests import ensure_get
__all__ = ["lookup", "PER_PAGE"]
@ -10,13 +9,9 @@ __all__ = ["lookup", "PER_PAGE"]
PER_PAGE = 10
async def lookup(http_session: ClientSession, query: str) -> tuple[int, list[Definition]]:
async def lookup(http_session: ClientSession, query: str) -> list[Definition]:
"""Fetch the Urban Dictionary definitions for a given word"""
url = "https://api.urbandictionary.com/v0/define"
async with http_session.get(url, params={"term": query}) as response:
if response.status != HTTPStatus.OK:
return response.status, []
response_json = await response.json()
return 200, list(map(Definition.parse_obj, response_json["list"]))
async with ensure_get(http_session, url, params={"term": query}) as response:
return list(map(Definition.parse_obj, response["list"]))

View File

@ -0,0 +1,100 @@
import discord
from overrides import overrides
from pydantic import BaseModel, Field, validator
from didier.data.embeds.base import EmbedPydantic
__all__ = ["CovidData"]
class _CovidNumbers(BaseModel):
"""Covid numbers for a country
For worldwide numbers, country_info will be None
"""
updated: int
country: str = "Worldwide"
cases: int
today_cases: int = Field(alias="todayCases")
deaths: int
today_deaths: int = Field(alias="todayDeaths")
recovered: int
today_recovered: int = Field(alias="todayRecovered")
active: int
tests: int
@validator("updated")
def updated_to_seconds(cls, value: int) -> int:
"""Turn the updated field into seconds instead of milliseconds"""
return int(value) // 1000
class CovidData(EmbedPydantic):
"""Covid information from two days combined into one model"""
today: _CovidNumbers
yesterday: _CovidNumbers
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(colour=discord.Colour.red(), title=f"Coronatracker {self.today.country}")
embed.description = f"Last update: <t:{self.today.updated}:R>"
embed.set_thumbnail(url="https://i.imgur.com/aWnDuBt.png")
cases_indicator = self._trend_indicator(self.today.today_cases, self.yesterday.today_cases)
embed.add_field(
name="Cases (Today)",
value=f"{self.today.cases:,} **({self.today.today_cases:,})** {cases_indicator}".replace(",", "."),
inline=False,
)
active_indicator = self._trend_indicator(self.today.active, self.yesterday.active)
active_diff = self.today.active - self.yesterday.active
embed.add_field(
name="Active (Today)",
value=f"{self.today.active:,} **({self._with_sign(active_diff)})** {active_indicator}".replace(",", "."),
inline=False,
)
deaths_indicator = self._trend_indicator(self.today.today_deaths, self.yesterday.today_deaths)
embed.add_field(
name="Deaths (Today)",
value=f"{self.today.deaths:,} **({self.today.today_deaths:,})** {deaths_indicator}".replace(",", "."),
inline=False,
)
recovered_indicator = self._trend_indicator(self.today.today_recovered, self.yesterday.today_recovered)
embed.add_field(
name="Recovered (Today)",
value=f"{self.today.recovered} **({self.today.today_recovered:,})** {recovered_indicator}".replace(
",", "."
),
inline=False,
)
tests_diff = self.today.tests - self.yesterday.tests
embed.add_field(
name="Tests Administered (Today)",
value=f"{self.today.tests:,} **({tests_diff:,})**".replace(",", "."),
inline=False,
)
return embed
def _with_sign(self, value: int) -> str:
"""Prepend a + symbol if a number is positive"""
if value > 0:
return f"+{value:,}"
return f"{value:,}"
def _trend_indicator(self, today: int, yesterday: int) -> str:
"""Function that returns a rise/decline indicator for the target key."""
if today > yesterday:
return ":small_red_triangle:"
if yesterday > today:
return ":small_red_triangle_down:"
return ""

View File

@ -0,0 +1,21 @@
import logging
import discord
__all__ = ["create_logging_embed"]
def create_logging_embed(level: int, message: str) -> discord.Embed:
"""Create an embed to send to the logging channel"""
colours = {
logging.DEBUG: discord.Colour.light_gray,
logging.ERROR: discord.Colour.red(),
logging.INFO: discord.Colour.blue(),
logging.WARNING: discord.Colour.yellow(),
}
colour = colours.get(level, discord.Colour.red())
embed = discord.Embed(colour=colour, title="Logging")
embed.description = message
return embed

View File

@ -190,8 +190,7 @@ async def parse_schedule_from_content(content: str, *, database_session: AsyncSe
if code not in course_codes:
course = await get_course_by_code(database_session, code)
if course is None:
# raise ValueError(f"Unable to find course with code {code} (event {event.name})") # noqa: E800
continue # TODO uncomment the line above after all courses have been added
continue
course_codes[code] = course

View File

@ -16,6 +16,7 @@ from database.crud import command_stats, custom_commands
from database.engine import DBSession
from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed
from didier.data.embeds.logging_embed import create_logging_embed
from didier.data.embeds.schedules import Schedule, parse_schedule
from didier.exceptions import HTTPException, NoMatch
from didier.utils.discord.prefix import get_prefix
@ -181,15 +182,16 @@ class Didier(commands.Bot):
async def _log(self, level: int, message: str, log_to_discord: bool = True):
"""Log a message to the logging file, and optionally to the configured channel"""
methods = {
logging.DEBUG: logger.debug,
logging.ERROR: logger.error,
logging.INFO: logger.info,
logging.WARNING: logger.warning,
}
methods.get(level, logger.error)(message)
if log_to_discord:
# TODO pretty embed
# different colours per level?
await self.error_channel.send(message)
embed = create_logging_embed(level, message)
await self.error_channel.send(embed=embed)
async def log_error(self, message: str, log_to_discord: bool = True):
"""Log an error message"""

View File

@ -22,7 +22,7 @@ class BookmarkSource(PageSource[Bookmark]):
description = ""
for bookmark in self.dataset[page : page + self.per_page]:
for bookmark in self.get_page_data(page):
description += f"`#{bookmark.bookmark_id}`: [{bookmark.label}]({bookmark.jump_url})\n"
embed.description = description.strip()

View File

@ -52,6 +52,10 @@ class PageSource(ABC, Generic[T]):
"""Method that builds the list of embeds from the input data"""
raise NotImplementedError
def get_page_data(self, page: int) -> list[T]:
"""Get the chunk of the dataset for page [page]"""
return self.dataset[page : page + self.per_page]
class Menu(discord.ui.View):
"""Base class for a menu"""

View File

@ -0,0 +1,27 @@
import discord
from discord.ext import commands
from overrides import overrides
from database.schemas import MemeTemplate
from didier.menus.common import PageSource
__all__ = ["MemeSource"]
class MemeSource(PageSource[MemeTemplate]):
"""PageSource for meme templates"""
@overrides
def create_embeds(self, ctx: commands.Context):
for page in range(self.page_count):
# The colour of the embed is (69,4,20) with the values +100 because they were too dark
embed = discord.Embed(title="Meme Templates", colour=discord.Colour.from_rgb(169, 14, 120))
description_data = []
for template in self.get_page_data(page):
description_data.append(f"{template.name} ({template.field_count})")
embed.description = "\n".join(description_data)
embed.set_footer(text="Format: Template Name (Field Count)")
self.embeds.append(embed)

View File

@ -0,0 +1,250 @@
from discord import app_commands
__all__ = ["autocomplete_country"]
# This list was parsed out of a request to https://disease.sh/v3/covid-19/countries
country_list = [
"Afghanistan",
"Albania",
"Algeria",
"Andorra",
"Angola",
"Anguilla",
"Antigua and Barbuda",
"Argentina",
"Armenia",
"Aruba",
"Australia",
"Austria",
"Azerbaijan",
"Bahamas",
"Bahrain",
"Bangladesh",
"Barbados",
"Belarus",
"Belgium",
"Belize",
"Benin",
"Bermuda",
"Bhutan",
"Bolivia",
"Bosnia",
"Botswana",
"Brazil",
"British Virgin Islands",
"Brunei",
"Bulgaria",
"Burkina Faso",
"Burundi",
"Cabo Verde",
"Cambodia",
"Cameroon",
"Canada",
"Caribbean Netherlands",
"Cayman Islands",
"Central African Republic",
"Chad",
"Channel Islands",
"Chile",
"China",
"Colombia",
"Comoros",
"Congo",
"Cook Islands",
"Costa Rica",
"Croatia",
"Cuba",
"Curaçao",
"Cyprus",
"Czechia",
"Côte d'Ivoire",
"DRC",
"Denmark",
"Diamond Princess",
"Djibouti",
"Dominica",
"Dominican Republic",
"Ecuador",
"Egypt",
"El Salvador",
"Equatorial Guinea",
"Eritrea",
"Estonia",
"Ethiopia",
"Falkland Islands (Malvinas)",
"Faroe Islands",
"Fiji",
"Finland",
"France",
"French Guiana",
"French Polynesia",
"Gabon",
"Gambia",
"Georgia",
"Germany",
"Ghana",
"Gibraltar",
"Greece",
"Greenland",
"Grenada",
"Guadeloupe",
"Guatemala",
"Guinea",
"Guinea-Bissau",
"Guyana",
"Haiti",
"Holy See (Vatican City State)",
"Honduras",
"Hong Kong",
"Hungary",
"Iceland",
"India",
"Indonesia",
"Iran",
"Iraq",
"Ireland",
"Isle of Man",
"Israel",
"Italy",
"Jamaica",
"Japan",
"Jordan",
"Kazakhstan",
"Kenya",
"Kiribati",
"Kuwait",
"Kyrgyzstan",
"Lao People's Democratic Republic",
"Latvia",
"Lebanon",
"Lesotho",
"Liberia",
"Libyan Arab Jamahiriya",
"Liechtenstein",
"Lithuania",
"Luxembourg",
"MS Zaandam",
"Macao",
"Macedonia",
"Madagascar",
"Malawi",
"Malaysia",
"Maldives",
"Mali",
"Malta",
"Marshall Islands",
"Martinique",
"Mauritania",
"Mauritius",
"Mayotte",
"Mexico",
"Micronesia",
"Moldova",
"Monaco",
"Mongolia",
"Montenegro",
"Montserrat",
"Morocco",
"Mozambique",
"Myanmar",
"N. Korea",
"Namibia",
"Nauru",
"Nepal",
"Netherlands",
"New Caledonia",
"New Zealand",
"Nicaragua",
"Niger",
"Nigeria",
"Niue",
"Norway",
"Oman",
"Pakistan",
"Palau",
"Palestine",
"Panama",
"Papua New Guinea",
"Paraguay",
"Peru",
"Philippines",
"Poland",
"Portugal",
"Qatar",
"Romania",
"Russia",
"Rwanda",
"Réunion",
"S. Korea",
"Saint Helena",
"Saint Kitts and Nevis",
"Saint Lucia",
"Saint Martin",
"Saint Pierre Miquelon",
"Saint Vincent and the Grenadines",
"Samoa",
"San Marino",
"Sao Tome and Principe",
"Saudi Arabia",
"Senegal",
"Serbia",
"Seychelles",
"Sierra Leone",
"Singapore",
"Sint Maarten",
"Slovakia",
"Slovenia",
"Solomon Islands",
"Somalia",
"South Africa",
"South Sudan",
"Spain",
"Sri Lanka",
"St. Barth",
"Sudan",
"Suriname",
"Swaziland",
"Sweden",
"Switzerland",
"Syrian Arab Republic",
"Taiwan",
"Tajikistan",
"Tanzania",
"Thailand",
"Timor-Leste",
"Togo",
"Tonga",
"Trinidad and Tobago",
"Tunisia",
"Turkey",
"Turks and Caicos Islands",
"Tuvalu",
"UAE",
"UK",
"USA",
"Uganda",
"Ukraine",
"Uruguay",
"Uzbekistan",
"Vanuatu",
"Venezuela",
"Vietnam",
"Wallis and Futuna",
"Western Sahara",
"Yemen",
"Zambia",
"Zimbabwe",
]
def autocomplete_country(argument: str) -> list[app_commands.Choice]:
"""Autocompletion for country names"""
argument = argument.lower()
global_autocomplete = ["Global"]
return [
app_commands.Choice(name=country, value=country)
for country in global_autocomplete + country_list
if argument in country.lower()
]

View File

@ -1,6 +1,6 @@
import logging
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from typing import AsyncGenerator, Optional
from aiohttp import ClientResponse, ClientSession, ContentTypeError
@ -19,10 +19,10 @@ def request_successful(response: ClientResponse) -> bool:
@asynccontextmanager
async def ensure_get(
http_session: ClientSession, endpoint: str, *, log_exceptions: bool = True
http_session: ClientSession, endpoint: str, *, params: Optional[dict] = None, log_exceptions: bool = True
) -> AsyncGenerator[dict, None]:
"""Context manager that automatically raises an exception if a GET-request fails"""
async with http_session.get(endpoint) as response:
async with http_session.get(endpoint, params=params) as response:
try:
content = await response.json()
except ContentTypeError:

View File

@ -3,7 +3,6 @@ alembic==1.8.0
asyncpg==0.25.0
beautifulsoup4==4.11.1
discord.py==2.0.1
git+https://github.com/Rapptz/discord-ext-menus@8686b5d
environs==9.5.0
feedparser==6.0.10
ics==0.7.2