mirror of https://github.com/stijndcl/didier
Send daily birthday notifications, add more settings & configs, fix small bugs in database
parent
393cc9c891
commit
da0e60ac4f
|
@ -2,7 +2,7 @@ import datetime
|
|||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import extract, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
|
@ -38,5 +38,8 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
|
|||
|
||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
||||
"""Get all birthdays that happen on a given day"""
|
||||
statement = select(Birthday).where(Birthday.birthday == day)
|
||||
days = extract("day", Birthday.birthday)
|
||||
months = extract("month", Birthday.birthday)
|
||||
|
||||
statement = select(Birthday).where((days == day.day) & (months == day.month))
|
||||
return list((await session.execute(statement)).scalars())
|
||||
|
|
|
@ -35,7 +35,13 @@ class Discord(commands.Cog):
|
|||
async def birthday_set(self, ctx: commands.Context, date_str: str):
|
||||
"""Command to set your birthday"""
|
||||
try:
|
||||
date = str_to_date(date_str)
|
||||
default_year = 2001
|
||||
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
|
||||
|
||||
# If no year was passed, make it 2001 by default
|
||||
if date_str.count("/") == 1:
|
||||
date.replace(year=default_year)
|
||||
|
||||
except ValueError:
|
||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||
|
||||
|
|
|
@ -5,12 +5,13 @@ from discord.ext import commands, tasks # type: ignore # Strange & incorrect My
|
|||
|
||||
import settings
|
||||
from database import enums
|
||||
from database.crud.birthdays import get_birthdays_on_day
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
from didier import Didier
|
||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||
from didier.decorators.tasks import timed_task
|
||||
from didier.utils.discord.checks import is_owner
|
||||
from didier.utils.types.datetime import LOCAL_TIMEZONE
|
||||
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
|
||||
|
||||
# datetime.time()-instances for when every task should run
|
||||
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||
|
@ -31,17 +32,18 @@ class Tasks(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
# Only check birthdays if there's a channel to send it to
|
||||
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
|
||||
self.check_birthdays.start()
|
||||
|
||||
# Only pull announcements if a token was provided
|
||||
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
|
||||
self.pull_ufora_announcements.start()
|
||||
self.remove_old_ufora_announcements.start()
|
||||
|
||||
# Start all tasks
|
||||
self.check_birthdays.start()
|
||||
|
||||
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
|
||||
|
||||
@commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True)
|
||||
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
|
||||
@commands.check(is_owner)
|
||||
async def tasks_group(self, ctx: commands.Context):
|
||||
"""Command group for Task-related commands
|
||||
|
@ -64,6 +66,18 @@ class Tasks(commands.Cog):
|
|||
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||
async def check_birthdays(self):
|
||||
"""Check if it's currently anyone's birthday"""
|
||||
now = tz_aware_now().date()
|
||||
async with self.client.db_session as session:
|
||||
birthdays = await get_birthdays_on_day(session, now)
|
||||
|
||||
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||
if channel is None:
|
||||
return await self.client.log_error("Unable to find channel for birthday announcements")
|
||||
|
||||
for birthday in birthdays:
|
||||
user = self.client.get_user(birthday.user_id)
|
||||
# TODO more messages?
|
||||
await channel.send(f"Gelukkig verjaardag {user.mention}!")
|
||||
|
||||
@check_birthdays.before_loop
|
||||
async def _before_check_birthdays(self):
|
||||
|
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import discord
|
||||
|
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
|
|||
__all__ = ["Didier"]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Didier(commands.Bot):
|
||||
"""DIDIER <3"""
|
||||
|
||||
database_caches: CacheManager
|
||||
error_channel: discord.abc.Messageable
|
||||
initial_extensions: tuple[str, ...] = ()
|
||||
http_session: ClientSession
|
||||
|
||||
|
@ -60,6 +65,12 @@ class Didier(commands.Bot):
|
|||
# Create aiohttp session
|
||||
self.http_session = ClientSession()
|
||||
|
||||
# Configure channel to send errors to
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
else:
|
||||
self.error_channel = self.get_user(self.owner_id)
|
||||
|
||||
async def _load_initial_extensions(self):
|
||||
"""Load all extensions that should be loaded before the others"""
|
||||
for extension in self.initial_extensions:
|
||||
|
@ -101,6 +112,13 @@ class Didier(commands.Bot):
|
|||
"""Add an X to a message"""
|
||||
await message.add_reaction("❌")
|
||||
|
||||
async def log_error(self, message: str, log_to_discord: bool = True):
|
||||
"""Send an error message to the logs, and optionally the configured channel"""
|
||||
logger.error(message)
|
||||
if log_to_discord:
|
||||
# TODO pretty embed
|
||||
await self.error_channel.send(message)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"]
|
||||
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
|
||||
|
||||
from typing import Union
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
||||
|
@ -12,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
|
|||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||
|
||||
|
||||
def str_to_date(date_str: str) -> datetime.date:
|
||||
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
||||
"""Turn a string into a DD/MM/YYYY date"""
|
||||
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
|
||||
# Allow passing multiple formats in a list
|
||||
if isinstance(formats, str):
|
||||
formats = [formats]
|
||||
|
||||
for format_str in formats:
|
||||
try:
|
||||
return datetime.datetime.strptime(date_str, format_str).date()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
def tz_aware_now() -> datetime.datetime:
|
||||
"""Get the current date & time, but timezone-aware"""
|
||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)
|
||||
|
|
|
@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
|
|||
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
||||
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
||||
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
||||
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
|
||||
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
|
||||
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
||||
|
||||
"""API Keys"""
|
||||
|
|
|
@ -52,7 +52,7 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
|
|||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
||||
"""Test getting all birthdays on a given day"""
|
||||
await crud.add_birthday(database_session, user.user_id, datetime.today())
|
||||
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
|
||||
|
||||
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
||||
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||
|
|
|
@ -0,0 +1,34 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from didier.utils.types.datetime import str_to_date
|
||||
|
||||
|
||||
def test_str_to_date_single_valid():
|
||||
"""Test parsing a string for a single possibility (default)"""
|
||||
result = str_to_date("23/11/2001")
|
||||
assert result == datetime.date(year=2001, month=11, day=23)
|
||||
|
||||
|
||||
def test_str_to_date_single_invalid():
|
||||
"""Test parsing a string for an invalid string"""
|
||||
# Invalid format
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("23/11/01")
|
||||
|
||||
# Invalid date
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("69/42/0")
|
||||
|
||||
|
||||
def test_str_to_date_multiple_valid():
|
||||
"""Test parsing a string for multiple possibilities"""
|
||||
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
|
||||
assert result == datetime.date(year=2001, month=11, day=23)
|
||||
|
||||
|
||||
def test_str_to_date_multiple_invalid():
|
||||
"""Test parsing a string for multiple possibilities when none are valid"""
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])
|
Loading…
Reference in New Issue