didier/database/crud/birthdays.py

46 lines
1.5 KiB
Python
Raw Permalink Normal View History

import datetime
2022-07-19 22:58:59 +02:00
from datetime import date
from typing import Optional
from sqlalchemy import extract, select
2022-07-19 22:58:59 +02:00
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
2022-07-19 22:58:59 +02:00
from database.crud import users
2022-08-29 20:24:42 +02:00
from database.schemas import Birthday, User
2022-07-19 22:58:59 +02:00
# Public API of this module
__all__ = [
    "add_birthday",
    "get_birthday_for_user",
    "get_birthdays_on_day",
]
2022-07-19 22:58:59 +02:00
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
    """Store a user's birthday, overwriting the previous one if there is one.

    The user is fetched through ``users.get_or_add_user`` with the
    ``User.birthday`` relationship eagerly loaded. If a ``Birthday`` row
    already exists, it is updated in place; otherwise a new row is created.
    The session is committed before returning.

    Args:
        session: the database session to use
        user_id: id of the user whose birthday is stored
        birthday: the date to store
    """
    owner = await users.get_or_add_user(session, user_id, options=[selectinload(User.birthday)])
    entry = owner.birthday

    if entry is None:
        # No birthday stored yet: create a fresh row for this user
        entry = Birthday(user_id=user_id, birthday=birthday)
    else:
        # NOTE(review): presumably refreshed so the row's attributes are loaded
        # before it is modified in the async session — confirm
        await session.refresh(entry)
        entry.birthday = birthday

    session.add(entry)
    await session.commit()


async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
    """Look up the birthday stored for a user.

    Args:
        session: the database session to use
        user_id: id of the user to look up

    Returns:
        The user's ``Birthday`` row, or ``None`` if no birthday is stored.

    Raises:
        sqlalchemy.exc.MultipleResultsFound: if more than one row matches
            (behaviour of ``scalar_one_or_none``).
    """
    result = await session.execute(select(Birthday).where(Birthday.user_id == user_id))
    return result.scalar_one_or_none()


async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
    """Get all birthdays that happen on a given day

    Only the day and month are compared; the year of birth is ignored.

    February 29th does not exist in common years. On February 28th of a
    common year, birthdays stored on February 29th are therefore returned as
    well. Otherwise those birthdays would never be returned in such years.

    Args:
        session: the database session to use
        day: the day to look up. Its day and month are matched; its year is
            only used to detect the common-year February 28th case.

    Returns:
        A list of all matching ``Birthday`` rows
    """
    days = extract("day", Birthday.birthday)
    months = extract("month", Birthday.birthday)

    condition = (days == day.day) & (months == day.month)

    # In a common year the day after Feb 28 is Mar 1, so leap-day birthdays
    # would otherwise never match: include them on Feb 28 instead
    is_common_year_feb_28 = day.month == 2 and day.day == 28 and (day + datetime.timedelta(days=1)).month == 3
    if is_common_year_feb_28:
        condition = condition | ((days == 29) & (months == 2))

    statement = select(Birthday).where(condition)
    return list((await session.execute(statement)).scalars().all())