Merge pull request #118 from stijndcl/currency-v3

Start of currency v3
pull/119/head
Stijn De Clercq 2022-07-02 00:00:21 +02:00 committed by GitHub
commit cb0b4a419e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 395 additions and 7 deletions

View File

@ -0,0 +1,51 @@
"""Initial currency models
Revision ID: 0d03c226d881
Revises: b2d511552a1f
Create Date: 2022-06-30 20:02:27.284759
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0d03c226d881'
down_revision = 'b2d511552a1f'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('user_id', sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint('user_id')
)
op.create_table('bank',
sa.Column('bank_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=True),
sa.Column('dinks', sa.BigInteger(), nullable=False),
sa.Column('interest_level', sa.Integer(), nullable=False),
sa.Column('capacity_level', sa.Integer(), nullable=False),
sa.Column('rob_level', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ),
sa.PrimaryKeyConstraint('bank_id')
)
op.create_table('nightly_data',
sa.Column('nightly_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=True),
sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True),
sa.Column('count', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ),
sa.PrimaryKeyConstraint('nightly_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('nightly_data')
op.drop_table('bank')
op.drop_table('users')
# ### end Alembic commands ###

View File

@ -0,0 +1,43 @@
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.exceptions import currency as exceptions
from database.models import Bank
NIGHTLY_AMOUNT = 420
async def get_bank(session: AsyncSession, user_id: int) -> Bank:
"""Get a user's bank info"""
user = await users.get_or_add(session, user_id)
return user.bank
async def add_dinks(session: AsyncSession, user_id: int, amount: int):
"""Increase the Dinks counter for a user"""
bank = await get_bank(session, user_id)
bank.dinks += amount
session.add(bank)
await session.commit()
async def claim_nightly(session: AsyncSession, user_id: int):
"""Claim daily Dinks"""
user = await users.get_or_add(session, user_id)
nightly_data = user.nightly_data
now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
raise exceptions.DoubleNightly
bank = user.bank
bank.dinks += NIGHTLY_AMOUNT
nightly_data.last_nightly = now
session.add(bank)
session.add(nightly_data)
await session.commit()

View File

@ -0,0 +1,37 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import User, Bank, NightlyData
async def get_or_add(session: AsyncSession, user_id: int) -> User:
"""Get a user's profile
If it doesn't exist yet, create it (along with all linked datastructures).
"""
statement = select(User).where(User.user_id == user_id)
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
# User exists
if user is not None:
return user
# Create new user
user = User(user_id=user_id)
session.add(user)
await session.commit()
# Add bank & nightly info
bank = Bank(user_id=user_id)
nightly_data = NightlyData(user_id=user_id)
user.bank = bank
user.nightly_data = nightly_data
session.add(bank)
session.add(nightly_data)
session.add(user)
await session.commit()
return user

View File

@ -0,0 +1,2 @@
class DoubleNightly(Exception):
"""Exception raised when claiming nightlies multiple times per day"""

View File

@ -1,13 +1,36 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, Text, ForeignKey, Boolean, DateTime
from sqlalchemy import BigInteger, Column, Integer, Text, ForeignKey, Boolean, DateTime
from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class Bank(Base):
"""A user's currency information"""
__tablename__ = "bank"
bank_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
dinks: int = Column(BigInteger, default=0, nullable=False)
# Interest rate
interest_level: int = Column(Integer, default=1, nullable=False)
# Maximum amount that can be stored in the bank
capacity_level: int = Column(Integer, default=1, nullable=False)
# Maximum amount that can be robbed
rob_level: int = Column(Integer, default=1, nullable=False)
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
@ -36,6 +59,19 @@ class CustomCommandAlias(Base):
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class NightlyData(Base):
"""Data for a user's Nightly stats"""
__tablename__ = "nightly_data"
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
count: int = Column(Integer, default=0, nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class UforaCourse(Base):
"""A course on Ufora"""
@ -72,8 +108,23 @@ class UforaAnnouncement(Base):
__tablename__ = "ufora_announcements"
announcement_id = Column(Integer, primary_key=True)
course_id = Column(Integer, ForeignKey("ufora_courses.course_id"))
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: datetime = Column(DateTime(timezone=True))
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")
class User(Base):
"""A Didier user"""
__tablename__ = "users"
user_id: int = Column(BigInteger, primary_key=True)
bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)

View File

@ -0,0 +1,64 @@
import typing
import discord
from discord.ext import commands
from database.crud import currency as crud
from database.exceptions.currency import DoubleNightly
from didier import Didier
from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number
from didier.utils.types.string import pluralize
class Currency(commands.Cog):
"""Everything Dinks-related"""
client: Didier
def __init__(self, client: Didier):
super().__init__()
self.client = client
@commands.command(name="Award")
@commands.check(is_owner)
async def award(self, ctx: commands.Context, user: discord.User, amount: abbreviated_number): # type: ignore
"""Award a user a given amount of Didier Dinks"""
amount = typing.cast(int, amount)
async with self.client.db_session as session:
await crud.add_dinks(session, user.id, amount)
plural = pluralize("Didier Dink", amount)
await ctx.reply(
f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken.",
mention_author=False,
)
@commands.hybrid_group(name="bank", case_insensitive=True, invoke_without_command=True)
async def bank(self, ctx: commands.Context):
"""Show your Didier Bank information"""
async with self.client.db_session as session:
await crud.get_bank(session, ctx.author.id)
@commands.hybrid_command(name="dinks")
async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks"""
async with self.client.db_session as session:
bank = await crud.get_bank(session, ctx.author.id)
plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@commands.hybrid_command(name="nightly")
async def nightly(self, ctx: commands.Context):
"""Claim nightly Dinks"""
async with self.client.db_session as session:
try:
await crud.claim_nightly(session, ctx.author.id)
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")
except DoubleNightly:
await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Currency(client))

View File

@ -1,7 +1,7 @@
from discord.ext import commands
class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"):
class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): # type: ignore
"""Base class to add POSIX-like flags to commands
Example usage:

