From 4587a49311e099702ebf43170e1b9f8b8bb00111 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 30 Jun 2022 19:42:48 +0200 Subject: [PATCH 1/7] Create database models --- .../5f3a11a80e69_initial_currency_models.py | 51 +++++++++++++++++ database/models.py | 56 ++++++++++++++++++- 2 files changed, 104 insertions(+), 3 deletions(-) create mode 100644 alembic/versions/5f3a11a80e69_initial_currency_models.py diff --git a/alembic/versions/5f3a11a80e69_initial_currency_models.py b/alembic/versions/5f3a11a80e69_initial_currency_models.py new file mode 100644 index 0000000..96981b9 --- /dev/null +++ b/alembic/versions/5f3a11a80e69_initial_currency_models.py @@ -0,0 +1,51 @@ +"""Initial currency models + +Revision ID: 5f3a11a80e69 +Revises: b2d511552a1f +Create Date: 2022-06-30 19:40:02.701336 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5f3a11a80e69' +down_revision = 'b2d511552a1f' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('users', + sa.Column('user_id', sa.BigInteger(), nullable=False), + sa.Column('dinks', sa.BigInteger(), nullable=False), + sa.PrimaryKeyConstraint('user_id') + ) + op.create_table('bank', + sa.Column('bank_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.BigInteger(), nullable=True), + sa.Column('interest_level', sa.Integer(), nullable=False), + sa.Column('capacity_level', sa.Integer(), nullable=False), + sa.Column('rob_level', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), + sa.PrimaryKeyConstraint('bank_id') + ) + op.create_table('nightly_data', + sa.Column('nightly_id', sa.Integer(), nullable=False), + sa.Column('user_id', sa.BigInteger(), nullable=True), + sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True), + sa.Column('count', sa.Integer(), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), + sa.PrimaryKeyConstraint('nightly_id') + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table('nightly_data') + op.drop_table('bank') + op.drop_table('users') + # ### end Alembic commands ### diff --git a/database/models.py b/database/models.py index 1c28d37..02b16a5 100644 --- a/database/models.py +++ b/database/models.py @@ -1,13 +1,34 @@ from __future__ import annotations from datetime import datetime +from typing import Optional -from sqlalchemy import Column, Integer, Text, ForeignKey, Boolean, DateTime +from sqlalchemy import BigInteger, Column, Integer, Text, ForeignKey, Boolean, DateTime from sqlalchemy.orm import declarative_base, relationship Base = declarative_base() +class Bank(Base): + """A user's currency information""" + + __tablename__ = "bank" + + bank_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + + # Interest rate + interest_level: int = Column(Integer, default=1, nullable=False) + + # Maximum amount that can be stored in the bank + capacity_level: int = Column(Integer, default=1, nullable=False) + + # Maximum amount that can be robbed + rob_level: int = Column(Integer, default=1, nullable=False) + + user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") + + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" @@ -36,6 +57,19 @@ class CustomCommandAlias(Base): command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") +class NightlyData(Base): + """Data for a user's Nightly stats""" + + __tablename__ = "nightly_data" + + nightly_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) + count: int = Column(Integer, default=0, nullable=False) + + user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") + + class UforaCourse(Base): """A course on Ufora""" @@ -72,8 +106,24 @@ class UforaAnnouncement(Base): __tablename__ = "ufora_announcements" - announcement_id = Column(Integer, primary_key=True) - course_id = Column(Integer, ForeignKey("ufora_courses.course_id")) + announcement_id: int = Column(Integer, primary_key=True) + course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) publication_date: datetime = Column(DateTime(timezone=True)) course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") + + +class User(Base): + """A Didier user""" + + __tablename__ = "users" + + user_id: int = Column(BigInteger, primary_key=True) + dinks: int = Column(BigInteger, default=0, nullable=False) + + bank: Bank = relationship( + "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + ) + nightly_data: NightlyData = relationship( + "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + ) From 032b636b025b56f1510b337cfd186b799962ac25 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 30 Jun 2022 21:17:48 +0200 Subject: [PATCH 2/7] Nightly, bank, award & dinks --- ...> 0d03c226d881_initial_currency_models.py} | 8 +-- database/crud/currency.py | 43 +++++++++++++++ database/crud/users.py | 37 +++++++++++++ database/exceptions/currency.py | 2 + database/models.py | 3 +- didier/cogs/currency.py | 54 +++++++++++++++++++ didier/utils/discord/checks/__init__.py | 1 + .../utils/discord/checks/message_commands.py | 6 +++ didier/utils/types/string.py | 8 +++ 9 files changed, 157 insertions(+), 5 deletions(-) rename alembic/versions/{5f3a11a80e69_initial_currency_models.py => 0d03c226d881_initial_currency_models.py} (94%) create mode 100644 database/crud/currency.py create mode 100644 database/crud/users.py create mode 100644 database/exceptions/currency.py create mode 100644 didier/cogs/currency.py diff --git a/alembic/versions/5f3a11a80e69_initial_currency_models.py b/alembic/versions/0d03c226d881_initial_currency_models.py similarity index 94% rename from alembic/versions/5f3a11a80e69_initial_currency_models.py rename to alembic/versions/0d03c226d881_initial_currency_models.py index 96981b9..45a5e26 100644 --- a/alembic/versions/5f3a11a80e69_initial_currency_models.py +++ b/alembic/versions/0d03c226d881_initial_currency_models.py @@ -1,8 +1,8 @@ """Initial currency models -Revision ID: 5f3a11a80e69 +Revision ID: 0d03c226d881 Revises: b2d511552a1f -Create Date: 2022-06-30 19:40:02.701336 +Create Date: 2022-06-30 20:02:27.284759 """ from alembic import op @@ -10,7 +10,7 @@ import sqlalchemy as sa # revision identifiers, used by Alembic. -revision = '5f3a11a80e69' +revision = '0d03c226d881' down_revision = 'b2d511552a1f' branch_labels = None depends_on = None @@ -20,12 +20,12 @@ def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.create_table('users', sa.Column('user_id', sa.BigInteger(), nullable=False), - sa.Column('dinks', sa.BigInteger(), nullable=False), sa.PrimaryKeyConstraint('user_id') ) op.create_table('bank', sa.Column('bank_id', sa.Integer(), nullable=False), sa.Column('user_id', sa.BigInteger(), nullable=True), + sa.Column('dinks', sa.BigInteger(), nullable=False), sa.Column('interest_level', sa.Integer(), nullable=False), sa.Column('capacity_level', sa.Integer(), nullable=False), sa.Column('rob_level', sa.Integer(), nullable=False), diff --git a/database/crud/currency.py b/database/crud/currency.py new file mode 100644 index 0000000..9fe21e0 --- /dev/null +++ b/database/crud/currency.py @@ -0,0 +1,43 @@ +from datetime import datetime + +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import users +from database.exceptions import currency as exceptions +from database.models import Bank + + +NIGHTLY_AMOUNT = 420 + + +async def get_bank(session: AsyncSession, user_id: int) -> Bank: + """Get a user's bank info""" + user = await users.get_or_add(session, user_id) + return user.bank + + +async def add_dinks(session: AsyncSession, user_id: int, amount: int): + """Increase the Dinks counter for a user""" + bank = await get_bank(session, user_id) + bank.dinks += amount + session.add(bank) + await session.commit() + + +async def claim_nightly(session: AsyncSession, user_id: int): + """Claim daily Dinks""" + user = await users.get_or_add(session, user_id) + nightly_data = user.nightly_data + + now = datetime.now() + + if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): + raise exceptions.DoubleNightly + + bank = user.bank + bank.dinks += NIGHTLY_AMOUNT + nightly_data.last_nightly = now + + session.add(bank) + session.add(nightly_data) + await session.commit() diff --git a/database/crud/users.py b/database/crud/users.py new file mode 100644 index 0000000..794d951 --- /dev/null +++ b/database/crud/users.py @@ -0,0 +1,37 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import User, Bank, NightlyData + + +async def get_or_add(session: AsyncSession, user_id: int) -> User: + """Get a user's profile + If it doesn't exist yet, create it (along with all linked datastructures). + """ + statement = select(User).where(User.user_id == user_id) + user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() + + # User exists + if user is not None: + return user + + # Create new user + user = User(user_id=user_id) + session.add(user) + await session.commit() + + # Add bank & nightly info + bank = Bank(user_id=user_id) + nightly_data = NightlyData(user_id=user_id) + user.bank = bank + user.nightly_data = nightly_data + + session.add(bank) + session.add(nightly_data) + session.add(user) + + await session.commit() + + return user diff --git a/database/exceptions/currency.py b/database/exceptions/currency.py new file mode 100644 index 0000000..6eab7a1 --- /dev/null +++ b/database/exceptions/currency.py @@ -0,0 +1,2 @@ +class DoubleNightly(Exception): + """Exception raised when claiming nightlies multiple times per day""" diff --git a/database/models.py b/database/models.py index 02b16a5..0bc6370 100644 --- a/database/models.py +++ b/database/models.py @@ -17,6 +17,8 @@ class Bank(Base): bank_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + dinks: int = Column(BigInteger, default=0, nullable=False) + # Interest rate interest_level: int = Column(Integer, default=1, nullable=False) @@ -119,7 +121,6 @@ class User(Base): __tablename__ = "users" user_id: int = Column(BigInteger, primary_key=True) - dinks: int = Column(BigInteger, default=0, nullable=False) bank: Bank = relationship( "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py new file mode 100644 index 0000000..193a9ef --- /dev/null +++ b/didier/cogs/currency.py @@ -0,0 +1,54 @@ +import discord +from discord.ext import commands + +from didier import Didier + +from database.crud import currency as crud +from database.exceptions.currency import DoubleNightly +from didier.utils.discord.checks import is_owner +from didier.utils.types.string import pluralize + + +class Currency(commands.Cog): + """Everything Dinks-related""" + + client: Didier + + def __init__(self, client: Didier): + super().__init__() + self.client = client + + @commands.command(name="Award") + @commands.check(is_owner) + async def award(self, ctx: commands.Context, user: discord.User, amount: int): + async with self.client.db_session as session: + await crud.add_dinks(session, user.id, amount) + await self.client.confirm_message(ctx.message) + + @commands.hybrid_command(name="bank") + async def bank(self, ctx: commands.Context): + """Show your Didier Bank information""" + async with self.client.db_session as session: + await crud.get_bank(session, ctx.author.id) + + @commands.hybrid_command(name="dinks") + async def dinks(self, ctx: commands.Context): + """Check your Didier Dinks""" + async with self.client.db_session as session: + bank = await crud.get_bank(session, ctx.author.id) + plural = pluralize("Didier Dink", bank.dinks) + await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) + + @commands.hybrid_command(name="nightly") + async def nightly(self, ctx: commands.Context): + """Claim nightly Dinks""" + async with self.client.db_session as session: + try: + await crud.claim_nightly(session, ctx.author.id) + await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") + except DoubleNightly: + await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True) + + +async def setup(client: Didier): + await client.add_cog(Currency(client)) diff --git a/didier/utils/discord/checks/__init__.py b/didier/utils/discord/checks/__init__.py index e69de29..7332b42 100644 --- a/didier/utils/discord/checks/__init__.py +++ b/didier/utils/discord/checks/__init__.py @@ -0,0 +1 @@ +from .message_commands import is_owner diff --git a/didier/utils/discord/checks/message_commands.py b/didier/utils/discord/checks/message_commands.py index e69de29..a57b315 100644 --- a/didier/utils/discord/checks/message_commands.py +++ b/didier/utils/discord/checks/message_commands.py @@ -0,0 +1,6 @@ +from discord.ext import commands + + +async def is_owner(ctx: commands.Context) -> bool: + """Check that a command is being invoked by the owner of the bot""" + return await ctx.bot.is_owner(ctx.author) diff --git a/didier/utils/types/string.py b/didier/utils/types/string.py index eddfff8..0255877 100644 --- a/didier/utils/types/string.py +++ b/didier/utils/types/string.py @@ -20,3 +20,11 @@ def leading(character: str, string: str, target_length: Optional[int] = 2) -> st frequency = math.ceil((target_length - len(string)) / len(character)) return (frequency * character) + string + + +def pluralize(word: str, amount: int, plural_form: Optional[str] = None) -> str: + """Turn a word into plural""" + if amount == 1: + return word + + return plural_form or (word + "s") From bec893bd2064c72fe005deb1ec60139bf9a236e8 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 30 Jun 2022 21:33:37 +0200 Subject: [PATCH 3/7] Add tests for users crud --- didier/cogs/currency.py | 9 +++++-- .../test_database/test_crud/test_currency.py | 0 tests/test_database/test_crud/test_users.py | 25 +++++++++++++++++++ 3 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 tests/test_database/test_crud/test_currency.py create mode 100644 tests/test_database/test_crud/test_users.py diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 193a9ef..3283ed4 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -1,10 +1,9 @@ import discord from discord.ext import commands -from didier import Didier - from database.crud import currency as crud from database.exceptions.currency import DoubleNightly +from didier import Didier from didier.utils.discord.checks import is_owner from didier.utils.types.string import pluralize @@ -21,8 +20,13 @@ class Currency(commands.Cog): @commands.command(name="Award") @commands.check(is_owner) async def award(self, ctx: commands.Context, user: discord.User, amount: int): + """Award a user a given amount of Didier Dinks""" async with self.client.db_session as session: await crud.add_dinks(session, user.id, amount) + plural = pluralize("Didier Dink", amount) + await ctx.reply( + f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken." + ) await self.client.confirm_message(ctx.message) @commands.hybrid_command(name="bank") @@ -51,4 +55,5 @@ class Currency(commands.Cog): async def setup(client: Didier): + """Load the cog""" await client.add_cog(Currency(client)) diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py new file mode 100644 index 0000000..08b4c81 --- /dev/null +++ b/tests/test_database/test_crud/test_users.py @@ -0,0 +1,25 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import users as crud +from database.models import User + + +async def test_get_or_add_non_existing(database_session: AsyncSession): + """Test get_or_add for a user that doesn't exist""" + await crud.get_or_add(database_session, 1) + statement = select(User) + res = (await database_session.execute(statement)).scalars().all() + + assert len(res) == 1 + assert res[0].bank is not None + assert res[0].nightly_data is not None + + +async def test_get_or_add_existing(database_session: AsyncSession): + """Test get_or_add for a user that does exist""" + user = await crud.get_or_add(database_session, 1) + bank = user.bank + + assert await crud.get_or_add(database_session, 1) == user + assert (await crud.get_or_add(database_session, 1)).bank == bank From bd63f80a7d5e908ccb6a0b5addbce231b85ffbaa Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 30 Jun 2022 21:39:13 +0200 Subject: [PATCH 4/7] Editing custom commands --- didier/data/modals/custom_commands.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/didier/data/modals/custom_commands.py b/didier/data/modals/custom_commands.py index 3cd908f..74e71e6 100644 --- a/didier/data/modals/custom_commands.py +++ b/didier/data/modals/custom_commands.py @@ -1,4 +1,5 @@ import traceback +import typing import discord @@ -49,7 +50,6 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): self.original_name = name self.client = client - # TODO find a way to access these items self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name)) self.add_item( discord.ui.TextInput( @@ -58,8 +58,11 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): ) async def on_submit(self, interaction: discord.Interaction): + name_field = typing.cast(discord.ui.TextInput, self.children[0]) + response_field = typing.cast(discord.ui.TextInput, self.children[1]) + async with self.client.db_session as session: - await edit_command(session, self.original_name, self.name.value, self.response.value) + await edit_command(session, self.original_name, name_field.value, response_field.value) await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) From fd72bb17749f089c055f052821ed01a9ac0476ff Mon Sep 17 00:00:00 2001 From: stijndcl Date: Thu, 30 Jun 2022 21:49:45 +0200 Subject: [PATCH 5/7] Typing --- didier/data/flags/posix.py | 2 +- didier/data/modals/custom_commands.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/didier/data/flags/posix.py b/didier/data/flags/posix.py index 582fd4e..8747165 100644 --- a/didier/data/flags/posix.py +++ b/didier/data/flags/posix.py @@ -1,7 +1,7 @@ from discord.ext import commands -class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): +class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): # type: ignore """Base class to add POSIX-like flags to commands Example usage: diff --git a/didier/data/modals/custom_commands.py b/didier/data/modals/custom_commands.py index 74e71e6..a54f690 100644 --- a/didier/data/modals/custom_commands.py +++ b/didier/data/modals/custom_commands.py @@ -24,7 +24,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): async def on_submit(self, interaction: discord.Interaction): async with self.client.db_session as session: - command = await create_command(session, self.name.value, self.response.value) + command = await create_command(session, str(self.name.value), str(self.response.value)) await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) From 96916d2abdd2a9246e51def8e8e6b5a0c175c2b5 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 1 Jul 2022 12:43:41 +0200 Subject: [PATCH 6/7] Re-create & test number converter --- didier/utils/discord/converters/__init__.py | 0 didier/utils/discord/converters/numbers.py | 42 ++++++++++++++++ .../test_discord/test_converters/__init__.py | 0 .../test_converters/test_numbers.py | 50 +++++++++++++++++++ 4 files changed, 92 insertions(+) create mode 100644 didier/utils/discord/converters/__init__.py create mode 100644 didier/utils/discord/converters/numbers.py create mode 100644 tests/test_didier/test_utils/test_discord/test_converters/__init__.py create mode 100644 tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py diff --git a/didier/utils/discord/converters/__init__.py b/didier/utils/discord/converters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/didier/utils/discord/converters/numbers.py b/didier/utils/discord/converters/numbers.py new file mode 100644 index 0000000..bd0f432 --- /dev/null +++ b/didier/utils/discord/converters/numbers.py @@ -0,0 +1,42 @@ +import math +from typing import Optional + + +def abbreviated_number(argument: str) -> int: + """Custom converter to allow numbers to be abbreviated + Examples: + 515k + 4m + """ + if not argument: + raise ValueError + + if argument.isdecimal(): + return int(argument) + + units = {"k": 3, "m": 6, "b": 9, "t": 12} + + # Get the unit if there is one, then chop it off + unit: Optional[str] = None + if not argument[-1].isdigit(): + if argument[-1].lower() not in units: + raise ValueError + + unit = argument[-1].lower() + argument = argument[:-1] + + # [int][unit] + if "." not in argument and unit is not None: + return int(argument) * (10 ** units.get(unit)) + + # [float][unit] + if "." in argument: + # Floats themselves are not supported + if unit is None: + raise ValueError + + as_float = float(argument) + return math.floor(as_float * (10 ** units.get(unit))) + + # Unparseable + raise ValueError diff --git a/tests/test_didier/test_utils/test_discord/test_converters/__init__.py b/tests/test_didier/test_utils/test_discord/test_converters/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py new file mode 100644 index 0000000..3efc9f4 --- /dev/null +++ b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py @@ -0,0 +1,50 @@ +import pytest + +from didier.utils.discord.converters import numbers + + +def test_abbreviated_int(): + """Test abbreviated_number for a regular int""" + assert numbers.abbreviated_number("500") == 500 + + +def test_abbreviated_float_errors(): + """Test abbreviated_number for a float""" + with pytest.raises(ValueError): + numbers.abbreviated_number("5.4") + + +def test_abbreviated_int_unit(): + """Test abbreviated_number for an int combined with a unit""" + assert numbers.abbreviated_number("20k") == 20000 + + +def test_abbreviated_int_unknown_unit(): + """Test abbreviated_number for an int combined with an unknown unit""" + with pytest.raises(ValueError): + numbers.abbreviated_number("20p") + + +def test_abbreviated_float_unit(): + """Test abbreviated_number for a float combined with a unit""" + assert numbers.abbreviated_number("20.5k") == 20500 + + +def test_abbreviated_float_unknown_unit(): + """Test abbreviated_number for a float combined with an unknown unit""" + with pytest.raises(ValueError): + numbers.abbreviated_number("20.5p") + + +def test_abbreviated_no_number(): + """Test abbreviated_number for unparseable content""" + with pytest.raises(ValueError): + numbers.abbreviated_number("didier") + + +def test_abbreviated_float_floors(): + """Test abbreviated_number for a float that is longer than the unit + Example: + 5.3k is 5300, but 5.3001k is 5300.1 + """ + assert numbers.abbreviated_number("5.3001k") == 5300 From c294bc8da50fdb1ba24db10a21b3457bfcf1e00f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Fri, 1 Jul 2022 14:05:00 +0200 Subject: [PATCH 7/7] Use abbreviated numbers in award --- didier/cogs/currency.py | 13 +++++++++---- didier/utils/discord/converters/__init__.py | 1 + didier/utils/discord/converters/numbers.py | 14 +++++++++----- 3 files changed, 19 insertions(+), 9 deletions(-) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 3283ed4..606a910 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -1,3 +1,5 @@ +import typing + import discord from discord.ext import commands @@ -5,6 +7,7 @@ from database.crud import currency as crud from database.exceptions.currency import DoubleNightly from didier import Didier from didier.utils.discord.checks import is_owner +from didier.utils.discord.converters import abbreviated_number from didier.utils.types.string import pluralize @@ -19,17 +22,19 @@ class Currency(commands.Cog): @commands.command(name="Award") @commands.check(is_owner) - async def award(self, ctx: commands.Context, user: discord.User, amount: int): + async def award(self, ctx: commands.Context, user: discord.User, amount: abbreviated_number): # type: ignore """Award a user a given amount of Didier Dinks""" + amount = typing.cast(int, amount) + async with self.client.db_session as session: await crud.add_dinks(session, user.id, amount) plural = pluralize("Didier Dink", amount) await ctx.reply( - f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken." + f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken.", + mention_author=False, ) - await self.client.confirm_message(ctx.message) - @commands.hybrid_command(name="bank") + @commands.hybrid_group(name="bank", case_insensitive=True, invoke_without_command=True) async def bank(self, ctx: commands.Context): """Show your Didier Bank information""" async with self.client.db_session as session: diff --git a/didier/utils/discord/converters/__init__.py b/didier/utils/discord/converters/__init__.py index e69de29..3f47753 100644 --- a/didier/utils/discord/converters/__init__.py +++ b/didier/utils/discord/converters/__init__.py @@ -0,0 +1 @@ +from .numbers import * diff --git a/didier/utils/discord/converters/numbers.py b/didier/utils/discord/converters/numbers.py index bd0f432..8019fa5 100644 --- a/didier/utils/discord/converters/numbers.py +++ b/didier/utils/discord/converters/numbers.py @@ -2,6 +2,9 @@ import math from typing import Optional +__all__ = ["abbreviated_number"] + + def abbreviated_number(argument: str) -> int: """Custom converter to allow numbers to be abbreviated Examples: @@ -17,26 +20,27 @@ def abbreviated_number(argument: str) -> int: units = {"k": 3, "m": 6, "b": 9, "t": 12} # Get the unit if there is one, then chop it off - unit: Optional[str] = None + value: Optional[int] = None if not argument[-1].isdigit(): if argument[-1].lower() not in units: raise ValueError unit = argument[-1].lower() + value = units.get(unit) argument = argument[:-1] # [int][unit] - if "." not in argument and unit is not None: - return int(argument) * (10 ** units.get(unit)) + if "." not in argument and value is not None: + return int(argument) * (10**value) # [float][unit] if "." in argument: # Floats themselves are not supported - if unit is None: + if value is None: raise ValueError as_float = float(argument) - return math.floor(as_float * (10 ** units.get(unit))) + return math.floor(as_float * (10**value)) # Unparseable raise ValueError