Compare commits

...

12 Commits

Author SHA1 Message Date
stijndcl c294bc8da5 Use abbreviated numbers in award 2022-07-01 16:06:12 +02:00
stijndcl 96916d2abd Re-create & test number converter 2022-07-01 16:06:12 +02:00
stijndcl fd72bb1774 Typing 2022-07-01 16:06:12 +02:00
stijndcl bd63f80a7d Editing custom commands 2022-07-01 16:06:12 +02:00
stijndcl bec893bd20 Add tests for users crud 2022-07-01 16:06:12 +02:00
stijndcl 032b636b02 Nightly, bank, award & dinks 2022-07-01 16:06:12 +02:00
stijndcl 4587a49311 Create database models 2022-07-01 16:06:12 +02:00
stijndcl 9552c38a70 Fix typo in toml file 2022-07-01 16:06:03 +02:00
Stijn De Clercq 76f1ba3543
Merge pull request #115 from stijndcl/codecov
Add CodeCov & increase coverage
2022-07-01 16:00:55 +02:00
stijndcl c95b7ed58f Remove discord stuff from tests 2022-07-01 15:59:33 +02:00
stijndcl 27d074d760 Increase coverage 2022-07-01 15:46:56 +02:00
stijndcl 9d04d62b1c Add codecov 2022-07-01 14:25:15 +02:00
25 changed files with 524 additions and 11 deletions

View File

@ -54,7 +54,11 @@ jobs:
- name: Install dependencies
run: pip3 install -r requirements.txt -r requirements-dev.txt
- name: Run Pytest
run: pytest tests
run: |
coverage run -m pytest
coverage xml
- name: Upload coverage report to CodeCov
uses: codecov/codecov-action@v3
linting:
needs: [dependencies]
runs-on: ubuntu-latest

View File

@ -0,0 +1,51 @@
"""Initial currency models
Revision ID: 0d03c226d881
Revises: b2d511552a1f
Create Date: 2022-06-30 20:02:27.284759
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '0d03c226d881'
down_revision = 'b2d511552a1f'
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('users',
sa.Column('user_id', sa.BigInteger(), nullable=False),
sa.PrimaryKeyConstraint('user_id')
)
op.create_table('bank',
sa.Column('bank_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=True),
sa.Column('dinks', sa.BigInteger(), nullable=False),
sa.Column('interest_level', sa.Integer(), nullable=False),
sa.Column('capacity_level', sa.Integer(), nullable=False),
sa.Column('rob_level', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ),
sa.PrimaryKeyConstraint('bank_id')
)
op.create_table('nightly_data',
sa.Column('nightly_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=True),
sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True),
sa.Column('count', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ),
sa.PrimaryKeyConstraint('nightly_id')
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('nightly_data')
op.drop_table('bank')
op.drop_table('users')
# ### end Alembic commands ###

14
codecov.yaml 100644
View File

@ -0,0 +1,14 @@
comment:
layout: "reach, diff, flags, files"
behavior: default
require_changes: false # if true: only post the comment if coverage changes
require_base: no # [yes :: must have a base report to post]
require_head: yes # [yes :: must have a head report to post]
coverage:
round: down
precision: 5
ignore:
- "./tests/*"
- "./didier/cogs/*" # Cogs can't really be tested properly

View File

@ -0,0 +1,43 @@
from datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.exceptions import currency as exceptions
from database.models import Bank
NIGHTLY_AMOUNT = 420
async def get_bank(session: AsyncSession, user_id: int) -> Bank:
"""Get a user's bank info"""
user = await users.get_or_add(session, user_id)
return user.bank
async def add_dinks(session: AsyncSession, user_id: int, amount: int):
"""Increase the Dinks counter for a user"""
bank = await get_bank(session, user_id)
bank.dinks += amount
session.add(bank)
await session.commit()
async def claim_nightly(session: AsyncSession, user_id: int):
"""Claim daily Dinks"""
user = await users.get_or_add(session, user_id)
nightly_data = user.nightly_data
now = datetime.now()
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
raise exceptions.DoubleNightly
bank = user.bank
bank.dinks += NIGHTLY_AMOUNT
nightly_data.last_nightly = now
session.add(bank)
session.add(nightly_data)
await session.commit()

View File

