Compare commits

...

4 Commits

27 changed files with 402 additions and 64 deletions

View File

@ -37,7 +37,7 @@ def upgrade() -> None:
"nightly_data",
sa.Column("nightly_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True),
sa.Column("last_nightly", sa.Date, nullable=True),
sa.Column("count", sa.Integer(), server_default="0", nullable=False),
sa.ForeignKeyConstraint(
["user_id"],

View File

@ -22,7 +22,7 @@ def upgrade() -> None:
"birthdays",
sa.Column("birthday_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("birthday", sa.DateTime(), nullable=False),
sa.Column("birthday", sa.Date, nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],

View File

@ -0,0 +1,36 @@
"""Create tasks
Revision ID: 346b408c362a
Revises: 1716bfecf684
Create Date: 2022-07-23 19:41:07.029482
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "346b408c362a"
down_revision = "1716bfecf684"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"tasks",
sa.Column("task_id", sa.Integer(), nullable=False),
sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False),
sa.Column("previous_run", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("task_id"),
sa.UniqueConstraint("task"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("tasks")
sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype").drop(op.get_bind())
# ### end Alembic commands ###

View File

@ -1,16 +1,16 @@
"""Initial migration
Revision ID: 4ec79dd5b191
Revises:
Revises:
Create Date: 2022-06-19 00:31:58.384360
"""
from alembic import op
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = '4ec79dd5b191'
revision = "4ec79dd5b191"
down_revision = None
branch_labels = None
depends_on = None
@ -18,37 +18,46 @@ depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('ufora_courses',
sa.Column('course_id', sa.Integer(), nullable=False),
sa.Column('name', sa.Text(), nullable=False),
sa.Column('code', sa.Text(), nullable=False),
sa.Column('year', sa.Integer(), nullable=False),
sa.Column('log_announcements', sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint('course_id'),
sa.UniqueConstraint('code'),
sa.UniqueConstraint('name')
op.create_table(
"ufora_courses",
sa.Column("course_id", sa.Integer(), nullable=False),
sa.Column("name", sa.Text(), nullable=False),
sa.Column("code", sa.Text(), nullable=False),
sa.Column("year", sa.Integer(), nullable=False),
sa.Column("log_announcements", sa.Boolean(), nullable=False),
sa.PrimaryKeyConstraint("course_id"),
sa.UniqueConstraint("code"),
sa.UniqueConstraint("name"),
)
op.create_table('ufora_announcements',
sa.Column('announcement_id', sa.Integer(), nullable=False),
sa.Column('course_id', sa.Integer(), nullable=True),
sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True),
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
sa.PrimaryKeyConstraint('announcement_id')
op.create_table(
"ufora_announcements",
sa.Column("announcement_id", sa.Integer(), nullable=False),
sa.Column("course_id", sa.Integer(), nullable=True),
sa.Column("publication_date", sa.Date, nullable=True),
sa.ForeignKeyConstraint(
["course_id"],
["ufora_courses.course_id"],
),
sa.PrimaryKeyConstraint("announcement_id"),
)
op.create_table('ufora_course_aliases',
sa.Column('alias_id', sa.Integer(), nullable=False),
sa.Column('alias', sa.Text(), nullable=False),
sa.Column('course_id', sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
sa.PrimaryKeyConstraint('alias_id'),
sa.UniqueConstraint('alias')
op.create_table(
"ufora_course_aliases",
sa.Column("alias_id", sa.Integer(), nullable=False),
sa.Column("alias", sa.Text(), nullable=False),
sa.Column("course_id", sa.Integer(), nullable=True),
sa.ForeignKeyConstraint(
["course_id"],
["ufora_courses.course_id"],
),
sa.PrimaryKeyConstraint("alias_id"),
sa.UniqueConstraint("alias"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('ufora_course_aliases')
op.drop_table('ufora_announcements')
op.drop_table('ufora_courses')
op.drop_table("ufora_course_aliases")
op.drop_table("ufora_announcements")
op.drop_table("ufora_courses")
# ### end Alembic commands ###

View File

@ -1,13 +1,15 @@
import datetime
from datetime import date
from typing import Optional
from sqlalchemy import select
from sqlalchemy import extract, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.crud import users
from database.models import Birthday
from database.models import Birthday, User
__all__ = ["add_birthday", "get_birthday_for_user"]
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
@ -15,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
If already present, overwrites the existing one
"""
user = await users.get_or_add(session, user_id)
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
if user.birthday is not None:
bd = user.birthday
@ -32,3 +34,12 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
"""Find a user's birthday"""
statement = select(Birthday).where(Birthday.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
"""Get all birthdays that happen on a given day"""
days = extract("day", Birthday.birthday)
months = extract("month", Birthday.birthday)
statement = select(Birthday).where((days == day.day) & (months == day.month))
return list((await session.execute(statement)).scalars())

View File

@ -1,4 +1,4 @@
from datetime import datetime
from datetime import date
from typing import Union
from sqlalchemy.ext.asyncio import AsyncSession
@ -69,9 +69,9 @@ async def claim_nightly(session: AsyncSession, user_id: int):
"""Claim daily Dinks"""
nightly_data = await get_nightly_data(session, user_id)
now = datetime.now()
now = date.today()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
if nightly_data.last_nightly is not None and nightly_data.last_nightly == now:
raise exceptions.DoubleNightly
bank = await get_bank(session, user_id)

View File

@ -0,0 +1,32 @@
import datetime
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType
from database.models import Task
from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]:
"""Get a task by its enum value, if it exists
Returns None if the task does not exist
"""
statement = select(Task).where(Task.task == task)
return (await session.execute(statement)).scalar_one_or_none()
async def set_last_task_execution_time(session: AsyncSession, task: TaskType):
"""Set the last time a specific task was executed"""
_task = await get_task_by_enum(session, task)
if _task is None:
_task = Task(task=task)
_task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE)
session.add(_task)
await session.commit()

View File

@ -10,12 +10,16 @@ __all__ = [
]
async def get_or_add(session: AsyncSession, user_id: int) -> User:
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
"""Get a user's profile
If it doesn't exist yet, create it (along with all linked datastructures)
"""
statement = select(User).where(User.user_id == user_id)
if options is None:
options = []
statement = select(User).where(User.user_id == user_id).options(*options)
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
# User exists
@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User:
session.add(user)
await session.commit()
await session.refresh(user)
return user

13
database/enums.py 100644
View File

@ -0,0 +1,13 @@
import enum
__all__ = ["TaskType"]
# There is a bug in typeshed that causes an incorrect PyCharm warning
# https://github.com/python/typeshed/issues/8286
# noinspection PyArgumentList
class TaskType(enum.IntEnum):
"""Enum for the different types of tasks"""
BIRTHDAYS = enum.auto()
UFORA_ANNOUNCEMENTS = enum.auto()

View File

@ -1,11 +1,23 @@
from __future__ import annotations
from datetime import datetime
from datetime import date, datetime
from typing import Optional
from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Date,
DateTime,
Enum,
ForeignKey,
Integer,
Text,
)
from sqlalchemy.orm import declarative_base, relationship
from database import enums
Base = declarative_base()
@ -17,6 +29,7 @@ __all__ = [
"CustomCommandAlias",
"DadJoke",
"NightlyData",
"Task",
"UforaAnnouncement",
"UforaCourse",
"UforaCourseAlias",
@ -54,7 +67,7 @@ class Birthday(Base):
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: datetime = Column(DateTime, nullable=False)
birthday: date = Column(Date, nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
@ -103,12 +116,22 @@ class NightlyData(Base):
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
last_nightly: Optional[date] = Column(Date, nullable=True)
count: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class Task(Base):
"""A Didier task"""
__tablename__ = "tasks"
task_id: int = Column(Integer, primary_key=True)
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
class UforaCourse(Base):
"""A course on Ufora"""
@ -147,7 +170,7 @@ class UforaAnnouncement(Base):
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: datetime = Column(DateTime(timezone=True))
publication_date: date = Column(Date)
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")

View File

@ -0,0 +1,5 @@
import zoneinfo
__all__ = ["LOCAL_TIMEZONE"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")

View File

@ -35,7 +35,13 @@ class Discord(commands.Cog):
async def birthday_set(self, ctx: commands.Context, date_str: str):
"""Command to set your birthday"""
try:
date = str_to_date(date_str)
default_year = 2001
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
# If no year was passed, make it 2001 by default
if date_str.count("/") == 1:
date.replace(year=default_year)
except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)

View File

@ -8,7 +8,7 @@ from database.crud import custom_commands
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from didier import Didier
from didier.data.flags.owner import EditCustomFlags
from didier.utils.discord.flags.owner import EditCustomFlags
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand

View File

@ -1,27 +1,90 @@
import datetime
import traceback
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
import settings
from database import enums
from database.crud.birthdays import get_birthdays_on_day
from database.crud.ufora_announcements import remove_old_announcements
from didier import Didier
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
from didier.decorators.tasks import timed_task
from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
# datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
class Tasks(commands.Cog):
"""Task loops that run periodically"""
"""Task loops that run periodically
Preferably these would use the new `time`-kwarg, but these don't run
on startup, which is not ideal. This means we still have to run them every hour
in order to never miss anything if Didier goes offline by coincidence
"""
client: Didier
_tasks: dict[str, tasks.Loop]
def __init__(self, client: Didier):
self.client = client
# Only check birthdays if there's a channel to send it to
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
self.check_birthdays.start()
# Only pull announcements if a token was provided
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
self.pull_ufora_announcements.start()
self.remove_old_ufora_announcements.start()
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
@commands.check(is_owner)
async def tasks_group(self, ctx: commands.Context):
"""Command group for Task-related commands
Invoking the group itself shows the time until the next iteration
"""
raise NotImplementedError()
@tasks_group.command(name="Force", case_insensitive=True)
async def force_task(self, ctx: commands.Context, name: str):
"""Command to force-run a task without waiting for the run time"""
name = name.lower()
if name not in self._tasks:
return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False)
task = self._tasks[name]
await task()
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
@timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self):
"""Check if it's currently anyone's birthday"""
now = tz_aware_now().date()
async with self.client.db_session as session:
birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to find channel for birthday announcements")
for birthday in birthdays:
user = self.client.get_user(birthday.user_id)
# TODO more messages?
await channel.send(f"Gelukkig verjaardag {user.mention}!")
@check_birthdays.before_loop
async def _before_check_birthdays(self):
await self.client.wait_until_ready()
@tasks.loop(minutes=10)
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
async def pull_ufora_announcements(self):
"""Task that checks for new Ufora announcements & logs them in a channel"""
# In theory this shouldn't happen but just to please Mypy
@ -37,23 +100,20 @@ class Tasks(commands.Cog):
@pull_ufora_announcements.before_loop
async def _before_ufora_announcements(self):
"""Don't try to get announcements if the bot isn't ready yet"""
await self.client.wait_until_ready()
@pull_ufora_announcements.error
async def _on_announcements_error(self, error: BaseException):
"""Error handler for the Ufora Announcements task"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
@tasks.loop(hours=24)
async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day"""
async with self.client.db_session as session:
await remove_old_announcements(session)
@remove_old_ufora_announcements.before_loop
async def _before_remove_old_ufora_announcements(self):
await self.client.wait_until_ready()
@check_birthdays.error
@pull_ufora_announcements.error
@remove_old_ufora_announcements.error
async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks"""
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
async def setup(client: Didier):

View File

View File

@ -0,0 +1,28 @@
from __future__ import annotations
import functools
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from didier.cogs.tasks import Tasks
from database import enums
from database.crud.tasks import set_last_task_execution_time
__all__ = ["timed_task"]
def timed_task(task: enums.TaskType):
"""Decorator to log the last execution time of a task"""
def _decorator(func):
@functools.wraps(func)
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
await func(tasks_cog, *args, **kwargs)
async with tasks_cog.client.db_session as session:
await set_last_task_execution_time(session, task)
return _wrapper
return _decorator

View File

@ -1,3 +1,4 @@
import logging
import os
import discord
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
__all__ = ["Didier"]
logger = logging.getLogger(__name__)
class Didier(commands.Bot):
"""DIDIER <3"""
database_caches: CacheManager
error_channel: discord.abc.Messageable
initial_extensions: tuple[str, ...] = ()
http_session: ClientSession
@ -60,6 +65,12 @@ class Didier(commands.Bot):
# Create aiohttp session
self.http_session = ClientSession()
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
else:
self.error_channel = self.get_user(self.owner_id)
async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions:
@ -101,6 +112,13 @@ class Didier(commands.Bot):
"""Add an X to a message"""
await message.add_reaction("")
async def log_error(self, message: str, log_to_discord: bool = True):
"""Send an error message to the logs, and optionally the configured channel"""
logger.error(message)
if log_to_discord:
# TODO pretty embed
await self.error_channel.send(message)
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)

View File

@ -1,6 +1,6 @@
from typing import Optional
from didier.data.flags import PosixFlags
from didier.utils.discord.flags import PosixFlags
__all__ = ["EditCustomFlags"]

View File

@ -1,6 +1,11 @@
import datetime
import zoneinfo
__all__ = ["int_to_weekday", "str_to_date"]
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
from typing import Union
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
@ -8,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
def str_to_date(date_str: str) -> datetime.date:
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
"""Turn a string into a DD/MM/YYYY date"""
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
# Allow passing multiple formats in a list
if isinstance(formats, str):
formats = [formats]
for format_str in formats:
try:
return datetime.datetime.strptime(date_str, format_str).date()
except ValueError:
continue
raise ValueError
def tz_aware_now() -> datetime.datetime:
"""Get the current date & time, but timezone-aware"""
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)

View File

@ -1,5 +1,6 @@
black==22.3.0
coverage[toml]==6.4.1
freezegun==1.2.1
mypy==0.961
pre-commit==2.20.0
pytest==7.1.2

View File

@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
"""API Keys"""

View File

@ -23,7 +23,7 @@ def event_loop() -> Generator:
@pytest.fixture(scope="session")
async def tables():
"""Initialize a database before the tests, and then tear it down again
"""Fixture to initialize a database before the tests, and then tear it down again
Checks that the migrations were successful by asserting that we are currently
on the latest migration

View File

@ -1,8 +1,10 @@
from datetime import datetime, timedelta
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud
from database.crud import users
from database.models import User
@ -14,7 +16,7 @@ async def test_add_birthday_not_present(database_session: AsyncSession, user: Us
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
assert user.birthday.birthday.date() == bd_date
assert user.birthday.birthday == bd_date
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
@ -27,7 +29,7 @@ async def test_add_birthday_overwrite(database_session: AsyncSession, user: User
new_bd_date = bd_date + timedelta(weeks=1)
await crud.add_birthday(database_session, user.user_id, new_bd_date)
await database_session.refresh(user)
assert user.birthday.birthday.date() == new_bd_date
assert user.birthday.birthday == new_bd_date
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
@ -38,10 +40,35 @@ async def test_get_birthday_exists(database_session: AsyncSession, user: User):
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is not None
assert bd.birthday.date() == bd_date
assert bd.birthday == bd_date
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is None
@freeze_time("2022/07/23")
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
"""Test getting all birthdays on a given day"""
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
user_2 = await users.get_or_add(database_session, user.user_id + 1)
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 1
assert birthdays[0].user_id == user.user_id
@freeze_time("2022/07/23")
async def test_get_birthdays_none_present(database_session: AsyncSession):
"""Test getting all birthdays when there are none"""
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 0
# Add a random birthday that is not today
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 0

View File

@ -1,4 +1,7 @@
import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import currency as crud
@ -14,13 +17,18 @@ async def test_add_dinks(database_session: AsyncSession, bank: Bank):
assert bank.dinks == 10
@freeze_time("2022/07/23")
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies when it hasn't been done yet"""
await crud.claim_nightly(database_session, bank.user_id)
await database_session.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT
nightly_data = await crud.get_nightly_data(database_session, bank.user_id)
assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23)
@freeze_time("2022/07/23")
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies twice in a day"""
await crud.claim_nightly(database_session, bank.user_id)

View File

@ -0,0 +1,34 @@
import datetime
import pytest
from didier.utils.types.datetime import str_to_date
def test_str_to_date_single_valid():
"""Test parsing a string for a single possibility (default)"""
result = str_to_date("23/11/2001")
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_single_invalid():
"""Test parsing a string for an invalid string"""
# Invalid format
with pytest.raises(ValueError):
str_to_date("23/11/01")
# Invalid date
with pytest.raises(ValueError):
str_to_date("69/42/0")
def test_str_to_date_multiple_valid():
"""Test parsing a string for multiple possibilities"""
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
assert result == datetime.date(year=2001, month=11, day=23)
def test_str_to_date_multiple_invalid():
"""Test parsing a string for multiple possibilities when none are valid"""
with pytest.raises(ValueError):
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])