diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 6380e4b..10eabb9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -34,6 +34,8 @@ repos: rev: 4.0.1 hooks: - id: flake8 + exclude: ^(alembic|.github) + args: [--config, .flake8] additional_dependencies: - "flake8-bandit" - "flake8-bugbear" diff --git a/alembic/versions/1716bfecf684_add_birthdays.py b/alembic/versions/1716bfecf684_add_birthdays.py new file mode 100644 index 0000000..9065993 --- /dev/null +++ b/alembic/versions/1716bfecf684_add_birthdays.py @@ -0,0 +1,38 @@ +"""Add birthdays + +Revision ID: 1716bfecf684 +Revises: 581ae6511b98 +Create Date: 2022-07-19 21:46:42.796349 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "1716bfecf684" +down_revision = "581ae6511b98" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "birthdays", + sa.Column("birthday_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("birthday", sa.DateTime(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("birthday_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("birthdays") + # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py new file mode 100644 index 0000000..6ff714d --- /dev/null +++ b/database/crud/birthdays.py @@ -0,0 +1,22 @@ +from datetime import date +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import Birthday + +__all__ = ["add_birthday", "get_birthday_for_user"] + + +async def add_birthday(session: AsyncSession, user_id: int, birthday: date): + """Add a user's birthday into the database""" + bd = Birthday(user_id=user_id, birthday=birthday) + session.add(bd) + await session.commit() + + +async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]: + """Find a user's birthday""" + statement = select(Birthday).where(Birthday.user_id == user_id) + return (await session.execute(statement)).scalar_one_or_none() diff --git a/database/models.py b/database/models.py index d663204..d906406 100644 --- a/database/models.py +++ b/database/models.py @@ -12,6 +12,7 @@ Base = declarative_base() __all__ = [ "Base", "Bank", + "Birthday", "CustomCommand", "CustomCommandAlias", "DadJoke", @@ -46,6 +47,18 @@ class Bank(Base): user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") +class Birthday(Base): + """A user's birthday""" + + __tablename__ = "birthdays" + + birthday_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + birthday: datetime = Column(DateTime, nullable=False) + + user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") + + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" @@ -149,6 +162,9 @@ class User(Base): bank: Bank = relationship( "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) + birthday: Birthday = relationship( + "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + ) nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py new file mode 100644 index 0000000..1152922 --- /dev/null +++ b/didier/cogs/discord.py @@ -0,0 +1,49 @@ +import discord +from discord.ext import commands + +from database.crud import birthdays +from didier import Didier +from didier.utils.types.datetime import str_to_date +from didier.utils.types.string import leading + + +class Discord(commands.Cog): + """Cog for commands related to Discord, servers, and members""" + + client: Didier + + def __init__(self, client: Didier): + self.client = client + + @commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True) + async def birthday(self, ctx: commands.Context, user: discord.User = None): + """Command to check the birthday of a user""" + user_id = (user and user.id) or ctx.author.id + async with self.client.db_session as session: + birthday = await birthdays.get_birthday_for_user(session, user_id) + + name = "Jouw" if user is None else f"{user.display_name}'s" + + if birthday is None: + return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False) + + day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) + + return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False) + + @birthday.command(name="Set", aliases=["Config"]) + async def birthday_set(self, ctx: commands.Context, date_str: str): + """Command to set your birthday""" + try: + date = str_to_date(date_str) + except ValueError: + return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) + + async with self.client.db_session as session: + await birthdays.add_birthday(session, ctx.author.id, date) + await self.client.confirm_message(ctx.message) + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(Discord(client)) diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 3701b4f..91c3583 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,6 +1,13 @@ -__all__ = ["int_to_weekday"] +import datetime + +__all__ = ["int_to_weekday", "str_to_date"] def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this """Get the Dutch name of a weekday from the number""" return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] + + +def str_to_date(date_str: str) -> datetime.date: + """Turn a string into a DD/MM/YYYY date""" + return datetime.datetime.strptime(date_str, "%d/%m/%Y").date() diff --git a/readme.md b/readme.md index 4fd4143..dae8347 100644 --- a/readme.md +++ b/readme.md @@ -42,6 +42,12 @@ docker-compose up -d db-pytest # Starting Didier python3 main.py +# Running database migrations +alembic upgrade head + +# Creating a new database migration +alembic revision --autogenerate -m "Revision message here" + # Running tests pytest