diff --git a/alembic/versions/8f7e8384cec3_currency_updates.py b/alembic/versions/8f7e8384cec3_currency_updates.py index f05451e..d174bfd 100644 --- a/alembic/versions/8f7e8384cec3_currency_updates.py +++ b/alembic/versions/8f7e8384cec3_currency_updates.py @@ -34,14 +34,14 @@ def upgrade() -> None: ) op.create_table( "jail", - sa.Column("jail_entry_i", sa.Integer(), nullable=False), + sa.Column("jail_entry_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=False), sa.Column("until", sa.DateTime(timezone=True), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["users.user_id"], ), - sa.PrimaryKeyConstraint("jail_entry_i"), + sa.PrimaryKeyConstraint("jail_entry_id"), ) op.create_table( "savings", diff --git a/database/crud/jail.py b/database/crud/jail.py index e2bdd80..c8c72c0 100644 --- a/database/crud/jail.py +++ b/database/crud/jail.py @@ -1,12 +1,19 @@ from datetime import datetime from typing import Optional -from sqlalchemy import select +from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession from database.schemas import Jail -__all__ = ["get_jail", "get_user_jail", "imprison"] +__all__ = [ + "get_jail", + "get_jail_entry_by_id", + "get_next_jail_release", + "get_user_jail", + "imprison", + "delete_prisoner_by_id", +] async def get_jail(session: AsyncSession) -> list[Jail]: @@ -15,14 +22,36 @@ async def get_jail(session: AsyncSession) -> list[Jail]: return list((await session.execute(statement)).scalars().all()) +async def get_jail_entry_by_id(session: AsyncSession, jail_id: int) -> Optional[Jail]: + """Get a jail entry by its id""" + statement = select(Jail).where(Jail.jail_entry_id == jail_id) + return (await session.execute(statement)).scalar_one_or_none() + + +async def get_next_jail_release(session: AsyncSession) -> Optional[Jail]: + """Get the next person being released from jail""" + statement = select(Jail).order_by(Jail.until) + return (await session.execute(statement)).scalars().first() + + async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]: """Check how long a given user is still in jail for""" statement = select(Jail).where(Jail.user_id == user_id) return (await session.execute(statement)).scalar_one_or_none() -async def imprison(session: AsyncSession, user_id: int, until: datetime): +async def imprison(session: AsyncSession, user_id: int, until: datetime) -> Jail: """Put a user in Didier Jail""" jail = Jail(user_id=user_id, until=until) session.add(jail) await session.commit() + await session.refresh(jail) + + return jail + + +async def delete_prisoner_by_id(session: AsyncSession, jail_id: int): + """Release a user from jail using their jail entry id""" + statement = delete(Jail).where(Jail.jail_entry_id == jail_id) + await session.execute(statement) + await session.commit() diff --git a/database/exceptions/currency.py b/database/exceptions/currency.py index 2cc9a1d..941faf1 100644 --- a/database/exceptions/currency.py +++ b/database/exceptions/currency.py @@ -1,4 +1,4 @@ -__all__ = ["DoubleNightly", "NotEnoughDinks"] +__all__ = ["DoubleNightly", "NotEnoughDinks", "SavingsCapExceeded"] class DoubleNightly(Exception): diff --git a/database/schemas.py b/database/schemas.py index 1d8ffcb..8543870 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -235,7 +235,7 @@ class Jail(Base): __tablename__ = "jail" - jail_entry_i: Mapped[int] = mapped_column(primary_key=True) + jail_entry_id: Mapped[int] = mapped_column(primary_key=True) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) until: Mapped[datetime] = mapped_column(nullable=False) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 69e6363..c5ea8db 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -11,7 +11,12 @@ from discord.ext import commands import settings from database.crud import currency as crud from database.crud import users -from database.crud.jail import get_user_jail, imprison +from database.crud.jail import ( + delete_prisoner_by_id, + get_jail_entry_by_id, + get_user_jail, + imprison, +) from database.exceptions.currency import ( DoubleNightly, NotEnoughDinks, @@ -32,6 +37,7 @@ from didier import Didier from didier.utils.discord import colours from didier.utils.discord.checks import is_owner from didier.utils.discord.converters import abbreviated_number +from didier.utils.timer import JailTimer from didier.utils.types.datetime import tz_aware_now from didier.utils.types.string import pluralize @@ -40,11 +46,14 @@ class Currency(commands.Cog): """Everything Dinks-related.""" client: Didier + + _jail_timer: JailTimer _rob_lock: asyncio.Lock def __init__(self, client: Didier): super().__init__() self.client = client + self._jail_timer = JailTimer(client) self._rob_lock = asyncio.Lock() @commands.command(name="award") # type: ignore[arg-type] @@ -276,7 +285,8 @@ class Currency(commands.Cog): if to_jail: jail_t = jail_time(robber.rob_level) * punishment_factor until = tz_aware_now() + timedelta(hours=jail_t) - await imprison(session, ctx.author.id, until) + jail = await imprison(session, ctx.author.id, until) + self._jail_timer.maybe_replace_task(jail) return await ctx.reply( f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, " @@ -307,6 +317,19 @@ class Currency(commands.Cog): return await ctx.reply(embed=embed, mention_author=False) + @commands.Cog.listener() + async def on_jail_release(self, jail_id: int): + """Custom listener called when a jail timer ends""" + async with self.client.postgres_session as session: + entry = await get_jail_entry_by_id(session, jail_id) + + if entry is None: + return await self.client.log_error(f"Unable to find jail entry with id {jail_id}.", log_to_discord=True) + + await delete_prisoner_by_id(session, jail_id) + + await self._jail_timer.update() + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index fdfa05d..0243d12 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -19,7 +19,7 @@ from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.constants import Limits -from didier.utils.timer import Timer +from didier.utils.timer import EventTimer from didier.utils.types.datetime import localize, str_to_date, tz_aware_now from didier.utils.types.string import abbreviate, leading from didier.views.modals import CreateBookmark @@ -29,7 +29,7 @@ class Discord(commands.Cog): """Commands related to Discord itself, which work with resources like servers and members.""" client: Didier - timer: Timer + _event_timer: EventTimer # Context-menu references _bookmark_ctx_menu: app_commands.ContextMenu @@ -42,20 +42,25 @@ class Discord(commands.Cog): self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu) - self.timer = Timer(self.client) + self._event_timer = EventTimer(self.client) + + async def cog_load(self) -> None: + """Start any stored timers when the cog is loaded""" + await self._event_timer.update() async def cog_unload(self) -> None: - """Remove the commands when the cog is unloaded""" + """Remove the commands and timers when the cog is unloaded""" self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) + self._event_timer.cancel() @commands.Cog.listener() async def on_event_create(self, event: Event): """Custom listener called when an event is created""" - self.timer.maybe_replace_task(event) + self._event_timer.maybe_replace_task(event) @commands.Cog.listener() - async def on_timer_end(self, event_id: int): + async def on_event_reminder(self, event_id: int): """Custom listener called when an event timer ends""" async with self.client.postgres_session as session: event = await events.get_event_by_id(session, event_id) @@ -89,7 +94,7 @@ class Discord(commands.Cog): await events.delete_event_by_id(session, event.event_id) # Set the next timer - self.client.loop.create_task(self.timer.update()) + await self._event_timer.update() @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None): diff --git a/didier/utils/timer.py b/didier/utils/timer.py index 08f674a..b39ccf5 100644 --- a/didier/utils/timer.py +++ b/didier/utils/timer.py @@ -1,69 +1,128 @@ +import abc import asyncio from datetime import datetime, timedelta -from typing import Optional +from typing import Generic, Optional, TypeVar import discord.utils +from overrides import overrides import settings from database.crud.events import get_next_event -from database.schemas import Event +from database.crud.jail import get_next_jail_release +from database.schemas import Event, Jail from didier import Didier from didier.utils.types.datetime import tz_aware_now -__all__ = ["Timer"] +__all__ = ["JailTimer", "EventTimer"] REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE) -# TODO make this generic -# TODO add timer for jail freeing -class Timer: - """Class for scheduled timers""" +T = TypeVar("T") + + +class ABCTimer(abc.ABC, Generic[T]): + """Base class for scheduled timers""" client: Didier upcoming_timer: Optional[datetime] upcoming_event_id: Optional[int] _task: Optional[asyncio.Task] - def __init__(self, client: Didier): + _delta: Optional[timedelta] + _event: str + + def __init__(self, client: Didier, *, event: str, delta: Optional[timedelta] = None): self.client = client self.upcoming_timer = None self.upcoming_event_id = None self._task = None - self.client.loop.create_task(self.update()) + self._delta = delta + self._event = event + + @abc.abstractmethod + async def dissect_item(self, item: T) -> tuple[datetime, int]: + """Method that takes an item and returns the corresponding timestamp and id""" + + @abc.abstractmethod + async def get_next(self) -> Optional[T]: + """Method that fetches the next item from the database""" async def update(self): - """Get & schedule the closest reminder""" - async with self.client.postgres_session as session: - event = await get_next_event(session, now=tz_aware_now()) + """Get & schedule the closest item""" + next_item = await self.get_next() - # No upcoming events - if event is None: + # No upcoming items + if next_item is None: return - self.maybe_replace_task(event) + self.maybe_replace_task(next_item) - def maybe_replace_task(self, event: Event): + def cancel(self): + """Cancel the running task""" + if self._task is not None: + self._task.cancel() + self._task = None + + def maybe_replace_task(self, item: T): """Replace the current task if necessary""" + timestamp, item_id = self.dissect_item(item) + # If there is a current (pending) task, and the new timer is sooner than the # pending one, cancel it if self._task is not None and not self._task.done(): # The upcoming timer will never be None at this point, but Mypy is mad - if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp: + if self.upcoming_timer is not None and self.upcoming_timer > timestamp: self._task.cancel() else: # The new task happens after the existing task, it has to wait for its turn return - self.upcoming_timer = event.timestamp - self.upcoming_event_id = event.event_id - self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) + self.upcoming_timer = timestamp + self.upcoming_event_id = item_id + self._task = self.client.loop.create_task(self.end_timer(endtime=timestamp, event_id=item_id)) async def end_timer(self, *, endtime: datetime, event_id: int): """Wait until a timer runs out, and then trigger an event to send the message""" - await discord.utils.sleep_until(endtime - REMINDER_PREDELAY) + until = endtime + if self._delta is not None: + until -= self._delta + + await discord.utils.sleep_until(until) self.upcoming_timer = None self.upcoming_event_id = None - self.client.dispatch("timer_end", event_id) + self.client.dispatch(self._event, event_id) + + +class EventTimer(ABCTimer[Event]): + """Timer for upcoming IRL events""" + + def __init__(self, client: Didier): + super().__init__(client, event="event_reminder", delta=REMINDER_PREDELAY) + + @overrides + async def dissect_item(self, item: Event) -> tuple[datetime, int]: + return item.timestamp, item.event_id + + @overrides + async def get_next(self) -> Optional[Event]: + async with self.client.postgres_session as session: + return await get_next_event(session, now=tz_aware_now()) + + +class JailTimer(ABCTimer[Jail]): + """Timer for people spending time in Didier Jail""" + + def __init__(self, client: Didier): + super().__init__(client, event="jail_release") + + @overrides + async def dissect_item(self, item: Jail) -> tuple[datetime, int]: + return item.until, item.jail_entry_id + + @overrides + async def get_next(self) -> Optional[Jail]: + async with self.client.postgres_session as session: + return await get_next_jail_release(session)