diff --git a/alembic/versions/0d03c226d881_initial_currency_models.py b/alembic/versions/0d03c226d881_initial_currency_models.py index feec2c1..7478410 100644 --- a/alembic/versions/0d03c226d881_initial_currency_models.py +++ b/alembic/versions/0d03c226d881_initial_currency_models.py @@ -37,7 +37,7 @@ def upgrade() -> None: "nightly_data", sa.Column("nightly_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("last_nightly", sa.Date, nullable=True), + sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True), sa.Column("count", sa.Integer(), server_default="0", nullable=False), sa.ForeignKeyConstraint( ["user_id"], diff --git a/alembic/versions/1716bfecf684_add_birthdays.py b/alembic/versions/1716bfecf684_add_birthdays.py index 5f01615..9065993 100644 --- a/alembic/versions/1716bfecf684_add_birthdays.py +++ b/alembic/versions/1716bfecf684_add_birthdays.py @@ -22,7 +22,7 @@ def upgrade() -> None: "birthdays", sa.Column("birthday_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("birthday", sa.Date, nullable=False), + sa.Column("birthday", sa.DateTime(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["users.user_id"], diff --git a/alembic/versions/346b408c362a_create_tasks.py b/alembic/versions/346b408c362a_create_tasks.py deleted file mode 100644 index f6efeeb..0000000 --- a/alembic/versions/346b408c362a_create_tasks.py +++ /dev/null @@ -1,36 +0,0 @@ -"""Create tasks - -Revision ID: 346b408c362a -Revises: 1716bfecf684 -Create Date: 2022-07-23 19:41:07.029482 - -""" -import sqlalchemy as sa - -from alembic import op - -# revision identifiers, used by Alembic. -revision = "346b408c362a" -down_revision = "1716bfecf684" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "tasks", - sa.Column("task_id", sa.Integer(), nullable=False), - sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False), - sa.Column("previous_run", sa.DateTime(), nullable=True), - sa.PrimaryKeyConstraint("task_id"), - sa.UniqueConstraint("task"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("tasks") - sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype").drop(op.get_bind()) - # ### end Alembic commands ### diff --git a/alembic/versions/4ec79dd5b191_initial_migration.py b/alembic/versions/4ec79dd5b191_initial_migration.py index 2bf8362..186e280 100644 --- a/alembic/versions/4ec79dd5b191_initial_migration.py +++ b/alembic/versions/4ec79dd5b191_initial_migration.py @@ -1,16 +1,16 @@ """Initial migration Revision ID: 4ec79dd5b191 -Revises: +Revises: Create Date: 2022-06-19 00:31:58.384360 """ +from alembic import op import sqlalchemy as sa -from alembic import op # revision identifiers, used by Alembic. -revision = "4ec79dd5b191" +revision = '4ec79dd5b191' down_revision = None branch_labels = None depends_on = None @@ -18,46 +18,37 @@ depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "ufora_courses", - sa.Column("course_id", sa.Integer(), nullable=False), - sa.Column("name", sa.Text(), nullable=False), - sa.Column("code", sa.Text(), nullable=False), - sa.Column("year", sa.Integer(), nullable=False), - sa.Column("log_announcements", sa.Boolean(), nullable=False), - sa.PrimaryKeyConstraint("course_id"), - sa.UniqueConstraint("code"), - sa.UniqueConstraint("name"), + op.create_table('ufora_courses', + sa.Column('course_id', sa.Integer(), nullable=False), + sa.Column('name', sa.Text(), nullable=False), + sa.Column('code', sa.Text(), nullable=False), + sa.Column('year', sa.Integer(), nullable=False), + sa.Column('log_announcements', sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint('course_id'), + sa.UniqueConstraint('code'), + sa.UniqueConstraint('name') ) - op.create_table( - "ufora_announcements", - sa.Column("announcement_id", sa.Integer(), nullable=False), - sa.Column("course_id", sa.Integer(), nullable=True), - sa.Column("publication_date", sa.Date, nullable=True), - sa.ForeignKeyConstraint( - ["course_id"], - ["ufora_courses.course_id"], - ), - sa.PrimaryKeyConstraint("announcement_id"), + op.create_table('ufora_announcements', + sa.Column('announcement_id', sa.Integer(), nullable=False), + sa.Column('course_id', sa.Integer(), nullable=True), + sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True), + sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), + sa.PrimaryKeyConstraint('announcement_id') ) - op.create_table( - "ufora_course_aliases", - sa.Column("alias_id", sa.Integer(), nullable=False), - sa.Column("alias", sa.Text(), nullable=False), - sa.Column("course_id", sa.Integer(), nullable=True), - sa.ForeignKeyConstraint( - ["course_id"], - ["ufora_courses.course_id"], - ), - sa.PrimaryKeyConstraint("alias_id"), - sa.UniqueConstraint("alias"), + op.create_table('ufora_course_aliases', + sa.Column('alias_id', sa.Integer(), nullable=False), + sa.Column('alias', sa.Text(), nullable=False), + sa.Column('course_id', sa.Integer(), nullable=True), + sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ), + sa.PrimaryKeyConstraint('alias_id'), + sa.UniqueConstraint('alias') ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("ufora_course_aliases") - op.drop_table("ufora_announcements") - op.drop_table("ufora_courses") + op.drop_table('ufora_course_aliases') + op.drop_table('ufora_announcements') + op.drop_table('ufora_courses') # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 054d4c5..99ea2db 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -1,15 +1,13 @@ -import datetime from datetime import date from typing import Optional -from sqlalchemy import extract, select +from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload from database.crud import users -from database.models import Birthday, User +from database.models import Birthday -__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] +__all__ = ["add_birthday", "get_birthday_for_user"] async def add_birthday(session: AsyncSession, user_id: int, birthday: date): @@ -17,7 +15,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date): If already present, overwrites the existing one """ - user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) + user = await users.get_or_add(session, user_id) if user.birthday is not None: bd = user.birthday @@ -34,12 +32,3 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional """Find a user's birthday""" statement = select(Birthday).where(Birthday.user_id == user_id) return (await session.execute(statement)).scalar_one_or_none() - - -async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]: - """Get all birthdays that happen on a given day""" - days = extract("day", Birthday.birthday) - months = extract("month", Birthday.birthday) - - statement = select(Birthday).where((days == day.day) & (months == day.month)) - return list((await session.execute(statement)).scalars()) diff --git a/database/crud/currency.py b/database/crud/currency.py index 1bb2d11..f72b653 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -1,4 +1,4 @@ -from datetime import date +from datetime import datetime from typing import Union from sqlalchemy.ext.asyncio import AsyncSession @@ -69,9 +69,9 @@ async def claim_nightly(session: AsyncSession, user_id: int): """Claim daily Dinks""" nightly_data = await get_nightly_data(session, user_id) - now = date.today() + now = datetime.now() - if nightly_data.last_nightly is not None and nightly_data.last_nightly == now: + if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date(): raise exceptions.DoubleNightly bank = await get_bank(session, user_id) diff --git a/database/crud/tasks.py b/database/crud/tasks.py deleted file mode 100644 index dd1a607..0000000 --- a/database/crud/tasks.py +++ /dev/null @@ -1,32 +0,0 @@ -import datetime -from typing import Optional - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from database.enums import TaskType -from database.models import Task -from database.utils.datetime import LOCAL_TIMEZONE - -__all__ = ["get_task_by_enum", "set_last_task_execution_time"] - - -async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]: - """Get a task by its enum value, if it exists - - Returns None if the task does not exist - """ - statement = select(Task).where(Task.task == task) - return (await session.execute(statement)).scalar_one_or_none() - - -async def set_last_task_execution_time(session: AsyncSession, task: TaskType): - """Set the last time a specific task was executed""" - _task = await get_task_by_enum(session, task) - - if _task is None: - _task = Task(task=task) - - _task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE) - session.add(_task) - await session.commit() diff --git a/database/crud/users.py b/database/crud/users.py index ba3011d..57c5029 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -10,16 +10,12 @@ __all__ = [ ] -async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: +async def get_or_add(session: AsyncSession, user_id: int) -> User: """Get a user's profile If it doesn't exist yet, create it (along with all linked datastructures) """ - if options is None: - options = [] - - statement = select(User).where(User.user_id == user_id).options(*options) - + statement = select(User).where(User.user_id == user_id) user: Optional[User] = (await session.execute(statement)).scalar_one_or_none() # User exists @@ -42,6 +38,5 @@ async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[l session.add(user) await session.commit() - await session.refresh(user) return user diff --git a/database/enums.py b/database/enums.py deleted file mode 100644 index 3f75130..0000000 --- a/database/enums.py +++ /dev/null @@ -1,13 +0,0 @@ -import enum - -__all__ = ["TaskType"] - - -# There is a bug in typeshed that causes an incorrect PyCharm warning -# https://github.com/python/typeshed/issues/8286 -# noinspection PyArgumentList -class TaskType(enum.IntEnum): - """Enum for the different types of tasks""" - - BIRTHDAYS = enum.auto() - UFORA_ANNOUNCEMENTS = enum.auto() diff --git a/database/models.py b/database/models.py index f0fb6e4..74aa5fc 100644 --- a/database/models.py +++ b/database/models.py @@ -1,23 +1,11 @@ from __future__ import annotations -from datetime import date, datetime +from datetime import datetime from typing import Optional -from sqlalchemy import ( - BigInteger, - Boolean, - Column, - Date, - DateTime, - Enum, - ForeignKey, - Integer, - Text, -) +from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text from sqlalchemy.orm import declarative_base, relationship -from database import enums - Base = declarative_base() @@ -29,7 +17,6 @@ __all__ = [ "CustomCommandAlias", "DadJoke", "NightlyData", - "Task", "UforaAnnouncement", "UforaCourse", "UforaCourseAlias", @@ -67,7 +54,7 @@ class Birthday(Base): birthday_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - birthday: date = Column(Date, nullable=False) + birthday: datetime = Column(DateTime, nullable=False) user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") @@ -116,22 +103,12 @@ class NightlyData(Base): nightly_id: int = Column(Integer, primary_key=True) user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - last_nightly: Optional[date] = Column(Date, nullable=True) + last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True) count: int = Column(Integer, server_default="0", nullable=False) user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") -class Task(Base): - """A Didier task""" - - __tablename__ = "tasks" - - task_id: int = Column(Integer, primary_key=True) - task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) - previous_run: datetime = Column(DateTime(timezone=True), nullable=True) - - class UforaCourse(Base): """A course on Ufora""" @@ -170,7 +147,7 @@ class UforaAnnouncement(Base): announcement_id: int = Column(Integer, primary_key=True) course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - publication_date: date = Column(Date) + publication_date: datetime = Column(DateTime(timezone=True)) course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") diff --git a/database/utils/datetime.py b/database/utils/datetime.py deleted file mode 100644 index 8450e84..0000000 --- a/database/utils/datetime.py +++ /dev/null @@ -1,5 +0,0 @@ -import zoneinfo - -__all__ = ["LOCAL_TIMEZONE"] - -LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index a73ead3..1152922 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -35,13 +35,7 @@ class Discord(commands.Cog): async def birthday_set(self, ctx: commands.Context, date_str: str): """Command to set your birthday""" try: - default_year = 2001 - date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) - - # If no year was passed, make it 2001 by default - if date_str.count("/") == 1: - date.replace(year=default_year) - + date = str_to_date(date_str) except ValueError: return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 9a93fdf..f43df42 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -8,7 +8,7 @@ from database.crud import custom_commands from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier -from didier.utils.discord.flags.owner import EditCustomFlags +from didier.data.flags.owner import EditCustomFlags from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 88bcb6e..b8488ae 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,90 +1,27 @@ -import datetime import traceback from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error import settings -from database import enums -from database.crud.birthdays import get_birthdays_on_day from database.crud.ufora_announcements import remove_old_announcements from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements -from didier.decorators.tasks import timed_task -from didier.utils.discord.checks import is_owner -from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now - -# datetime.time()-instances for when every task should run -DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) -SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) class Tasks(commands.Cog): - """Task loops that run periodically - - Preferably these would use the new `time`-kwarg, but these don't run - on startup, which is not ideal. This means we still have to run them every hour - in order to never miss anything if Didier goes offline by coincidence - """ + """Task loops that run periodically""" client: Didier - _tasks: dict[str, tasks.Loop] def __init__(self, client: Didier): self.client = client - # Only check birthdays if there's a channel to send it to - if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: - self.check_birthdays.start() - # Only pull announcements if a token was provided if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None: self.pull_ufora_announcements.start() self.remove_old_ufora_announcements.start() - self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements} - - @commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True) - @commands.check(is_owner) - async def tasks_group(self, ctx: commands.Context): - """Command group for Task-related commands - - Invoking the group itself shows the time until the next iteration - """ - raise NotImplementedError() - - @tasks_group.command(name="Force", case_insensitive=True) - async def force_task(self, ctx: commands.Context, name: str): - """Command to force-run a task without waiting for the run time""" - name = name.lower() - if name not in self._tasks: - return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False) - - task = self._tasks[name] - await task() - - @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) - @timed_task(enums.TaskType.BIRTHDAYS) - async def check_birthdays(self): - """Check if it's currently anyone's birthday""" - now = tz_aware_now().date() - async with self.client.db_session as session: - birthdays = await get_birthdays_on_day(session, now) - - channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) - if channel is None: - return await self.client.log_error("Unable to find channel for birthday announcements") - - for birthday in birthdays: - user = self.client.get_user(birthday.user_id) - # TODO more messages? - await channel.send(f"Gelukkig verjaardag {user.mention}!") - - @check_birthdays.before_loop - async def _before_check_birthdays(self): - await self.client.wait_until_ready() - @tasks.loop(minutes=10) - @timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS) async def pull_ufora_announcements(self): """Task that checks for new Ufora announcements & logs them in a channel""" # In theory this shouldn't happen but just to please Mypy @@ -100,20 +37,23 @@ class Tasks(commands.Cog): @pull_ufora_announcements.before_loop async def _before_ufora_announcements(self): + """Don't try to get announcements if the bot isn't ready yet""" await self.client.wait_until_ready() + @pull_ufora_announcements.error + async def _on_announcements_error(self, error: BaseException): + """Error handler for the Ufora Announcements task""" + print("".join(traceback.format_exception(type(error), error, error.__traceback__))) + @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" async with self.client.db_session as session: await remove_old_announcements(session) - @check_birthdays.error - @pull_ufora_announcements.error - @remove_old_ufora_announcements.error - async def _on_tasks_error(self, error: BaseException): - """Error handler for all tasks""" - print("".join(traceback.format_exception(type(error), error, error.__traceback__))) + @remove_old_ufora_announcements.before_loop + async def _before_remove_old_ufora_announcements(self): + await self.client.wait_until_ready() async def setup(client: Didier): diff --git a/didier/utils/discord/flags/__init__.py b/didier/data/flags/__init__.py similarity index 100% rename from didier/utils/discord/flags/__init__.py rename to didier/data/flags/__init__.py diff --git a/didier/utils/discord/flags/owner.py b/didier/data/flags/owner.py similarity index 80% rename from didier/utils/discord/flags/owner.py rename to didier/data/flags/owner.py index 282957c..5ff75a9 100644 --- a/didier/utils/discord/flags/owner.py +++ b/didier/data/flags/owner.py @@ -1,6 +1,6 @@ from typing import Optional -from didier.utils.discord.flags import PosixFlags +from didier.data.flags import PosixFlags __all__ = ["EditCustomFlags"] diff --git a/didier/utils/discord/flags/posix.py b/didier/data/flags/posix.py similarity index 100% rename from didier/utils/discord/flags/posix.py rename to didier/data/flags/posix.py diff --git a/didier/decorators/__init__.py b/didier/decorators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/didier/decorators/tasks.py b/didier/decorators/tasks.py deleted file mode 100644 index 36b4f89..0000000 --- a/didier/decorators/tasks.py +++ /dev/null @@ -1,28 +0,0 @@ -from __future__ import annotations - -import functools -from typing import TYPE_CHECKING - -if TYPE_CHECKING: - from didier.cogs.tasks import Tasks - -from database import enums -from database.crud.tasks import set_last_task_execution_time - -__all__ = ["timed_task"] - - -def timed_task(task: enums.TaskType): - """Decorator to log the last execution time of a task""" - - def _decorator(func): - @functools.wraps(func) - async def _wrapper(tasks_cog: Tasks, *args, **kwargs): - await func(tasks_cog, *args, **kwargs) - - async with tasks_cog.client.db_session as session: - await set_last_task_execution_time(session, task) - - return _wrapper - - return _decorator diff --git a/didier/didier.py b/didier/didier.py index 2f07372..9fef227 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -1,4 +1,3 @@ -import logging import os import discord @@ -15,14 +14,10 @@ from didier.utils.discord.prefix import get_prefix __all__ = ["Didier"] -logger = logging.getLogger(__name__) - - class Didier(commands.Bot): """DIDIER <3""" database_caches: CacheManager - error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession @@ -65,12 +60,6 @@ class Didier(commands.Bot): # Create aiohttp session self.http_session = ClientSession() - # Configure channel to send errors to - if settings.ERRORS_CHANNEL is not None: - self.error_channel = self.get_channel(settings.ERRORS_CHANNEL) - else: - self.error_channel = self.get_user(self.owner_id) - async def _load_initial_extensions(self): """Load all extensions that should be loaded before the others""" for extension in self.initial_extensions: @@ -112,13 +101,6 @@ class Didier(commands.Bot): """Add an X to a message""" await message.add_reaction("❌") - async def log_error(self, message: str, log_to_discord: bool = True): - """Send an error message to the logs, and optionally the configured channel""" - logger.error(message) - if log_to_discord: - # TODO pretty embed - await self.error_channel.send(message) - async def on_ready(self): """Event triggered when the bot is ready""" print(settings.DISCORD_READY_MESSAGE) diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 42f58a9..91c3583 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -1,11 +1,6 @@ import datetime -import zoneinfo -__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"] - -from typing import Union - -LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") +__all__ = ["int_to_weekday", "str_to_date"] def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this @@ -13,21 +8,6 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] -def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date: +def str_to_date(date_str: str) -> datetime.date: """Turn a string into a DD/MM/YYYY date""" - # Allow passing multiple formats in a list - if isinstance(formats, str): - formats = [formats] - - for format_str in formats: - try: - return datetime.datetime.strptime(date_str, format_str).date() - except ValueError: - continue - - raise ValueError - - -def tz_aware_now() -> datetime.datetime: - """Get the current date & time, but timezone-aware""" - return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE) + return datetime.datetime.strptime(date_str, "%d/%m/%Y").date() diff --git a/requirements-dev.txt b/requirements-dev.txt index 64b8467..d82fde3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,5 @@ black==22.3.0 coverage[toml]==6.4.1 -freezegun==1.2.1 mypy==0.961 pre-commit==2.20.0 pytest==7.1.2 diff --git a/settings.py b/settings.py index d03b48b..bb2296f 100644 --- a/settings.py +++ b/settings.py @@ -46,8 +46,6 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?") -BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None) -ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None) UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) """API Keys""" diff --git a/tests/conftest.py b/tests/conftest.py index 2e425ef..8530de5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ def event_loop() -> Generator: @pytest.fixture(scope="session") async def tables(): - """Fixture to initialize a database before the tests, and then tear it down again + """Initialize a database before the tests, and then tear it down again Checks that the migrations were successful by asserting that we are currently on the latest migration diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 544e5b0..96b924c 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,10 +1,8 @@ from datetime import datetime, timedelta -from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud -from database.crud import users from database.models import User @@ -16,7 +14,7 @@ async def test_add_birthday_not_present(database_session: AsyncSession, user: Us await crud.add_birthday(database_session, user.user_id, bd_date) await database_session.refresh(user) assert user.birthday is not None - assert user.birthday.birthday == bd_date + assert user.birthday.birthday.date() == bd_date async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): @@ -29,7 +27,7 @@ async def test_add_birthday_overwrite(database_session: AsyncSession, user: User new_bd_date = bd_date + timedelta(weeks=1) await crud.add_birthday(database_session, user.user_id, new_bd_date) await database_session.refresh(user) - assert user.birthday.birthday == new_bd_date + assert user.birthday.birthday.date() == new_bd_date async def test_get_birthday_exists(database_session: AsyncSession, user: User): @@ -40,35 +38,10 @@ async def test_get_birthday_exists(database_session: AsyncSession, user: User): bd = await crud.get_birthday_for_user(database_session, user.user_id) assert bd is not None - assert bd.birthday == bd_date + assert bd.birthday.date() == bd_date async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): """Test getting a user's birthday when it doesn't exist""" bd = await crud.get_birthday_for_user(database_session, user.user_id) assert bd is None - - -@freeze_time("2022/07/23") -async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): - """Test getting all birthdays on a given day""" - await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001)) - - user_2 = await users.get_or_add(database_session, user.user_id + 1) - await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) - assert len(birthdays) == 1 - assert birthdays[0].user_id == user.user_id - - -@freeze_time("2022/07/23") -async def test_get_birthdays_none_present(database_session: AsyncSession): - """Test getting all birthdays when there are none""" - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) - assert len(birthdays) == 0 - - # Add a random birthday that is not today - await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) - - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) - assert len(birthdays) == 0 diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index a2eeec8..1f0a163 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -1,7 +1,4 @@ -import datetime - import pytest -from freezegun import freeze_time from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud @@ -17,18 +14,13 @@ async def test_add_dinks(database_session: AsyncSession, bank: Bank): assert bank.dinks == 10 -@freeze_time("2022/07/23") async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): """Test claiming nightlies when it hasn't been done yet""" await crud.claim_nightly(database_session, bank.user_id) await database_session.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT - nightly_data = await crud.get_nightly_data(database_session, bank.user_id) - assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23) - -@freeze_time("2022/07/23") async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): """Test claiming nightlies twice in a day""" await crud.claim_nightly(database_session, bank.user_id) diff --git a/tests/test_didier/test_utils/test_types/test_datetime.py b/tests/test_didier/test_utils/test_types/test_datetime.py deleted file mode 100644 index ecb1973..0000000 --- a/tests/test_didier/test_utils/test_types/test_datetime.py +++ /dev/null @@ -1,34 +0,0 @@ -import datetime - -import pytest - -from didier.utils.types.datetime import str_to_date - - -def test_str_to_date_single_valid(): - """Test parsing a string for a single possibility (default)""" - result = str_to_date("23/11/2001") - assert result == datetime.date(year=2001, month=11, day=23) - - -def test_str_to_date_single_invalid(): - """Test parsing a string for an invalid string""" - # Invalid format - with pytest.raises(ValueError): - str_to_date("23/11/01") - - # Invalid date - with pytest.raises(ValueError): - str_to_date("69/42/0") - - -def test_str_to_date_multiple_valid(): - """Test parsing a string for multiple possibilities""" - result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"]) - assert result == datetime.date(year=2001, month=11, day=23) - - -def test_str_to_date_multiple_invalid(): - """Test parsing a string for multiple possibilities when none are valid""" - with pytest.raises(ValueError): - str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])