diff --git a/alembic/versions/0d03c226d881_initial_currency_models.py b/alembic/versions/0d03c226d881_initial_currency_models.py index 7478410..feec2c1 100644 --- a/alembic/versions/0d03c226d881_initial_currency_models.py +++ b/alembic/versions/0d03c226d881_initial_currency_models.py @@ -37,7 +37,7 @@ def upgrade() -> None: "nightly_data", sa.Column("nightly_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True), + sa.Column("last_nightly", sa.Date, nullable=True), sa.Column("count", sa.Integer(), server_default="0", nullable=False), sa.ForeignKeyConstraint( ["user_id"], diff --git a/alembic/versions/1716bfecf684_add_birthdays.py b/alembic/versions/1716bfecf684_add_birthdays.py index 9065993..5f01615 100644 --- a/alembic/versions/1716bfecf684_add_birthdays.py +++ b/alembic/versions/1716bfecf684_add_birthdays.py @@ -22,7 +22,7 @@ def upgrade() -> None: "birthdays", sa.Column("birthday_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("birthday", sa.DateTime(), nullable=False), + sa.Column("birthday", sa.Date, nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["users.user_id"], diff --git a/alembic/versions/346b408c362a_create_tasks.py b/alembic/versions/346b408c362a_create_tasks.py new file mode 100644 index 0000000..25f1530 --- /dev/null +++ b/alembic/versions/346b408c362a_create_tasks.py @@ -0,0 +1,35 @@ +"""Create tasks + +Revision ID: 346b408c362a +Revises: 1716bfecf684 +Create Date: 2022-07-23 19:41:07.029482 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "346b408c362a" +down_revision = "1716bfecf684" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "tasks", + sa.Column("task_id", sa.Integer(), nullable=False), + sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False), + sa.Column("previous_run", sa.DateTime(), nullable=True), + sa.PrimaryKeyConstraint("task_id"), + sa.UniqueConstraint("task"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("tasks") + # ### end Alembic commands ### diff --git a/alembic/versions/4ec79dd5b191_initial_migration.py b/alembic/versions/4ec79dd5b191_initial_migration.py index 186e280..2bf8362 100644 --- a/alembic/versions/4ec79dd5b191_initial_migration.py +++ b/alembic/versions/4ec79dd5b191_initial_migration.py @@ -1,16 +1,16 @@ """Initial migration Revision ID: 4ec79dd5b191 -Revises: +Revises: Create Date: 2022-06-19 00:31:58.384360 """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '4ec79dd5b191' +revision = "4ec79dd5b191" down_revision = None branch_labels = None depends_on = None @@ -18,37 +18,46 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('ufora_courses', - sa.Column('course_id', sa.Integer(), nullable=False), - sa.Column('name', sa.Text(), nullable=False), - sa.Column('code', sa.Text(), nullable=False), - sa.Column('year', sa.Integer(), nullable=False), - sa.Column('log_announcements', sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint('course_id'), - sa.UniqueConstraint('code'), - sa.UniqueConstraint('name') + op.create_table( + "ufora_courses", + sa.Column("course_id", sa.Integer(), nullable=False), + sa.Column("name", sa.Text(), nullable=False), + sa.Column("code", sa.Text(), nullable=False), + sa.Column("year", sa.Integer(), nullable=False), + sa.Column("log_announcements", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("course_id"), + sa.UniqueConstraint("code"), + sa.UniqueConstraint("name"), ) - op.create_table('ufora_announcements', - sa.Column('announcement_id', sa.Integer(), nullable=False), - sa.Column('course_id', sa.Integer(), nullable=True), - sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True), - sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), - sa.PrimaryKeyConstraint('announcement_id') + op.create_table( + "ufora_announcements", + sa.Column("announcement_id", sa.Integer(), nullable=False), + sa.Column("course_id", sa.Integer(), nullable=True), + sa.Column("publication_date", sa.Date, nullable=True), + sa.ForeignKeyConstraint( + ["course_id"], + ["ufora_courses.course_id"], + ), + sa.PrimaryKeyConstraint("announcement_id"), ) - op.create_table('ufora_course_aliases', - sa.Column('alias_id', sa.Integer(), nullable=False), - sa.Column('alias', sa.Text(), nullable=False), - sa.Column('course_id', sa.Integer(), nullable=True), - sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), - sa.PrimaryKeyConstraint('alias_id'), - sa.UniqueConstraint('alias') + op.create_table( + "ufora_course_aliases", + sa.Column("alias_id", sa.Integer(), nullable=False), + sa.Column("alias", sa.Text(), nullable=False), + sa.Column("course_id", sa.Integer(), nullable=True), + sa.ForeignKeyConstraint( + ["course_id"], + ["ufora_courses.course_id"], + ), + sa.PrimaryKeyConstraint("alias_id"), + sa.UniqueConstraint("alias"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('ufora_course_aliases') - op.drop_table('ufora_announcements') - op.drop_table('ufora_courses') + op.drop_table("ufora_course_aliases") + op.drop_table("ufora_announcements") + op.drop_table("ufora_courses") # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 99ea2db..6b52ef3 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -1,3 +1,4 @@ +import datetime from datetime import date from typing import Optional @@ -32,3 +33,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional """Find a user's birthday""" statement = select(Birthday).where(Birthday.user_id == user_id) return (await session.execute(statement)).scalar_one_or_none() + + +async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]: + """Get all birthdays that happen on a given day""" diff --git a/database/crud/currency.py b/database/crud/currency.py index f72b653..dea303e 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -71,7 +71,7 @@ async def claim_nightly(session: AsyncSession, user_id: int): now = datetime.now() - if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): + if nightly_data.last_nightly is not None and nightly_data.last_nightly == now.date(): raise exceptions.DoubleNightly bank = await get_bank(session, user_id) diff --git a/database/crud/tasks.py b/database/crud/tasks.py new file mode 100644 index 0000000..dd1a607 --- /dev/null +++ b/database/crud/tasks.py @@ -0,0 +1,32 @@ +import datetime +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.enums import TaskType +from database.models import Task +from database.utils.datetime import LOCAL_TIMEZONE + +__all__ = ["get_task_by_enum", "set_last_task_execution_time"] + + +async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]: + """Get a task by its enum value, if it exists + + Returns None if the task does not exist + """ + statement = select(Task).where(Task.task == task) + return (await session.execute(statement)).scalar_one_or_none() + + +async def set_last_task_execution_time(session: AsyncSession, task: TaskType): + """Set the last time a specific task was executed""" + _task = await get_task_by_enum(session, task) + + if _task is None: + _task = Task(task=task) + + _task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE) + session.add(_task) + await session.commit() diff --git a/database/enums.py b/database/enums.py new file mode 100644 index 0000000..3f75130 --- /dev/null +++ b/database/enums.py @@ -0,0 +1,13 @@ +import enum + +__all__ = ["TaskType"] + + +# There is a bug in typeshed that causes an incorrect PyCharm warning +# https://github.com/python/typeshed/issues/8286 +# noinspection PyArgumentList +class TaskType(enum.IntEnum): + """Enum for the different types of tasks""" + + BIRTHDAYS = enum.auto() + UFORA_ANNOUNCEMENTS = enum.auto() diff --git a/database/models.py b/database/models.py index 74aa5fc..f0fb6e4 100644 --- a/database/models.py +++ b/database/models.py @@ -1,11 +1,23 @@ from __future__ import annotations -from datetime import datetime +from datetime import date, datetime from typing import Optional -from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + Date, + DateTime, + Enum, + ForeignKey, + Integer, + Text, +) from sqlalchemy.orm import declarative_base, relationship +from database import enums + Base = declarative_base() @@ -17,6 +29,7 @@ __all__ = [ "CustomCommandAlias", "DadJoke", "NightlyData", + "Task", "UforaAnnouncement", "UforaCourse", "UforaCourseAlias", @@ -54,7 +67,7 @@ class Birthday(Base): birthday_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - birthday: datetime = Column(DateTime, nullable=False) + birthday: date = Column(Date, nullable=False) user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") @@ -103,12 +116,22 @@ class NightlyData(Base): nightly_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) + last_nightly: Optional[date] = Column(Date, nullable=True) count: int = Column(Integer, server_default="0", nullable=False) user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") +class Task(Base): + """A Didier task""" + + __tablename__ = "tasks" + + task_id: int = Column(Integer, primary_key=True) + task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) + previous_run: datetime = Column(DateTime(timezone=True), nullable=True) + + class UforaCourse(Base): """A course on Ufora""" @@ -147,7 +170,7 @@ class UforaAnnouncement(Base): announcement_id: int = Column(Integer, primary_key=True) course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - publication_date: datetime = Column(DateTime(timezone=True)) + publication_date: date = Column(Date) course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") diff --git a/database/utils/datetime.py b/database/utils/datetime.py new file mode 100644 index 0000000..8450e84 --- /dev/null +++ b/database/utils/datetime.py @@ -0,0 +1,5 @@ +import zoneinfo + +__all__ = ["LOCAL_TIMEZONE"] + +LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index f43df42..9a93fdf 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -8,7 +8,7 @@ from database.crud import custom_commands from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.data.flags.owner import EditCustomFlags +from didier.utils.discord.flags.owner import EditCustomFlags from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index b8488ae..bcd6e8a 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,17 +1,32 @@ +import datetime import traceback from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error import settings +from database import enums from database.crud.ufora_announcements import remove_old_announcements from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements +from didier.decorators.tasks import timed_task +from didier.utils.discord.checks import is_owner +from didier.utils.types.datetime import LOCAL_TIMEZONE + +# datetime.time()-instances for when every task should run +DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) +SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) class Tasks(commands.Cog): - """Task loops that run periodically""" + """Task loops that run periodically + + Preferably these would use the new `time`-kwarg, but these don't run + on startup, which is not ideal. This means we still have to run them every hour + in order to never miss anything if Didier goes offline by coincidence + """ client: Didier + _tasks: dict[str, tasks.Loop] def __init__(self, client: Didier): self.client = client @@ -21,7 +36,41 @@ class Tasks(commands.Cog): self.pull_ufora_announcements.start() self.remove_old_ufora_announcements.start() + # Start all tasks + self.check_birthdays.start() + + self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements} + + @commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True) + @commands.check(is_owner) + async def tasks_group(self, ctx: commands.Context): + """Command group for Task-related commands + + Invoking the group itself shows the time until the next iteration + """ + raise NotImplementedError() + + @tasks_group.command(name="Force", case_insensitive=True) + async def force_task(self, ctx: commands.Context, name: str): + """Command to force-run a task without waiting for the run time""" + name = name.lower() + if name not in self._tasks: + return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False) + + task = self._tasks[name] + await task() + + @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) + @timed_task(enums.TaskType.BIRTHDAYS) + async def check_birthdays(self): + """Check if it's currently anyone's birthday""" + + @check_birthdays.before_loop + async def _before_check_birthdays(self): + await self.client.wait_until_ready() + @tasks.loop(minutes=10) + @timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS) async def pull_ufora_announcements(self): """Task that checks for new Ufora announcements & logs them in a channel""" # In theory this shouldn't happen but just to please Mypy @@ -37,23 +86,20 @@ class Tasks(commands.Cog): @pull_ufora_announcements.before_loop async def _before_ufora_announcements(self): - """Don't try to get announcements if the bot isn't ready yet""" await self.client.wait_until_ready() - @pull_ufora_announcements.error - async def _on_announcements_error(self, error: BaseException): - """Error handler for the Ufora Announcements task""" - print("".join(traceback.format_exception(type(error), error, error.__traceback__))) - @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" async with self.client.db_session as session: await remove_old_announcements(session) - @remove_old_ufora_announcements.before_loop - async def _before_remove_old_ufora_announcements(self): - await self.client.wait_until_ready() + @check_birthdays.error + @pull_ufora_announcements.error + @remove_old_ufora_announcements.error + async def _on_tasks_error(self, error: BaseException): + """Error handler for all tasks""" + print("".join(traceback.format_exception(type(error), error, error.__traceback__))) async def setup(client: Didier): diff --git a/didier/decorators/__init__.py b/didier/decorators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/decorators/tasks.py b/didier/decorators/tasks.py new file mode 100644 index 0000000..36b4f89 --- /dev/null +++ b/didier/decorators/tasks.py @@ -0,0 +1,28 @@ +from __future__ import annotations + +import functools +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from didier.cogs.tasks import Tasks + +from database import enums +from database.crud.tasks import set_last_task_execution_time + +__all__ = ["timed_task"] + + +def timed_task(task: enums.TaskType): + """Decorator to log the last execution time of a task""" + + def _decorator(func): + @functools.wraps(func) + async def _wrapper(tasks_cog: Tasks, *args, **kwargs): + await func(tasks_cog, *args, **kwargs) + + async with tasks_cog.client.db_session as session: + await set_last_task_execution_time(session, task) + + return _wrapper + + return _decorator diff --git a/didier/data/flags/__init__.py b/didier/utils/discord/flags/__init__.py similarity index 100% rename from didier/data/flags/__init__.py rename to didier/utils/discord/flags/__init__.py diff --git a/didier/data/flags/owner.py b/didier/utils/discord/flags/owner.py similarity index 80% rename from didier/data/flags/owner.py rename to didier/utils/discord/flags/owner.py index 5ff75a9..282957c 100644 --- a/didier/data/flags/owner.py +++ b/didier/utils/discord/flags/owner.py @@ -1,6 +1,6 @@ from typing import Optional -from didier.data.flags import PosixFlags +from didier.utils.discord.flags import PosixFlags __all__ = ["EditCustomFlags"] diff --git a/didier/data/flags/posix.py b/didier/utils/discord/flags/posix.py similarity index 100% rename from didier/data/flags/posix.py rename to didier/utils/discord/flags/posix.py diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 91c3583..a34b21f 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,6 +1,10 @@ import datetime +import zoneinfo -__all__ = ["int_to_weekday", "str_to_date"] +__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"] + + +LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this