@ -0,0 +1,37 @@
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import User, Bank, NightlyData
async def get_or_add(session: AsyncSession, user_id: int) -> User:
"""Get a user's profile
If it doesn't exist yet, create it (along with all linked datastructures).
"""
statement = select(User).where(User.user_id == user_id)
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
# User exists
if user is not None:
return user
# Create new user
user = User(user_id=user_id)
session.add(user)
await session.commit()
# Add bank & nightly info
bank = Bank(user_id=user_id)
nightly_data = NightlyData(user_id=user_id)
user.bank = bank
user.nightly_data = nightly_data
session.add(bank)
session.add(nightly_data)
session.add(user)
await session.commit()
return user

View File

@ -0,0 +1,2 @@
class DoubleNightly(Exception):
"""Exception raised when claiming nightlies multiple times per day"""

View File

@ -1,13 +1,36 @@
from __future__ import annotations
from datetime import datetime
from typing import Optional
from sqlalchemy import Column, Integer, Text, ForeignKey, Boolean, DateTime
from sqlalchemy import BigInteger, Column, Integer, Text, ForeignKey, Boolean, DateTime
from sqlalchemy.orm import declarative_base, relationship
Base = declarative_base()
class Bank(Base):
"""A user's currency information"""
__tablename__ = "bank"
bank_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
dinks: int = Column(BigInteger, default=0, nullable=False)
# Interest rate
interest_level: int = Column(Integer, default=1, nullable=False)
# Maximum amount that can be stored in the bank
capacity_level: int = Column(Integer, default=1, nullable=False)
# Maximum amount that can be robbed
rob_level: int = Column(Integer, default=1, nullable=False)
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
@ -36,6 +59,19 @@ class CustomCommandAlias(Base):
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class NightlyData(Base):
"""Data for a user's Nightly stats"""
__tablename__ = "nightly_data"
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
count: int = Column(Integer, default=0, nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
class UforaCourse(Base):
"""A course on Ufora"""
@ -72,8 +108,23 @@ class UforaAnnouncement(Base):
__tablename__ = "ufora_announcements"
announcement_id = Column(Integer, primary_key=True)
course_id = Column(Integer, ForeignKey("ufora_courses.course_id"))
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: datetime = Column(DateTime(timezone=True))
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")
class User(Base):
"""A Didier user"""
__tablename__ = "users"
user_id: int = Column(BigInteger, primary_key=True)
bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)

View File

@ -0,0 +1,64 @@
import typing
import discord
from discord.ext import commands
from database.crud import currency as crud
from database.exceptions.currency import DoubleNightly
from didier import Didier
from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number
from didier.utils.types.string import pluralize
class Currency(commands.Cog):
"""Everything Dinks-related"""
client: Didier
def __init__(self, client: Didier):
super().__init__()
self.client = client
@commands.command(name="Award")
@commands.check(is_owner)
async def award(self, ctx: commands.Context, user: discord.User, amount: abbreviated_number): # type: ignore
"""Award a user a given amount of Didier Dinks"""
amount = typing.cast(int, amount)
async with self.client.db_session as session:
await crud.add_dinks(session, user.id, amount)
plural = pluralize("Didier Dink", amount)
await ctx.reply(
f"**{ctx.author.display_name}** heeft **{user.display_name}** **{amount}** {plural} geschonken.",
mention_author=False,
)
@commands.hybrid_group(name="bank", case_insensitive=True, invoke_without_command=True)
async def bank(self, ctx: commands.Context):
"""Show your Didier Bank information"""
async with self.client.db_session as session:
await crud.get_bank(session, ctx.author.id)
@commands.hybrid_command(name="dinks")
async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks"""
async with self.client.db_session as session:
bank = await crud.get_bank(session, ctx.author.id)
plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@commands.hybrid_command(name="nightly")
async def nightly(self, ctx: commands.Context):
"""Claim nightly Dinks"""
async with self.client.db_session as session:
try:
await crud.claim_nightly(session, ctx.author.id)
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")
except DoubleNightly:
await ctx.reply("Je hebt je nightly al geclaimd vandaag.", mention_author=False, ephemeral=True)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Currency(client))

View File

@ -1,7 +1,7 @@
from discord.ext import commands
class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"):
class PosixFlags(commands.FlagConverter, delimiter=" ", prefix="--"): # type: ignore
"""Base class to add POSIX-like flags to commands
Example usage:

