mirror of https://github.com/stijndcl/didier
				
				
				
			
						commit
						cb0b4a419e
					
				|  | @ -0,0 +1,51 @@ | |||
| """Initial currency models | ||||
| 
 | ||||
| Revision ID: 0d03c226d881 | ||||
| Revises: b2d511552a1f | ||||
| Create Date: 2022-06-30 20:02:27.284759 | ||||
| 
 | ||||
| """ | ||||
| from alembic import op | ||||
| import sqlalchemy as sa | ||||
| 
 | ||||
| 
 | ||||
| # revision identifiers, used by Alembic. | ||||
| revision = '0d03c226d881' | ||||
| down_revision = 'b2d511552a1f' | ||||
| branch_labels = None | ||||
| depends_on = None | ||||
| 
 | ||||
| 
 | ||||
| def upgrade() -> None: | ||||
|     # ### commands auto generated by Alembic - please adjust! ### | ||||
|     op.create_table('users', | ||||
|     sa.Column('user_id', sa.BigInteger(), nullable=False), | ||||
|     sa.PrimaryKeyConstraint('user_id') | ||||
|     ) | ||||
|     op.create_table('bank', | ||||
|     sa.Column('bank_id', sa.Integer(), nullable=False), | ||||
|     sa.Column('user_id', sa.BigInteger(), nullable=True), | ||||
|     sa.Column('dinks', sa.BigInteger(), nullable=False), | ||||
|     sa.Column('interest_level', sa.Integer(), nullable=False), | ||||
|     sa.Column('capacity_level', sa.Integer(), nullable=False), | ||||
|     sa.Column('rob_level', sa.Integer(), nullable=False), | ||||
|     sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), | ||||
|     sa.PrimaryKeyConstraint('bank_id') | ||||
|     ) | ||||
|     op.create_table('nightly_data', | ||||
|     sa.Column('nightly_id', sa.Integer(), nullable=False), | ||||
|     sa.Column('user_id', sa.BigInteger(), nullable=True), | ||||
|     sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True), | ||||
|     sa.Column('count', sa.Integer(), nullable=False), | ||||
|     sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), | ||||
|     sa.PrimaryKeyConstraint('nightly_id') | ||||
|     ) | ||||
|     # ### end Alembic commands ### | ||||
| 
 | ||||
| 
 | ||||
| def downgrade() -> None: | ||||
|     # ### commands auto generated by Alembic - please adjust! ### | ||||
|     op.drop_table('nightly_data') | ||||
|     op.drop_table('bank') | ||||
|     op.drop_table('users') | ||||
|     # ### end Alembic commands ### | ||||
|  | @ -0,0 +1,43 @@ | |||
| from datetime import datetime | ||||
| 
 | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.crud import users | ||||
| from database.exceptions import currency as exceptions | ||||
| from database.models import Bank | ||||
| 
 | ||||
| 
 | ||||
| NIGHTLY_AMOUNT = 420 | ||||
| 
 | ||||
| 
 | ||||
| async def get_bank(session: AsyncSession, user_id: int) -> Bank: | ||||
|     """Get a user's bank info""" | ||||
|     user = await users.get_or_add(session, user_id) | ||||
|     return user.bank | ||||
| 
 | ||||
| 
 | ||||
| async def add_dinks(session: AsyncSession, user_id: int, amount: int): | ||||
|     """Increase the Dinks counter for a user""" | ||||
|     bank = await get_bank(session, user_id) | ||||
|     bank.dinks += amount | ||||
|     session.add(bank) | ||||
|     await session.commit() | ||||
| 
 | ||||
| 
 | ||||
| async def claim_nightly(session: AsyncSession, user_id: int): | ||||
|     """Claim daily Dinks""" | ||||
|     user = await users.get_or_add(session, user_id) | ||||
|     nightly_data = user.nightly_data | ||||
| 
 | ||||
|     now = datetime.now() | ||||
| 
 | ||||
|     if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): | ||||
|         raise exceptions.DoubleNightly | ||||
| 
 | ||||
