Send daily birthday notifications, add more settings & configs, fix small bugs in database

pull/125/head
stijndcl 2022-07-23 23:21:32 +02:00
parent 393cc9c891
commit da0e60ac4f
8 changed files with 105 additions and 12 deletions

View File

@ -2,7 +2,7 @@ import datetime
from datetime import date
from typing import Optional
from sqlalchemy import select
from sqlalchemy import extract, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
@ -38,5 +38,8 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
"""Get all birthdays that happen on a given day"""
statement = select(Birthday).where(Birthday.birthday == day)
days = extract("day", Birthday.birthday)
months = extract("month", Birthday.birthday)
statement = select(Birthday).where((days == day.day) & (months == day.month))
return list((await session.execute(statement)).scalars())

View File

@ -35,7 +35,13 @@ class Discord(commands.Cog):
async def birthday_set(self, ctx: commands.Context, date_str: str):
"""Command to set your birthday"""
try:
date = str_to_date(date_str)
default_year = 2001
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
# If no year was passed, make it 2001 by default
if date_str.count("/") == 1:
date.replace(year=default_year)
except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)

View File

@ -5,12 +5,13 @@ from discord.ext import commands, tasks # type: ignore # Strange & incorrect My
import settings
from database import enums
from database.crud.birthdays import get_birthdays_on_day
from database.crud.ufora_announcements import remove_old_announcements
from didier import Didier
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
# datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
@ -31,17 +32,18 @@ class Tasks(commands.Cog):
def __init__(self, client: Didier):
self.client = client
# Only check birthdays if there's a channel to send it to
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
self.check_birthdays.start()
# Only pull announcements if a token was provided
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
self.pull_ufora_announcements.start()
self.remove_old_ufora_announcements.start()
# Start all tasks
self.check_birthdays.start()
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
@commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True)
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
@commands.check(is_owner)
async def tasks_group(self, ctx: commands.Context):
"""Command group for Task-related commands
@ -64,6 +66,18 @@ class Tasks(commands.Cog):
@timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self):
"""Check if it's currently anyone's birthday"""
now = tz_aware_now().date()
async with self.client.db_session as session:
birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to find channel for birthday announcements")
for birthday in birthdays:
user = self.client.get_user(birthday.user_id)
# TODO more messages?
await channel.send(f"Gelukkig verjaardag {user.mention}!")
@check_birthdays.before_loop
async def _before_check_birthdays(self):

View File

@ -1,3 +1,4 @@
import logging
import os
import discord
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
__all__ = ["Didier"]
logger = logging.getLogger(__name__)
class Didier(commands.Bot):
"""DIDIER <3"""
database_caches: CacheManager
error_channel: discord.abc.Messageable
initial_extensions: tuple[str, ...] = ()
http_session: ClientSession
@ -60,6 +65,12 @@ class Didier(commands.Bot):
# Create aiohttp session
self.http_session = ClientSession()
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
else:
self.error_channel = self.get_user(self.owner_id)
async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions:
@ -101,6 +112,13 @@ class Didier(commands.Bot):
"""Add an X to a message"""
await message.add_reaction("")
async def log_error(self, message: str, log_to_discord: bool = True):
"""Send an error message to the logs, and optionally the configured channel"""
logger.error(message)
if log_to_discord:
# TODO pretty embed
await self.error_channel.send(message)
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)

View File

@ -1,8 +1,9 @@
import datetime
import zoneinfo
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"]
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
from typing import Union
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
@ -12,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
def str_to_date(date_str: str) -> datetime.date:
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
"""Turn a string into a DD/MM/YYYY date"""
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
# Allow passing multiple formats in a list
if isinstance(formats, str):
formats = [formats]
for format_str in formats:
try:
return datetime.datetime.strptime(date_str, format_str).date()
except ValueError:
continue
raise ValueError
def tz_aware_now() -> datetime.datetime:
"""Get the current date & time, but timezone-aware"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)

View File

@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys"""

View File

@ -52,7 +52,7 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
@freeze_time("2022/07/23")
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
"""Test getting all birthdays on a given day"""
await crud.add_birthday(database_session, user.user_id, datetime.today())
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
user_2 = await users.get_or_add(database_session, user.user_id + 1)
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))

View File

@ -0,0 +1,34 @@
import datetime
import pytest
from didier.utils.types.datetime import str_to_date
def test_str_to_date_single_valid():
"""Test parsing a string for a single possibility (default)"""
result = str_to_date("23/11/2001")
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_single_invalid():
"""Test parsing a string for an invalid string"""
# Invalid format
with pytest.raises(ValueError):
str_to_date("23/11/01")
# Invalid date
with pytest.raises(ValueError):
str_to_date("69/42/0")
def test_str_to_date_multiple_valid():
"""Test parsing a string for multiple possibilities"""
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_multiple_invalid():
"""Test parsing a string for multiple possibilities when none are valid"""
with pytest.raises(ValueError):
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])