View File

@ -1,4 +1,5 @@
import traceback
import typing
import discord
@ -23,7 +24,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session:
command = await create_command(session, self.name.value, self.response.value)
command = await create_command(session, str(self.name.value), str(self.response.value))
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
@ -49,7 +50,6 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
self.original_name = name
self.client = client
# TODO find a way to access these items
self.add_item(discord.ui.TextInput(label="Name", placeholder="Didier", default=name))
self.add_item(
discord.ui.TextInput(
@ -58,8 +58,11 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
)
async def on_submit(self, interaction: discord.Interaction):
name_field = typing.cast(discord.ui.TextInput, self.children[0])
response_field = typing.cast(discord.ui.TextInput, self.children[1])
async with self.client.db_session as session:
await edit_command(session, self.original_name, self.name.value, self.response.value)
await edit_command(session, self.original_name, name_field.value, response_field.value)
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)

View File

@ -0,0 +1 @@
from .message_commands import is_owner

View File

@ -0,0 +1,6 @@
from discord.ext import commands
async def is_owner(ctx: commands.Context) -> bool:
"""Check that a command is being invoked by the owner of the bot"""
return await ctx.bot.is_owner(ctx.author)

View File

@ -0,0 +1 @@
from .numbers import *

View File

@ -0,0 +1,46 @@
import math
from typing import Optional
__all__ = ["abbreviated_number"]
def abbreviated_number(argument: str) -> int:
"""Custom converter to allow numbers to be abbreviated
Examples:
515k
4m
"""
if not argument:
raise ValueError
if argument.isdecimal():
return int(argument)
units = {"k": 3, "m": 6, "b": 9, "t": 12}
# Get the unit if there is one, then chop it off
value: Optional[int] = None
if not argument[-1].isdigit():
if argument[-1].lower() not in units:
raise ValueError
unit = argument[-1].lower()
value = units.get(unit)
argument = argument[:-1]
# [int][unit]
if "." not in argument and value is not None:
return int(argument) * (10**value)
# [float][unit]
if "." in argument:
# Floats themselves are not supported
if value is None:
raise ValueError
as_float = float(argument)
return math.floor(as_float * (10**value))
# Unparseable
raise ValueError

View File

@ -1,3 +1,3 @@
def int_to_weekday(number: int) -> str:
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
"""Get the Dutch name of a weekday from the number"""
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]

View File

@ -1,9 +1,10 @@
import math
from typing import Optional
def leading(character: str, string: str, target_length: Optional[int] = 2) -> str:
"""Add a leading [character] to [string] to make it length [target_length]
Pass None to target length to always do it, no matter the length
Pass None to target length to always do it (once), no matter the length
"""
# Cast to string just in case
string = str(string)
@ -16,6 +17,14 @@ def leading(character: str, string: str, target_length: Optional[int] = 2) -> st
if len(string) >= target_length:
return string
frequency = (target_length - len(string)) // len(character)
frequency = math.ceil((target_length - len(string)) / len(character))
return (frequency * character) + string
def pluralize(word: str, amount: int, plural_form: Optional[str] = None) -> str:
"""Turn a word into plural"""
if amount == 1:
return word
return plural_form or (word + "s")

View File

