From adcf94c66e8af1d1add8e1d0efbe18dbd8f00389 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 19 Jul 2022 23:35:41 +0200 Subject: [PATCH] Tests for birthday commands, overwrite existing birthdays --- .flake8 | 11 +++-- database/crud/birthdays.py | 16 ++++++- database/models.py | 2 +- tests/conftest.py | 6 +++ tests/test_database/conftest.py | 28 ++++++++++- .../test_database/test_crud/test_birthdays.py | 47 +++++++++++++++++++ .../test_database/test_crud/test_currency.py | 10 ---- .../test_crud/test_ufora_announcements.py | 1 - .../test_converters/test_numbers.py | 1 + 9 files changed, 103 insertions(+), 19 deletions(-) create mode 100644 tests/test_database/test_crud/test_birthdays.py diff --git a/.flake8 b/.flake8 index cab8ba8..1707912 100644 --- a/.flake8 +++ b/.flake8 @@ -8,7 +8,6 @@ exclude = __pycache__, alembic, htmlcov, - tests, venv # Disable rules that we don't care about (or conflict with others) extend-ignore = @@ -30,10 +29,14 @@ ignore-decorators=overrides max-line-length = 120 # Disable some rules for entire files per-file-ignores = - # Missing __all__, main isn't supposed to be imported + # DALL000: Missing __all__, main isn't supposed to be imported main.py: DALL000, - # Missing __all__, Cogs aren't modules + # DALL000: Missing __all__, Cogs aren't modules ./didier/cogs/*: DALL000, + # DALL000: Missing __all__, tests aren't supposed to be imported + # S101: Use of assert, this is the point of tests + ./tests/*: DALL000 S101, + # D103: Missing docstring in public function # All of the colours methods are just oneliners to create a colour, # there's no point adding docstrings (function names are enough) - ./didier/utils/discord/colours.py: D103 + ./didier/utils/discord/colours.py: D103, diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 6ff714d..99ea2db 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -4,14 +4,26 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession +from database.crud import users from database.models import Birthday __all__ = ["add_birthday", "get_birthday_for_user"] async def add_birthday(session: AsyncSession, user_id: int, birthday: date): - """Add a user's birthday into the database""" - bd = Birthday(user_id=user_id, birthday=birthday) + """Add a user's birthday into the database + + If already present, overwrites the existing one + """ + user = await users.get_or_add(session, user_id) + + if user.birthday is not None: + bd = user.birthday + await session.refresh(bd) + bd.birthday = birthday + else: + bd = Birthday(user_id=user_id, birthday=birthday) + session.add(bd) await session.commit() diff --git a/database/models.py b/database/models.py index d906406..74aa5fc 100644 --- a/database/models.py +++ b/database/models.py @@ -162,7 +162,7 @@ class User(Base): bank: Bank = relationship( "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - birthday: Birthday = relationship( + birthday: Optional[Birthday] = relationship( "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) nightly_data: NightlyData = relationship( diff --git a/tests/conftest.py b/tests/conftest.py index 95b44db..8530de5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,10 @@ from didier import Didier @pytest.fixture(scope="session", autouse=True) def event_loop() -> Generator: + """Fixture to change the event loop + + This fixes a lot of headaches during async tests + """ loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() @@ -33,6 +37,7 @@ async def tables(): @pytest.fixture async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test + Rollbacks the transaction afterwards so that the future tests start with a clean database """ connection = await engine.connect() @@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: @pytest.fixture def mock_client() -> Didier: """Fixture to get a mock Didier instance + The mock uses 0 as the id """ mock_client = MagicMock() diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index de1e939..8bc765c 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -3,7 +3,33 @@ import datetime import pytest from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias +from database.crud import users +from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User + + +@pytest.fixture(scope="session") +def test_user_id() -> int: + """User id used when creating the debug user + + Fixture is useful when comparing, fetching data, ... + """ + return 1 + + +@pytest.fixture +async def user(database_session: AsyncSession, test_user_id) -> User: + """Fixture to create a user""" + _user = await users.get_or_add(database_session, test_user_id) + await database_session.refresh(_user) + return _user + + +@pytest.fixture +async def bank(database_session: AsyncSession, user: User) -> Bank: + """Fixture to fetch the test user's bank""" + _bank = user.bank + await database_session.refresh(_bank) + return _bank @pytest.fixture diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py new file mode 100644 index 0000000..96b924c --- /dev/null +++ b/tests/test_database/test_crud/test_birthdays.py @@ -0,0 +1,47 @@ +from datetime import datetime, timedelta + +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import birthdays as crud +from database.models import User + + +async def test_add_birthday_not_present(database_session: AsyncSession, user: User): + """Test setting a user's birthday when it doesn't exist yet""" + assert user.birthday is None + + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + assert user.birthday is not None + assert user.birthday.birthday.date() == bd_date + + +async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): + """Test that setting a user's birthday when it already exists overwrites it""" + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + assert user.birthday is not None + + new_bd_date = bd_date + timedelta(weeks=1) + await crud.add_birthday(database_session, user.user_id, new_bd_date) + await database_session.refresh(user) + assert user.birthday.birthday.date() == new_bd_date + + +async def test_get_birthday_exists(database_session: AsyncSession, user: User): + """Test getting a user's birthday when it exists""" + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + + bd = await crud.get_birthday_for_user(database_session, user.user_id) + assert bd is not None + assert bd.birthday.date() == bd_date + + +async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): + """Test getting a user's birthday when it doesn't exist""" + bd = await crud.get_birthday_for_user(database_session, user.user_id) + assert bd is None diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index f996bf9..1f0a163 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions from database.models import Bank -DEBUG_USER_ID = 1 - - -@pytest.fixture -async def bank(database_session: AsyncSession) -> Bank: - _bank = await crud.get_bank(database_session, DEBUG_USER_ID) - await database_session.refresh(_bank) - return _bank - - async def test_add_dinks(database_session: AsyncSession, bank: Bank): """Test adding dinks to an account""" assert bank.dinks == 0 diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index b2385a2..4e6fc47 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,6 +1,5 @@ import datetime -import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_announcements as crud diff --git a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py index ed88692..75ab401 100644 --- a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py +++ b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py @@ -44,6 +44,7 @@ def test_abbreviated_no_number(): def test_abbreviated_float_floors(): """Test abbreviated_number for a float that is longer than the unit + Example: 5.3k is 5300, but 5.3001k is 5300.1 """