Write some tests for currency crud

pull/119/head
stijndcl 2022-07-03 19:26:30 +02:00
parent 8da0eb0b2a
commit 61128dda92
5 changed files with 112 additions and 10 deletions

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users from database.crud import users
from database.exceptions import currency as exceptions from database.exceptions import currency as exceptions
from database.models import Bank from database.models import Bank, NightlyData
from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price
@ -18,13 +18,20 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank:
return user.bank return user.bank
async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData:
"""Get a user's nightly info"""
user = await users.get_or_add(session, user_id)
return user.nightly_data
async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int: async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int:
"""Invest all your Dinks""" """Invest all your Dinks"""
bank = await get_bank(session, user_id) bank = await get_bank(session, user_id)
if amount == "all": if amount == "all":
amount = bank.dinks amount = bank.dinks
amount = int(amount) # Don't allow investing more dinks than you own
amount = min(bank.dinks, int(amount))
bank.dinks -= amount bank.dinks -= amount
bank.invested += amount bank.invested += amount
@ -45,15 +52,14 @@ async def add_dinks(session: AsyncSession, user_id: int, amount: int):
async def claim_nightly(session: AsyncSession, user_id: int): async def claim_nightly(session: AsyncSession, user_id: int):
"""Claim daily Dinks""" """Claim daily Dinks"""
user = await users.get_or_add(session, user_id) nightly_data = await get_nightly_data(session, user_id)
nightly_data = user.nightly_data
now = datetime.now() now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
raise exceptions.DoubleNightly raise exceptions.DoubleNightly
bank = user.bank bank = await get_bank(session, user_id)
bank.dinks += NIGHTLY_AMOUNT bank.dinks += NIGHTLY_AMOUNT
nightly_data.last_nightly = now nightly_data.last_nightly = now

View File

@ -5,10 +5,10 @@ from discord.ext import commands
from database.crud import currency as crud from database.crud import currency as crud
from database.exceptions.currency import DoubleNightly, NotEnoughDinks from database.exceptions.currency import DoubleNightly, NotEnoughDinks
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
from didier import Didier from didier import Didier
from didier.utils.discord.checks import is_owner from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number from didier.utils.discord.converters import abbreviated_number
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
from didier.utils.types.string import pluralize from didier.utils.types.string import pluralize
@ -113,7 +113,7 @@ class Currency(commands.Cog):
plural = pluralize("Didier Dink", bank.dinks) plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@commands.command(name="Invest") @commands.command(name="Invest", aliases=["Deposit", "Dep"])
async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore
"""Invest a given amount of Didier Dinks""" """Invest a given amount of Didier Dinks"""
amount = typing.cast(typing.Union[str, int], amount) amount = typing.cast(typing.Union[str, int], amount)

View File

@ -10,12 +10,12 @@ def abbreviated_number(argument: str) -> Union[str, int]:
515k 515k
4m 4m
""" """
if argument.lower() == "all":
return "all"
if not argument: if not argument:
raise ValueError raise ValueError
if argument.lower() == "all":
return "all"
if argument.isdecimal(): if argument.isdecimal():
return int(argument) return int(argument)

View File

@ -0,0 +1,81 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import currency as crud
from database.exceptions import currency as exceptions
from database.models import Bank
DEBUG_USER_ID = 1
@pytest.fixture
async def bank(database_session: AsyncSession) -> Bank:
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
await database_session.refresh(_bank)
return _bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
"""Test adding dinks to an account"""
assert bank.dinks == 0
await crud.add_dinks(database_session, bank.user_id, 10)
await database_session.refresh(bank)
assert bank.dinks == 10
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies when it hasn't been done yet"""
await crud.claim_nightly(database_session, bank.user_id)
await database_session.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies twice in a day"""
await crud.claim_nightly(database_session, bank.user_id)
with pytest.raises(exceptions.DoubleNightly):
await crud.claim_nightly(database_session, bank.user_id)
await database_session.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT
async def test_invest(database_session: AsyncSession, bank: Bank):
"""Test investing some Dinks"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, 20)
await database_session.refresh(bank)
assert bank.dinks == 80
assert bank.invested == 20
async def test_invest_all(database_session: AsyncSession, bank: Bank):
"""Test investing all dinks"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, "all")
await database_session.refresh(bank)
assert bank.dinks == 0
assert bank.invested == 100
async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank):
"""Test investing more Dinks than you own"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, 200)
await database_session.refresh(bank)
assert bank.dinks == 0
assert bank.invested == 100

View File

@ -48,3 +48,18 @@ def test_abbreviated_float_floors():
5.3k is 5300, but 5.3001k is 5300.1 5.3k is 5300, but 5.3001k is 5300.1
""" """
assert numbers.abbreviated_number("5.3001k") == 5300 assert numbers.abbreviated_number("5.3001k") == 5300
def test_abbreviated_all():
"""Test abbreviated_number for the 'all' argument"""
assert numbers.abbreviated_number("all") == "all"
assert numbers.abbreviated_number("ALL") == "all"
def test_abbreviated_empty():
"""Test abbreviated_number for empty arguments"""
with pytest.raises(ValueError):
numbers.abbreviated_number("")
with pytest.raises(ValueError):
numbers.abbreviated_number(None)