mirror of https://github.com/stijndcl/didier
Compare commits
2 Commits
016d87bcea
...
adcf94c66e
| Author | SHA1 | Date |
|---|---|---|
|
|
adcf94c66e | |
|
|
f49f32d2e9 |
11
.flake8
11
.flake8
|
|
@ -8,7 +8,6 @@ exclude =
|
|||
__pycache__,
|
||||
alembic,
|
||||
htmlcov,
|
||||
tests,
|
||||
venv
|
||||
# Disable rules that we don't care about (or conflict with others)
|
||||
extend-ignore =
|
||||
|
|
@ -30,10 +29,14 @@ ignore-decorators=overrides
|
|||
max-line-length = 120
|
||||
# Disable some rules for entire files
|
||||
per-file-ignores =
|
||||
# Missing __all__, main isn't supposed to be imported
|
||||
# DALL000: Missing __all__, main isn't supposed to be imported
|
||||
main.py: DALL000,
|
||||
# Missing __all__, Cogs aren't modules
|
||||
# DALL000: Missing __all__, Cogs aren't modules
|
||||
./didier/cogs/*: DALL000,
|
||||
# DALL000: Missing __all__, tests aren't supposed to be imported
|
||||
# S101: Use of assert, this is the point of tests
|
||||
./tests/*: DALL000 S101,
|
||||
# D103: Missing docstring in public function
|
||||
# All of the colours methods are just oneliners to create a colour,
|
||||
# there's no point adding docstrings (function names are enough)
|
||||
./didier/utils/discord/colours.py: D103
|
||||
./didier/utils/discord/colours.py: D103,
|
||||
|
|
|
|||
|
|
@ -34,6 +34,8 @@ repos:
|
|||
rev: 4.0.1
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: ^(alembic|.github)
|
||||
args: [--config, .flake8]
|
||||
additional_dependencies:
|
||||
- "flake8-bandit"
|
||||
- "flake8-bugbear"
|
||||
|
|
|
|||
|
|
@ -0,0 +1,38 @@
|
|||
"""Add birthdays
|
||||
|
||||
Revision ID: 1716bfecf684
|
||||
Revises: 581ae6511b98
|
||||
Create Date: 2022-07-19 21:46:42.796349
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1716bfecf684"
|
||||
down_revision = "581ae6511b98"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"birthdays",
|
||||
sa.Column("birthday_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("birthday", sa.DateTime(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("birthday_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("birthdays")
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -0,0 +1,34 @@
|
|||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users
|
||||
from database.models import Birthday
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user"]
|
||||
|
||||
|
||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||
"""Add a user's birthday into the database
|
||||
|
||||
If already present, overwrites the existing one
|
||||
"""
|
||||
user = await users.get_or_add(session, user_id)
|
||||
|
||||
if user.birthday is not None:
|
||||
bd = user.birthday
|
||||
await session.refresh(bd)
|
||||
bd.birthday = birthday
|
||||
else:
|
||||
bd = Birthday(user_id=user_id, birthday=birthday)
|
||||
|
||||
session.add(bd)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional[Birthday]:
|
||||
"""Find a user's birthday"""
|
||||
statement = select(Birthday).where(Birthday.user_id == user_id)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
|
@ -12,6 +12,7 @@ Base = declarative_base()
|
|||
__all__ = [
|
||||
"Base",
|
||||
"Bank",
|
||||
"Birthday",
|
||||
"CustomCommand",
|
||||
"CustomCommandAlias",
|
||||
"DadJoke",
|
||||
|
|
@ -46,6 +47,18 @@ class Bank(Base):
|
|||
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
|
||||
|
||||
|
||||
class Birthday(Base):
|
||||
"""A user's birthday"""
|
||||
|
||||
__tablename__ = "birthdays"
|
||||
|
||||
birthday_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
birthday: datetime = Column(DateTime, nullable=False)
|
||||
|
||||
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
|
||||
|
||||
|
||||
class CustomCommand(Base):
|
||||
"""Custom commands to fill the hole Dyno couldn't"""
|
||||
|
||||
|
|
@ -149,6 +162,9 @@ class User(Base):
|
|||
bank: Bank = relationship(
|
||||
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
birthday: Optional[Birthday] = relationship(
|
||||
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,49 @@
|
|||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
from database.crud import birthdays
|
||||
from didier import Didier
|
||||
from didier.utils.types.datetime import str_to_date
|
||||
from didier.utils.types.string import leading
|
||||
|
||||
|
||||
class Discord(commands.Cog):
|
||||
"""Cog for commands related to Discord, servers, and members"""
|
||||
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.group(name="Birthday", aliases=["Bd", "Birthdays"], case_insensitive=True, invoke_without_command=True)
|
||||
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
||||
"""Command to check the birthday of a user"""
|
||||
user_id = (user and user.id) or ctx.author.id
|
||||
async with self.client.db_session as session:
|
||||
birthday = await birthdays.get_birthday_for_user(session, user_id)
|
||||
|
||||
name = "Jouw" if user is None else f"{user.display_name}'s"
|
||||
|
||||
if birthday is None:
|
||||
return await ctx.reply(f"{name} verjaardag zit niet in de database.", mention_author=False)
|
||||
|
||||
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
|
||||
|
||||
return await ctx.reply(f"{name} verjaardag staat ingesteld op **{day}/{month}**.", mention_author=False)
|
||||
|
||||
@birthday.command(name="Set", aliases=["Config"])
|
||||
async def birthday_set(self, ctx: commands.Context, date_str: str):
|
||||
"""Command to set your birthday"""
|
||||
try:
|
||||
date = str_to_date(date_str)
|
||||
except ValueError:
|
||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||
|
||||
async with self.client.db_session as session:
|
||||
await birthdays.add_birthday(session, ctx.author.id, date)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(Discord(client))
|
||||
|
|
@ -1,6 +1,13 @@
|
|||
__all__ = ["int_to_weekday"]
|
||||
import datetime
|
||||
|
||||
__all__ = ["int_to_weekday", "str_to_date"]
|
||||
|
||||
|
||||
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
|
||||
"""Get the Dutch name of a weekday from the number"""
|
||||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||
|
||||
|
||||
def str_to_date(date_str: str) -> datetime.date:
|
||||
"""Turn a string into a DD/MM/YYYY date"""
|
||||
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
|
||||
|
|
|
|||
|
|
@ -42,6 +42,12 @@ docker-compose up -d db-pytest
|
|||
# Starting Didier
|
||||
python3 main.py
|
||||
|
||||
# Running database migrations
|
||||
alembic upgrade head
|
||||
|
||||
# Creating a new database migration
|
||||
alembic revision --autogenerate -m "Revision message here"
|
||||
|
||||
# Running tests
|
||||
pytest
|
||||
|
||||
|
|
|
|||
|
|
@ -12,6 +12,10 @@ from didier import Didier
|
|||
|
||||
@pytest.fixture(scope="session", autouse=True)
|
||||
def event_loop() -> Generator:
|
||||
"""Fixture to change the event loop
|
||||
|
||||
This fixes a lot of headaches during async tests
|
||||
"""
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
|
@ -33,6 +37,7 @@ async def tables():
|
|||
@pytest.fixture
|
||||
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Fixture to create a session for every test
|
||||
|
||||
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||
"""
|
||||
connection = await engine.connect()
|
||||
|
|
@ -52,6 +57,7 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
|||
@pytest.fixture
|
||||
def mock_client() -> Didier:
|
||||
"""Fixture to get a mock Didier instance
|
||||
|
||||
The mock uses 0 as the id
|
||||
"""
|
||||
mock_client = MagicMock()
|
||||
|
|
|
|||
|
|
@ -3,7 +3,33 @@ import datetime
|
|||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
|
||||
from database.crud import users
|
||||
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def test_user_id() -> int:
|
||||
"""User id used when creating the debug user
|
||||
|
||||
Fixture is useful when comparing, fetching data, ...
|
||||
"""
|
||||
return 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def user(database_session: AsyncSession, test_user_id) -> User:
|
||||
"""Fixture to create a user"""
|
||||
_user = await users.get_or_add(database_session, test_user_id)
|
||||
await database_session.refresh(_user)
|
||||
return _user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def bank(database_session: AsyncSession, user: User) -> Bank:
|
||||
"""Fixture to fetch the test user's bank"""
|
||||
_bank = user.bank
|
||||
await database_session.refresh(_bank)
|
||||
return _bank
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
|||
|
|
@ -0,0 +1,47 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import birthdays as crud
|
||||
from database.models import User
|
||||
|
||||
|
||||
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
|
||||
"""Test setting a user's birthday when it doesn't exist yet"""
|
||||
assert user.birthday is None
|
||||
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
assert user.birthday is not None
|
||||
assert user.birthday.birthday.date() == bd_date
|
||||
|
||||
|
||||
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
|
||||
"""Test that setting a user's birthday when it already exists overwrites it"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
assert user.birthday is not None
|
||||
|
||||
new_bd_date = bd_date + timedelta(weeks=1)
|
||||
await crud.add_birthday(database_session, user.user_id, new_bd_date)
|
||||
await database_session.refresh(user)
|
||||
assert user.birthday.birthday.date() == new_bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
|
||||
"""Test getting a user's birthday when it exists"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
assert bd is not None
|
||||
assert bd.birthday.date() == bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
|
||||
"""Test getting a user's birthday when it doesn't exist"""
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
assert bd is None
|
||||
|
|
@ -6,16 +6,6 @@ from database.exceptions import currency as exceptions
|
|||
from database.models import Bank
|
||||
|
||||
|
||||
DEBUG_USER_ID = 1
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def bank(database_session: AsyncSession) -> Bank:
|
||||
_bank = await crud.get_bank(database_session, DEBUG_USER_ID)
|
||||
await database_session.refresh(_bank)
|
||||
return _bank
|
||||
|
||||
|
||||
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
||||
"""Test adding dinks to an account"""
|
||||
assert bank.dinks == 0
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_announcements as crud
|
||||
|
|
|
|||
|
|
@ -44,6 +44,7 @@ def test_abbreviated_no_number():
|
|||
|
||||
def test_abbreviated_float_floors():
|
||||
"""Test abbreviated_number for a float that is longer than the unit
|
||||
|
||||
Example:
|
||||
5.3k is 5300, but 5.3001k is 5300.1
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue