mirror of https://github.com/stijndcl/didier
Write some tests for currency crud
parent
8da0eb0b2a
commit
61128dda92
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.exceptions import currency as exceptions
|
from database.exceptions import currency as exceptions
|
||||||
from database.models import Bank
|
from database.models import Bank, NightlyData
|
||||||
from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price
|
from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price
|
||||||
|
|
||||||
|
|
||||||
|
@ -18,13 +18,20 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank:
|
||||||
return user.bank
|
return user.bank
|
||||||
|
|
||||||
|
|
||||||
|
async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData:
|
||||||
|
"""Get a user's nightly info"""
|
||||||
|
user = await users.get_or_add(session, user_id)
|
||||||
|
return user.nightly_data
|
||||||
|
|
||||||
|
|
||||||
async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int:
|
async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int:
|
||||||
"""Invest all your Dinks"""
|
"""Invest all your Dinks"""
|
||||||
bank = await get_bank(session, user_id)
|
bank = await get_bank(session, user_id)
|
||||||
if amount == "all":
|
if amount == "all":
|
||||||
amount = bank.dinks
|
amount = bank.dinks
|
||||||
|
|
||||||
amount = int(amount)
|
# Don't allow investing more dinks than you own
|
||||||
|
amount = min(bank.dinks, int(amount))
|
||||||
|
|
||||||
bank.dinks -= amount
|
bank.dinks -= amount
|
||||||
bank.invested += amount
|
bank.invested += amount
|
||||||
|
@ -45,15 +52,14 @@ async def add_dinks(session: AsyncSession, user_id: int, amount: int):
|
||||||
|
|
||||||
async def claim_nightly(session: AsyncSession, user_id: int):
|
async def claim_nightly(session: AsyncSession, user_id: int):
|
||||||
"""Claim daily Dinks"""
|
"""Claim daily Dinks"""
|
||||||
user = await users.get_or_add(session, user_id)
|
nightly_data = await get_nightly_data(session, user_id)
|
||||||
nightly_data = user.nightly_data
|
|
||||||
|
|
||||||
now = datetime.now()
|
now = datetime.now()
|
||||||
|
|
||||||
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
|
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
|
||||||
raise exceptions.DoubleNightly
|
raise exceptions.DoubleNightly
|
||||||
|
|
||||||
bank = user.bank
|
bank = await get_bank(session, user_id)
|
||||||
bank.dinks += NIGHTLY_AMOUNT
|
bank.dinks += NIGHTLY_AMOUNT
|
||||||
nightly_data.last_nightly = now
|
nightly_data.last_nightly = now
|
||||||
|
|
||||||
|
|
|
@ -5,10 +5,10 @@ from discord.ext import commands
|
||||||
|
|
||||||
from database.crud import currency as crud
|
from database.crud import currency as crud
|
||||||
from database.exceptions.currency import DoubleNightly, NotEnoughDinks
|
from database.exceptions.currency import DoubleNightly, NotEnoughDinks
|
||||||
|
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.utils.discord.checks import is_owner
|
from didier.utils.discord.checks import is_owner
|
||||||
from didier.utils.discord.converters import abbreviated_number
|
from didier.utils.discord.converters import abbreviated_number
|
||||||
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
|
|
||||||
from didier.utils.types.string import pluralize
|
from didier.utils.types.string import pluralize
|
||||||
|
|
||||||
|
|
||||||
|
@ -113,7 +113,7 @@ class Currency(commands.Cog):
|
||||||
plural = pluralize("Didier Dink", bank.dinks)
|
plural = pluralize("Didier Dink", bank.dinks)
|
||||||
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
|
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
|
||||||
|
|
||||||
@commands.command(name="Invest")
|
@commands.command(name="Invest", aliases=["Deposit", "Dep"])
|
||||||
async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore
|
async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore
|
||||||
"""Invest a given amount of Didier Dinks"""
|
"""Invest a given amount of Didier Dinks"""
|
||||||
amount = typing.cast(typing.Union[str, int], amount)
|
amount = typing.cast(typing.Union[str, int], amount)
|
||||||
|
|
|
@ -10,12 +10,12 @@ def abbreviated_number(argument: str) -> Union[str, int]:
|
||||||
515k
|
515k
|
||||||
4m
|
4m
|
||||||
"""
|
"""
|
||||||
if argument.lower() == "all":
|
|
||||||
return "all"
|
|
||||||
|
|
||||||
if not argument:
|
if not argument:
|
||||||
raise ValueError
|
raise ValueError
|
||||||
|
|
||||||
|
if argument.lower() == "all":
|
||||||
|
return "all"
|
||||||
|
|
||||||
if argument.isdecimal():
|
if argument.isdecimal():
|
||||||
return int(argument)
|
return int(argument)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,81 @@
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import currency as crud
|
||||||
|
from database.exceptions import currency as exceptions
|
||||||
|
from database.models import Bank
|
||||||
|
|
||||||
|
|
||||||
|
DEBUG_USER_ID = 1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def bank(database_session: AsyncSession) -> Bank:
|
||||||
|
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
|
||||||
|
await database_session.refresh(_bank)
|
||||||
|
return _bank
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test adding dinks to an account"""
|
||||||
|
assert bank.dinks == 0
|
||||||
|
await crud.add_dinks(database_session, bank.user_id, 10)
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
assert bank.dinks == 10
|
||||||
|
|
||||||
|
|
||||||
|
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test claiming nightlies when it hasn't been done yet"""
|
||||||
|
await crud.claim_nightly(database_session, bank.user_id)
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||||
|
|
||||||
|
|
||||||
|
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test claiming nightlies twice in a day"""
|
||||||
|
await crud.claim_nightly(database_session, bank.user_id)
|
||||||
|
|
||||||
|
with pytest.raises(exceptions.DoubleNightly):
|
||||||
|
await crud.claim_nightly(database_session, bank.user_id)
|
||||||
|
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invest(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test investing some Dinks"""
|
||||||
|
bank.dinks = 100
|
||||||
|
database_session.add(bank)
|
||||||
|
await database_session.commit()
|
||||||
|
|
||||||
|
await crud.invest(database_session, bank.user_id, 20)
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
|
||||||
|
assert bank.dinks == 80
|
||||||
|
assert bank.invested == 20
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invest_all(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test investing all dinks"""
|
||||||
|
bank.dinks = 100
|
||||||
|
database_session.add(bank)
|
||||||
|
await database_session.commit()
|
||||||
|
|
||||||
|
await crud.invest(database_session, bank.user_id, "all")
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
|
||||||
|
assert bank.dinks == 0
|
||||||
|
assert bank.invested == 100
|
||||||
|
|
||||||
|
|
||||||
|
async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank):
|
||||||
|
"""Test investing more Dinks than you own"""
|
||||||
|
bank.dinks = 100
|
||||||
|
database_session.add(bank)
|
||||||
|
await database_session.commit()
|
||||||
|
|
||||||
|
await crud.invest(database_session, bank.user_id, 200)
|
||||||
|
await database_session.refresh(bank)
|
||||||
|
|
||||||
|
assert bank.dinks == 0
|
||||||
|
assert bank.invested == 100
|
|
@ -48,3 +48,18 @@ def test_abbreviated_float_floors():
|
||||||
5.3k is 5300, but 5.3001k is 5300.1
|
5.3k is 5300, but 5.3001k is 5300.1
|
||||||
"""
|
"""
|
||||||
assert numbers.abbreviated_number("5.3001k") == 5300
|
assert numbers.abbreviated_number("5.3001k") == 5300
|
||||||
|
|
||||||
|
|
||||||
|
def test_abbreviated_all():
|
||||||
|
"""Test abbreviated_number for the 'all' argument"""
|
||||||
|
assert numbers.abbreviated_number("all") == "all"
|
||||||
|
assert numbers.abbreviated_number("ALL") == "all"
|
||||||
|
|
||||||
|
|
||||||
|
def test_abbreviated_empty():
|
||||||
|
"""Test abbreviated_number for empty arguments"""
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
numbers.abbreviated_number("")
|
||||||
|
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
numbers.abbreviated_number(None)
|
||||||
|
|
Loading…
Reference in New Issue