View File

@ -1,4 +1,5 @@
import traceback
import typing
import discord
@ -23,7 +24,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session:
command = await create_command(session, self.name.value, self.response.value)
command = await create_command(session, str(self.name.value), str(self.response.value))
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
@ -49,7 +50,6 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
self.original_name = name
self.client = client
# TODO find a way to access these items
self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name))
self.add_item(
discord.ui.TextInput(
@ -58,8 +58,11 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
)
async def on_submit(self, interaction: discord.Interaction):
name_field = typing.cast(discord.ui.TextInput, self.children[0])
response_field = typing.cast(discord.ui.TextInput, self.children[1])
async with self.client.db_session as session:
await edit_command(session, self.original_name, self.name.value, self.response.value)
await edit_command(session, self.original_name, name_field.value, response_field.value)
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)

View File

@ -0,0 +1 @@
from .message_commands import is_owner

View File

@ -0,0 +1,6 @@
from discord.ext import commands
async def is_owner(ctx: commands.Context) -> bool:
"""Check that a command is being invoked by the owner of the bot"""
return await ctx.bot.is_owner(ctx.author)

View File

@ -0,0 +1 @@
from .numbers import *

View File

@ -0,0 +1,46 @@
import math
from typing import Optional
__all__ = ["abbreviated_number"]
def abbreviated_number(argument: str) -> int:
"""Custom converter to allow numbers to be abbreviated
Examples:
515k
4m
"""
if not argument:
raise ValueError
if argument.isdecimal():
return int(argument)
units = {"k": 3, "m": 6, "b": 9, "t": 12}
# Get the unit if there is one, then chop it off
value: Optional[int] = None
if not argument[-1].isdigit():
if argument[-1].lower() not in units:
raise ValueError
unit = argument[-1].lower()
value = units.get(unit)
argument = argument[:-1]
# [int][unit]
if "." not in argument and value is not None:
return int(argument) * (10**value)
# [float][unit]
if "." in argument:
# Floats themselves are not supported
if value is None:
raise ValueError
as_float = float(argument)
return math.floor(as_float * (10**value))
# Unparseable
raise ValueError

View File

@ -20,3 +20,11 @@ def leading(character: str, string: str, target_length: Optional[int] = 2) -> st
frequency = math.ceil((target_length - len(string)) / len(character))
return (frequency * character) + string
def pluralize(word: str, amount: int, plural_form: Optional[str] = None) -> str:
"""Turn a word into plural"""
if amount == 1:
return word
return plural_form or (word + "s")

View File

@ -0,0 +1,25 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users as crud
from database.models import User
async def test_get_or_add_non_existing(database_session: AsyncSession):
"""Test get_or_add for a user that doesn't exist"""
await crud.get_or_add(database_session, 1)
statement = select(User)
res = (await database_session.execute(statement)).scalars().all()
assert len(res) == 1
assert res[0].bank is not None
assert res[0].nightly_data is not None
async def test_get_or_add_existing(database_session: AsyncSession):
"""Test get_or_add for a user that does exist"""
user = await crud.get_or_add(database_session, 1)
bank = user.bank
assert await crud.get_or_add(database_session, 1) == user
assert (await crud.get_or_add(database_session, 1)).bank == bank

View File

@ -0,0 +1,50 @@
import pytest
from didier.utils.discord.converters import numbers
def test_abbreviated_int():
"""Test abbreviated_number for a regular int"""
assert numbers.abbreviated_number("500") == 500
def test_abbreviated_float_errors():
"""Test abbreviated_number for a float"""
with pytest.raises(ValueError):
numbers.abbreviated_number("5.4")
def test_abbreviated_int_unit():
"""Test abbreviated_number for an int combined with a unit"""
assert numbers.abbreviated_number("20k") == 20000
def test_abbreviated_int_unknown_unit():
"""Test abbreviated_number for an int combined with an unknown unit"""
with pytest.raises(ValueError):
numbers.abbreviated_number("20p")
def test_abbreviated_float_unit():
"""Test abbreviated_number for a float combined with a unit"""
assert numbers.abbreviated_number("20.5k") == 20500
def test_abbreviated_float_unknown_unit():
"""Test abbreviated_number for a float combined with an unknown unit"""
with pytest.raises(ValueError):
numbers.abbreviated_number("20.5p")
def test_abbreviated_no_number():
"""Test abbreviated_number for unparseable content"""
with pytest.raises(ValueError):
numbers.abbreviated_number("didier")
def test_abbreviated_float_floors():
"""Test abbreviated_number for a float that is longer than the unit
Example:
5.3k is 5300, but 5.3001k is 5300.1
"""
assert numbers.abbreviated_number("5.3001k") == 5300