diff --git a/database/crud/currency.py b/database/crud/currency.py index 8e7f2e6..3ef4de0 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -1,11 +1,10 @@ from datetime import datetime -from typing import Union from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.models import Bank, NightlyData +from database.models import Bank from database.utils.math.currency import rob_upgrade_price, interest_upgrade_price, capacity_upgrade_price @@ -18,30 +17,6 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank: return user.bank -async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData: - """Get a user's nightly info""" - user = await users.get_or_add(session, user_id) - return user.nightly_data - - -async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int: - """Invest all your Dinks""" - bank = await get_bank(session, user_id) - if amount == "all": - amount = bank.dinks - - # Don't allow investing more dinks than you own - amount = min(bank.dinks, int(amount)) - - bank.dinks -= amount - bank.invested += amount - - session.add(bank) - await session.commit() - - return amount - - async def add_dinks(session: AsyncSession, user_id: int, amount: int): """Increase the Dinks counter for a user""" bank = await get_bank(session, user_id) @@ -52,14 +27,15 @@ async def add_dinks(session: AsyncSession, user_id: int, amount: int): async def claim_nightly(session: AsyncSession, user_id: int): """Claim daily Dinks""" - nightly_data = await get_nightly_data(session, user_id) + user = await users.get_or_add(session, user_id) + nightly_data = user.nightly_data now = datetime.now() if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): raise exceptions.DoubleNightly - bank = await get_bank(session, user_id) + bank = user.bank bank.dinks += NIGHTLY_AMOUNT nightly_data.last_nightly = now @@ -68,9 +44,9 @@ async def claim_nightly(session: AsyncSession, user_id: int): await session.commit() -async def upgrade_capacity(session: AsyncSession, user_id: int) -> int: +async def upgrade_capacity(database_session: AsyncSession, user_id: int) -> int: """Upgrade capacity level""" - bank = await get_bank(session, user_id) + bank = await get_bank(database_session, user_id) upgrade_price = capacity_upgrade_price(bank.capacity_level) # Can't afford this upgrade @@ -80,15 +56,15 @@ async def upgrade_capacity(session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.capacity_level += 1 - session.add(bank) - await session.commit() + database_session.add(bank) + await database_session.commit() return bank.capacity_level -async def upgrade_interest(session: AsyncSession, user_id: int) -> int: +async def upgrade_interest(database_session: AsyncSession, user_id: int) -> int: """Upgrade interest level""" - bank = await get_bank(session, user_id) + bank = await get_bank(database_session, user_id) upgrade_price = interest_upgrade_price(bank.interest_level) # Can't afford this upgrade @@ -98,15 +74,15 @@ async def upgrade_interest(session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.interest_level += 1 - session.add(bank) - await session.commit() + database_session.add(bank) + await database_session.commit() return bank.interest_level -async def upgrade_rob(session: AsyncSession, user_id: int) -> int: +async def upgrade_rob(database_session: AsyncSession, user_id: int) -> int: """Upgrade rob level""" - bank = await get_bank(session, user_id) + bank = await get_bank(database_session, user_id) upgrade_price = rob_upgrade_price(bank.rob_level) # Can't afford this upgrade @@ -116,7 +92,7 @@ async def upgrade_rob(session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.rob_level += 1 - session.add(bank) - await session.commit() + database_session.add(bank) + await database_session.commit() return bank.rob_level diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index b30c283..d4910ee 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -5,10 +5,10 @@ from discord.ext import commands from database.crud import currency as crud from database.exceptions.currency import DoubleNightly, NotEnoughDinks -from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price from didier import Didier from didier.utils.discord.checks import is_owner from didier.utils.discord.converters import abbreviated_number +from database.utils.math.currency import capacity_upgrade_price, interest_upgrade_price, rob_upgrade_price from didier.utils.types.string import pluralize @@ -113,25 +113,9 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) - @commands.command(name="Invest", aliases=["Deposit", "Dep"]) - async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore - """Invest a given amount of Didier Dinks""" - amount = typing.cast(typing.Union[str, int], amount) - - async with self.client.db_session as session: - invested = await crud.invest(session, ctx.author.id, amount) - plural = pluralize("Didier Dink", invested) - - if invested == 0: - await ctx.reply("Je hebt geen Didier Dinks om te investeren.", mention_author=False) - else: - await ctx.reply( - f"**{ctx.author.display_name}** heeft **{invested}** {plural} geïnvesteerd.", mention_author=False - ) - @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): - """Claim nightly Didier Dinks""" + """Claim nightly Dinks""" async with self.client.db_session as session: try: await crud.claim_nightly(session, ctx.author.id) diff --git a/didier/utils/discord/converters/numbers.py b/didier/utils/discord/converters/numbers.py index 443f3e7..8019fa5 100644 --- a/didier/utils/discord/converters/numbers.py +++ b/didier/utils/discord/converters/numbers.py @@ -1,10 +1,11 @@ import math -from typing import Optional, Union +from typing import Optional + __all__ = ["abbreviated_number"] -def abbreviated_number(argument: str) -> Union[str, int]: +def abbreviated_number(argument: str) -> int: """Custom converter to allow numbers to be abbreviated Examples: 515k @@ -13,9 +14,6 @@ def abbreviated_number(argument: str) -> Union[str, int]: if not argument: raise ValueError - if argument.lower() == "all": - return "all" - if argument.isdecimal(): return int(argument) diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index f996bf9..e69de29 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -1,81 +0,0 @@ -import pytest -from sqlalchemy.ext.asyncio import AsyncSession - -from database.crud import currency as crud -from database.exceptions import currency as exceptions -from database.models import Bank - - -DEBUG_USER_ID = 1 - - -@pytest.fixture -async def bank(database_session: AsyncSession) -> Bank: - _bank = await crud.get_bank(database_session, DEBUG_USER_ID) - await database_session.refresh(_bank) - return _bank - - -async def test_add_dinks(database_session: AsyncSession, bank: Bank): - """Test adding dinks to an account""" - assert bank.dinks == 0 - await crud.add_dinks(database_session, bank.user_id, 10) - await database_session.refresh(bank) - assert bank.dinks == 10 - - -async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): - """Test claiming nightlies when it hasn't been done yet""" - await crud.claim_nightly(database_session, bank.user_id) - await database_session.refresh(bank) - assert bank.dinks == crud.NIGHTLY_AMOUNT - - -async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): - """Test claiming nightlies twice in a day""" - await crud.claim_nightly(database_session, bank.user_id) - - with pytest.raises(exceptions.DoubleNightly): - await crud.claim_nightly(database_session, bank.user_id) - - await database_session.refresh(bank) - assert bank.dinks == crud.NIGHTLY_AMOUNT - - -async def test_invest(database_session: AsyncSession, bank: Bank): - """Test investing some Dinks""" - bank.dinks = 100 - database_session.add(bank) - await database_session.commit() - - await crud.invest(database_session, bank.user_id, 20) - await database_session.refresh(bank) - - assert bank.dinks == 80 - assert bank.invested == 20 - - -async def test_invest_all(database_session: AsyncSession, bank: Bank): - """Test investing all dinks""" - bank.dinks = 100 - database_session.add(bank) - await database_session.commit() - - await crud.invest(database_session, bank.user_id, "all") - await database_session.refresh(bank) - - assert bank.dinks == 0 - assert bank.invested == 100 - - -async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank): - """Test investing more Dinks than you own""" - bank.dinks = 100 - database_session.add(bank) - await database_session.commit() - - await crud.invest(database_session, bank.user_id, 200) - await database_session.refresh(bank) - - assert bank.dinks == 0 - assert bank.invested == 100 diff --git a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py index ed88692..3efc9f4 100644 --- a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py +++ b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py @@ -48,18 +48,3 @@ def test_abbreviated_float_floors(): 5.3k is 5300, but 5.3001k is 5300.1 """ assert numbers.abbreviated_number("5.3001k") == 5300 - - -def test_abbreviated_all(): - """Test abbreviated_number for the 'all' argument""" - assert numbers.abbreviated_number("all") == "all" - assert numbers.abbreviated_number("ALL") == "all" - - -def test_abbreviated_empty(): - """Test abbreviated_number for empty arguments""" - with pytest.raises(ValueError): - numbers.abbreviated_number("") - - with pytest.raises(ValueError): - numbers.abbreviated_number(None)