|     bank = user.bank | ||||
|     bank.dinks += NIGHTLY_AMOUNT | ||||
|     nightly_data.last_nightly = now | ||||
| 
 | ||||
|     session.add(bank) | ||||
|     session.add(nightly_data) | ||||
|     await session.commit() | ||||
|  | @ -0,0 +1,37 @@ | |||
| from typing import Optional | ||||
| 
 | ||||
| from sqlalchemy import select | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.models import User, Bank, NightlyData | ||||
| 
 | ||||
| 
 | ||||
| async def get_or_add(session: AsyncSession, user_id: int) -> User: | ||||
|     """Get a user's profile | ||||
|     If it doesn't exist yet, create it (along with all linked datastructures). | ||||
|     """ | ||||
|     statement = select(User).where(User.user_id == user_id) | ||||
|     user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() | ||||
| 
 | ||||
|     # User exists | ||||
|     if user is not None: | ||||
|         return user | ||||
| 
 | ||||
|     # Create new user | ||||
|     user = User(user_id=user_id) | ||||
|     session.add(user) | ||||
|     await session.commit() | ||||
| 
 | ||||
|     # Add bank & nightly info | ||||
|     bank = Bank(user_id=user_id) | ||||
|     nightly_data = NightlyData(user_id=user_id) | ||||
|     user.bank = bank | ||||
|     user.nightly_data = nightly_data | ||||
| 
 | ||||
|     session.add(bank) | ||||
|     session.add(nightly_data) | ||||
|     session.add(user) | ||||
| 
 | ||||
|     await session.commit() | ||||
| 
 | ||||
|     return user | ||||
|  | @ -0,0 +1,2 @@ | |||
| class DoubleNightly(Exception): | ||||
|     """Exception raised when claiming nightlies multiple times per day""" | ||||
|  | @ -1,13 +1,36 @@ | |||
| from __future__ import annotations | ||||
| 
 | ||||
| from datetime import datetime | ||||
| from typing import Optional | ||||
| 
 | ||||
| from sqlalchemy import Column, Integer, Text, ForeignKey, Boolean, DateTime | ||||
| from sqlalchemy import BigInteger, Column, Integer, Text, ForeignKey, Boolean, DateTime | ||||
| from sqlalchemy.orm import declarative_base, relationship | ||||
| 
 | ||||
| Base = declarative_base() | ||||
| 
 | ||||
| 
 | ||||
| class Bank(Base): | ||||
|     """A user's currency information""" | ||||
| 
 | ||||
|     __tablename__ = "bank" | ||||
| 
 | ||||
|     bank_id: int = Column(Integer, primary_key=True) | ||||
|     user_id: int = Column(BigInteger, ForeignKey("users.user_id")) | ||||
| 
 | ||||
|     dinks: int = Column(BigInteger, default=0, nullable=False) | ||||
| 
 | ||||
|     # Interest rate | ||||
|     interest_level: int = Column(Integer, default=1, nullable=False) | ||||
| 
 | ||||
|     # Maximum amount that can be stored in the bank | ||||
|     capacity_level: int = Column(Integer, default=1, nullable=False) | ||||
| 
 | ||||
|     # Maximum amount that can be robbed | ||||
|     rob_level: int = Column(Integer, default=1, nullable=False) | ||||
| 
 | ||||
|     user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") | ||||
| 
 | ||||
| 
 | ||||
| class CustomCommand(Base): | ||||
|     """Custom commands to fill the hole Dyno couldn't""" | ||||
| 
 | ||||
|  | @ -36,6 +59,19 @@ class CustomCommandAlias(Base): | |||
|     command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") | ||||
| 
 | ||||
| 
 | ||||
| class NightlyData(Base): | ||||
|     """Data for a user's Nightly stats""" | ||||
| 
 | ||||
|     __tablename__ = "nightly_data" | ||||
| 
 | ||||
