mirror of https://github.com/stijndcl/didier
Send daily birthday notifications, add more settings & configs, fix small bugs in database
parent
393cc9c891
commit
da0e60ac4f
|
@ -2,7 +2,7 @@ import datetime
|
||||||
from datetime import date
|
from datetime import date
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import extract, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
|
@ -38,5 +38,8 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
|
||||||
|
|
||||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
||||||
"""Get all birthdays that happen on a given day"""
|
"""Get all birthdays that happen on a given day"""
|
||||||
statement = select(Birthday).where(Birthday.birthday == day)
|
days = extract("day", Birthday.birthday)
|
||||||
|
months = extract("month", Birthday.birthday)
|
||||||
|
|
||||||
|
statement = select(Birthday).where((days == day.day) & (months == day.month))
|
||||||
return list((await session.execute(statement)).scalars())
|
return list((await session.execute(statement)).scalars())
|
||||||
|
|
|
@ -35,7 +35,13 @@ class Discord(commands.Cog):
|
||||||
async def birthday_set(self, ctx: commands.Context, date_str: str):
|
async def birthday_set(self, ctx: commands.Context, date_str: str):
|
||||||
"""Command to set your birthday"""
|
"""Command to set your birthday"""
|
||||||
try:
|
try:
|
||||||
date = str_to_date(date_str)
|
default_year = 2001
|
||||||
|
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
|
||||||
|
|
||||||
|
# If no year was passed, make it 2001 by default
|
||||||
|
if date_str.count("/") == 1:
|
||||||
|
date.replace(year=default_year)
|
||||||
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||||
|
|
||||||
|
|
|
@ -5,12 +5,13 @@ from discord.ext import commands, tasks # type: ignore # Strange & incorrect My
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database import enums
|
from database import enums
|
||||||
|
from database.crud.birthdays import get_birthdays_on_day
|
||||||
from database.crud.ufora_announcements import remove_old_announcements
|
from database.crud.ufora_announcements import remove_old_announcements
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||||
from didier.decorators.tasks import timed_task
|
from didier.decorators.tasks import timed_task
|
||||||
from didier.utils.discord.checks import is_owner
|
from didier.utils.discord.checks import is_owner
|
||||||
from didier.utils.types.datetime import LOCAL_TIMEZONE
|
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
|
||||||
|
|
||||||
# datetime.time()-instances for when every task should run
|
# datetime.time()-instances for when every task should run
|
||||||
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||||
|
@ -31,17 +32,18 @@ class Tasks(commands.Cog):
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
|
# Only check birthdays if there's a channel to send it to
|
||||||
|
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
|
||||||
|
self.check_birthdays.start()
|
||||||
|
|
||||||
# Only pull announcements if a token was provided
|
# Only pull announcements if a token was provided
|
||||||
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
|
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
|
||||||
self.pull_ufora_announcements.start()
|
self.pull_ufora_announcements.start()
|
||||||
self.remove_old_ufora_announcements.start()
|
self.remove_old_ufora_announcements.start()
|
||||||
|
|
||||||
# Start all tasks
|
|
||||||
self.check_birthdays.start()
|
|
||||||
|
|
||||||
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
|
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
|
||||||
|
|
||||||
@commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True)
|
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
|
||||||
@commands.check(is_owner)
|
@commands.check(is_owner)
|
||||||
async def tasks_group(self, ctx: commands.Context):
|
async def tasks_group(self, ctx: commands.Context):
|
||||||
"""Command group for Task-related commands
|
"""Command group for Task-related commands
|
||||||
|
@ -64,6 +66,18 @@ class Tasks(commands.Cog):
|
||||||
@timed_task(enums.TaskType.BIRTHDAYS)
|
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||||
async def check_birthdays(self):
|
async def check_birthdays(self):
|
||||||
"""Check if it's currently anyone's birthday"""
|
"""Check if it's currently anyone's birthday"""
|
||||||
|
now = tz_aware_now().date()
|
||||||
|
async with self.client.db_session as session:
|
||||||
|
birthdays = await get_birthdays_on_day(session, now)
|
||||||
|
|
||||||
|
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||||
|
if channel is None:
|
||||||
|
return await self.client.log_error("Unable to find channel for birthday announcements")
|
||||||
|
|
||||||
|
for birthday in birthdays:
|
||||||
|
user = self.client.get_user(birthday.user_id)
|
||||||
|
# TODO more messages?
|
||||||
|
await channel.send(f"Gelukkig verjaardag {user.mention}!")
|
||||||
|
|
||||||
@check_birthdays.before_loop
|
@check_birthdays.before_loop
|
||||||
async def _before_check_birthdays(self):
|
async def _before_check_birthdays(self):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
|
||||||
__all__ = ["Didier"]
|
__all__ = ["Didier"]
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class Didier(commands.Bot):
|
class Didier(commands.Bot):
|
||||||
"""DIDIER <3"""
|
"""DIDIER <3"""
|
||||||
|
|
||||||
database_caches: CacheManager
|
database_caches: CacheManager
|
||||||
|
error_channel: discord.abc.Messageable
|
||||||
initial_extensions: tuple[str, ...] = ()
|
initial_extensions: tuple[str, ...] = ()
|
||||||
http_session: ClientSession
|
http_session: ClientSession
|
||||||
|
|
||||||
|
@ -60,6 +65,12 @@ class Didier(commands.Bot):
|
||||||
# Create aiohttp session
|
# Create aiohttp session
|
||||||
self.http_session = ClientSession()
|
self.http_session = ClientSession()
|
||||||
|
|
||||||
|
# Configure channel to send errors to
|
||||||
|
if settings.ERRORS_CHANNEL is not None:
|
||||||
|
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||||
|
else:
|
||||||
|
self.error_channel = self.get_user(self.owner_id)
|
||||||
|
|
||||||
async def _load_initial_extensions(self):
|
async def _load_initial_extensions(self):
|
||||||
"""Load all extensions that should be loaded before the others"""
|
"""Load all extensions that should be loaded before the others"""
|
||||||
for extension in self.initial_extensions:
|
for extension in self.initial_extensions:
|
||||||
|
@ -101,6 +112,13 @@ class Didier(commands.Bot):
|
||||||
"""Add an X to a message"""
|
"""Add an X to a message"""
|
||||||
await message.add_reaction("❌")
|
await message.add_reaction("❌")
|
||||||
|
|
||||||
|
async def log_error(self, message: str, log_to_discord: bool = True):
|
||||||
|
"""Send an error message to the logs, and optionally the configured channel"""
|
||||||
|
logger.error(message)
|
||||||
|
if log_to_discord:
|
||||||
|
# TODO pretty embed
|
||||||
|
await self.error_channel.send(message)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Event triggered when the bot is ready"""
|
"""Event triggered when the bot is ready"""
|
||||||
print(settings.DISCORD_READY_MESSAGE)
|
print(settings.DISCORD_READY_MESSAGE)
|
||||||
|
|
|
@ -1,8 +1,9 @@
|
||||||
import datetime
|
import datetime
|
||||||
import zoneinfo
|
import zoneinfo
|
||||||
|
|
||||||
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"]
|
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
|
||||||
|
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||||
|
|
||||||
|
@ -12,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
|
||||||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||||
|
|
||||||
|
|
||||||
def str_to_date(date_str: str) -> datetime.date:
|
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
||||||
"""Turn a string into a DD/MM/YYYY date"""
|
"""Turn a string into a DD/MM/YYYY date"""
|
||||||
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
|
# Allow passing multiple formats in a list
|
||||||
|
if isinstance(formats, str):
|
||||||
|
formats = [formats]
|
||||||
|
|
||||||
|
for format_str in formats:
|
||||||
|
try:
|
||||||
|
return datetime.datetime.strptime(date_str, format_str).date()
|
||||||
|
except ValueError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
|
def tz_aware_now() -> datetime.datetime:
|
||||||
|
"""Get the current date & time, but timezone-aware"""
|
||||||
|
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)
|
||||||
|
|
|
@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
|
||||||
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
||||||
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
||||||
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
||||||
|
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
|
||||||
|
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
|
||||||
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
||||||
|
|
||||||
"""API Keys"""
|
"""API Keys"""
|
||||||
|
|
|
@ -52,7 +52,7 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
|
||||||
@freeze_time("2022/07/23")
|
@freeze_time("2022/07/23")
|
||||||
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
||||||
"""Test getting all birthdays on a given day"""
|
"""Test getting all birthdays on a given day"""
|
||||||
await crud.add_birthday(database_session, user.user_id, datetime.today())
|
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
|
||||||
|
|
||||||
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
||||||
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||||
|
|
|
@ -0,0 +1,34 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from didier.utils.types.datetime import str_to_date
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_to_date_single_valid():
|
||||||
|
"""Test parsing a string for a single possibility (default)"""
|
||||||
|
result = str_to_date("23/11/2001")
|
||||||
|
assert result == datetime.date(year=2001, month=11, day=23)
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_to_date_single_invalid():
|
||||||
|
"""Test parsing a string for an invalid string"""
|
||||||
|
# Invalid format
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
str_to_date("23/11/01")
|
||||||
|
|
||||||
|
# Invalid date
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
str_to_date("69/42/0")
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_to_date_multiple_valid():
|
||||||
|
"""Test parsing a string for multiple possibilities"""
|
||||||
|
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
|
||||||
|
assert result == datetime.date(year=2001, month=11, day=23)
|
||||||
|
|
||||||
|
|
||||||
|
def test_str_to_date_multiple_invalid():
|
||||||
|
"""Test parsing a string for multiple possibilities when none are valid"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])
|
Loading…
Reference in New Issue