Compare commits

...

2 Commits

Author SHA1 Message Date
stijndcl adcf94c66e Tests for birthday commands, overwrite existing birthdays 2022-07-19 23:35:41 +02:00
stijndcl f49f32d2e9 Add birthday commands 2022-07-19 22:58:59 +02:00
14 changed files with 241 additions and 17 deletions

11
.flake8
View File

@ -8,7 +8,6 @@ exclude =
__pycache__,
alembic,
htmlcov,
tests,
venv
# Disable rules that we don't care about (or conflict with others)
extend-ignore =
@ -30,10 +29,14 @@ ignore-decorators=overrides
max-line-length = 120
# Disable some rules for entire files
per-file-ignores =
# Missing __all__, main isn't supposed to be imported
# DALL000: Missing __all__, main isn't supposed to be imported
main.py: DALL000,
# Missing __all__, Cogs aren't modules
# DALL000: Missing __all__, Cogs aren't modules
./didier/cogs/*: DALL000,
# DALL000: Missing __all__, tests aren't supposed to be imported
# S101: Use of assert, this is the point of tests
./tests/*: DALL000 S101,
# D103: Missing docstring in public function
# All of the colours methods are just oneliners to create a colour,
# there's no point adding docstrings (function names are enough)
./didier/utils/discord/colours.py: D103
./didier/utils/discord/colours.py: D103,

View File

@ -34,6 +34,8 @@ repos:
rev: 4.0.1
hooks:
- id: flake8
exclude: ^(alembic|.github)
args: [--config, .flake8]
additional_dependencies:
- "flake8-bandit"
- "flake8-bugbear"

View File

@ -0,0 +1,38 @@
"""Add birthdays
Revision ID: 1716bfecf684
Revises: 581ae6511b98
Create Date: 2022-07-19 21:46:42.796349
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "1716bfecf684"
down_revision = "581ae6511b98"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"birthdays",
sa.Column("birthday_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("birthday", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("birthday_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("birthdays")
# ### end Alembic commands ###

View File

@ -0,0 +1,34 @@
from datetime import date
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.models import Birthday
__all__ = ["add_birthday", "get_birthday_for_user"]
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
"""Add a user's birthday into the database
If already present, overwrites the existing one
"""
user = await users.get_or_add(session, user_id)
if user.birthday is not None:
bd = user.birthday
await session.refresh(bd)
bd.birthday = birthday
else:
bd = Birthday(user_id=user_id, birthday=birthday)
session.add(bd)
await session.commit()
async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
"""Find a user's birthday"""
statement = select(Birthday).where(Birthday.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()

View File

@ -12,6 +12,7 @@ Base = declarative_base()
__all__ = [
"Base",
"Bank",
"Birthday",
"CustomCommand",
"CustomCommandAlias",
"DadJoke",
@ -46,6 +47,18 @@ class Bank(Base):
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
class Birthday(Base):
"""A user's birthday"""
__tablename__ = "birthdays"
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: datetime = Column(DateTime, nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't"""
@ -149,6 +162,9 @@ class User(Base):
bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
birthday: Optional[Birthday] = relationship(
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)

View File

@ -0,0 +1,49 @@
import discord
from discord.ext import commands
from database.crud import birthdays
from didier import Didier
from didier.utils.types.datetime import str_to_date
from didier.utils.types.string import leading
class Discord(commands.Cog):
"""Cog for commands related to Discord, servers, and members"""
client: Didier
def __init__(self, client: Didier):
self.client = client
@commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of a user"""
user_id = (user and user.id) or ctx.author.id
async with self.client.db_session as session:
birthday = await birthdays.get_birthday_for_user(session, user_id)
name = "Jouw" if user is None else f"{user.display_name}'s"
if birthday is None:
return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False)
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False)
@birthday.command(name="Set", aliases=["Config"])
async def birthday_set(self, ctx: commands.Context, date_str: str):
"""Command to set your birthday"""
try:
date = str_to_date(date_str)
except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
async with self.client.db_session as session:
await birthdays.add_birthday(session, ctx.author.id, date)
await self.client.confirm_message(ctx.message)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Discord(client))

View File

@ -1,6 +1,13 @@
__all__ = ["int_to_weekday"]
import datetime
__all__ = ["int_to_weekday", "str_to_date"]
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
"""Get the Dutch name of a weekday from the number"""
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
def str_to_date(date_str: str) -> datetime.date:
"""Turn a string into a DD/MM/YYYY date"""
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()

View File

@ -42,6 +42,12 @@ docker-compose up -d db-pytest
# Starting Didier
python3 main.py
# Running database migrations
alembic upgrade head
# Creating a new database migration
alembic revision --autogenerate -m "Revision message here"
# Running tests
pytest

View File

@ -12,6 +12,10 @@ from didier import Didier
@pytest.fixture(scope="session", autouse=True)
def event_loop() -> Generator:
"""Fixture to change the event loop
This fixes a lot of headaches during async tests
"""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
@ -33,6 +37,7 @@ async def tables():
@pytest.fixture
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database
"""
connection = await engine.connect()
@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
@pytest.fixture
def mock_client() -> Didier:
"""Fixture to get a mock Didier instance
The mock uses 0 as the id
"""
mock_client = MagicMock()

View File

@ -3,7 +3,33 @@ import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
from database.crud import users
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
@pytest.fixture(scope="session")
def test_user_id() -> int:
"""User id used when creating the debug user
Fixture is useful when comparing, fetching data, ...
"""
return 1
@pytest.fixture
async def user(database_session: AsyncSession, test_user_id) -> User:
"""Fixture to create a user"""
_user = await users.get_or_add(database_session, test_user_id)
await database_session.refresh(_user)
return _user
@pytest.fixture
async def bank(database_session: AsyncSession, user: User) -> Bank:
"""Fixture to fetch the test user's bank"""
_bank = user.bank
await database_session.refresh(_bank)
return _bank
@pytest.fixture

View File

@ -0,0 +1,47 @@
from datetime import datetime, timedelta
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud
from database.models import User
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
"""Test setting a user's birthday when it doesn't exist yet"""
assert user.birthday is None
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
assert user.birthday.birthday.date() == bd_date
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
"""Test that setting a user's birthday when it already exists overwrites it"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
assert user.birthday is not None
new_bd_date = bd_date + timedelta(weeks=1)
await crud.add_birthday(database_session, user.user_id, new_bd_date)
await database_session.refresh(user)
assert user.birthday.birthday.date() == new_bd_date
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it exists"""
bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date)
await database_session.refresh(user)
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is not None
assert bd.birthday.date() == bd_date
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
"""Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id)
assert bd is None

View File

@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
from database.models import Bank
DEBUG_USER_ID = 1
@pytest.fixture
async def bank(database_session: AsyncSession) -> Bank:
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
await database_session.refresh(_bank)
return _bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
"""Test adding dinks to an account"""
assert bank.dinks == 0

View File

@ -1,6 +1,5 @@
import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud

View File

@ -44,6 +44,7 @@ def test_abbreviated_no_number():
def test_abbreviated_float_floors():
"""Test abbreviated_number for a float that is longer than the unit
Example:
5.3k is 5300, but 5.3001k is 5300.1
"""