|     nightly_id: int = Column(Integer, primary_key=True) | ||||
|     user_id: int = Column(BigInteger, ForeignKey("users.user_id")) | ||||
|     last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) | ||||
|     count: int = Column(Integer, default=0, nullable=False) | ||||
| 
 | ||||
|     user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") | ||||
| 
 | ||||
| 
 | ||||
| class UforaCourse(Base): | ||||
|     """A course on Ufora""" | ||||
| 
 | ||||
|  | @ -72,8 +108,23 @@ class UforaAnnouncement(Base): | |||
| 
 | ||||
|     __tablename__ = "ufora_announcements" | ||||
| 
 | ||||
|     announcement_id = Column(Integer, primary_key=True) | ||||
|     course_id = Column(Integer, ForeignKey("ufora_courses.course_id")) | ||||
|     announcement_id: int = Column(Integer, primary_key=True) | ||||
|     course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) | ||||
|     publication_date: datetime = Column(DateTime(timezone=True)) | ||||
| 
 | ||||
|     course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") | ||||
| 
 | ||||
| 
 | ||||
| class User(Base): | ||||
|     """A Didier user""" | ||||
| 
 | ||||
|     __tablename__ = "users" | ||||
| 
 | ||||
|     user_id: int = Column(BigInteger, primary_key=True) | ||||
| 
 | ||||
|     bank: Bank = relationship( | ||||
|         "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" | ||||
|     ) | ||||
|     nightly_data: NightlyData = relationship( | ||||
|         "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" | ||||
|     ) | ||||
|  |  | |||
|  | @ -0,0 +1,64 @@ | |||
| import typing | ||||
| 
 | ||||
| import discord | ||||
| from discord.ext import commands | ||||
| 
 | ||||
| from database.crud import currency as crud | ||||
| from database.exceptions.currency import DoubleNightly | ||||
| from didier import Didier | ||||
| from didier.utils.discord.checks import is_owner | ||||
| from didier.utils.discord.converters import abbreviated_number | ||||
| from didier.utils.types.string import pluralize | ||||
| 
 | ||||
| 
 | ||||
| class Currency(commands.Cog): | ||||
|     """Everything Dinks-related""" | ||||
| 
 | ||||
|     client: Didier | ||||
| 
 | ||||
|     def __init__(self, client: Didier): | ||||
|         super().__init__() | ||||
|         self.client = client | ||||
| 
 | ||||
|     @commands.command(name="Award") | ||||
|     @commands.check(is_owner) | ||||
|     async def award(self, ctx: commands.Context, user: discord.User, amount: abbreviated_number):  # type: ignore | ||||
|         """Award a user a given amount of Didier Dinks""" | ||||
|         amount = typing.cast(int, amount) | ||||
| 
 | ||||
|         async with self.client.db_session as session: | ||||
|             await crud.add_dinks(session, user.id, amount) | ||||
|             plural = pluralize("Didier Dink", amount) | ||||
|             await ctx.reply( | ||||
|                 f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken.", | ||||
|                 mention_author=False, | ||||
|             ) | ||||
| 
 | ||||
|     @commands.hybrid_group(name="bank", case_insensitive=True, invoke_without_command=True) | ||||
|     async def bank(self, ctx: commands.Context): | ||||
|         """Show your Didier Bank information""" | ||||
|         async with self.client.db_session as session: | ||||
|             await crud.get_bank(session, ctx.author.id) | ||||
| 
 | ||||
|     @commands.hybrid_command(name="dinks") | ||||
|     async def dinks(self, ctx: commands.Context): | ||||
|         """Check your Didier Dinks""" | ||||
|         async with self.client.db_session as session: | ||||
|             bank = await crud.get_bank(session, ctx.author.id) | ||||
|             plural = pluralize("Didier Dink", bank.dinks) | ||||
|             await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) | ||||
| 
 | ||||
|     @commands.hybrid_command(name="nightly") | ||||
|     async def nightly(self, ctx: commands.Context): | ||||
|         """Claim nightly Dinks""" | ||||
|         async with self.client.db_session as session: | ||||
|             try: | ||||
|                 await crud.claim_nightly(session, ctx.author.id) | ||||
|                 await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") | ||||
|             except DoubleNightly: | ||||
|                 await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True) | ||||
| 
 | ||||
