Compare commits

...

2 Commits

Author SHA1 Message Date
stijndcl 61128dda92 Write some tests for currency crud 2022-07-03 19:26:30 +02:00
stijndcl 8da0eb0b2a Investing 2022-07-03 18:35:30 +02:00
5 changed files with 159 additions and 21 deletions

View File

@ -1,10 +1,11 @@
from datetime import datetime
from typing import Union
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.exceptions import currency as exceptions
from database.models import Bank
from database.models import Bank, NightlyData
from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price
@ -17,6 +18,30 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank:
return user.bank
async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData:
"""Get a user's nightly info"""
user = await users.get_or_add(session, user_id)
return user.nightly_data
async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int:
"""Invest all your Dinks"""
bank = await get_bank(session, user_id)
if amount == "all":
amount = bank.dinks
# Don't allow investing more dinks than you own
amount = min(bank.dinks, int(amount))
bank.dinks -= amount
bank.invested += amount
session.add(bank)
await session.commit()
return amount
async def add_dinks(session: AsyncSession, user_id: int, amount: int):
"""Increase the Dinks counter for a user"""
bank = await get_bank(session, user_id)
@ -27,15 +52,14 @@ async def add_dinks(session: AsyncSession, user_id: int, amount: int):
async def claim_nightly(session: AsyncSession, user_id: int):
"""Claim daily Dinks"""
user = await users.get_or_add(session, user_id)
nightly_data = user.nightly_data
nightly_data = await get_nightly_data(session, user_id)
now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
raise exceptions.DoubleNightly
bank = user.bank
bank = await get_bank(session, user_id)
bank.dinks += NIGHTLY_AMOUNT
nightly_data.last_nightly = now
@ -44,9 +68,9 @@ async def claim_nightly(session: AsyncSession, user_id: int):
await session.commit()
async def upgrade_capacity(database_session: AsyncSession, user_id: int) -> int:
async def upgrade_capacity(session: AsyncSession, user_id: int) -> int:
"""Upgrade capacity level"""
bank = await get_bank(database_session, user_id)
bank = await get_bank(session, user_id)
upgrade_price = capacity_upgrade_price(bank.capacity_level)
# Can't afford this upgrade
@ -56,15 +80,15 @@ async def upgrade_capacity(database_session: AsyncSession, user_id: int) -> int:
bank.dinks -= upgrade_price
bank.capacity_level += 1
database_session.add(bank)
await database_session.commit()
session.add(bank)
await session.commit()
return bank.capacity_level
async def upgrade_interest(database_session: AsyncSession, user_id: int) -> int:
async def upgrade_interest(session: AsyncSession, user_id: int) -> int:
"""Upgrade interest level"""
bank = await get_bank(database_session, user_id)
bank = await get_bank(session, user_id)
upgrade_price = interest_upgrade_price(bank.interest_level)
# Can't afford this upgrade
@ -74,15 +98,15 @@ async def upgrade_interest(database_session: AsyncSession, user_id: int) -> int:
bank.dinks -= upgrade_price
bank.interest_level += 1
database_session.add(bank)
await database_session.commit()
session.add(bank)
await session.commit()
return bank.interest_level
async def upgrade_rob(database_session: AsyncSession, user_id: int) -> int:
async def upgrade_rob(session: AsyncSession, user_id: int) -> int:
"""Upgrade rob level"""
bank = await get_bank(database_session, user_id)
bank = await get_bank(session, user_id)
upgrade_price = rob_upgrade_price(bank.rob_level)
# Can't afford this upgrade
@ -92,7 +116,7 @@ async def upgrade_rob(database_session: AsyncSession, user_id: int) -> int:
bank.dinks -= upgrade_price
bank.rob_level += 1
database_session.add(bank)
await database_session.commit()
session.add(bank)
await session.commit()
return bank.rob_level

View File

@ -5,10 +5,10 @@ from discord.ext import commands
from database.crud import currency as crud
from database.exceptions.currency import DoubleNightly, NotEnoughDinks
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
from didier import Didier
from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number
from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price
from didier.utils.types.string import pluralize
@ -113,9 +113,25 @@ class Currency(commands.Cog):
plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@commands.command(name="Invest", aliases=["Deposit", "Dep"])
async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore
"""Invest a given amount of Didier Dinks"""
amount = typing.cast(typing.Union[str, int], amount)
async with self.client.db_session as session:
invested = await crud.invest(session, ctx.author.id, amount)
plural = pluralize("Didier Dink", invested)
if invested == 0:
await ctx.reply("Je hebt geen Didier Dinks om te investeren.", mention_author=False)
else:
await ctx.reply(
f"**{ctx.author.display_name}** heeft **{invested}** {plural} geïnvesteerd.", mention_author=False
)
@commands.hybrid_command(name="nightly")
async def nightly(self, ctx: commands.Context):
"""Claim nightly Dinks"""
"""Claim nightly Didier Dinks"""
async with self.client.db_session as session:
try:
await crud.claim_nightly(session, ctx.author.id)

View File

@ -1,11 +1,10 @@
import math
from typing import Optional
from typing import Optional, Union
__all__ = ["abbreviated_number"]
def abbreviated_number(argument: str) -> int:
def abbreviated_number(argument: str) -> Union[str, int]:
"""Custom converter to allow numbers to be abbreviated
Examples:
515k
@ -14,6 +13,9 @@ def abbreviated_number(argument: str) -> int:
if not argument:
raise ValueError
if argument.lower() == "all":
return "all"
if argument.isdecimal():
return int(argument)

View File

@ -0,0 +1,81 @@
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import currency as crud
from database.exceptions import currency as exceptions
from database.models import Bank
DEBUG_USER_ID = 1
@pytest.fixture
async def bank(database_session: AsyncSession) -> Bank:
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
await database_session.refresh(_bank)
return _bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
"""Test adding dinks to an account"""
assert bank.dinks == 0
await crud.add_dinks(database_session, bank.user_id, 10)
await database_session.refresh(bank)
assert bank.dinks == 10
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies when it hasn't been done yet"""
await crud.claim_nightly(database_session, bank.user_id)
await database_session.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
"""Test claiming nightlies twice in a day"""
await crud.claim_nightly(database_session, bank.user_id)
with pytest.raises(exceptions.DoubleNightly):
await crud.claim_nightly(database_session, bank.user_id)
await database_session.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT
async def test_invest(database_session: AsyncSession, bank: Bank):
"""Test investing some Dinks"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, 20)
await database_session.refresh(bank)
assert bank.dinks == 80
assert bank.invested == 20
async def test_invest_all(database_session: AsyncSession, bank: Bank):
"""Test investing all dinks"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, "all")
await database_session.refresh(bank)
assert bank.dinks == 0
assert bank.invested == 100
async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank):
"""Test investing more Dinks than you own"""
bank.dinks = 100
database_session.add(bank)
await database_session.commit()
await crud.invest(database_session, bank.user_id, 200)
await database_session.refresh(bank)
assert bank.dinks == 0
assert bank.invested == 100

View File

@ -48,3 +48,18 @@ def test_abbreviated_float_floors():
5.3k is 5300, but 5.3001k is 5300.1
"""
assert numbers.abbreviated_number("5.3001k") == 5300
def test_abbreviated_all():
"""Test abbreviated_number for the 'all' argument"""
assert numbers.abbreviated_number("all") == "all"
assert numbers.abbreviated_number("ALL") == "all"
def test_abbreviated_empty():
"""Test abbreviated_number for empty arguments"""
with pytest.raises(ValueError):
numbers.abbreviated_number("")
with pytest.raises(ValueError):
numbers.abbreviated_number(None)