2022-06-30 21:17:48 +02:00
|
|
|
from typing import Optional
|
|
|
|
|
|
|
|
from sqlalchemy import select
|
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
|
2024-03-01 14:18:58 +01:00
|
|
|
from database.schemas import Bank, BankSavings, NightlyData, User
|
2022-07-11 22:23:38 +02:00
|
|
|
|
|
|
|
__all__ = ["get_or_add_user"]
|
2022-06-30 21:17:48 +02:00
|
|
|
|
|
|
|
|
2022-08-29 20:49:29 +02:00
|
|
|
async def get_or_add_user(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
    """Get a user's profile

    If it doesn't exist yet, create it (along with all linked datastructures)

    Args:
        session: the session used to look the user up and, if needed, to create it.
        user_id: the id of the user to fetch.
        options: loader options passed to ``Select.options`` for the query that
            fetches the user. They are applied both when the user already exists
            and when it has just been created.

    Returns:
        The existing ``User``, or a newly created one whose ``Bank``,
        ``NightlyData`` and ``BankSavings`` rows were committed together with it.
    """
    if options is None:
        options = []

    statement = select(User).where(User.user_id == user_id).options(*options)

    user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()

    # User exists
    if user is not None:
        return user

    # Create the user together with its bank, nightly & savings info.
    # Everything is committed in ONE transaction so that a failure cannot leave
    # a user row behind without its linked rows.
    # NOTE(review): two concurrent calls for the same new user_id can both reach
    # this point; presumably the later commit then fails on a unique/primary-key
    # constraint - confirm whether callers need to handle that.
    user = User(user_id=user_id)
    bank = Bank(user_id=user_id)
    nightly_data = NightlyData(user_id=user_id)
    savings = BankSavings(user_id=user_id)

    user.bank = bank
    user.nightly_data = nightly_data
    user.savings = savings

    session.add_all([user, bank, nightly_data, savings])
    await session.commit()

    # Re-run the original query rather than a plain `session.refresh(user)`:
    # refresh does not apply the caller's loader `options`.
    # populate_existing makes the query overwrite the attributes of the instance
    # already held by the session, so the same `user` object is returned.
    result = await session.execute(statement.execution_options(populate_existing=True))
    return result.scalar_one()
|