Send daily birthday notifications, add more settings & configs, fix small bugs in database

pull/125/head
stijndcl 2022-07-23 23:21:32 +02:00
parent 393cc9c891
commit da0e60ac4f
8 changed files with 105 additions and 12 deletions

View File

@ -2,7 +2,7 @@ import datetime
from datetime import date from datetime import date
from typing import Optional from typing import Optional
from sqlalchemy import select from sqlalchemy import extract, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
@ -38,5 +38,8 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]: async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
"""Get all birthdays that happen on a given day""" """Get all birthdays that happen on a given day"""
statement = select(Birthday).where(Birthday.birthday == day) days = extract("day", Birthday.birthday)
months = extract("month", Birthday.birthday)
statement = select(Birthday).where((days == day.day) & (months == day.month))
return list((await session.execute(statement)).scalars()) return list((await session.execute(statement)).scalars())

View File

@ -35,7 +35,13 @@ class Discord(commands.Cog):
async def birthday_set(self, ctx: commands.Context, date_str: str): async def birthday_set(self, ctx: commands.Context, date_str: str):
"""Command to set your birthday""" """Command to set your birthday"""
try: try:
date = str_to_date(date_str) default_year = 2001
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
# If no year was passed, make it 2001 by default
if date_str.count("/") == 1:
date.replace(year=default_year)
except ValueError: except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)

View File

@ -5,12 +5,13 @@ from discord.ext import commands, tasks # type: ignore # Strange & incorrect My
import settings import settings
from database import enums from database import enums
from database.crud.birthdays import get_birthdays_on_day
from database.crud.ufora_announcements import remove_old_announcements from database.crud.ufora_announcements import remove_old_announcements
from didier import Didier from didier import Didier
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
# datetime.time()-instances for when every task should run # datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
@ -31,17 +32,18 @@ class Tasks(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
# Only check birthdays if there's a channel to send it to
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
self.check_birthdays.start()
# Only pull announcements if a token was provided # Only pull announcements if a token was provided
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None: if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
self.pull_ufora_announcements.start() self.pull_ufora_announcements.start()
self.remove_old_ufora_announcements.start() self.remove_old_ufora_announcements.start()
# Start all tasks
self.check_birthdays.start()
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements} self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
@commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True) @commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
@commands.check(is_owner) @commands.check(is_owner)
async def tasks_group(self, ctx: commands.Context): async def tasks_group(self, ctx: commands.Context):
"""Command group for Task-related commands """Command group for Task-related commands
@ -64,6 +66,18 @@ class Tasks(commands.Cog):
@timed_task(enums.TaskType.BIRTHDAYS) @timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self): async def check_birthdays(self):
"""Check if it's currently anyone's birthday""" """Check if it's currently anyone's birthday"""
now = tz_aware_now().date()
async with self.client.db_session as session:
birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to find channel for birthday announcements")
for birthday in birthdays:
user = self.client.get_user(birthday.user_id)
# TODO more messages?
await channel.send(f"Gelukkig verjaardag {user.mention}!")
@check_birthdays.before_loop @check_birthdays.before_loop
async def _before_check_birthdays(self): async def _before_check_birthdays(self):

View File

@ -1,3 +1,4 @@
import logging
import os import os
import discord import discord
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
__all__ = ["Didier"] __all__ = ["Didier"]
logger = logging.getLogger(__name__)
class Didier(commands.Bot): class Didier(commands.Bot):
"""DIDIER <3""" """DIDIER <3"""
database_caches: CacheManager database_caches: CacheManager
error_channel: discord.abc.Messageable
initial_extensions: tuple[str, ...] = () initial_extensions: tuple[str, ...] = ()
http_session: ClientSession http_session: ClientSession
@ -60,6 +65,12 @@ class Didier(commands.Bot):
# Create aiohttp session # Create aiohttp session
self.http_session = ClientSession() self.http_session = ClientSession()
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
else:
self.error_channel = self.get_user(self.owner_id)
async def _load_initial_extensions(self): async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others""" """Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions: for extension in self.initial_extensions:
@ -101,6 +112,13 @@ class Didier(commands.Bot):
"""Add an X to a message""" """Add an X to a message"""
await message.add_reaction("") await message.add_reaction("")
async def log_error(self, message: str, log_to_discord: bool = True):
"""Send an error message to the logs, and optionally the configured channel"""
logger.error(message)
if log_to_discord:
# TODO pretty embed
await self.error_channel.send(message)
async def on_ready(self): async def on_ready(self):
"""Event triggered when the bot is ready""" """Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE) print(settings.DISCORD_READY_MESSAGE)

View File

@ -1,8 +1,9 @@
import datetime import datetime
import zoneinfo import zoneinfo
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"] __all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
from typing import Union
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
@ -12,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
def str_to_date(date_str: str) -> datetime.date: def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
"""Turn a string into a DD/MM/YYYY date""" """Turn a string into a DD/MM/YYYY date"""
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date() # Allow passing multiple formats in a list
if isinstance(formats, str):
formats = [formats]
for format_str in formats:
try:
return datetime.datetime.strptime(date_str, format_str).date()
except ValueError:
continue
raise ValueError
def tz_aware_now() -> datetime.datetime:
"""Get the current date & time, but timezone-aware"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)

View File

@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?") DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys""" """API Keys"""

View File

@ -52,7 +52,7 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
@freeze_time("2022/07/23") @freeze_time("2022/07/23")
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
"""Test getting all birthdays on a given day""" """Test getting all birthdays on a given day"""
await crud.add_birthday(database_session, user.user_id, datetime.today()) await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
user_2 = await users.get_or_add(database_session, user.user_id + 1) user_2 = await users.get_or_add(database_session, user.user_id + 1)
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))

View File

@ -0,0 +1,34 @@
import datetime
import pytest
from didier.utils.types.datetime import str_to_date
def test_str_to_date_single_valid():
"""Test parsing a string for a single possibility (default)"""
result = str_to_date("23/11/2001")
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_single_invalid():
"""Test parsing a string for an invalid string"""
# Invalid format
with pytest.raises(ValueError):
str_to_date("23/11/01")
# Invalid date
with pytest.raises(ValueError):
str_to_date("69/42/0")
def test_str_to_date_multiple_valid():
"""Test parsing a string for multiple possibilities"""
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_multiple_invalid():
"""Test parsing a string for multiple possibilities when none are valid"""
with pytest.raises(ValueError):
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])