diff --git a/database/crud/currency.py b/database/crud/currency.py index 3ef4de0..612b969 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Union from sqlalchemy.ext.asyncio import AsyncSession @@ -17,6 +18,23 @@ async def get_bank(session: AsyncSession, user_id: int) -> Bank: return user.bank +async def invest(session: AsyncSession, user_id: int, amount: Union[str, int]) -> int: + """Invest all your Dinks""" + bank = await get_bank(session, user_id) + if amount == "all": + amount = bank.dinks + + amount = int(amount) + + bank.dinks -= amount + bank.invested += amount + + session.add(bank) + await session.commit() + + return amount + + async def add_dinks(session: AsyncSession, user_id: int, amount: int): """Increase the Dinks counter for a user""" bank = await get_bank(session, user_id) @@ -44,9 +62,9 @@ async def claim_nightly(session: AsyncSession, user_id: int): await session.commit() -async def upgrade_capacity(database_session: AsyncSession, user_id: int) -> int: +async def upgrade_capacity(session: AsyncSession, user_id: int) -> int: """Upgrade capacity level""" - bank = await get_bank(database_session, user_id) + bank = await get_bank(session, user_id) upgrade_price = capacity_upgrade_price(bank.capacity_level) # Can't afford this upgrade @@ -56,15 +74,15 @@ async def upgrade_capacity(database_session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.capacity_level += 1 - database_session.add(bank) - await database_session.commit() + session.add(bank) + await session.commit() return bank.capacity_level -async def upgrade_interest(database_session: AsyncSession, user_id: int) -> int: +async def upgrade_interest(session: AsyncSession, user_id: int) -> int: """Upgrade interest level""" - bank = await get_bank(database_session, user_id) + bank = await get_bank(session, user_id) upgrade_price = interest_upgrade_price(bank.interest_level) # Can't afford this upgrade @@ -74,15 +92,15 @@ async def upgrade_interest(database_session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.interest_level += 1 - database_session.add(bank) - await database_session.commit() + session.add(bank) + await session.commit() return bank.interest_level -async def upgrade_rob(database_session: AsyncSession, user_id: int) -> int: +async def upgrade_rob(session: AsyncSession, user_id: int) -> int: """Upgrade rob level""" - bank = await get_bank(database_session, user_id) + bank = await get_bank(session, user_id) upgrade_price = rob_upgrade_price(bank.rob_level) # Can't afford this upgrade @@ -92,7 +110,7 @@ async def upgrade_rob(database_session: AsyncSession, user_id: int) -> int: bank.dinks -= upgrade_price bank.rob_level += 1 - database_session.add(bank) - await database_session.commit() + session.add(bank) + await session.commit() return bank.rob_level diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index d4910ee..bb46bb0 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -113,9 +113,25 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) + @commands.command(name="Invest") + async def invest(self, ctx: commands.Context, amount: abbreviated_number): # type: ignore + """Invest a given amount of Didier Dinks""" + amount = typing.cast(typing.Union[str, int], amount) + + async with self.client.db_session as session: + invested = await crud.invest(session, ctx.author.id, amount) + plural = pluralize("Didier Dink", invested) + + if invested == 0: + await ctx.reply("Je hebt geen Didier Dinks om te investeren.", mention_author=False) + else: + await ctx.reply( + f"**{ctx.author.display_name}** heeft **{invested}** {plural} geïnvesteerd.", mention_author=False + ) + @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): - """Claim nightly Dinks""" + """Claim nightly Didier Dinks""" async with self.client.db_session as session: try: await crud.claim_nightly(session, ctx.author.id) diff --git a/didier/utils/discord/converters/numbers.py b/didier/utils/discord/converters/numbers.py index 8019fa5..f184ca8 100644 --- a/didier/utils/discord/converters/numbers.py +++ b/didier/utils/discord/converters/numbers.py @@ -1,16 +1,18 @@ import math -from typing import Optional - +from typing import Optional, Union __all__ = ["abbreviated_number"] -def abbreviated_number(argument: str) -> int: +def abbreviated_number(argument: str) -> Union[str, int]: """Custom converter to allow numbers to be abbreviated Examples: 515k 4m """ + if argument.lower() == "all": + return "all" + if not argument: raise ValueError