@ -1,6 +1,22 @@
[tool.black]
line-length = 120
[tool.coverage.run]
concurrency = [
"greenlet"
]
source = [
"didier",
"database"
]
omit = [
"./tests/*",
"./database/migrations.py",
"./didier/cogs/*",
"./didier/didier.py",
"./didier/data/*"
]
[tool.mypy]
plugins = [
"sqlalchemy.ext.mypy.plugin"

View File

@ -1,4 +1,5 @@
black==22.3.0
coverage[toml]==6.4.1
mypy==0.961
pylint==2.14.1
pytest==7.1.2

View File

@ -0,0 +1,67 @@
import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud
from database.models import UforaAnnouncement, UforaCourse
@pytest.fixture
async def course(database_session: AsyncSession) -> UforaCourse:
"""Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course)
await database_session.commit()
return course
@pytest.fixture
async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
"""Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement)
await database_session.commit()
return announcement
async def test_get_courses_with_announcements_none(database_session: AsyncSession):
"""Test getting all courses with announcements when there are none"""
results = await crud.get_courses_with_announcements(database_session)
assert len(results) == 0
async def test_get_courses_with_announcements(database_session: AsyncSession):
"""Test getting all courses with announcements"""
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
database_session.add_all([course_1, course_2])
await database_session.commit()
results = await crud.get_courses_with_announcements(database_session)
assert len(results) == 1
assert results[0] == course_1
async def test_create_new_announcement(course: UforaCourse, database_session: AsyncSession):
"""Test creating a new announcement"""
await crud.create_new_announcement(database_session, 1, course=course, publication_date=datetime.datetime.now())
await database_session.refresh(course)
assert len(course.announcements) == 1
async def test_remove_old_announcements(announcement: UforaAnnouncement, database_session: AsyncSession):
"""Test removing all stale announcements"""
course = announcement.course
announcement.publication_date -= datetime.timedelta(weeks=2)
announcement_2 = UforaAnnouncement(course_id=announcement.course_id, publication_date=datetime.datetime.now())
database_session.add_all([announcement, announcement_2])
await database_session.commit()
await database_session.refresh(course)
assert len(course.announcements) == 2
await crud.remove_old_announcements(database_session)
await database_session.refresh(course)
assert len(course.announcements) == 1
assert announcement_2.course.announcements[0] == announcement_2

View File

@ -0,0 +1,25 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users as crud
from database.models import User
async def test_get_or_add_non_existing(database_session: AsyncSession):
"""Test get_or_add for a user that doesn't exist"""
await crud.get_or_add(database_session, 1)
statement = select(User)
res = (await database_session.execute(statement)).scalars().all()
assert len(res) == 1
assert res[0].bank is not None
assert res[0].nightly_data is not None
async def test_get_or_add_existing(database_session: AsyncSession):
"""Test get_or_add for a user that does exist"""
user = await crud.get_or_add(database_session, 1)
bank = user.bank
assert await crud.get_or_add(database_session, 1) == user
assert (await crud.get_or_add(database_session, 1)).bank == bank

View File

@ -0,0 +1,50 @@
import pytest
from didier.utils.discord.converters import numbers
def test_abbreviated_int():
"""Test abbreviated_number for a regular int"""
assert numbers.abbreviated_number("500") == 500
def test_abbreviated_float_errors():
"""Test abbreviated_number for a float"""
with pytest.raises(ValueError):
numbers.abbreviated_number("5.4")
def test_abbreviated_int_unit():
"""Test abbreviated_number for an int combined with a unit"""
assert numbers.abbreviated_number("20k") == 20000
def test_abbreviated_int_unknown_unit():
"""Test abbreviated_number for an int combined with an unknown unit"""
with pytest.raises(ValueError):
numbers.abbreviated_number("20p")
def test_abbreviated_float_unit():
"""Test abbreviated_number for a float combined with a unit"""
assert numbers.abbreviated_number("20.5k") == 20500
def test_abbreviated_float_unknown_unit():
"""Test abbreviated_number for a float combined with an unknown unit"""
with pytest.raises(ValueError):
numbers.abbreviated_number("20.5p")
def test_abbreviated_no_number():
"""Test abbreviated_number for unparseable content"""
with pytest.raises(ValueError):
numbers.abbreviated_number("didier")
def test_abbreviated_float_floors():
"""Test abbreviated_number for a float that is longer than the unit
Example:
5.3k is 5300, but 5.3001k is 5300.1
"""
assert numbers.abbreviated_number("5.3001k") == 5300

View File

@ -0,0 +1,22 @@
from didier.utils.types.string import leading
def test_leading():
"""Test leading() when it actually does something"""
assert leading("0", "5") == "05"
assert leading("0", "5", target_length=3) == "005"
def test_leading_not_necessary():
"""Test leading() when the input is already long enough"""
assert leading("0", "05") == "05"
def test_leading_no_exact():
"""Test leading() when adding would bring you over the required length"""
assert leading("ab", "c", target_length=6) == "abababc"
def test_leading_no_target_length():
"""Test leading() when target_length is None"""
assert leading("0", "05", target_length=None) == "005"