From 61128dda92d5a66773e8950559be2a922a75e8fc Mon Sep 17 00:00:00 2001 From: stijndcl Date: Sun, 3 Jul 2022 19:26:30 +0200 Subject: [PATCH] Write some tests for currency crud --- database/crud/currency.py | 16 ++-- didier/cogs/currency.py | 4 +- didier/utils/discord/converters/numbers.py | 6 +- .../test_database/test_crud/test_currency.py | 81 +++++++++++++++++++ .../test_converters/test_numbers.py | 15 ++++ 5 files changed, 112 insertions(+), 10 deletions(-) diff --git a/database/crud/currency.py b/database/crud/currency.py index 612b969..8e7f2e6 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.models import Bank +from database.models import Bank, NightlyData from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price @@ -18,13 +18,20 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank: return user.bank +async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData: + """Get a user's nightly info""" + user = await users.get_or_add(session, user_id) + return user.nightly_data + + async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int: """Invest all your Dinks""" bank = await get_bank(session, user_id) if amount == "all": amount = bank.dinks - amount = int(amount) + # Don't allow investing more dinks than you own + amount = min(bank.dinks, int(amount)) bank.dinks -= amount bank.invested += amount @@ -45,15 +52,14 @@ async def add_dinks(session: AsyncSession, user_id: int, amount: int): async def claim_nightly(session: AsyncSession, user_id: int): """Claim daily Dinks""" - user = await users.get_or_add(session, user_id) - nightly_data = user.nightly_data + nightly_data = await get_nightly_data(session, user_id) now = datetime.now() if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): raise exceptions.DoubleNightly - bank = user.bank + bank = await get_bank(session, user_id) bank.dinks += NIGHTLY_AMOUNT nightly_data.last_nightly = now diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index bb46bb0..b30c283 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -5,10 +5,10 @@ from discord.ext import commands from database.crud import currency as crud from database.exceptions.currency import DoubleNightly, NotEnoughDinks +from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price from didier import Didier from didier.utils.discord.checks import is_owner from didier.utils.discord.converters import abbreviated_number -from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price from didier.utils.types.string import pluralize @@ -113,7 +113,7 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) - @commands.command(name="Invest") + @commands.command(name="Invest", aliases=["Deposit", "Dep"]) async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore """Invest a given amount of Didier Dinks""" amount = typing.cast(typing.Union[str, int], amount) diff --git a/didier/utils/discord/converters/numbers.py b/didier/utils/discord/converters/numbers.py index f184ca8..443f3e7 100644 --- a/didier/utils/discord/converters/numbers.py +++ b/didier/utils/discord/converters/numbers.py @@ -10,12 +10,12 @@ def abbreviated_number(argument: str) -> Union[str, int]: 515k 4m """ - if argument.lower() == "all": - return "all" - if not argument: raise ValueError + if argument.lower() == "all": + return "all" + if argument.isdecimal(): return int(argument) diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index e69de29..f996bf9 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -0,0 +1,81 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import currency as crud +from database.exceptions import currency as exceptions +from database.models import Bank + + +DEBUG_USER_ID = 1 + + +@pytest.fixture +async def bank(database_session: AsyncSession) -> Bank: + _bank = await crud.get_bank(database_session, DEBUG_USER_ID) + await database_session.refresh(_bank) + return _bank + + +async def test_add_dinks(database_session: AsyncSession, bank: Bank): + """Test adding dinks to an account""" + assert bank.dinks == 0 + await crud.add_dinks(database_session, bank.user_id, 10) + await database_session.refresh(bank) + assert bank.dinks == 10 + + +async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): + """Test claiming nightlies when it hasn't been done yet""" + await crud.claim_nightly(database_session, bank.user_id) + await database_session.refresh(bank) + assert bank.dinks == crud.NIGHTLY_AMOUNT + + +async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): + """Test claiming nightlies twice in a day""" + await crud.claim_nightly(database_session, bank.user_id) + + with pytest.raises(exceptions.DoubleNightly): + await crud.claim_nightly(database_session, bank.user_id) + + await database_session.refresh(bank) + assert bank.dinks == crud.NIGHTLY_AMOUNT + + +async def test_invest(database_session: AsyncSession, bank: Bank): + """Test investing some Dinks""" + bank.dinks = 100 + database_session.add(bank) + await database_session.commit() + + await crud.invest(database_session, bank.user_id, 20) + await database_session.refresh(bank) + + assert bank.dinks == 80 + assert bank.invested == 20 + + +async def test_invest_all(database_session: AsyncSession, bank: Bank): + """Test investing all dinks""" + bank.dinks = 100 + database_session.add(bank) + await database_session.commit() + + await crud.invest(database_session, bank.user_id, "all") + await database_session.refresh(bank) + + assert bank.dinks == 0 + assert bank.invested == 100 + + +async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank): + """Test investing more Dinks than you own""" + bank.dinks = 100 + database_session.add(bank) + await database_session.commit() + + await crud.invest(database_session, bank.user_id, 200) + await database_session.refresh(bank) + + assert bank.dinks == 0 + assert bank.invested == 100 diff --git a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py index 3efc9f4..ed88692 100644 --- a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py +++ b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py @@ -48,3 +48,18 @@ def test_abbreviated_float_floors(): 5.3k is 5300, but 5.3001k is 5300.1 """ assert numbers.abbreviated_number("5.3001k") == 5300 + + +def test_abbreviated_all(): + """Test abbreviated_number for the 'all' argument""" + assert numbers.abbreviated_number("all") == "all" + assert numbers.abbreviated_number("ALL") == "all" + + +def test_abbreviated_empty(): + """Test abbreviated_number for empty arguments""" + with pytest.raises(ValueError): + numbers.abbreviated_number("") + + with pytest.raises(ValueError): + numbers.abbreviated_number(None)