import datetime
from datetime import date
from typing import Optional

from sqlalchemy import extract, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from database.crud import users
from database.schemas.relational import Birthday, User

__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]


async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
    """Store a birthday for a user, replacing any birthday already stored

    The user is fetched through ``users.get_or_add`` with the ``birthday``
    relationship loaded via ``selectinload``. The change is committed before
    returning.
    """
    owner = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
    record = owner.birthday

    if record is None:
        # Nothing stored yet for this user: create a new row
        record = Birthday(user_id=user_id, birthday=birthday)
    else:
        # Reload the existing row from the database before overwriting its date
        await session.refresh(record)
        record.birthday = birthday

    session.add(record)
    await session.commit()


async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
    """Look up the birthday stored for a user

    Returns ``None`` when no row matches the given user id.
    """
    query = select(Birthday).where(Birthday.user_id == user_id)
    result = await session.execute(query)
    return result.scalar_one_or_none()


async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
    """Collect every birthday whose day and month match the given date

    Only the day-of-month and month parts are compared, the year is ignored.
    """
    same_day = extract("day", Birthday.birthday) == day.day
    same_month = extract("month", Birthday.birthday) == day.month

    query = select(Birthday).where(same_day & same_month)
    result = await session.execute(query)
    return list(result.scalars().all())