mirror of https://github.com/stijndcl/didier
				
				
				
			Tests for birthday commands, overwrite existing birthdays
							parent
							
								
									f49f32d2e9
								
							
						
					
					
						commit
						adcf94c66e
					
				
							
								
								
									
										11
									
								
								.flake8
								
								
								
								
							
							
						
						
									
										11
									
								
								.flake8
								
								
								
								
							| 
						 | 
				
			
			@ -8,7 +8,6 @@ exclude =
 | 
			
		|||
    __pycache__,
 | 
			
		||||
    alembic,
 | 
			
		||||
    htmlcov,
 | 
			
		||||
    tests,
 | 
			
		||||
    venv
 | 
			
		||||
# Disable rules that we don't care about (or conflict with others)
 | 
			
		||||
extend-ignore =
 | 
			
		||||
| 
						 | 
				
			
			@ -30,10 +29,14 @@ ignore-decorators=overrides
 | 
			
		|||
max-line-length = 120
 | 
			
		||||
# Disable some rules for entire files
 | 
			
		||||
per-file-ignores =
 | 
			
		||||
    # Missing __all__, main isn't supposed to be imported
 | 
			
		||||
    # DALL000: Missing __all__, main isn't supposed to be imported
 | 
			
		||||
    main.py: DALL000,
 | 
			
		||||
    # Missing __all__, Cogs aren't modules
 | 
			
		||||
    # DALL000: Missing __all__, Cogs aren't modules
 | 
			
		||||
    ./didier/cogs/*: DALL000,
 | 
			
		||||
    # DALL000: Missing __all__, tests aren't supposed to be imported
 | 
			
		||||
    # S101: Use of assert, this is the point of tests
 | 
			
		||||
    ./tests/*: DALL000 S101,
 | 
			
		||||
    # D103: Missing docstring in public function
 | 
			
		||||
    # All of the colours methods are just oneliners to create a colour,
 | 
			
		||||
    # there's no point adding docstrings (function names are enough)
 | 
			
		||||
    ./didier/utils/discord/colours.py: D103
 | 
			
		||||
    ./didier/utils/discord/colours.py: D103,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,14 +4,26 @@ from typing import Optional
 | 
			
		|||
from sqlalchemy import select
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.crud import users
 | 
			
		||||
from database.models import Birthday
 | 
			
		||||
 | 
			
		||||
__all__ = ["add_birthday", "get_birthday_for_user"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
 | 
			
		||||
    """Add a user's birthday into the database"""
 | 
			
		||||
    """Add a user's birthday into the database
 | 
			
		||||
 | 
			
		||||
    If already present, overwrites the existing one
 | 
			
		||||
    """
 | 
			
		||||
    user = await users.get_or_add(session, user_id)
 | 
			
		||||
 | 
			
		||||
    if user.birthday is not None:
 | 
			
		||||
        bd = user.birthday
 | 
			
		||||
        await session.refresh(bd)
 | 
			
		||||
        bd.birthday = birthday
 | 
			
		||||
    else:
 | 
			
		||||
        bd = Birthday(user_id=user_id, birthday=birthday)
 | 
			
		||||
 | 
			
		||||
    session.add(bd)
 | 
			
		||||
    await session.commit()
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -162,7 +162,7 @@ class User(Base):
 | 
			
		|||
    bank: Bank = relationship(
 | 
			
		||||
        "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
 | 
			
		||||
    )
 | 
			
		||||
    birthday: Birthday = relationship(
 | 
			
		||||
    birthday: Optional[Birthday] = relationship(
 | 
			
		||||
        "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
 | 
			
		||||
    )
 | 
			
		||||
    nightly_data: NightlyData = relationship(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -12,6 +12,10 @@ from didier import Didier
 | 
			
		|||
 | 
			
		||||
@pytest.fixture(scope="session", autouse=True)
 | 
			
		||||
def event_loop() -> Generator:
 | 
			
		||||
    """Fixture to change the event loop
 | 
			
		||||
 | 
			
		||||
    This fixes a lot of headaches during async tests
 | 
			
		||||
    """
 | 
			
		||||
    loop = asyncio.get_event_loop_policy().new_event_loop()
 | 
			
		||||
    yield loop
 | 
			
		||||
    loop.close()
 | 
			
		||||
| 
						 | 
				
			
			@ -33,6 +37,7 @@ async def tables():
 | 
			
		|||
@pytest.fixture
 | 
			
		||||
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
 | 
			
		||||
    """Fixture to create a session for every test
 | 
			
		||||
 | 
			
		||||
    Rollbacks the transaction afterwards so that the future tests start with a clean database
 | 
			
		||||
    """
 | 
			
		||||
    connection = await engine.connect()
 | 
			
		||||
| 
						 | 
				
			
			@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
 | 
			
		|||
@pytest.fixture
 | 
			
		||||
def mock_client() -> Didier:
 | 
			
		||||
    """Fixture to get a mock Didier instance
 | 
			
		||||
 | 
			
		||||
    The mock uses 0 as the id
 | 
			
		||||
    """
 | 
			
		||||
    mock_client = MagicMock()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -3,7 +3,33 @@ import datetime
 | 
			
		|||
import pytest
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
 | 
			
		||||
from database.crud import users
 | 
			
		||||
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture(scope="session")
 | 
			
		||||
def test_user_id() -> int:
 | 
			
		||||
    """User id used when creating the debug user
 | 
			
		||||
 | 
			
		||||
    Fixture is useful when comparing, fetching data, ...
 | 
			
		||||
    """
 | 
			
		||||
    return 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
async def user(database_session: AsyncSession, test_user_id) -> User:
 | 
			
		||||
    """Fixture to create a user"""
 | 
			
		||||
    _user = await users.get_or_add(database_session, test_user_id)
 | 
			
		||||
    await database_session.refresh(_user)
 | 
			
		||||
    return _user
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
async def bank(database_session: AsyncSession, user: User) -> Bank:
 | 
			
		||||
    """Fixture to fetch the test user's bank"""
 | 
			
		||||
    _bank = user.bank
 | 
			
		||||
    await database_session.refresh(_bank)
 | 
			
		||||
    return _bank
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,47 @@
 | 
			
		|||
from datetime import datetime, timedelta
 | 
			
		||||
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.crud import birthdays as crud
 | 
			
		||||
from database.models import User
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
 | 
			
		||||
    """Test setting a user's birthday when it doesn't exist yet"""
 | 
			
		||||
    assert user.birthday is None
 | 
			
		||||
 | 
			
		||||
    bd_date = datetime.today().date()
 | 
			
		||||
    await crud.add_birthday(database_session, user.user_id, bd_date)
 | 
			
		||||
    await database_session.refresh(user)
 | 
			
		||||
    assert user.birthday is not None
 | 
			
		||||
    assert user.birthday.birthday.date() == bd_date
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
 | 
			
		||||
    """Test that setting a user's birthday when it already exists overwrites it"""
 | 
			
		||||
    bd_date = datetime.today().date()
 | 
			
		||||
    await crud.add_birthday(database_session, user.user_id, bd_date)
 | 
			
		||||
    await database_session.refresh(user)
 | 
			
		||||
    assert user.birthday is not None
 | 
			
		||||
 | 
			
		||||
    new_bd_date = bd_date + timedelta(weeks=1)
 | 
			
		||||
    await crud.add_birthday(database_session, user.user_id, new_bd_date)
 | 
			
		||||
    await database_session.refresh(user)
 | 
			
		||||
    assert user.birthday.birthday.date() == new_bd_date
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
 | 
			
		||||
    """Test getting a user's birthday when it exists"""
 | 
			
		||||
    bd_date = datetime.today().date()
 | 
			
		||||
    await crud.add_birthday(database_session, user.user_id, bd_date)
 | 
			
		||||
    await database_session.refresh(user)
 | 
			
		||||
 | 
			
		||||
    bd = await crud.get_birthday_for_user(database_session, user.user_id)
 | 
			
		||||
    assert bd is not None
 | 
			
		||||
    assert bd.birthday.date() == bd_date
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
 | 
			
		||||
    """Test getting a user's birthday when it doesn't exist"""
 | 
			
		||||
    bd = await crud.get_birthday_for_user(database_session, user.user_id)
 | 
			
		||||
    assert bd is None
 | 
			
		||||
| 
						 | 
				
			
			@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
 | 
			
		|||
from database.models import Bank
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DEBUG_USER_ID = 1
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
async def bank(database_session: AsyncSession) -> Bank:
 | 
			
		||||
    _bank = await crud.get_bank(database_session, DEBUG_USER_ID)
 | 
			
		||||
    await database_session.refresh(_bank)
 | 
			
		||||
    return _bank
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
 | 
			
		||||
    """Test adding dinks to an account"""
 | 
			
		||||
    assert bank.dinks == 0
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,5 @@
 | 
			
		|||
import datetime
 | 
			
		||||
 | 
			
		||||
import pytest
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.crud import ufora_announcements as crud
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -44,6 +44,7 @@ def test_abbreviated_no_number():
 | 
			
		|||
 | 
			
		||||
def test_abbreviated_float_floors():
 | 
			
		||||
    """Test abbreviated_number for a float that is longer than the unit
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
        5.3k is 5300, but 5.3001k is 5300.1
 | 
			
		||||
    """
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue