2022-07-23 20:35:42 +02:00
|
|
|
import datetime
|
2022-07-19 22:58:59 +02:00
|
|
|
from datetime import date
|
|
|
|
from typing import Optional
|
|
|
|
|
2022-07-23 23:21:32 +02:00
|
|
|
from sqlalchemy import extract, select
|
2022-07-19 22:58:59 +02:00
|
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
2022-07-23 22:34:03 +02:00
|
|
|
from sqlalchemy.orm import selectinload
|
2022-07-19 22:58:59 +02:00
|
|
|
|
2022-07-19 23:35:41 +02:00
|
|
|
from database.crud import users
|
2022-07-23 22:34:03 +02:00
|
|
|
from database.models import Birthday, User
|
2022-07-19 22:58:59 +02:00
|
|
|
|
2022-07-23 22:34:03 +02:00
|
|
|
# Public API of this module: the names exported by ``from ... import *``.
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
2022-07-19 22:58:59 +02:00
|
|
|
|
|
|
|
|
|
|
|
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
    """Store ``birthday`` for the user identified by ``user_id``.

    The user is obtained through ``users.get_or_add`` with the ``birthday``
    relationship eagerly loaded. If the user already has a birthday record,
    its date is overwritten in place; otherwise a new ``Birthday`` row is
    created. The session is committed before returning.
    """
    user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
    existing = user.birthday

    if existing is None:
        record = Birthday(user_id=user_id, birthday=birthday)
    else:
        # Reload the row from the database before overwriting it.
        # NOTE(review): presumably guards against expired attributes after
        # get_or_add in an async session — confirm before removing.
        await session.refresh(existing)
        existing.birthday = birthday
        record = existing

    session.add(record)
    await session.commit()
|
|
|
|
|
|
|
|
|
|
|
|
async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
    """Look up the stored birthday of a single user.

    Returns the ``Birthday`` row whose ``user_id`` matches, or ``None`` when the
    user has no birthday on record. ``scalar_one_or_none`` raises if more than
    one row matches.
    """
    query = select(Birthday).where(Birthday.user_id == user_id)
    result = await session.execute(query)
    return result.scalar_one_or_none()
|
2022-07-23 20:35:42 +02:00
|
|
|
|
|
|
|
|
2022-07-23 22:34:03 +02:00
|
|
|
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
    """Return every stored birthday that falls on the same calendar day as ``day``.

    Only the day-of-month and the month are compared; the year is ignored, so
    this finds recurring anniversaries rather than exact dates.
    """
    same_day = extract("day", Birthday.birthday) == day.day
    same_month = extract("month", Birthday.birthday) == day.month
    # NOTE(review): a 29 February birthday only matches when ``day`` itself is
    # 29 February, i.e. in leap years — confirm this is the intended behavior.
    query = select(Birthday).where(same_day & same_month)
    result = await session.execute(query)
    # Materialize into a plain list to honour the ``list[Birthday]`` annotation.
    return list(result.scalars().all())
|