mirror of https://github.com/stijndcl/didier
Tests for birthday commands, overwrite existing birthdays
parent
f49f32d2e9
commit
adcf94c66e
11
.flake8
11
.flake8
|
@ -8,7 +8,6 @@ exclude =
|
||||||
__pycache__,
|
__pycache__,
|
||||||
alembic,
|
alembic,
|
||||||
htmlcov,
|
htmlcov,
|
||||||
tests,
|
|
||||||
venv
|
venv
|
||||||
# Disable rules that we don't care about (or conflict with others)
|
# Disable rules that we don't care about (or conflict with others)
|
||||||
extend-ignore =
|
extend-ignore =
|
||||||
|
@ -30,10 +29,14 @@ ignore-decorators=overrides
|
||||||
max-line-length = 120
|
max-line-length = 120
|
||||||
# Disable some rules for entire files
|
# Disable some rules for entire files
|
||||||
per-file-ignores =
|
per-file-ignores =
|
||||||
# Missing __all__, main isn't supposed to be imported
|
# DALL000: Missing __all__, main isn't supposed to be imported
|
||||||
main.py: DALL000,
|
main.py: DALL000,
|
||||||
# Missing __all__, Cogs aren't modules
|
# DALL000: Missing __all__, Cogs aren't modules
|
||||||
./didier/cogs/*: DALL000,
|
./didier/cogs/*: DALL000,
|
||||||
|
# DALL000: Missing __all__, tests aren't supposed to be imported
|
||||||
|
# S101: Use of assert, this is the point of tests
|
||||||
|
./tests/*: DALL000 S101,
|
||||||
|
# D103: Missing docstring in public function
|
||||||
# All of the colours methods are just oneliners to create a colour,
|
# All of the colours methods are just oneliners to create a colour,
|
||||||
# there's no point adding docstrings (function names are enough)
|
# there's no point adding docstrings (function names are enough)
|
||||||
./didier/utils/discord/colours.py: D103
|
./didier/utils/discord/colours.py: D103,
|
||||||
|
|
|
@ -4,14 +4,26 @@ from typing import Optional
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import users
|
||||||
from database.models import Birthday
|
from database.models import Birthday
|
||||||
|
|
||||||
__all__ = ["add_birthday", "get_birthday_for_user"]
|
__all__ = ["add_birthday", "get_birthday_for_user"]
|
||||||
|
|
||||||
|
|
||||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||||
"""Add a user's birthday into the database"""
|
"""Add a user's birthday into the database
|
||||||
bd = Birthday(user_id=user_id, birthday=birthday)
|
|
||||||
|
If already present, overwrites the existing one
|
||||||
|
"""
|
||||||
|
user = await users.get_or_add(session, user_id)
|
||||||
|
|
||||||
|
if user.birthday is not None:
|
||||||
|
bd = user.birthday
|
||||||
|
await session.refresh(bd)
|
||||||
|
bd.birthday = birthday
|
||||||
|
else:
|
||||||
|
bd = Birthday(user_id=user_id, birthday=birthday)
|
||||||
|
|
||||||
session.add(bd)
|
session.add(bd)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
|
|
|
@ -162,7 +162,7 @@ class User(Base):
|
||||||
bank: Bank = relationship(
|
bank: Bank = relationship(
|
||||||
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
birthday: Birthday = relationship(
|
birthday: Optional[Birthday] = relationship(
|
||||||
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
nightly_data: NightlyData = relationship(
|
nightly_data: NightlyData = relationship(
|
||||||
|
|
|
@ -12,6 +12,10 @@ from didier import Didier
|
||||||
|
|
||||||
@pytest.fixture(scope="session", autouse=True)
|
@pytest.fixture(scope="session", autouse=True)
|
||||||
def event_loop() -> Generator:
|
def event_loop() -> Generator:
|
||||||
|
"""Fixture to change the event loop
|
||||||
|
|
||||||
|
This fixes a lot of headaches during async tests
|
||||||
|
"""
|
||||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||||
yield loop
|
yield loop
|
||||||
loop.close()
|
loop.close()
|
||||||
|
@ -33,6 +37,7 @@ async def tables():
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||||
"""Fixture to create a session for every test
|
"""Fixture to create a session for every test
|
||||||
|
|
||||||
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||||
"""
|
"""
|
||||||
connection = await engine.connect()
|
connection = await engine.connect()
|
||||||
|
@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_client() -> Didier:
|
def mock_client() -> Didier:
|
||||||
"""Fixture to get a mock Didier instance
|
"""Fixture to get a mock Didier instance
|
||||||
|
|
||||||
The mock uses 0 as the id
|
The mock uses 0 as the id
|
||||||
"""
|
"""
|
||||||
mock_client = MagicMock()
|
mock_client = MagicMock()
|
||||||
|
|
|
@ -3,7 +3,33 @@ import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
|
from database.crud import users
|
||||||
|
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def test_user_id() -> int:
|
||||||
|
"""User id used when creating the debug user
|
||||||
|
|
||||||
|
Fixture is useful when comparing, fetching data, ...
|
||||||
|
"""
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def user(database_session: AsyncSession, test_user_id) -> User:
|
||||||
|
"""Fixture to create a user"""
|
||||||
|
_user = await users.get_or_add(database_session, test_user_id)
|
||||||
|
await database_session.refresh(_user)
|
||||||
|
return _user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def bank(database_session: AsyncSession, user: User) -> Bank:
|
||||||
|
"""Fixture to fetch the test user's bank"""
|
||||||
|
_bank = user.bank
|
||||||
|
await database_session.refresh(_bank)
|
||||||
|
return _bank
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -0,0 +1,47 @@
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import birthdays as crud
|
||||||
|
from database.models import User
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
|
||||||
|
"""Test setting a user's birthday when it doesn't exist yet"""
|
||||||
|
assert user.birthday is None
|
||||||
|
|
||||||
|
bd_date = datetime.today().date()
|
||||||
|
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||||
|
await database_session.refresh(user)
|
||||||
|
assert user.birthday is not None
|
||||||
|
assert user.birthday.birthday.date() == bd_date
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
|
||||||
|
"""Test that setting a user's birthday when it already exists overwrites it"""
|
||||||
|
bd_date = datetime.today().date()
|
||||||
|
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||||
|
await database_session.refresh(user)
|
||||||
|
assert user.birthday is not None
|
||||||
|
|
||||||
|
new_bd_date = bd_date + timedelta(weeks=1)
|
||||||
|
await crud.add_birthday(database_session, user.user_id, new_bd_date)
|
||||||
|
await database_session.refresh(user)
|
||||||
|
assert user.birthday.birthday.date() == new_bd_date
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
|
||||||
|
"""Test getting a user's birthday when it exists"""
|
||||||
|
bd_date = datetime.today().date()
|
||||||
|
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||||
|
await database_session.refresh(user)
|
||||||
|
|
||||||
|
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||||
|
assert bd is not None
|
||||||
|
assert bd.birthday.date() == bd_date
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
|
||||||
|
"""Test getting a user's birthday when it doesn't exist"""
|
||||||
|
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||||
|
assert bd is None
|
|
@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
|
||||||
from database.models import Bank
|
from database.models import Bank
|
||||||
|
|
||||||
|
|
||||||
DEBUG_USER_ID = 1
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def bank(database_session: AsyncSession) -> Bank:
|
|
||||||
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
|
|
||||||
await database_session.refresh(_bank)
|
|
||||||
return _bank
|
|
||||||
|
|
||||||
|
|
||||||
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
||||||
"""Test adding dinks to an account"""
|
"""Test adding dinks to an account"""
|
||||||
assert bank.dinks == 0
|
assert bank.dinks == 0
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pytest
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_announcements as crud
|
from database.crud import ufora_announcements as crud
|
||||||
|
|
|
@ -44,6 +44,7 @@ def test_abbreviated_no_number():
|
||||||
|
|
||||||
def test_abbreviated_float_floors():
|
def test_abbreviated_float_floors():
|
||||||
"""Test abbreviated_number for a float that is longer than the unit
|
"""Test abbreviated_number for a float that is longer than the unit
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
5.3k is 5300, but 5.3001k is 5300.1
|
5.3k is 5300, but 5.3001k is 5300.1
|
||||||
"""
|
"""
|
||||||
|
|
Loading…
Reference in New Issue