mirror of https://github.com/stijndcl/didier
Compare commits
4 Commits
adcf94c66e
...
da0e60ac4f
| Author | SHA1 | Date |
|---|---|---|
|
|
da0e60ac4f | |
|
|
393cc9c891 | |
|
|
66997b7556 | |
|
|
8bc0f1fa7a |
|
|
@ -37,7 +37,7 @@ def upgrade() -> None:
|
|||
"nightly_data",
|
||||
sa.Column("nightly_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("last_nightly", sa.Date, nullable=True),
|
||||
sa.Column("count", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ def upgrade() -> None:
|
|||
"birthdays",
|
||||
sa.Column("birthday_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("birthday", sa.DateTime(), nullable=False),
|
||||
sa.Column("birthday", sa.Date, nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
|
|
|
|||
|
|
@ -0,0 +1,36 @@
|
|||
"""Create tasks
|
||||
|
||||
Revision ID: 346b408c362a
|
||||
Revises: 1716bfecf684
|
||||
Create Date: 2022-07-23 19:41:07.029482
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "346b408c362a"
|
||||
down_revision = "1716bfecf684"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"tasks",
|
||||
sa.Column("task_id", sa.Integer(), nullable=False),
|
||||
sa.Column("task", sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype"), nullable=False),
|
||||
sa.Column("previous_run", sa.DateTime(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("task_id"),
|
||||
sa.UniqueConstraint("task"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("tasks")
|
||||
sa.Enum("BIRTHDAYS", "UFORA_ANNOUNCEMENTS", name="tasktype").drop(op.get_bind())
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -1,16 +1,16 @@
|
|||
"""Initial migration
|
||||
|
||||
Revision ID: 4ec79dd5b191
|
||||
Revises:
|
||||
Revises:
|
||||
Create Date: 2022-06-19 00:31:58.384360
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '4ec79dd5b191'
|
||||
revision = "4ec79dd5b191"
|
||||
down_revision = None
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
|
@ -18,37 +18,46 @@ depends_on = None
|
|||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('ufora_courses',
|
||||
sa.Column('course_id', sa.Integer(), nullable=False),
|
||||
sa.Column('name', sa.Text(), nullable=False),
|
||||
sa.Column('code', sa.Text(), nullable=False),
|
||||
sa.Column('year', sa.Integer(), nullable=False),
|
||||
sa.Column('log_announcements', sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('course_id'),
|
||||
sa.UniqueConstraint('code'),
|
||||
sa.UniqueConstraint('name')
|
||||
op.create_table(
|
||||
"ufora_courses",
|
||||
sa.Column("course_id", sa.Integer(), nullable=False),
|
||||
sa.Column("name", sa.Text(), nullable=False),
|
||||
sa.Column("code", sa.Text(), nullable=False),
|
||||
sa.Column("year", sa.Integer(), nullable=False),
|
||||
sa.Column("log_announcements", sa.Boolean(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("course_id"),
|
||||
sa.UniqueConstraint("code"),
|
||||
sa.UniqueConstraint("name"),
|
||||
)
|
||||
op.create_table('ufora_announcements',
|
||||
sa.Column('announcement_id', sa.Integer(), nullable=False),
|
||||
sa.Column('course_id', sa.Integer(), nullable=True),
|
||||
sa.Column('publication_date', sa.DateTime(timezone=True), nullable=True),
|
||||
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
|
||||
sa.PrimaryKeyConstraint('announcement_id')
|
||||
op.create_table(
|
||||
"ufora_announcements",
|
||||
sa.Column("announcement_id", sa.Integer(), nullable=False),
|
||||
sa.Column("course_id", sa.Integer(), nullable=True),
|
||||
sa.Column("publication_date", sa.Date, nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["course_id"],
|
||||
["ufora_courses.course_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("announcement_id"),
|
||||
)
|
||||
op.create_table('ufora_course_aliases',
|
||||
sa.Column('alias_id', sa.Integer(), nullable=False),
|
||||
sa.Column('alias', sa.Text(), nullable=False),
|
||||
sa.Column('course_id', sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['course_id'], ['ufora_courses.course_id'], ),
|
||||
sa.PrimaryKeyConstraint('alias_id'),
|
||||
sa.UniqueConstraint('alias')
|
||||
op.create_table(
|
||||
"ufora_course_aliases",
|
||||
sa.Column("alias_id", sa.Integer(), nullable=False),
|
||||
sa.Column("alias", sa.Text(), nullable=False),
|
||||
sa.Column("course_id", sa.Integer(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["course_id"],
|
||||
["ufora_courses.course_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("alias_id"),
|
||||
sa.UniqueConstraint("alias"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('ufora_course_aliases')
|
||||
op.drop_table('ufora_announcements')
|
||||
op.drop_table('ufora_courses')
|
||||
op.drop_table("ufora_course_aliases")
|
||||
op.drop_table("ufora_announcements")
|
||||
op.drop_table("ufora_courses")
|
||||
# ### end Alembic commands ###
|
||||
|
|
|
|||
|
|
@ -1,13 +1,15 @@
|
|||
import datetime
|
||||
from datetime import date
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import extract, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.crud import users
|
||||
from database.models import Birthday
|
||||
from database.models import Birthday, User
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user"]
|
||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||
|
||||
|
||||
async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
||||
|
|
@ -15,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
|||
|
||||
If already present, overwrites the existing one
|
||||
"""
|
||||
user = await users.get_or_add(session, user_id)
|
||||
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
|
||||
|
||||
if user.birthday is not None:
|
||||
bd = user.birthday
|
||||
|
|
@ -32,3 +34,12 @@ async def get_birthday_for_user(session: AsyncSession, user_id: int) -> Optional
|
|||
"""Find a user's birthday"""
|
||||
statement = select(Birthday).where(Birthday.user_id == user_id)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> list[Birthday]:
|
||||
"""Get all birthdays that happen on a given day"""
|
||||
days = extract("day", Birthday.birthday)
|
||||
months = extract("month", Birthday.birthday)
|
||||
|
||||
statement = select(Birthday).where((days == day.day) & (months == day.month))
|
||||
return list((await session.execute(statement)).scalars())
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from datetime import datetime
|
||||
from datetime import date
|
||||
from typing import Union
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
|
@ -69,9 +69,9 @@ async def claim_nightly(session: AsyncSession, user_id: int):
|
|||
"""Claim daily Dinks"""
|
||||
nightly_data = await get_nightly_data(session, user_id)
|
||||
|
||||
now = datetime.now()
|
||||
now = date.today()
|
||||
|
||||
if nightly_data.last_nightly is not None and nightly_data.last_nightly.date() == now.date():
|
||||
if nightly_data.last_nightly is not None and nightly_data.last_nightly == now:
|
||||
raise exceptions.DoubleNightly
|
||||
|
||||
bank = await get_bank(session, user_id)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,32 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.enums import TaskType
|
||||
from database.models import Task
|
||||
from database.utils.datetime import LOCAL_TIMEZONE
|
||||
|
||||
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
||||
|
||||
|
||||
async def get_task_by_enum(session: AsyncSession, task: TaskType) -> Optional[Task]:
|
||||
"""Get a task by its enum value, if it exists
|
||||
|
||||
Returns None if the task does not exist
|
||||
"""
|
||||
statement = select(Task).where(Task.task == task)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def set_last_task_execution_time(session: AsyncSession, task: TaskType):
|
||||
"""Set the last time a specific task was executed"""
|
||||
_task = await get_task_by_enum(session, task)
|
||||
|
||||
if _task is None:
|
||||
_task = Task(task=task)
|
||||
|
||||
_task.previous_run = datetime.datetime.now(tz=LOCAL_TIMEZONE)
|
||||
session.add(_task)
|
||||
await session.commit()
|
||||
|
|
@ -10,12 +10,16 @@ __all__ = [
|
|||
]
|
||||
|
||||
|
||||
async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
||||
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
|
||||
"""Get a user's profile
|
||||
|
||||
If it doesn't exist yet, create it (along with all linked datastructures)
|
||||
"""
|
||||
statement = select(User).where(User.user_id == user_id)
|
||||
if options is None:
|
||||
options = []
|
||||
|
||||
statement = select(User).where(User.user_id == user_id).options(*options)
|
||||
|
||||
user: Optional[User] = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
# User exists
|
||||
|
|
@ -38,5 +42,6 @@ async def get_or_add(session: AsyncSession, user_id: int) -> User:
|
|||
session.add(user)
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
|
||||
return user
|
||||
|
|
|
|||
|
|
@ -0,0 +1,13 @@
|
|||
import enum
|
||||
|
||||
__all__ = ["TaskType"]
|
||||
|
||||
|
||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||
# https://github.com/python/typeshed/issues/8286
|
||||
# noinspection PyArgumentList
|
||||
class TaskType(enum.IntEnum):
|
||||
"""Enum for the different types of tasks"""
|
||||
|
||||
BIRTHDAYS = enum.auto()
|
||||
UFORA_ANNOUNCEMENTS = enum.auto()
|
||||
|
|
@ -1,11 +1,23 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, Column, DateTime, ForeignKey, Integer, Text
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Text,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
|
||||
from database import enums
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
|
|
@ -17,6 +29,7 @@ __all__ = [
|
|||
"CustomCommandAlias",
|
||||
"DadJoke",
|
||||
"NightlyData",
|
||||
"Task",
|
||||
"UforaAnnouncement",
|
||||
"UforaCourse",
|
||||
"UforaCourseAlias",
|
||||
|
|
@ -54,7 +67,7 @@ class Birthday(Base):
|
|||
|
||||
birthday_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
birthday: datetime = Column(DateTime, nullable=False)
|
||||
birthday: date = Column(Date, nullable=False)
|
||||
|
||||
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
|
||||
|
||||
|
|
@ -103,12 +116,22 @@ class NightlyData(Base):
|
|||
|
||||
nightly_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_nightly: Optional[datetime] = Column(DateTime(timezone=True), nullable=True)
|
||||
last_nightly: Optional[date] = Column(Date, nullable=True)
|
||||
count: int = Column(Integer, server_default="0", nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Task(Base):
|
||||
"""A Didier task"""
|
||||
|
||||
__tablename__ = "tasks"
|
||||
|
||||
task_id: int = Column(Integer, primary_key=True)
|
||||
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
|
||||
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
|
||||
class UforaCourse(Base):
|
||||
"""A course on Ufora"""
|
||||
|
||||
|
|
@ -147,7 +170,7 @@ class UforaAnnouncement(Base):
|
|||
|
||||
announcement_id: int = Column(Integer, primary_key=True)
|
||||
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
|
||||
publication_date: datetime = Column(DateTime(timezone=True))
|
||||
publication_date: date = Column(Date)
|
||||
|
||||
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")
|
||||
|
||||
|
|
|
|||
|
|
@ -0,0 +1,5 @@
|
|||
import zoneinfo
|
||||
|
||||
__all__ = ["LOCAL_TIMEZONE"]
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
|
@ -35,7 +35,13 @@ class Discord(commands.Cog):
|
|||
async def birthday_set(self, ctx: commands.Context, date_str: str):
|
||||
"""Command to set your birthday"""
|
||||
try:
|
||||
date = str_to_date(date_str)
|
||||
default_year = 2001
|
||||
date = str_to_date(date_str, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
|
||||
|
||||
# If no year was passed, make it 2001 by default
|
||||
if date_str.count("/") == 1:
|
||||
date.replace(year=default_year)
|
||||
|
||||
except ValueError:
|
||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from database.crud import custom_commands
|
|||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from didier import Didier
|
||||
from didier.data.flags.owner import EditCustomFlags
|
||||
from didier.utils.discord.flags.owner import EditCustomFlags
|
||||
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,27 +1,90 @@
|
|||
import datetime
|
||||
import traceback
|
||||
|
||||
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
|
||||
|
||||
import settings
|
||||
from database import enums
|
||||
from database.crud.birthdays import get_birthdays_on_day
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
from didier import Didier
|
||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||
from didier.decorators.tasks import timed_task
|
||||
from didier.utils.discord.checks import is_owner
|
||||
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
|
||||
|
||||
# datetime.time()-instances for when every task should run
|
||||
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||
|
||||
|
||||
class Tasks(commands.Cog):
|
||||
"""Task loops that run periodically"""
|
||||
"""Task loops that run periodically
|
||||
|
||||
Preferably these would use the new `time`-kwarg, but these don't run
|
||||
on startup, which is not ideal. This means we still have to run them every hour
|
||||
in order to never miss anything if Didier goes offline by coincidence
|
||||
"""
|
||||
|
||||
client: Didier
|
||||
_tasks: dict[str, tasks.Loop]
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
# Only check birthdays if there's a channel to send it to
|
||||
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
|
||||
self.check_birthdays.start()
|
||||
|
||||
# Only pull announcements if a token was provided
|
||||
if settings.UFORA_RSS_TOKEN is not None and settings.UFORA_ANNOUNCEMENTS_CHANNEL is not None:
|
||||
self.pull_ufora_announcements.start()
|
||||
self.remove_old_ufora_announcements.start()
|
||||
|
||||
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
|
||||
|
||||
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
|
||||
@commands.check(is_owner)
|
||||
async def tasks_group(self, ctx: commands.Context):
|
||||
"""Command group for Task-related commands
|
||||
|
||||
Invoking the group itself shows the time until the next iteration
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
@tasks_group.command(name="Force", case_insensitive=True)
|
||||
async def force_task(self, ctx: commands.Context, name: str):
|
||||
"""Command to force-run a task without waiting for the run time"""
|
||||
name = name.lower()
|
||||
if name not in self._tasks:
|
||||
return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False)
|
||||
|
||||
task = self._tasks[name]
|
||||
await task()
|
||||
|
||||
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
||||
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||
async def check_birthdays(self):
|
||||
"""Check if it's currently anyone's birthday"""
|
||||
now = tz_aware_now().date()
|
||||
async with self.client.db_session as session:
|
||||
birthdays = await get_birthdays_on_day(session, now)
|
||||
|
||||
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||
if channel is None:
|
||||
return await self.client.log_error("Unable to find channel for birthday announcements")
|
||||
|
||||
for birthday in birthdays:
|
||||
user = self.client.get_user(birthday.user_id)
|
||||
# TODO more messages?
|
||||
await channel.send(f"Gelukkig verjaardag {user.mention}!")
|
||||
|
||||
@check_birthdays.before_loop
|
||||
async def _before_check_birthdays(self):
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
@tasks.loop(minutes=10)
|
||||
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
|
||||
async def pull_ufora_announcements(self):
|
||||
"""Task that checks for new Ufora announcements & logs them in a channel"""
|
||||
# In theory this shouldn't happen but just to please Mypy
|
||||
|
|
@ -37,23 +100,20 @@ class Tasks(commands.Cog):
|
|||
|
||||
@pull_ufora_announcements.before_loop
|
||||
async def _before_ufora_announcements(self):
|
||||
"""Don't try to get announcements if the bot isn't ready yet"""
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
@pull_ufora_announcements.error
|
||||
async def _on_announcements_error(self, error: BaseException):
|
||||
"""Error handler for the Ufora Announcements task"""
|
||||
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
|
||||
|
||||
@tasks.loop(hours=24)
|
||||
async def remove_old_ufora_announcements(self):
|
||||
"""Remove all announcements that are over 1 week old, once per day"""
|
||||
async with self.client.db_session as session:
|
||||
await remove_old_announcements(session)
|
||||
|
||||
@remove_old_ufora_announcements.before_loop
|
||||
async def _before_remove_old_ufora_announcements(self):
|
||||
await self.client.wait_until_ready()
|
||||
@check_birthdays.error
|
||||
@pull_ufora_announcements.error
|
||||
@remove_old_ufora_announcements.error
|
||||
async def _on_tasks_error(self, error: BaseException):
|
||||
"""Error handler for all tasks"""
|
||||
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
|
|
|
|||
|
|
@ -0,0 +1,28 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import functools
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from didier.cogs.tasks import Tasks
|
||||
|
||||
from database import enums
|
||||
from database.crud.tasks import set_last_task_execution_time
|
||||
|
||||
__all__ = ["timed_task"]
|
||||
|
||||
|
||||
def timed_task(task: enums.TaskType):
|
||||
"""Decorator to log the last execution time of a task"""
|
||||
|
||||
def _decorator(func):
|
||||
@functools.wraps(func)
|
||||
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
|
||||
await func(tasks_cog, *args, **kwargs)
|
||||
|
||||
async with tasks_cog.client.db_session as session:
|
||||
await set_last_task_execution_time(session, task)
|
||||
|
||||
return _wrapper
|
||||
|
||||
return _decorator
|
||||
|
|
@ -1,3 +1,4 @@
|
|||
import logging
|
||||
import os
|
||||
|
||||
import discord
|
||||
|
|
@ -14,10 +15,14 @@ from didier.utils.discord.prefix import get_prefix
|
|||
__all__ = ["Didier"]
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Didier(commands.Bot):
|
||||
"""DIDIER <3"""
|
||||
|
||||
database_caches: CacheManager
|
||||
error_channel: discord.abc.Messageable
|
||||
initial_extensions: tuple[str, ...] = ()
|
||||
http_session: ClientSession
|
||||
|
||||
|
|
@ -60,6 +65,12 @@ class Didier(commands.Bot):
|
|||
# Create aiohttp session
|
||||
self.http_session = ClientSession()
|
||||
|
||||
# Configure channel to send errors to
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
else:
|
||||
self.error_channel = self.get_user(self.owner_id)
|
||||
|
||||
async def _load_initial_extensions(self):
|
||||
"""Load all extensions that should be loaded before the others"""
|
||||
for extension in self.initial_extensions:
|
||||
|
|
@ -101,6 +112,13 @@ class Didier(commands.Bot):
|
|||
"""Add an X to a message"""
|
||||
await message.add_reaction("❌")
|
||||
|
||||
async def log_error(self, message: str, log_to_discord: bool = True):
|
||||
"""Send an error message to the logs, and optionally the configured channel"""
|
||||
logger.error(message)
|
||||
if log_to_discord:
|
||||
# TODO pretty embed
|
||||
await self.error_channel.send(message)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
from didier.data.flags import PosixFlags
|
||||
from didier.utils.discord.flags import PosixFlags
|
||||
|
||||
__all__ = ["EditCustomFlags"]
|
||||
|
||||
|
|
@ -1,6 +1,11 @@
|
|||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
__all__ = ["int_to_weekday", "str_to_date"]
|
||||
__all__ = ["LOCAL_TIMEZONE", "int_to_weekday", "str_to_date", "tz_aware_now"]
|
||||
|
||||
from typing import Union
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
||||
|
||||
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
|
||||
|
|
@ -8,6 +13,21 @@ def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to wr
|
|||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||
|
||||
|
||||
def str_to_date(date_str: str) -> datetime.date:
|
||||
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
||||
"""Turn a string into a DD/MM/YYYY date"""
|
||||
return datetime.datetime.strptime(date_str, "%d/%m/%Y").date()
|
||||
# Allow passing multiple formats in a list
|
||||
if isinstance(formats, str):
|
||||
formats = [formats]
|
||||
|
||||
for format_str in formats:
|
||||
try:
|
||||
return datetime.datetime.strptime(date_str, format_str).date()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
raise ValueError
|
||||
|
||||
|
||||
def tz_aware_now() -> datetime.datetime:
|
||||
"""Get the current date & time, but timezone-aware"""
|
||||
return datetime.datetime.utcnow().replace(tzinfo=datetime.timezone.utc).astimezone(LOCAL_TIMEZONE)
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
black==22.3.0
|
||||
coverage[toml]==6.4.1
|
||||
freezegun==1.2.1
|
||||
mypy==0.961
|
||||
pre-commit==2.20.0
|
||||
pytest==7.1.2
|
||||
|
|
|
|||
|
|
@ -46,6 +46,8 @@ DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didie
|
|||
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
||||
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
||||
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
||||
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
|
||||
ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None)
|
||||
UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None)
|
||||
|
||||
"""API Keys"""
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ def event_loop() -> Generator:
|
|||
|
||||
@pytest.fixture(scope="session")
|
||||
async def tables():
|
||||
"""Initialize a database before the tests, and then tear it down again
|
||||
"""Fixture to initialize a database before the tests, and then tear it down again
|
||||
|
||||
Checks that the migrations were successful by asserting that we are currently
|
||||
on the latest migration
|
||||
|
|
|
|||
|
|
@ -1,8 +1,10 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import birthdays as crud
|
||||
from database.crud import users
|
||||
from database.models import User
|
||||
|
||||
|
||||
|
|
@ -14,7 +16,7 @@ async def test_add_birthday_not_present(database_session: AsyncSession, user: Us
|
|||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
assert user.birthday is not None
|
||||
assert user.birthday.birthday.date() == bd_date
|
||||
assert user.birthday.birthday == bd_date
|
||||
|
||||
|
||||
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
|
||||
|
|
@ -27,7 +29,7 @@ async def test_add_birthday_overwrite(database_session: AsyncSession, user: User
|
|||
new_bd_date = bd_date + timedelta(weeks=1)
|
||||
await crud.add_birthday(database_session, user.user_id, new_bd_date)
|
||||
await database_session.refresh(user)
|
||||
assert user.birthday.birthday.date() == new_bd_date
|
||||
assert user.birthday.birthday == new_bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
|
||||
|
|
@ -38,10 +40,35 @@ async def test_get_birthday_exists(database_session: AsyncSession, user: User):
|
|||
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
assert bd is not None
|
||||
assert bd.birthday.date() == bd_date
|
||||
assert bd.birthday == bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
|
||||
"""Test getting a user's birthday when it doesn't exist"""
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
assert bd is None
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
||||
"""Test getting all birthdays on a given day"""
|
||||
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
|
||||
|
||||
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
||||
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
assert len(birthdays) == 1
|
||||
assert birthdays[0].user_id == user.user_id
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_none_present(database_session: AsyncSession):
|
||||
"""Test getting all birthdays when there are none"""
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
assert len(birthdays) == 0
|
||||
|
||||
# Add a random birthday that is not today
|
||||
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1))
|
||||
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
assert len(birthdays) == 0
|
||||
|
|
|
|||
|
|
@ -1,4 +1,7 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import currency as crud
|
||||
|
|
@ -14,13 +17,18 @@ async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
|||
assert bank.dinks == 10
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
|
||||
"""Test claiming nightlies when it hasn't been done yet"""
|
||||
await crud.claim_nightly(database_session, bank.user_id)
|
||||
await database_session.refresh(bank)
|
||||
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||
|
||||
nightly_data = await crud.get_nightly_data(database_session, bank.user_id)
|
||||
assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23)
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
|
||||
"""Test claiming nightlies twice in a day"""
|
||||
await crud.claim_nightly(database_session, bank.user_id)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,34 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from didier.utils.types.datetime import str_to_date
|
||||
|
||||
|
||||
def test_str_to_date_single_valid():
|
||||
"""Test parsing a string for a single possibility (default)"""
|
||||
result = str_to_date("23/11/2001")
|
||||
assert result == datetime.date(year=2001, month=11, day=23)
|
||||
|
||||
|
||||
def test_str_to_date_single_invalid():
|
||||
"""Test parsing a string for an invalid string"""
|
||||
# Invalid format
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("23/11/01")
|
||||
|
||||
# Invalid date
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("69/42/0")
|
||||
|
||||
|
||||
def test_str_to_date_multiple_valid():
|
||||
"""Test parsing a string for multiple possibilities"""
|
||||
result = str_to_date("23/11/01", formats=["%d/%m/%Y", "%d/%m/%y"])
|
||||
assert result == datetime.date(year=2001, month=11, day=23)
|
||||
|
||||
|
||||
def test_str_to_date_multiple_invalid():
|
||||
"""Test parsing a string for multiple possibilities when none are valid"""
|
||||
with pytest.raises(ValueError):
|
||||
str_to_date("2001/01/02", formats=["%d/%m/%Y", "%d/%m/%y"])
|
||||
Loading…
Reference in New Issue