diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 6b52ef3..8300d49 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -4,11 +4,12 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload from database.crud import users -from database.models import Birthday +from database.models import Birthday, User -__all__ = ["add_birthday", "get_birthday_for_user"] +__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] async def add_birthday(session: AsyncSession, user_id: int, birthday: date): @@ -16,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date): If already present, overwrites the existing one """ - user = await users.get_or_add(session, user_id) + user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) if user.birthday is not None: bd = user.birthday @@ -35,5 +36,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional return (await session.execute(statement)).scalar_one_or_none() -async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]: +async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]: """Get all birthdays that happen on a given day""" + statement = select(Birthday).where(Birthday.birthday == day) + return list((await session.execute(statement)).scalars()) diff --git a/database/crud/users.py b/database/crud/users.py index 57c5029..ba3011d 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -10,12 +10,16 @@ __all__ = [ ] -async def get_or_add(session: AsyncSession, user_id: int) -> User: +async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: """Get a user's profile If it doesn't exist yet, create it (along with all linked datastructures) """ - statement = select(User).where(User.user_id == user_id) + if options is None: + options = [] + + statement = select(User).where(User.user_id == user_id).options(*options) + user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() # User exists @@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User: session.add(user) await session.commit() + await session.refresh(user) return user diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index ba90a3a..5d40914 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,8 +1,10 @@ from datetime import datetime, timedelta +from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud +from database.crud import users from database.models import User @@ -45,3 +47,28 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use """Test getting a user's birthday when it doesn't exist""" bd = await crud.get_birthday_for_user(database_session, user.user_id) assert bd is None + + +@freeze_time("2022/07/23") +async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): + """Test getting all birthdays on a given day""" + await crud.add_birthday(database_session, user.user_id, datetime.today()) + + user_2 = await users.get_or_add(database_session, user.user_id + 1) + await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 1 + assert birthdays[0].user_id == user.user_id + + +@freeze_time("2022/07/23") +async def test_get_birthdays_none_present(database_session: AsyncSession): + """Test getting all birthdays when there are none""" + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 0 + + # Add a random birthday that is not today + await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) + + birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + assert len(birthdays) == 0