mirror of
https://github.com/stijndcl/didier.git
synced 2026-04-07 15:48:29 +02:00
Add support for lazy loading of user fields
This commit is contained in:
parent
66997b7556
commit
393cc9c891
3 changed files with 41 additions and 6 deletions
|
|
@ -4,11 +4,12 @@ from typing import Optional
|
|||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.crud import users
|
||||
from database.models import Birthday
|
||||
from database.models import Birthday, User
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user"]
|
||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||
|
||||
|
||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||
|
|
@ -16,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
|||
|
||||
If already present, overwrites the existing one
|
||||
"""
|
||||
user = await users.get_or_add(session, user_id)
|
||||
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
|
||||
|
||||
if user.birthday is not None:
|
||||
bd = user.birthday
|
||||
|
|
@ -35,5 +36,7 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
|
|||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.datetime) -> list[Birthday]:
|
||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
||||
"""Get all birthdays that happen on a given day"""
|
||||
statement = select(Birthday).where(Birthday.birthday == day)
|
||||
return list((await session.execute(statement)).scalars())
|
||||
|
|
|
|||
|
|
@ -10,12 +10,16 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
||||
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
|
||||
"""Get a user's profile
|
||||
|
||||
If it doesn't exist yet, create it (along with all linked datastructures)
|
||||
"""
|
||||
statement = select(User).where(User.user_id == user_id)
|
||||
if options is None:
|
||||
options = []
|
||||
|
||||
statement = select(User).where(User.user_id == user_id).options(*options)
|
||||
|
||||
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
# User exists
|
||||
|
|
@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
|||
session.add(user)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return user
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue