From 8bc0f1fa7aaa3945c3ccd96cd5d1c6e8e92346f3 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 23 Jul 2022 20:35:42 +0200 Subject: [PATCH 1/4] Add birthday task, change migrations to use date instead of datetime --- .../0d03c226d881_initial_currency_models.py | 2 +- .../versions/1716bfecf684_add_birthdays.py | 2 +- alembic/versions/346b408c362a_create_tasks.py | 35 ++++++++++ .../4ec79dd5b191_initial_migration.py | 65 ++++++++++-------- database/crud/birthdays.py | 5 ++ database/crud/currency.py | 2 +- database/crud/tasks.py | 32 +++++++++ database/enums.py | 13 ++++ database/models.py | 33 ++++++++-- database/utils/datetime.py | 5 ++ didier/cogs/owner.py | 2 +- didier/cogs/tasks.py | 66 ++++++++++++++++--- didier/decorators/__init__.py | 0 didier/decorators/tasks.py | 28 ++++++++ .../{data => utils/discord}/flags/__init__.py | 0 didier/{data => utils/discord}/flags/owner.py | 2 +- didier/{data => utils/discord}/flags/posix.py | 0 didier/utils/types/datetime.py | 6 +- 18 files changed, 249 insertions(+), 49 deletions(-) create mode 100644 alembic/versions/346b408c362a_create_tasks.py create mode 100644 database/crud/tasks.py create mode 100644 database/enums.py create mode 100644 database/utils/datetime.py create mode 100644 didier/decorators/__init__.py create mode 100644 didier/decorators/tasks.py rename didier/{data => utils/discord}/flags/__init__.py (100%) rename didier/{data => utils/discord}/flags/owner.py (80%) rename didier/{data => utils/discord}/flags/posix.py (100%) diff --git a/alembic/versions/0d03c226d881_initial_currency_models.py b/alembic/versions/0d03c226d881_initial_currency_models.py index 7478410..feec2c1 100644 --- a/alembic/versions/0d03c226d881_initial_currency_models.py +++ b/alembic/versions/0d03c226d881_initial_currency_models.py @@ -37,7 +37,7 @@ def upgrade() -> None: "nightly_data", sa.Column("nightly_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_nightly", sa.Date, nullable=True), sa.Column("count", sa.Integer(), server_default="0", nullable=False), sa.ForeignKeyConstraint( ["user_id"], diff --git a/alembic/versions/1716bfecf684_add_birthdays.py b/alembic/versions/1716bfecf684_add_birthdays.py index 9065993..5f01615 100644 --- a/alembic/versions/1716bfecf684_add_birthdays.py +++ b/alembic/versions/1716bfecf684_add_birthdays.py @@ -22,7 +22,7 @@ def upgrade() -> None: "birthdays", sa.Column("birthday_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("birthday", sa.DateTime(), nullable=False), + sa.Column("birthday", sa.Date, nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["users.user_id"], diff --git a/alembic/versions/346b408c362a_create_tasks.py b/alembic/versions/346b408c362a_create_tasks.py new file mode 100644 index 0000000..25f1530 --- /dev/null +++ b/alembic/versions/346b408c362a_create_tasks.py @@ -0,0 +1,35 @@ +"""Create tasks + +Revision ID: 346b408c362a +Revises: 1716bfecf684 +Create Date: 2022-07-23 19:41:07.029482 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "346b408c362a" +down_revision = "1716bfecf684" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "tasks", + sa.Column("task_id", sa.Integer(), nullable=False), + sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False), + sa.Column("previous_run", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("task_id"), + sa.UniqueConstraint("task"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("tasks") + # ### end Alembic commands ### diff --git a/alembic/versions/4ec79dd5b191_initial_migration.py b/alembic/versions/4ec79dd5b191_initial_migration.py index 186e280..2bf8362 100644 --- a/alembic/versions/4ec79dd5b191_initial_migration.py +++ b/alembic/versions/4ec79dd5b191_initial_migration.py @@ -1,16 +1,16 @@ """Initial migration Revision ID: 4ec79dd5b191 -Revises: +Revises: Create Date: 2022-06-19 00:31:58.384360 """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '4ec79dd5b191' +revision = "4ec79dd5b191" down_revision = None branch_labels = None depends_on = None @@ -18,37 +18,46 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('ufora_courses', - sa.Column('course_id', sa.Integer(), nullable=False), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('code', sa.Text(), nullable=False), - sa.Column('year', sa.Integer(), nullable=False), - sa.Column('log_announcements', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('course_id'), - sa.UniqueConstraint('code'), - sa.UniqueConstraint('name') + op.create_table( + "ufora_courses", + sa.Column("course_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("code", sa.Text(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("log_announcements", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("course_id"), + sa.UniqueConstraint("code"), + sa.UniqueConstraint("name"), ) - op.create_table('ufora_announcements', - sa.Column('announcement_id', sa.Integer(), nullable=False), - sa.Column('course_id', sa.Integer(), nullable=True), - sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), - sa.PrimaryKeyConstraint('announcement_id') + op.create_table( + "ufora_announcements", + sa.Column("announcement_id", sa.Integer(), nullable=False), + sa.Column("course_id", sa.Integer(), nullable=True), + sa.Column("publication_date", sa.Date, nullable=True), + sa.ForeignKeyConstraint( + ["course_id"], + ["ufora_courses.course_id"], + ), + sa.PrimaryKeyConstraint("announcement_id"), ) - op.create_table('ufora_course_aliases', - sa.Column('alias_id', sa.Integer(), nullable=False), - sa.Column('alias', sa.Text(), nullable=False), - sa.Column('course_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), - sa.PrimaryKeyConstraint('alias_id'), - sa.UniqueConstraint('alias') + op.create_table( + "ufora_course_aliases", + sa.Column("alias_id", sa.Integer(), nullable=False), + sa.Column("alias", sa.Text(), nullable=False), + sa.Column("course_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["course_id"], + ["ufora_courses.course_id"], + ), + sa.PrimaryKeyConstraint("alias_id"), + sa.UniqueConstraint("alias"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('ufora_course_aliases') - op.drop_table('ufora_announcements') - op.drop_table('ufora_courses') + op.drop_table("ufora_course_aliases") + op.drop_table("ufora_announcements") + op.drop_table("ufora_courses") # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 99ea2db..6b52ef3 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -1,3 +1,4 @@ +import datetime from datetime import date from typing import Optional @@ -32,3 +33,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional """Find a user's birthday""" statement = select(Birthday).where(Birthday.user_id == user_id) return (await session.execute(statement)).scalar_one_or_none() + + +async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]: + """Get all birthdays that happen on a given day""" diff --git a/database/crud/currency.py b/database/crud/currency.py index f72b653..dea303e 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -71,7 +71,7 @@ async def claim_nightly(session: AsyncSession, user_id: int): now = datetime.now() - if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): + if nightly_data.last_nightly is not None and nightly_data.last_nightly == now.date(): raise exceptions.DoubleNightly bank = await get_bank(session, user_id) diff --git a/database/crud/tasks.py b/database/crud/tasks.py new file mode 100644 index 0000000..dd1a607 --- /dev/null +++ b/database/crud/tasks.py @@ -0,0 +1,32 @@ +import datetime +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.enums import TaskType +from database.models import Task +from database.utils.datetime import LOCAL_TIMEZONE + +__all__ = ["get_task_by_enum", "set_last_task_execution_time"] + + +async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]: + """Get a task by its enum value, if it exists + + Returns None if the task does not exist + """ + statement = select(Task).where(Task.task == task) + return (await session.execute(statement)).scalar_one_or_none() + + +async def set_last_task_execution_time(session: AsyncSession, task: TaskType): + """Set the last time a specific task was executed""" + _task = await get_task_by_enum(session, task) + + if _task is None: + _task = Task(task=task) + + _task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE) + session.add(_task) + await session.commit() diff --git a/database/enums.py b/database/enums.py new file mode 100644 index 0000000..3f75130 --- /dev/null +++ b/database/enums.py @@ -0,0 +1,13 @@ +import enum + +__all__ = ["TaskType"] + + +# There is a bug in typeshed that causes an incorrect PyCharm warning +# https://github.com/python/typeshed/issues/8286 +# noinspection PyArgumentList +class TaskType(enum.IntEnum): + """Enum for the different types of tasks""" + + BIRTHDAYS = enum.auto() + UFORA_ANNOUNCEMENTS = enum.auto() diff --git a/database/models.py b/database/models.py index 74aa5fc..f0fb6e4 100644 --- a/database/models.py +++ b/database/models.py @@ -1,11 +1,23 @@ from __future__ import annotations -from datetime import datetime +from datetime import date, datetime from typing import Optional -from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + Date, + DateTime, + Enum, + ForeignKey, + Integer, + Text, +) from sqlalchemy.orm import declarative_base, relationship +from database import enums + Base = declarative_base() @@ -17,6 +29,7 @@ __all__ = [ "CustomCommandAlias", "DadJoke", "NightlyData", + "Task", "UforaAnnouncement", "UforaCourse", "UforaCourseAlias", @@ -54,7 +67,7 @@ class Birthday(Base): birthday_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - birthday: datetime = Column(DateTime, nullable=False) + birthday: date = Column(Date, nullable=False) user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") @@ -103,12 +116,22 @@ class NightlyData(Base): nightly_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) + last_nightly: Optional[date] = Column(Date, nullable=True) count: int = Column(Integer, server_default="0", nullable=False) user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") +class Task(Base): + """A Didier task""" + + __tablename__ = "tasks" + + task_id: int = Column(Integer, primary_key=True) + task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) + previous_run: datetime = Column(DateTime(timezone=True), nullable=True) + + class UforaCourse(Base): """A course on Ufora""" @@ -147,7 +170,7 @@ class UforaAnnouncement(Base): announcement_id: int = Column(Integer, primary_key=True) course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - publication_date: datetime = Column(DateTime(timezone=True)) + publication_date: date = Column(Date) course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") diff --git a/database/utils/datetime.py b/database/utils/datetime.py new file mode 100644 index 0000000..8450e84 --- /dev/null +++ b/database/utils/datetime.py @@ -0,0 +1,5 @@ +import zoneinfo + +__all__ = ["LOCAL_TIMEZONE"] + +LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index f43df42..9a93fdf 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -8,7 +8,7 @@ from database.crud import custom_commands from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.data.flags.owner import EditCustomFlags +from didier.utils.discord.flags.owner import EditCustomFlags from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index b8488ae..bcd6e8a 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,17 +1,32 @@ +import datetime import traceback from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error import settings +from database import enums from database.crud.ufora_announcements import remove_old_announcements from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements +from didier.decorators.tasks import timed_task +from didier.utils.discord.checks import is_owner +from didier.utils.types.datetime import LOCAL_TIMEZONE + +# datetime.time()-instances for when every task should run +DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) +SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) class Tasks(commands.Cog): - """Task loops that run periodically""" + """Task loops that run periodically + + Preferably these would use the new `time`-kwarg, but these don't run + on startup, which is not ideal. This means we still have to run them every hour + in order to never miss anything if Didier goes offline by coincidence + """ client: Didier + _tasks: dict[str, tasks.Loop] def __init__(self, client: Didier): self.client = client @@ -21,7 +36,41 @@ class Tasks(commands.Cog): self.pull_ufora_announcements.start() self.remove_old_ufora_announcements.start() + # Start all tasks + self.check_birthdays.start() + + self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements} + + @commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True) + @commands.check(is_owner) + async def tasks_group(self, ctx: commands.Context): + """Command group for Task-related commands + + Invoking the group itself shows the time until the next iteration + """ + raise NotImplementedError() + + @tasks_group.command(name="Force", case_insensitive=True) + async def force_task(self, ctx: commands.Context, name: str): + """Command to force-run a task without waiting for the run time""" + name = name.lower() + if name not in self._tasks: + return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False) + + task = self._tasks[name] + await task() + + @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) + @timed_task(enums.TaskType.BIRTHDAYS) + async def check_birthdays(self): + """Check if it's currently anyone's birthday""" + + @check_birthdays.before_loop + async def _before_check_birthdays(self): + await self.client.wait_until_ready() + @tasks.loop(minutes=10) + @timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS) async def pull_ufora_announcements(self): """Task that checks for new Ufora announcements & logs them in a channel""" # In theory this shouldn't happen but just to please Mypy @@ -37,23 +86,20 @@ class Tasks(commands.Cog): @pull_ufora_announcements.before_loop async def _before_ufora_announcements(self): - """Don't try to get announcements if the bot isn't ready yet""" await self.client.wait_until_ready() - @pull_ufora_announcements.error - async def _on_announcements_error(self, error: BaseException): - """Error handler for the Ufora Announcements task""" - print("".join(traceback.format_exception(type(error), error, error.__traceback__))) - @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" async with self.client.db_session as session: await remove_old_announcements(session) - @remove_old_ufora_announcements.before_loop - async def _before_remove_old_ufora_announcements(self): - await self.client.wait_until_ready() + @check_birthdays.error + @pull_ufora_announcements.error + @remove_old_ufora_announcements.error + async def _on_tasks_error(self, error: BaseException): + """Error handler for all tasks""" + print("".join(traceback.format_exception(type(error), error, error.__traceback__))) async def setup(client: Didier): diff --git a/didier/decorators/__init__.py b/didier/decorators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/decorators/tasks.py b/didier/decorators/tasks.py new file mode 100644 index 0000000..36b4f89 --- /dev/null +++ b/didier/decorators/tasks.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from didier.cogs.tasks import Tasks + +from database import enums +from database.crud.tasks import set_last_task_execution_time + +__all__ = ["timed_task"] + + +def timed_task(task: enums.TaskType): + """Decorator to log the last execution time of a task""" + + def _decorator(func): + @functools.wraps(func) + async def _wrapper(tasks_cog: Tasks, *args, **kwargs): + await func(tasks_cog, *args, **kwargs) + + async with tasks_cog.client.db_session as session: + await set_last_task_execution_time(session, task) + + return _wrapper + + return _decorator diff --git a/didier/data/flags/__init__.py b/didier/utils/discord/flags/__init__.py similarity index 100% rename from didier/data/flags/__init__.py rename to didier/utils/discord/flags/__init__.py diff --git a/didier/data/flags/owner.py b/didier/utils/discord/flags/owner.py similarity index 80% rename from didier/data/flags/owner.py rename to didier/utils/discord/flags/owner.py index 5ff75a9..282957c 100644 --- a/didier/data/flags/owner.py +++ b/didier/utils/discord/flags/owner.py @@ -1,6 +1,6 @@ from typing import Optional -from didier.data.flags import PosixFlags +from didier.utils.discord.flags import PosixFlags __all__ = ["EditCustomFlags"] diff --git a/didier/data/flags/posix.py b/didier/utils/discord/flags/posix.py similarity index 100% rename from didier/data/flags/posix.py rename to didier/utils/discord/flags/posix.py diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 91c3583..a34b21f 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,6 +1,10 @@ import datetime +import zoneinfo -__all__ = ["int_to_weekday", "str_to_date"] +__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"] + + +LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this From 66997b7556530f3fa7f6138133075a780138a2d4 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 23 Jul 2022 20:59:02 +0200 Subject: [PATCH 2/4] Fix broken migration --- alembic/versions/346b408c362a_create_tasks.py | 1 + database/crud/currency.py | 6 +++--- requirements-dev.txt | 1 + tests/conftest.py | 2 +- tests/test_database/test_crud/test_birthdays.py | 6 +++--- tests/test_database/test_crud/test_currency.py | 8 ++++++++ 6 files changed, 17 insertions(+), 7 deletions(-) diff --git a/alembic/versions/346b408c362a_create_tasks.py b/alembic/versions/346b408c362a_create_tasks.py index 25f1530..f6efeeb 100644 --- a/alembic/versions/346b408c362a_create_tasks.py +++ b/alembic/versions/346b408c362a_create_tasks.py @@ -32,4 +32,5 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_table("tasks") + sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype").drop(op.get_bind()) # ### end Alembic commands ### diff --git a/database/crud/currency.py b/database/crud/currency.py index dea303e..1bb2d11 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -1,4 +1,4 @@ -from datetime import datetime +from datetime import date from typing import Union from sqlalchemy.ext.asyncio import AsyncSession @@ -69,9 +69,9 @@ async def claim_nightly(session: AsyncSession, user_id: int): """Claim daily Dinks""" nightly_data = await get_nightly_data(session, user_id) - now = datetime.now() + now = date.today() - if nightly_data.last_nightly is not None and nightly_data.last_nightly == now.date(): + if nightly_data.last_nightly is not None and nightly_data.last_nightly == now: raise exceptions.DoubleNightly bank = await get_bank(session, user_id) diff --git a/requirements-dev.txt b/requirements-dev.txt index d82fde3..64b8467 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,5 +1,6 @@ black==22.3.0 coverage[toml]==6.4.1 +freezegun==1.2.1 mypy==0.961 pre-commit==2.20.0 pytest==7.1.2 diff --git a/tests/conftest.py b/tests/conftest.py index 8530de5..2e425ef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def event_loop() -> Generator: @pytest.fixture(scope="session") async def tables(): - """Initialize a database before the tests, and then tear it down again + """Fixture to initialize a database before the tests, and then tear it down again Checks that the migrations were successful by asserting that we are currently on the latest migration diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 96b924c..ba90a3a 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -14,7 +14,7 @@ async def test_add_birthday_not_present(database_session: AsyncSession, user: Us await crud.add_birthday(database_session, user.user_id, bd_date) await database_session.refresh(user) assert user.birthday is not None - assert user.birthday.birthday.date() == bd_date + assert user.birthday.birthday == bd_date async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): @@ -27,7 +27,7 @@ async def test_add_birthday_overwrite(database_session: AsyncSession, user: User new_bd_date = bd_date + timedelta(weeks=1) await crud.add_birthday(database_session, user.user_id, new_bd_date) await database_session.refresh(user) - assert user.birthday.birthday.date() == new_bd_date + assert user.birthday.birthday == new_bd_date async def test_get_birthday_exists(database_session: AsyncSession, user: User): @@ -38,7 +38,7 @@ async def test_get_birthday_exists(database_session: AsyncSession, user: User): bd = await crud.get_birthday_for_user(database_session, user.user_id) assert bd is not None - assert bd.birthday.date() == bd_date + assert bd.birthday == bd_date async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index 1f0a163..a2eeec8 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -1,4 +1,7 @@ +import datetime + import pytest +from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud @@ -14,13 +17,18 @@ async def test_add_dinks(database_session: AsyncSession, bank: Bank): assert bank.dinks == 10 +@freeze_time("2022/07/23") async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): """Test claiming nightlies when it hasn't been done yet""" await crud.claim_nightly(database_session, bank.user_id) await database_session.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT + nightly_data = await crud.get_nightly_data(database_session, bank.user_id) + assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23) + +@freeze_time("2022/07/23") async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): """Test claiming nightlies twice in a day""" await crud.claim_nightly(database_session, bank.user_id) From 393cc9c891230a2ccddc37c23b794d346343c571 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 23 Jul 2022 22:34:03 +0200 Subject: [PATCH 3/4] Add support for lazy loading of user fields --- database/crud/birthdays.py | 11 +++++--- database/crud/users.py | 9 +++++-- .../test_database/test_crud/test_birthdays.py | 27 +++++++++++++++++++ 3 files changed, 41 insertions(+), 6 deletions(-) diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 6b52ef3..8300d49 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -4,11 +4,12 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from database.crud import users -from database.models import Birthday +from database.models import Birthday, User -__all__ = ["add_birthday", "get_birthday_for_user"] +__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] async def add_birthday(session: AsyncSession, user_id: int, birthday: date): @@ -16,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date): If already present, overwrites the existing one """ - user = await users.get_or_add(session, user_id) + user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) if user.birthday is not None: bd = user.birthday @@ -35,5 +36,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional return (await session.execute(statement)).scalar_one_or_none() -async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]: +async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]: """Get all birthdays that happen on a given day""" + statement = select(Birthday).where(Birthday.birthday == day) + return list((await session.execute(statement)).scalars()) diff --git a/database/crud/users.py b/database/crud/users.py index 57c5029..ba3011d 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -10,12 +10,16 @@ __all__ = [ ] -async def get_or_add(session: AsyncSession, user_id: int) -> User: +async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: """Get a user's profile If it doesn't exist yet, create it (along with all linked datastructures) """ - statement = select(User).where(User.user_id == user_id) + if options is None: + options = [] + + statement = select(User).where(User.user_id == user_id).options(*options) + user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() # User exists @@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User: session.add(user) await session.commit() + await session.refresh(user) return user diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index ba90a3a..5d40914 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,8 +1,10 @@ from datetime import datetime, timedelta +from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud +from database.crud import users from database.models import User @@ -45,3 +47,28 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use """Test getting a user's birthday when it doesn't exist""" bd = await crud.get_birthday_for_user(database_session, user.user_id) assert bd is None + + +@freeze_time("2022/07/23") +async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): + """Test getting all birthdays on a given day""" + await crud.add_birthday(database_session, user.user_id, datetime.today()) + + user_2 = await users.get_or_add(database_session, user.user_id + 1) + await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 1 + assert birthdays[0].user_id == user.user_id + + +@freeze_time("2022/07/23") +async def test_get_birthdays_none_present(database_session: AsyncSession): + """Test getting all birthdays when there are none""" + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 0 + + # Add a random birthday that is not today + await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) + + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 0 From da0e60ac4fede16a2cc97f66662aaecbc6547af8 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sat, 23 Jul 2022 23:21:32 +0200 Subject: [PATCH 4/4] Send daily birthday notifications, add more settings & configs, fix small bugs in database --- database/crud/birthdays.py | 7 ++-- didier/cogs/discord.py | 8 ++++- didier/cogs/tasks.py | 24 ++++++++++--- didier/didier.py | 18 ++++++++++ didier/utils/types/datetime.py | 22 ++++++++++-- settings.py | 2 ++ .../test_database/test_crud/test_birthdays.py | 2 +- .../test_utils/test_types/test_datetime.py | 34 +++++++++++++++++++ 8 files changed, 105 insertions(+), 12 deletions(-) create mode 100644 tests/test_didier/test_utils/test_types/test_datetime.py diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 8300d49..054d4c5 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -2,7 +2,7 @@ import datetime from datetime import date from typing import Optional -from sqlalchemy import select +from sqlalchemy import extract, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -38,5 +38,8 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]: """Get all birthdays that happen on a given day""" - statement = select(Birthday).where(Birthday.birthday == day) + days = extract("day", Birthday.birthday) + months = extract("month", Birthday.birthday) + + statement = select(Birthday).where((days == day.day) & (months == day.month)) return list((await session.execute(statement)).scalars()) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 1152922..a73ead3 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -35,7 +35,13 @@ class Discord(commands.Cog): async def birthday_set(self, ctx: commands.Context, date_str: str): """Command to set your birthday""" try: - date = str_to_date(date_str) + default_year = 2001 + date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) + + # If no year was passed, make it 2001 by default + if date_str.count("/") == 1: + date.replace(year=default_year) + except ValueError: return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index bcd6e8a..88bcb6e 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -5,12 +5,13 @@ from discord.ext import commands, tasks # type: ignore # Strange & incorrect My import settings from database import enums +from database.crud.birthdays import get_birthdays_on_day from database.crud.ufora_announcements import remove_old_announcements from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements from didier.decorators.tasks import timed_task from didier.utils.discord.checks import is_owner -from didier.utils.types.datetime import LOCAL_TIMEZONE +from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now # datetime.time()-instances for when every task should run DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) @@ -31,17 +32,18 @@ class Tasks(commands.Cog): def __init__(self, client: Didier): self.client = client + # Only check birthdays if there's a channel to send it to + if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: + self.check_birthdays.start() + # Only pull announcements if a token was provided if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None: self.pull_ufora_announcements.start() self.remove_old_ufora_announcements.start() - # Start all tasks - self.check_birthdays.start() - self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements} - @commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True) + @commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True) @commands.check(is_owner) async def tasks_group(self, ctx: commands.Context): """Command group for Task-related commands @@ -64,6 +66,18 @@ class Tasks(commands.Cog): @timed_task(enums.TaskType.BIRTHDAYS) async def check_birthdays(self): """Check if it's currently anyone's birthday""" + now = tz_aware_now().date() + async with self.client.db_session as session: + birthdays = await get_birthdays_on_day(session, now) + + channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) + if channel is None: + return await self.client.log_error("Unable to find channel for birthday announcements") + + for birthday in birthdays: + user = self.client.get_user(birthday.user_id) + # TODO more messages? + await channel.send(f"Gelukkig verjaardag {user.mention}!") @check_birthdays.before_loop async def _before_check_birthdays(self): diff --git a/didier/didier.py b/didier/didier.py index 9fef227..2f07372 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -1,3 +1,4 @@ +import logging import os import discord @@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix __all__ = ["Didier"] +logger = logging.getLogger(__name__) + + class Didier(commands.Bot): """DIDIER <3""" database_caches: CacheManager + error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession @@ -60,6 +65,12 @@ class Didier(commands.Bot): # Create aiohttp session self.http_session = ClientSession() + # Configure channel to send errors to + if settings.ERRORS_CHANNEL is not None: + self.error_channel = self.get_channel(settings.ERRORS_CHANNEL) + else: + self.error_channel = self.get_user(self.owner_id) + async def _load_initial_extensions(self): """Load all extensions that should be loaded before the others""" for extension in self.initial_extensions: @@ -101,6 +112,13 @@ class Didier(commands.Bot): """Add an X to a message""" await message.add_reaction("❌") + async def log_error(self, message: str, log_to_discord: bool = True): + """Send an error message to the logs, and optionally the configured channel""" + logger.error(message) + if log_to_discord: + # TODO pretty embed + await self.error_channel.send(message) + async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index a34b21f..42f58a9 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,8 +1,9 @@ import datetime import zoneinfo -__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"] +__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"] +from typing import Union LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") @@ -12,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] -def str_to_date(date_str: str) -> datetime.date: +def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date: """Turn a string into a DD/MM/YYYY date""" - return datetime.datetime.strptime(date_str, "%d/%m/%Y").date() + # Allow passing multiple formats in a list + if isinstance(formats, str): + formats = [formats] + + for format_str in formats: + try: + return datetime.datetime.strptime(date_str, format_str).date() + except ValueError: + continue + + raise ValueError + + +def tz_aware_now() -> datetime.datetime: + """Get the current date & time, but timezone-aware""" + return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE) diff --git a/settings.py b/settings.py index bb2296f..d03b48b 100644 --- a/settings.py +++ b/settings.py @@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?") +BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None) +ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None) UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) """API Keys""" diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 5d40914..544e5b0 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -52,7 +52,7 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use @freeze_time("2022/07/23") async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): """Test getting all birthdays on a given day""" - await crud.add_birthday(database_session, user.user_id, datetime.today()) + await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001)) user_2 = await users.get_or_add(database_session, user.user_id + 1) await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) diff --git a/tests/test_didier/test_utils/test_types/test_datetime.py b/tests/test_didier/test_utils/test_types/test_datetime.py new file mode 100644 index 0000000..ecb1973 --- /dev/null +++ b/tests/test_didier/test_utils/test_types/test_datetime.py @@ -0,0 +1,34 @@ +import datetime + +import pytest + +from didier.utils.types.datetime import str_to_date + + +def test_str_to_date_single_valid(): + """Test parsing a string for a single possibility (default)""" + result = str_to_date("23/11/2001") + assert result == datetime.date(year=2001, month=11, day=23) + + +def test_str_to_date_single_invalid(): + """Test parsing a string for an invalid string""" + # Invalid format + with pytest.raises(ValueError): + str_to_date("23/11/01") + + # Invalid date + with pytest.raises(ValueError): + str_to_date("69/42/0") + + +def test_str_to_date_multiple_valid(): + """Test parsing a string for multiple possibilities""" + result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"]) + assert result == datetime.date(year=2001, month=11, day=23) + + +def test_str_to_date_multiple_invalid(): + """Test parsing a string for multiple possibilities when none are valid""" + with pytest.raises(ValueError): + str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])