mirror of https://github.com/stijndcl/didier
Add support for lazy loading of user fields
parent
66997b7556
commit
393cc9c891
|
@ -4,11 +4,12 @@ from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.models import Birthday
|
from database.models import Birthday, User
|
||||||
|
|
||||||
__all__ = ["add_birthday", "get_birthday_for_user"]
|
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||||
|
|
||||||
|
|
||||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||||
|
@ -16,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||||
|
|
||||||
If already present, overwrites the existing one
|
If already present, overwrites the existing one
|
||||||
"""
|
"""
|
||||||
user = await users.get_or_add(session, user_id)
|
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
|
||||||
|
|
||||||
if user.birthday is not None:
|
if user.birthday is not None:
|
||||||
bd = user.birthday
|
bd = user.birthday
|
||||||
|
@ -35,5 +36,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
|
||||||
return (await session.execute(statement)).scalar_one_or_none()
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]:
|
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
||||||
"""Get all birthdays that happen on a given day"""
|
"""Get all birthdays that happen on a given day"""
|
||||||
|
statement = select(Birthday).where(Birthday.birthday == day)
|
||||||
|
return list((await session.execute(statement)).scalars())
|
||||||
|
|
|
@ -10,12 +10,16 @@ __all__ = [
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
|
||||||
"""Get a user's profile
|
"""Get a user's profile
|
||||||
|
|
||||||
If it doesn't exist yet, create it (along with all linked datastructures)
|
If it doesn't exist yet, create it (along with all linked datastructures)
|
||||||
"""
|
"""
|
||||||
statement = select(User).where(User.user_id == user_id)
|
if options is None:
|
||||||
|
options = []
|
||||||
|
|
||||||
|
statement = select(User).where(User.user_id == user_id).options(*options)
|
||||||
|
|
||||||
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
|
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
# User exists
|
# User exists
|
||||||
|
@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
||||||
session.add(user)
|
session.add(user)
|
||||||
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from freezegun import freeze_time
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import birthdays as crud
|
from database.crud import birthdays as crud
|
||||||
|
from database.crud import users
|
||||||
from database.models import User
|
from database.models import User
|
||||||
|
|
||||||
|
|
||||||
|
@ -45,3 +47,28 @@ async def test_get_birthday_not_exists(database_session: AsyncSession, user: Use
|
||||||
"""Test getting a user's birthday when it doesn't exist"""
|
"""Test getting a user's birthday when it doesn't exist"""
|
||||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||||
assert bd is None
|
assert bd is None
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2022/07/23")
|
||||||
|
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
||||||
|
"""Test getting all birthdays on a given day"""
|
||||||
|
await crud.add_birthday(database_session, user.user_id, datetime.today())
|
||||||
|
|
||||||
|
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
||||||
|
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||||
|
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||||
|
assert len(birthdays) == 1
|
||||||
|
assert birthdays[0].user_id == user.user_id
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2022/07/23")
|
||||||
|
async def test_get_birthdays_none_present(database_session: AsyncSession):
|
||||||
|
"""Test getting all birthdays when there are none"""
|
||||||
|
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||||
|
assert len(birthdays) == 0
|
||||||
|
|
||||||
|
# Add a random birthday that is not today
|
||||||
|
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1))
|
||||||
|
|
||||||
|
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||||
|
assert len(birthdays) == 0
|
||||||
|
|
Loading…
Reference in New Issue