Add support for lazy loading of user fields

pull/125/head
stijndcl 2022-07-23 22:34:03 +02:00
parent 66997b7556
commit 393cc9c891
3 changed files with 41 additions and 6 deletions

View File

@ -4,11 +4,12 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.crud import users from database.crud import users
from database.models import Birthday from database.models import Birthday, User
__all__ = ["add_birthday", "get_birthday_for_user"] __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
async def add_birthday(session: AsyncSession, user_id: int, birthday: date): async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
@ -16,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
If already present, overwrites the existing one If already present, overwrites the existing one
""" """
user = await users.get_or_add(session, user_id) user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
if user.birthday is not None: if user.birthday is not None:
bd = user.birthday bd = user.birthday
@ -35,5 +36,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
return (await session.execute(statement)).scalar_one_or_none() return (await session.execute(statement)).scalar_one_or_none()
async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]: async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
"""Get all birthdays that happen on a given day""" """Get all birthdays that happen on a given day"""
statement = select(Birthday).where(Birthday.birthday == day)
return list((await session.execute(statement)).scalars())

View File

@ -10,12 +10,16 @@ __all__ = [
] ]
async def get_or_add(session: AsyncSession, user_id: int) -> User: async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
"""Get a user's profile """Get a user's profile
If it doesn't exist yet, create it (along with all linked datastructures) If it doesn't exist yet, create it (along with all linked datastructures)
""" """
statement = select(User).where(User.user_id == user_id) if options is None:
options = []
statement = select(User).where(User.user_id == user_id).options(*options)
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
# User exists # User exists
@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User:
session.add(user) session.add(user)
await session.commit() await session.commit()
await session.refresh(user)
return user return user

View File

@ -1,8 +1,10 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud from database.crud import birthdays as crud
from database.crud import users
from database.models import User from database.models import User
@ -45,3 +47,28 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
"""Test getting a user's birthday when it doesn't exist""" """Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id) bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is None assert bd is None
@freeze_time("2022/07/23")
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
"""Test getting all birthdays on a given day"""
await crud.add_birthday(database_session, user.user_id, datetime.today())
user_2 = await users.get_or_add(database_session, user.user_id + 1)
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 1
assert birthdays[0].user_id == user.user_id
@freeze_time("2022/07/23")
async def test_get_birthdays_none_present(database_session: AsyncSession):
"""Test getting all birthdays when there are none"""
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 0
# Add a random birthday that is not today
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
assert len(birthdays) == 0