diff --git a/.flake8 b/.flake8 index cab8ba8..1707912 100644 --- a/.flake8 +++ b/.flake8 @@ -8,7 +8,6 @@ exclude = __pycache__, alembic, htmlcov, - tests, venv # Disable rules that we don't care about (or conflict with others) extend-ignore = @@ -30,10 +29,14 @@ ignore-decorators=overrides max-line-length = 120 # Disable some rules for entire files per-file-ignores = - # Missing __all__, main isn't supposed to be imported + # DALL000: Missing __all__, main isn't supposed to be imported main.py: DALL000, - # Missing __all__, Cogs aren't modules + # DALL000: Missing __all__, Cogs aren't modules ./didier/cogs/*: DALL000, + # DALL000: Missing __all__, tests aren't supposed to be imported + # S101: Use of assert, this is the point of tests + ./tests/*: DALL000 S101, + # D103: Missing docstring in public function # All of the colours methods are just oneliners to create a colour, # there's no point adding docstrings (function names are enough) - ./didier/utils/discord/colours.py: D103 + ./didier/utils/discord/colours.py: D103, diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6380e4b..10eabb9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,8 @@ repos: rev: 4.0.1 hooks: - id: flake8 + exclude: ^(alembic|.github) + args: [--config, .flake8] additional_dependencies: - "flake8-bandit" - "flake8-bugbear" diff --git a/alembic/versions/1716bfecf684_add_birthdays.py b/alembic/versions/1716bfecf684_add_birthdays.py new file mode 100644 index 0000000..9065993 --- /dev/null +++ b/alembic/versions/1716bfecf684_add_birthdays.py @@ -0,0 +1,38 @@ +"""Add birthdays + +Revision ID: 1716bfecf684 +Revises: 581ae6511b98 +Create Date: 2022-07-19 21:46:42.796349 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "1716bfecf684" +down_revision = "581ae6511b98" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "birthdays", + sa.Column("birthday_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("birthday", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("birthday_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("birthdays") + # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py new file mode 100644 index 0000000..99ea2db --- /dev/null +++ b/database/crud/birthdays.py @@ -0,0 +1,34 @@ +from datetime import date +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import users +from database.models import Birthday + +__all__ = ["add_birthday", "get_birthday_for_user"] + + +async def add_birthday(session: AsyncSession, user_id: int, birthday: date): + """Add a user's birthday into the database + + If already present, overwrites the existing one + """ + user = await users.get_or_add(session, user_id) + + if user.birthday is not None: + bd = user.birthday + await session.refresh(bd) + bd.birthday = birthday + else: + bd = Birthday(user_id=user_id, birthday=birthday) + + session.add(bd) + await session.commit() + + +async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]: + """Find a user's birthday""" + statement = select(Birthday).where(Birthday.user_id == user_id) + return (await session.execute(statement)).scalar_one_or_none() diff --git a/database/models.py b/database/models.py index d663204..74aa5fc 100644 --- a/database/models.py +++ b/database/models.py @@ -12,6 +12,7 @@ Base = declarative_base() __all__ = [ "Base", "Bank", + "Birthday", "CustomCommand", "CustomCommandAlias", "DadJoke", @@ -46,6 +47,18 @@ class Bank(Base): user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") +class Birthday(Base): + """A user's birthday""" + + __tablename__ = "birthdays" + + birthday_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + birthday: datetime = Column(DateTime, nullable=False) + + user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") + + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" @@ -149,6 +162,9 @@ class User(Base): bank: Bank = relationship( "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) + birthday: Optional[Birthday] = relationship( + "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + ) nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py new file mode 100644 index 0000000..1152922 --- /dev/null +++ b/didier/cogs/discord.py @@ -0,0 +1,49 @@ +import discord +from discord.ext import commands + +from database.crud import birthdays +from didier import Didier +from didier.utils.types.datetime import str_to_date +from didier.utils.types.string import leading + + +class Discord(commands.Cog): + """Cog for commands related to Discord, servers, and members""" + + client: Didier + + def __init__(self, client: Didier): + self.client = client + + @commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True) + async def birthday(self, ctx: commands.Context, user: discord.User = None): + """Command to check the birthday of a user""" + user_id = (user and user.id) or ctx.author.id + async with self.client.db_session as session: + birthday = await birthdays.get_birthday_for_user(session, user_id) + + name = "Jouw" if user is None else f"{user.display_name}'s" + + if birthday is None: + return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False) + + day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) + + return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False) + + @birthday.command(name="Set", aliases=["Config"]) + async def birthday_set(self, ctx: commands.Context, date_str: str): + """Command to set your birthday""" + try: + date = str_to_date(date_str) + except ValueError: + return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) + + async with self.client.db_session as session: + await birthdays.add_birthday(session, ctx.author.id, date) + await self.client.confirm_message(ctx.message) + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(Discord(client)) diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 3701b4f..91c3583 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,6 +1,13 @@ -__all__ = ["int_to_weekday"] +import datetime + +__all__ = ["int_to_weekday", "str_to_date"] def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this """Get the Dutch name of a weekday from the number""" return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] + + +def str_to_date(date_str: str) -> datetime.date: + """Turn a string into a DD/MM/YYYY date""" + return datetime.datetime.strptime(date_str, "%d/%m/%Y").date() diff --git a/readme.md b/readme.md index 4fd4143..dae8347 100644 --- a/readme.md +++ b/readme.md @@ -42,6 +42,12 @@ docker-compose up -d db-pytest # Starting Didier python3 main.py +# Running database migrations +alembic upgrade head + +# Creating a new database migration +alembic revision --autogenerate -m "Revision message here" + # Running tests pytest diff --git a/tests/conftest.py b/tests/conftest.py index 95b44db..8530de5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,10 @@ from didier import Didier @pytest.fixture(scope="session", autouse=True) def event_loop() -> Generator: + """Fixture to change the event loop + + This fixes a lot of headaches during async tests + """ loop = asyncio.get_event_loop_policy().new_event_loop() yield loop loop.close() @@ -33,6 +37,7 @@ async def tables(): @pytest.fixture async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test + Rollbacks the transaction afterwards so that the future tests start with a clean database """ connection = await engine.connect() @@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: @pytest.fixture def mock_client() -> Didier: """Fixture to get a mock Didier instance + The mock uses 0 as the id """ mock_client = MagicMock() diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index de1e939..8bc765c 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -3,7 +3,33 @@ import datetime import pytest from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias +from database.crud import users +from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User + + +@pytest.fixture(scope="session") +def test_user_id() -> int: + """User id used when creating the debug user + + Fixture is useful when comparing, fetching data, ... + """ + return 1 + + +@pytest.fixture +async def user(database_session: AsyncSession, test_user_id) -> User: + """Fixture to create a user""" + _user = await users.get_or_add(database_session, test_user_id) + await database_session.refresh(_user) + return _user + + +@pytest.fixture +async def bank(database_session: AsyncSession, user: User) -> Bank: + """Fixture to fetch the test user's bank""" + _bank = user.bank + await database_session.refresh(_bank) + return _bank @pytest.fixture diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py new file mode 100644 index 0000000..96b924c --- /dev/null +++ b/tests/test_database/test_crud/test_birthdays.py @@ -0,0 +1,47 @@ +from datetime import datetime, timedelta + +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import birthdays as crud +from database.models import User + + +async def test_add_birthday_not_present(database_session: AsyncSession, user: User): + """Test setting a user's birthday when it doesn't exist yet""" + assert user.birthday is None + + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + assert user.birthday is not None + assert user.birthday.birthday.date() == bd_date + + +async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): + """Test that setting a user's birthday when it already exists overwrites it""" + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + assert user.birthday is not None + + new_bd_date = bd_date + timedelta(weeks=1) + await crud.add_birthday(database_session, user.user_id, new_bd_date) + await database_session.refresh(user) + assert user.birthday.birthday.date() == new_bd_date + + +async def test_get_birthday_exists(database_session: AsyncSession, user: User): + """Test getting a user's birthday when it exists""" + bd_date = datetime.today().date() + await crud.add_birthday(database_session, user.user_id, bd_date) + await database_session.refresh(user) + + bd = await crud.get_birthday_for_user(database_session, user.user_id) + assert bd is not None + assert bd.birthday.date() == bd_date + + +async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): + """Test getting a user's birthday when it doesn't exist""" + bd = await crud.get_birthday_for_user(database_session, user.user_id) + assert bd is None diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index f996bf9..1f0a163 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions from database.models import Bank -DEBUG_USER_ID = 1 - - -@pytest.fixture -async def bank(database_session: AsyncSession) -> Bank: - _bank = await crud.get_bank(database_session, DEBUG_USER_ID) - await database_session.refresh(_bank) - return _bank - - async def test_add_dinks(database_session: AsyncSession, bank: Bank): """Test adding dinks to an account""" assert bank.dinks == 0 diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index b2385a2..4e6fc47 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,6 +1,5 @@ import datetime -import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_announcements as crud diff --git a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py index ed88692..75ab401 100644 --- a/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py +++ b/tests/test_didier/test_utils/test_discord/test_converters/test_numbers.py @@ -44,6 +44,7 @@ def test_abbreviated_no_number(): def test_abbreviated_float_floors(): """Test abbreviated_number for a float that is longer than the unit + Example: 5.3k is 5300, but 5.3001k is 5300.1 """