Compare commits

...

2 Commits

Author SHA1 Message Date
stijndcl adcf94c66e Tests for birthday commands, overwrite existing birthdays 2022-07-19 23:35:41 +02:00
stijndcl f49f32d2e9 Add birthday commands 2022-07-19 22:58:59 +02:00
14 changed files with 241 additions and 17 deletions

11
.flake8
View File

@ -8,7 +8,6 @@ exclude =
__pycache__, __pycache__,
alembic, alembic,
htmlcov, htmlcov,
tests,
venv venv
# Disable rules that we don't care about (or conflict with others) # Disable rules that we don't care about (or conflict with others)
extend-ignore = extend-ignore =
@ -30,10 +29,14 @@ ignore-decorators=overrides
max-line-length = 120 max-line-length = 120
# Disable some rules for entire files # Disable some rules for entire files
per-file-ignores = per-file-ignores =
# Missing __all__, main isn't supposed to be imported # DALL000: Missing __all__, main isn't supposed to be imported
main.py: DALL000, main.py: DALL000,
# Missing __all__, Cogs aren't modules # DALL000: Missing __all__, Cogs aren't modules
./didier/cogs/*: DALL000, ./didier/cogs/*: DALL000,
# DALL000: Missing __all__, tests aren't supposed to be imported
# S101: Use of assert, this is the point of tests
./tests/*: DALL000 S101,
# D103: Missing docstring in public function
# All of the colours methods are just oneliners to create a colour, # All of the colours methods are just oneliners to create a colour,
# there's no point adding docstrings (function names are enough) # there's no point adding docstrings (function names are enough)
./didier/utils/discord/colours.py: D103 ./didier/utils/discord/colours.py: D103,

View File

@ -34,6 +34,8 @@ repos:
rev: 4.0.1 rev: 4.0.1
hooks: hooks:
- id: flake8 - id: flake8
exclude: ^(alembic|.github)
args: [--config, .flake8]
additional_dependencies: additional_dependencies:
- "flake8-bandit" - "flake8-bandit"
- "flake8-bugbear" - "flake8-bugbear"

View File

@ -0,0 +1,38 @@
"""Add birthdays
Revision ID: 1716bfecf684
Revises: 581ae6511b98
Create Date: 2022-07-19 21:46:42.796349
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "1716bfecf684"
down_revision = "581ae6511b98"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"birthdays",
sa.Column("birthday_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("birthday", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("birthday_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("birthdays")
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
from datetime import date
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.models import Birthday
__all__ = ["add_birthday", "get_birthday_for_user"]
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
"""Add a user's birthday into the database
If already present, overwrites the existing one
"""
user = await users.get_or_add(session, user_id)
if user.birthday is not None:
bd = user.birthday
await session.refresh(bd)
bd.birthday = birthday
else:
bd = Birthday(user_id=user_id, birthday=birthday)
session.add(bd)
await session.commit()
async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
"""Find a user's birthday"""
statement = select(Birthday).where(Birthday.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()

View File

@ -12,6 +12,7 @@ Base = declarative_base()
__all__ = [ __all__ = [
"Base", "Base",
"Bank", "Bank",
"Birthday",
"CustomCommand", "CustomCommand",
"CustomCommandAlias", "CustomCommandAlias",
"DadJoke", "DadJoke",
@ -46,6 +47,18 @@ class Bank(Base):
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
class Birthday(Base):
"""A user's birthday"""
__tablename__ = "birthdays"
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: datetime = Column(DateTime, nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
class CustomCommand(Base): class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't""" """Custom commands to fill the hole Dyno couldn't"""
@ -149,6 +162,9 @@ class User(Base):
bank: Bank = relationship( bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
birthday: Optional[Birthday] = relationship(
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )

View File

@ -0,0 +1,49 @@
import discord
from discord.ext import commands
from database.crud import birthdays
from didier import Didier
from didier.utils.types.datetime import str_to_date
from didier.utils.types.string import leading
class Discord(commands.Cog):
"""Cog for commands related to Discord, servers, and members"""
client: Didier
def __init__(self, client: Didier):
self.client = client
@commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of a user"""
user_id = (user and user.id) or ctx.author.id
async with self.client.db_session as session:
birthday = await birthdays.get_birthday_for_user(session, user_id)
name = "Jouw" if user is None else f"{user.display_name}'s"
if birthday is None:
return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False)
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False)
@birthday.command(name="Set", aliases=["Config"])
async def birthday_set(self, ctx: commands.Context, date_str: str):
"""Command to set your birthday"""
try:
date = str_to_date(date_str)
except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
async with self.client.db_session as session:
await birthdays.add_birthday(session, ctx.author.id, date)
await self.client.confirm_message(ctx.message)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Discord(client))

View File

@ -1,6 +1,13 @@
__all__ = ["int_to_weekday"] import datetime
__all__ = ["int_to_weekday", "str_to_date"]
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
"""Get the Dutch name of a weekday from the number""" """Get the Dutch name of a weekday from the number"""
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
def str_to_date(date_str: str) -> datetime.date:
"""Turn a string into a DD/MM/YYYY date"""
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()

View File

@ -42,6 +42,12 @@ docker-compose up -d db-pytest
# Starting Didier # Starting Didier
python3 main.py python3 main.py
# Running database migrations
alembic upgrade head
# Creating a new database migration
alembic revision --autogenerate -m "Revision message here"
# Running tests # Running tests
pytest pytest

View File

@ -12,6 +12,10 @@ from didier import Didier
@pytest.fixture(scope="session", autouse=True) @pytest.fixture(scope="session", autouse=True)
def event_loop() -> Generator: def event_loop() -> Generator:
"""Fixture to change the event loop
This fixes a lot of headaches during async tests
"""
loop = asyncio.get_event_loop_policy().new_event_loop() loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop yield loop
loop.close() loop.close()
@ -33,6 +37,7 @@ async def tables():
@pytest.fixture @pytest.fixture
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test """Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database Rollbacks the transaction afterwards so that the future tests start with a clean database
""" """
connection = await engine.connect() connection = await engine.connect()
@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
@pytest.fixture @pytest.fixture
def mock_client() -> Didier: def mock_client() -> Didier:
"""Fixture to get a mock Didier instance """Fixture to get a mock Didier instance
The mock uses 0 as the id The mock uses 0 as the id
""" """
mock_client = MagicMock() mock_client = MagicMock()

View File

@ -3,7 +3,33 @@ import datetime
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias from database.crud import users
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
@pytest.fixture(scope="session")
def test_user_id() -> int:
"""User id used when creating the debug user
Fixture is useful when comparing, fetching data, ...
"""
return 1
@pytest.fixture
async def user(database_session: AsyncSession, test_user_id) -> User:
"""Fixture to create a user"""
_user = await users.get_or_add(database_session, test_user_id)
await database_session.refresh(_user)
return _user
@pytest.fixture
async def bank(database_session: AsyncSession, user: User) -> Bank:
"""Fixture to fetch the test user's bank"""
_bank = user.bank
await database_session.refresh(_bank)
return _bank
@pytest.fixture @pytest.fixture

View File

@ -0,0 +1,47 @@
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud
from database.models import User
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
"""Test setting a user's birthday when it doesn't exist yet"""
assert user.birthday is None
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
assert user.birthday.birthday.date() == bd_date
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
"""Test that setting a user's birthday when it already exists overwrites it"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
new_bd_date = bd_date + timedelta(weeks=1)
await crud.add_birthday(database_session, user.user_id, new_bd_date)
await database_session.refresh(user)
assert user.birthday.birthday.date() == new_bd_date
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it exists"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is not None
assert bd.birthday.date() == bd_date
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is None

View File

@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
from database.models import Bank from database.models import Bank
DEBUG_USER_ID = 1
@pytest.fixture
async def bank(database_session: AsyncSession) -> Bank:
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
await database_session.refresh(_bank)
return _bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank): async def test_add_dinks(database_session: AsyncSession, bank: Bank):
"""Test adding dinks to an account""" """Test adding dinks to an account"""
assert bank.dinks == 0 assert bank.dinks == 0

View File

@ -1,6 +1,5 @@
import datetime import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud

View File

@ -44,6 +44,7 @@ def test_abbreviated_no_number():
def test_abbreviated_float_floors(): def test_abbreviated_float_floors():
"""Test abbreviated_number for a float that is longer than the unit """Test abbreviated_number for a float that is longer than the unit
Example: Example:
5.3k is 5300, but 5.3001k is 5300.1 5.3k is 5300, but 5.3001k is 5300.1
""" """