| 
 | ||||
| async def setup(client: Didier): | ||||
|     """Load the cog""" | ||||
|     await client.add_cog(Currency(client)) | ||||
|  | @ -1,7 +1,7 @@ | |||
| from discord.ext import commands | ||||
| 
 | ||||
| 
 | ||||
| class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): | ||||
| class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"):  # type: ignore | ||||
|     """Base class to add POSIX-like flags to commands | ||||
| 
 | ||||
|     Example usage: | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| import traceback | ||||
| import typing | ||||
| 
 | ||||
| import discord | ||||
| 
 | ||||
|  | @ -23,7 +24,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): | |||
| 
 | ||||
|     async def on_submit(self, interaction: discord.Interaction): | ||||
|         async with self.client.db_session as session: | ||||
|             command = await create_command(session, self.name.value, self.response.value) | ||||
|             command = await create_command(session, str(self.name.value), str(self.response.value)) | ||||
| 
 | ||||
|         await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) | ||||
| 
 | ||||
|  | @ -49,7 +50,6 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): | |||
|         self.original_name = name | ||||
|         self.client = client | ||||
| 
 | ||||
|         # TODO find a way to access these items | ||||
|         self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name)) | ||||
|         self.add_item( | ||||
|             discord.ui.TextInput( | ||||
|  | @ -58,8 +58,11 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): | |||
|         ) | ||||
| 
 | ||||
|     async def on_submit(self, interaction: discord.Interaction): | ||||
|         name_field = typing.cast(discord.ui.TextInput, self.children[0]) | ||||
|         response_field = typing.cast(discord.ui.TextInput, self.children[1]) | ||||
| 
 | ||||
|         async with self.client.db_session as session: | ||||
|             await edit_command(session, self.original_name, self.name.value, self.response.value) | ||||
|             await edit_command(session, self.original_name, name_field.value, response_field.value) | ||||
| 
 | ||||
|         await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) | ||||
| 
 | ||||
|  |  | |||
|  | @ -0,0 +1 @@ | |||
| from .message_commands import is_owner | ||||
|  | @ -0,0 +1,6 @@ | |||
| from discord.ext import commands | ||||
| 
 | ||||
| 
 | ||||
| async def is_owner(ctx: commands.Context) -> bool: | ||||
|     """Check that a command is being invoked by the owner of the bot""" | ||||
|     return await ctx.bot.is_owner(ctx.author) | ||||
|  | @ -0,0 +1 @@ | |||
| from .numbers import * | ||||
|  | @ -0,0 +1,46 @@ | |||
| import math | ||||
| from typing import Optional | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["abbreviated_number"] | ||||
| 
 | ||||
| 
 | ||||
| def abbreviated_number(argument: str) -> int: | ||||
|     """Custom converter to allow numbers to be abbreviated | ||||
|     Examples: | ||||
|         515k | ||||
|         4m | ||||
|     """ | ||||
|     if not argument: | ||||
|         raise ValueError | ||||
| 
 | ||||
|     if argument.isdecimal(): | ||||
|         return int(argument) | ||||
| 
 | ||||
|     units = {"k": 3, "m": 6, "b": 9, "t": 12} | ||||
| 
 | ||||
|     # Get the unit if there is one, then chop it off | ||||
|     value: Optional[int] = None | ||||
|     if not argument[-1].isdigit(): | ||||
|         if argument[-1].lower() not in units: | ||||
|             raise ValueError | ||||
| 
 | ||||
|         unit = argument[-1].lower() | ||||
|         value = units.get(unit) | ||||
|         argument = argument[:-1] | ||||
| 
 | ||||
|     # [int][unit] | ||||
|     if "." not in argument and value is not None: | ||||
|         return int(argument) * (10**value) | ||||
| 
 | ||||
|     # [float][unit] | ||||
|     if "." in argument: | ||||
|         # Floats themselves are not supported | ||||
|         if value is None: | ||||
|             raise ValueError | ||||
| 
 | ||||
|         as_float = float(argument) | ||||
|         return math.floor(as_float * (10**value)) | ||||
| 
 | ||||
|     # Unparseable | ||||
|     raise ValueError | ||||
|  | @ -20,3 +20,11 @@ def leading(character: str, string: str, target_length: Optional[int] = 2) -> st | |||
|     frequency = math.ceil((target_length - len(string)) / len(character)) | ||||
| 
 | ||||
|     return (frequency * character) + string | ||||
| 
 | ||||
| 
 | ||||
| def pluralize(word: str, amount: int, plural_form: Optional[str] = None) -> str: | ||||
|     """Turn a word into plural""" | ||||
|     if amount == 1: | ||||
|         return word | ||||
| 
 | ||||
|     return plural_form or (word + "s") | ||||
|  |  | |||
|  | @ -0,0 +1,25 @@ | |||
| from sqlalchemy import select | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.crud import users as crud | ||||
| from database.models import User | ||||
| 
 | ||||
| 
 | ||||
| async def test_get_or_add_non_existing(database_session: AsyncSession): | ||||
|     """Test get_or_add for a user that doesn't exist""" | ||||
|     await crud.get_or_add(database_session, 1) | ||||
|     statement = select(User) | ||||
|     res = (await database_session.execute(statement)).scalars().all() | ||||
| 
 | ||||
