Add birthday task, change migrations to use date instead of datetime

pull/125/head
stijndcl 2022-07-23 20:35:42 +02:00
parent adcf94c66e
commit 8bc0f1fa7a
18 changed files with 249 additions and 49 deletions

View File

@ -37,7 +37,7 @@ def upgrade() -> None:
"nightly_data",
sa.Column("nightly_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_nightly", sa.Date, nullable=True),
sa.Column("count", sa.Integer(), server_default="0", nullable=False),
sa.ForeignKeyConstraint(
["user_id"],

View File

@ -22,7 +22,7 @@ def upgrade() -> None:
"birthdays",
sa.Column("birthday_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("birthday", sa.DateTime(), nullable=False),
sa.Column("birthday", sa.Date, nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],

View File

@ -0,0 +1,35 @@
"""Create tasks
Revision ID: 346b408c362a
Revises: 1716bfecf684
Create Date: 2022-07-23 19:41:07.029482
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "346b408c362a"
down_revision = "1716bfecf684"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"tasks",
sa.Column("task_id", sa.Integer(), nullable=False),
sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False),
sa.Column("previous_run", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("task_id"),
sa.UniqueConstraint("task"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("tasks")
# ### end Alembic commands ###

View File

@ -1,16 +1,16 @@
"""Initial migration
Revision ID: 4ec79dd5b191
Revises:
Revises:
Create Date: 2022-06-19 00:31:58.384360
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '4ec79dd5b191'
revision = "4ec79dd5b191"
down_revision = None
branch_labels = None
depends_on = None
@ -18,37 +18,46 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('ufora_courses',
sa.Column('course_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('code', sa.Text(), nullable=False),
sa.Column('year', sa.Integer(), nullable=False),
sa.Column('log_announcements', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('course_id'),
sa.UniqueConstraint('code'),
sa.UniqueConstraint('name')
op.create_table(
"ufora_courses",
sa.Column("course_id", sa.Integer(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("code", sa.Text(), nullable=False),
sa.Column("year", sa.Integer(), nullable=False),
sa.Column("log_announcements", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("course_id"),
sa.UniqueConstraint("code"),
sa.UniqueConstraint("name"),
)
op.create_table('ufora_announcements',
sa.Column('announcement_id', sa.Integer(), nullable=False),
sa.Column('course_id', sa.Integer(), nullable=True),
sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
sa.PrimaryKeyConstraint('announcement_id')
op.create_table(
"ufora_announcements",
sa.Column("announcement_id", sa.Integer(), nullable=False),
sa.Column("course_id", sa.Integer(), nullable=True),
sa.Column("publication_date", sa.Date, nullable=True),
sa.ForeignKeyConstraint(
["course_id"],
["ufora_courses.course_id"],
),
sa.PrimaryKeyConstraint("announcement_id"),
)
op.create_table('ufora_course_aliases',
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias', sa.Text(), nullable=False),
sa.Column('course_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
sa.PrimaryKeyConstraint('alias_id'),
sa.UniqueConstraint('alias')
op.create_table(
"ufora_course_aliases",
sa.Column("alias_id", sa.Integer(), nullable=False),
sa.Column("alias", sa.Text(), nullable=False),
sa.Column("course_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["course_id"],
["ufora_courses.course_id"],
),
sa.PrimaryKeyConstraint("alias_id"),
sa.UniqueConstraint("alias"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('ufora_course_aliases')
op.drop_table('ufora_announcements')
op.drop_table('ufora_courses')
op.drop_table("ufora_course_aliases")
op.drop_table("ufora_announcements")
op.drop_table("ufora_courses")
# ### end Alembic commands ###

View File

@ -1,3 +1,4 @@
import datetime
from datetime import date
from typing import Optional
@ -32,3 +33,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
"""Find a user's birthday"""
statement = select(Birthday).where(Birthday.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]:
"""Get all birthdays that happen on a given day"""

View File

@ -71,7 +71,7 @@ async def claim_nightly(session: AsyncSession, user_id: int):
now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
if nightly_data.last_nightly is not None and nightly_data.last_nightly == now.date():
raise exceptions.DoubleNightly
bank = await get_bank(session, user_id)

View File

@ -0,0 +1,32 @@
import datetime
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType
from database.models import Task
from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]:
"""Get a task by its enum value, if it exists
Returns None if the task does not exist
"""
statement = select(Task).where(Task.task == task)
return (await session.execute(statement)).scalar_one_or_none()
async def set_last_task_execution_time(session: AsyncSession, task: TaskType):
"""Set the last time a specific task was executed"""
_task = await get_task_by_enum(session, task)
if _task is None:
_task = Task(task=task)
_task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE)
session.add(_task)
await session.commit()

13
database/enums.py 100644
View File

@ -0,0 +1,13 @@
import enum
__all__ = ["TaskType"]
# There is a bug in typeshed that causes an incorrect PyCharm warning
# https://github.com/python/typeshed/issues/8286
# noinspection PyArgumentList
class TaskType(enum.IntEnum):
"""Enum for the different types of tasks"""
BIRTHDAYS = enum.auto()
UFORA_ANNOUNCEMENTS = enum.auto()

View File

@ -1,11 +1,23 @@
from __future__ import annotations
from datetime import datetime
from datetime import date, datetime
from typing import Optional
from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Date,
DateTime,
Enum,
ForeignKey,
Integer,
Text,
)
from sqlalchemy.orm import declarative_base, relationship
from database import enums
Base = declarative_base()
@ -17,6 +29,7 @@ __all__ = [
"CustomCommandAlias",
"DadJoke",
"NightlyData",
"Task",
"UforaAnnouncement",
"UforaCourse",
"UforaCourseAlias",
@ -54,7 +67,7 @@ class Birthday(Base):
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: datetime = Column(DateTime, nullable=False)
birthday: date = Column(Date, nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
@ -103,12 +116,22 @@ class NightlyData(Base):
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
last_nightly: Optional[date] = Column(Date, nullable=True)
count: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class Task(Base):
"""A Didier task"""
__tablename__ = "tasks"
task_id: int = Column(Integer, primary_key=True)
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
class UforaCourse(Base):
"""A course on Ufora"""
@ -147,7 +170,7 @@ class UforaAnnouncement(Base):
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: datetime = Column(DateTime(timezone=True))
publication_date: date = Column(Date)
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")

View File

@ -0,0 +1,5 @@
import zoneinfo
__all__ = ["LOCAL_TIMEZONE"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")

View File

@ -8,7 +8,7 @@ from database.crud import custom_commands
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from didier import Didier
from didier.data.flags.owner import EditCustomFlags
from didier.utils.discord.flags.owner import EditCustomFlags
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand

View File

@ -1,17 +1,32 @@
import datetime
import traceback
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
import settings
from database import enums
from database.crud.ufora_announcements import remove_old_announcements
from didier import Didier
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE
# datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
class Tasks(commands.Cog):
"""Task loops that run periodically"""
"""Task loops that run periodically
Preferably these would use the new `time`-kwarg, but these don't run
on startup, which is not ideal. This means we still have to run them every hour
in order to never miss anything if Didier goes offline by coincidence
"""
client: Didier
_tasks: dict[str, tasks.Loop]
def __init__(self, client: Didier):
self.client = client
@ -21,7 +36,41 @@ class Tasks(commands.Cog):
self.pull_ufora_announcements.start()
self.remove_old_ufora_announcements.start()
# Start all tasks
self.check_birthdays.start()
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
@commands.group(name="Tasks", case_insensitive=True, invoke_without_command=True)
@commands.check(is_owner)
async def tasks_group(self, ctx: commands.Context):
"""Command group for Task-related commands
Invoking the group itself shows the time until the next iteration
"""
raise NotImplementedError()
@tasks_group.command(name="Force", case_insensitive=True)
async def force_task(self, ctx: commands.Context, name: str):
"""Command to force-run a task without waiting for the run time"""
name = name.lower()
if name not in self._tasks:
return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False)
task = self._tasks[name]
await task()
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
@timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self):
"""Check if it's currently anyone's birthday"""
@check_birthdays.before_loop
async def _before_check_birthdays(self):
await self.client.wait_until_ready()
@tasks.loop(minutes=10)
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
async def pull_ufora_announcements(self):
"""Task that checks for new Ufora announcements & logs them in a channel"""
# In theory this shouldn't happen but just to please Mypy
@ -37,23 +86,20 @@ class Tasks(commands.Cog):
@pull_ufora_announcements.before_loop
async def _before_ufora_announcements(self):
"""Don't try to get announcements if the bot isn't ready yet"""
await self.client.wait_until_ready()
@pull_ufora_announcements.error
async def _on_announcements_error(self, error: BaseException):
"""Error handler for the Ufora Announcements task"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
@tasks.loop(hours=24)
async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day"""
async with self.client.db_session as session:
await remove_old_announcements(session)
@remove_old_ufora_announcements.before_loop
async def _before_remove_old_ufora_announcements(self):
await self.client.wait_until_ready()
@check_birthdays.error
@pull_ufora_announcements.error
@remove_old_ufora_announcements.error
async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
async def setup(client: Didier):

View File

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import functools
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from didier.cogs.tasks import Tasks
from database import enums
from database.crud.tasks import set_last_task_execution_time
__all__ = ["timed_task"]
def timed_task(task: enums.TaskType):
"""Decorator to log the last execution time of a task"""
def _decorator(func):
@functools.wraps(func)
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
await func(tasks_cog, *args, **kwargs)
async with tasks_cog.client.db_session as session:
await set_last_task_execution_time(session, task)
return _wrapper
return _decorator

View File

@ -1,6 +1,6 @@
from typing import Optional
from didier.data.flags import PosixFlags
from didier.utils.discord.flags import PosixFlags
__all__ = ["EditCustomFlags"]

View File

@ -1,6 +1,10 @@
import datetime
import zoneinfo
__all__ = ["int_to_weekday", "str_to_date"]
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this