Tests for birthday commands, overwrite existing birthdays

pull/125/head
stijndcl 2022-07-19 23:35:41 +02:00
parent f49f32d2e9
commit adcf94c66e
9 changed files with 103 additions and 19 deletions

11
.flake8
View File

@ -8,7 +8,6 @@ exclude =
__pycache__, __pycache__,
alembic, alembic,
htmlcov, htmlcov,
tests,
venv venv
# Disable rules that we don't care about (or conflict with others) # Disable rules that we don't care about (or conflict with others)
extend-ignore = extend-ignore =
@ -30,10 +29,14 @@ ignore-decorators=overrides
max-line-length = 120 max-line-length = 120
# Disable some rules for entire files # Disable some rules for entire files
per-file-ignores = per-file-ignores =
# Missing __all__, main isn't supposed to be imported # DALL000: Missing __all__, main isn't supposed to be imported
main.py: DALL000, main.py: DALL000,
# Missing __all__, Cogs aren't modules # DALL000: Missing __all__, Cogs aren't modules
./didier/cogs/*: DALL000, ./didier/cogs/*: DALL000,
# DALL000: Missing __all__, tests aren't supposed to be imported
# S101: Use of assert, this is the point of tests
./tests/*: DALL000 S101,
# D103: Missing docstring in public function
# All of the colours methods are just oneliners to create a colour, # All of the colours methods are just oneliners to create a colour,
# there's no point adding docstrings (function names are enough) # there's no point adding docstrings (function names are enough)
./didier/utils/discord/colours.py: D103 ./didier/utils/discord/colours.py: D103,

View File

@ -4,14 +4,26 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.models import Birthday from database.models import Birthday
__all__ = ["add_birthday", "get_birthday_for_user"] __all__ = ["add_birthday", "get_birthday_for_user"]
async def add_birthday(session: AsyncSession, user_id: int, birthday: date): async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
"""Add a user's birthday into the database""" """Add a user's birthday into the database
bd = Birthday(user_id=user_id, birthday=birthday)
If already present, overwrites the existing one
"""
user = await users.get_or_add(session, user_id)
if user.birthday is not None:
bd = user.birthday
await session.refresh(bd)
bd.birthday = birthday
else:
bd = Birthday(user_id=user_id, birthday=birthday)
session.add(bd) session.add(bd)
await session.commit() await session.commit()

View File

@ -162,7 +162,7 @@ class User(Base):
bank: Bank = relationship( bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
birthday: Birthday = relationship( birthday: Optional[Birthday] = relationship(
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(

View File

@ -12,6 +12,10 @@ from didier import Didier
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def event_loop() -> Generator: def event_loop() -> Generator:
"""Fixture to change the event loop
This fixes a lot of headaches during async tests
"""
loop = asyncio.get_event_loop_policy().new_event_loop() loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop yield loop
loop.close() loop.close()
@ -33,6 +37,7 @@ async def tables():
@pytest.fixture @pytest.fixture
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test """Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database Rollbacks the transaction afterwards so that the future tests start with a clean database
""" """
connection = await engine.connect() connection = await engine.connect()
@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
@pytest.fixture @pytest.fixture
def mock_client() -> Didier: def mock_client() -> Didier:
"""Fixture to get a mock Didier instance """Fixture to get a mock Didier instance
The mock uses 0 as the id The mock uses 0 as the id
""" """
mock_client = MagicMock() mock_client = MagicMock()

View File

@ -3,7 +3,33 @@ import datetime
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias from database.crud import users
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
@pytest.fixture(scope="session")
def test_user_id() -> int:
"""User id used when creating the debug user
Fixture is useful when comparing, fetching data, ...
"""
return 1
@pytest.fixture
async def user(database_session: AsyncSession, test_user_id) -> User:
"""Fixture to create a user"""
_user = await users.get_or_add(database_session, test_user_id)
await database_session.refresh(_user)
return _user
@pytest.fixture
async def bank(database_session: AsyncSession, user: User) -> Bank:
"""Fixture to fetch the test user's bank"""
_bank = user.bank
await database_session.refresh(_bank)
return _bank
@pytest.fixture @pytest.fixture

View File

@ -0,0 +1,47 @@
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud
from database.models import User
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
"""Test setting a user's birthday when it doesn't exist yet"""
assert user.birthday is None
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
assert user.birthday.birthday.date() == bd_date
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
"""Test that setting a user's birthday when it already exists overwrites it"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
new_bd_date = bd_date + timedelta(weeks=1)
await crud.add_birthday(database_session, user.user_id, new_bd_date)
await database_session.refresh(user)
assert user.birthday.birthday.date() == new_bd_date
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it exists"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is not None
assert bd.birthday.date() == bd_date
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is None

View File

@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
from database.models import Bank from database.models import Bank
DEBUG_USER_ID = 1
@pytest.fixture
async def bank(database_session: AsyncSession) -> Bank:
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
await database_session.refresh(_bank)
return _bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank): async def test_add_dinks(database_session: AsyncSession, bank: Bank):
"""Test adding dinks to an account""" """Test adding dinks to an account"""
assert bank.dinks == 0 assert bank.dinks == 0

View File

@ -1,6 +1,5 @@
import datetime import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud

View File

@ -44,6 +44,7 @@ def test_abbreviated_no_number():
def test_abbreviated_float_floors(): def test_abbreviated_float_floors():
"""Test abbreviated_number for a float that is longer than the unit """Test abbreviated_number for a float that is longer than the unit
Example: Example:
5.3k is 5300, but 5.3001k is 5300.1 5.3k is 5300, but 5.3001k is 5300.1
""" """