|     assert len(res) == 1 | ||||
|     assert res[0].bank is not None | ||||
|     assert res[0].nightly_data is not None | ||||
| 
 | ||||
| 
 | ||||
| async def test_get_or_add_existing(database_session: AsyncSession): | ||||
|     """Test get_or_add for a user that does exist""" | ||||
|     user = await crud.get_or_add(database_session, 1) | ||||
|     bank = user.bank | ||||
| 
 | ||||
|     assert await crud.get_or_add(database_session, 1) == user | ||||
|     assert (await crud.get_or_add(database_session, 1)).bank == bank | ||||
|  | @ -0,0 +1,50 @@ | |||
| import pytest | ||||
| 
 | ||||
| from didier.utils.discord.converters import numbers | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_int(): | ||||
|     """Test abbreviated_number for a regular int""" | ||||
|     assert numbers.abbreviated_number("500") == 500 | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_float_errors(): | ||||
|     """Test abbreviated_number for a float""" | ||||
|     with pytest.raises(ValueError): | ||||
|         numbers.abbreviated_number("5.4") | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_int_unit(): | ||||
|     """Test abbreviated_number for an int combined with a unit""" | ||||
|     assert numbers.abbreviated_number("20k") == 20000 | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_int_unknown_unit(): | ||||
|     """Test abbreviated_number for an int combined with an unknown unit""" | ||||
|     with pytest.raises(ValueError): | ||||
|         numbers.abbreviated_number("20p") | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_float_unit(): | ||||
|     """Test abbreviated_number for a float combined with a unit""" | ||||
|     assert numbers.abbreviated_number("20.5k") == 20500 | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_float_unknown_unit(): | ||||
|     """Test abbreviated_number for a float combined with an unknown unit""" | ||||
|     with pytest.raises(ValueError): | ||||
|         numbers.abbreviated_number("20.5p") | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_no_number(): | ||||
|     """Test abbreviated_number for unparseable content""" | ||||
|     with pytest.raises(ValueError): | ||||
|         numbers.abbreviated_number("didier") | ||||
| 
 | ||||
| 
 | ||||
| def test_abbreviated_float_floors(): | ||||
|     """Test abbreviated_number for a float that is longer than the unit | ||||
|     Example: | ||||
|         5.3k is 5300, but 5.3001k is 5300.1 | ||||
|     """ | ||||
|     assert numbers.abbreviated_number("5.3001k") == 5300 | ||||
		Loading…
	
		Reference in New Issue