Add birthday task, change migrations to use date instead of datetime

This commit is contained in:
stijndcl 2022-07-23 20:35:42 +02:00
parent adcf94c66e
commit 8bc0f1fa7a
18 changed files with 249 additions and 49 deletions

View file

@ -1,3 +1,4 @@
import datetime
from datetime import date
from typing import Optional
@ -32,3 +33,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
"""Find a user's birthday"""
statement = select(Birthday).where(Birthday.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]:
"""Get all birthdays that happen on a given day"""

View file

@ -71,7 +71,7 @@ async def claim_nightly(session: AsyncSession, user_id: int):
now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
if nightly_data.last_nightly is not None and nightly_data.last_nightly == now.date():
raise exceptions.DoubleNightly
bank = await get_bank(session, user_id)

32
database/crud/tasks.py Normal file
View file

@ -0,0 +1,32 @@
import datetime
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType
from database.models import Task
from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]:
"""Get a task by its enum value, if it exists
Returns None if the task does not exist
"""
statement = select(Task).where(Task.task == task)
return (await session.execute(statement)).scalar_one_or_none()
async def set_last_task_execution_time(session: AsyncSession, task: TaskType):
"""Set the last time a specific task was executed"""
_task = await get_task_by_enum(session, task)
if _task is None:
_task = Task(task=task)
_task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE)
session.add(_task)
await session.commit()

13
database/enums.py Normal file
View file

@ -0,0 +1,13 @@
import enum
__all__ = ["TaskType"]
# There is a bug in typeshed that causes an incorrect PyCharm warning
# https://github.com/python/typeshed/issues/8286
# noinspection PyArgumentList
class TaskType(enum.IntEnum):
"""Enum for the different types of tasks"""
BIRTHDAYS = enum.auto()
UFORA_ANNOUNCEMENTS = enum.auto()

View file

@ -1,11 +1,23 @@
from __future__ import annotations
from datetime import datetime
from datetime import date, datetime
from typing import Optional
from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Date,
DateTime,
Enum,
ForeignKey,
Integer,
Text,
)
from sqlalchemy.orm import declarative_base, relationship
from database import enums
Base = declarative_base()
@ -17,6 +29,7 @@ __all__ = [
"CustomCommandAlias",
"DadJoke",
"NightlyData",
"Task",
"UforaAnnouncement",
"UforaCourse",
"UforaCourseAlias",
@ -54,7 +67,7 @@ class Birthday(Base):
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: datetime = Column(DateTime, nullable=False)
birthday: date = Column(Date, nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
@ -103,12 +116,22 @@ class NightlyData(Base):
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
last_nightly: Optional[date] = Column(Date, nullable=True)
count: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class Task(Base):
"""A Didier task"""
__tablename__ = "tasks"
task_id: int = Column(Integer, primary_key=True)
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
class UforaCourse(Base):
"""A course on Ufora"""
@ -147,7 +170,7 @@ class UforaAnnouncement(Base):
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: datetime = Column(DateTime(timezone=True))
publication_date: date = Column(Date)
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")

View file

@ -0,0 +1,5 @@
import zoneinfo
__all__ = ["LOCAL_TIMEZONE"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")