Jail timer

feature/currency-improvements
stijndcl 2024-03-04 00:48:17 +01:00
parent a051423203
commit 3509073d7f
7 changed files with 154 additions and 38 deletions

View File

@ -34,14 +34,14 @@ def upgrade() -> None:
)
op.create_table(
"jail",
sa.Column("jail_entry_i", sa.Integer(), nullable=False),
sa.Column("jail_entry_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("until", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("jail_entry_i"),
sa.PrimaryKeyConstraint("jail_entry_id"),
)
op.create_table(
"savings",

View File

@ -1,12 +1,19 @@
from datetime import datetime
from typing import Optional
from sqlalchemy import select
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import Jail
__all__ = ["get_jail", "get_user_jail", "imprison"]
__all__ = [
"get_jail",
"get_jail_entry_by_id",
"get_next_jail_release",
"get_user_jail",
"imprison",
"delete_prisoner_by_id",
]
async def get_jail(session: AsyncSession) -> list[Jail]:
@ -15,14 +22,36 @@ async def get_jail(session: AsyncSession) -> list[Jail]:
return list((await session.execute(statement)).scalars().all())
async def get_jail_entry_by_id(session: AsyncSession, jail_id: int) -> Optional[Jail]:
"""Get a jail entry by its id"""
statement = select(Jail).where(Jail.jail_entry_id == jail_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_next_jail_release(session: AsyncSession) -> Optional[Jail]:
"""Get the next person being released from jail"""
statement = select(Jail).order_by(Jail.until)
return (await session.execute(statement)).scalars().first()
async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]:
"""Check how long a given user is still in jail for"""
statement = select(Jail).where(Jail.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none()
async def imprison(session: AsyncSession, user_id: int, until: datetime):
async def imprison(session: AsyncSession, user_id: int, until: datetime) -> Jail:
"""Put a user in Didier Jail"""
jail = Jail(user_id=user_id, until=until)
session.add(jail)
await session.commit()
await session.refresh(jail)
return jail
async def delete_prisoner_by_id(session: AsyncSession, jail_id: int):
"""Release a user from jail using their jail entry id"""
statement = delete(Jail).where(Jail.jail_entry_id == jail_id)
await session.execute(statement)
await session.commit()

View File

@ -1,4 +1,4 @@
__all__ = ["DoubleNightly", "NotEnoughDinks"]
__all__ = ["DoubleNightly", "NotEnoughDinks", "SavingsCapExceeded"]
class DoubleNightly(Exception):

View File

@ -235,7 +235,7 @@ class Jail(Base):
__tablename__ = "jail"
jail_entry_i: Mapped[int] = mapped_column(primary_key=True)
jail_entry_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
until: Mapped[datetime] = mapped_column(nullable=False)

View File

@ -11,7 +11,12 @@ from discord.ext import commands
import settings
from database.crud import currency as crud
from database.crud import users
from database.crud.jail import get_user_jail, imprison
from database.crud.jail import (
delete_prisoner_by_id,
get_jail_entry_by_id,
get_user_jail,
imprison,
)
from database.exceptions.currency import (
DoubleNightly,
NotEnoughDinks,
@ -32,6 +37,7 @@ from didier import Didier
from didier.utils.discord import colours
from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number
from didier.utils.timer import JailTimer
from didier.utils.types.datetime import tz_aware_now
from didier.utils.types.string import pluralize
@ -40,11 +46,14 @@ class Currency(commands.Cog):
"""Everything Dinks-related."""
client: Didier
_jail_timer: JailTimer
_rob_lock: asyncio.Lock
def __init__(self, client: Didier):
super().__init__()
self.client = client
self._jail_timer = JailTimer(client)
self._rob_lock = asyncio.Lock()
@commands.command(name="award") # type: ignore[arg-type]
@ -276,7 +285,8 @@ class Currency(commands.Cog):
if to_jail:
jail_t = jail_time(robber.rob_level) * punishment_factor
until = tz_aware_now() + timedelta(hours=jail_t)
await imprison(session, ctx.author.id, until)
jail = await imprison(session, ctx.author.id, until)
self._jail_timer.maybe_replace_task(jail)
return await ctx.reply(
f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, "
@ -307,6 +317,19 @@ class Currency(commands.Cog):
return await ctx.reply(embed=embed, mention_author=False)
@commands.Cog.listener()
async def on_jail_release(self, jail_id: int):
"""Custom listener called when a jail timer ends"""
async with self.client.postgres_session as session:
entry = await get_jail_entry_by_id(session, jail_id)
if entry is None:
return await self.client.log_error(f"Unable to find jail entry with id {jail_id}.", log_to_discord=True)
await delete_prisoner_by_id(session, jail_id)
await self._jail_timer.update()
async def setup(client: Didier):
"""Load the cog"""

View File

@ -19,7 +19,7 @@ from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.constants import Limits
from didier.utils.timer import Timer
from didier.utils.timer import EventTimer
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
from didier.utils.types.string import abbreviate, leading
from didier.views.modals import CreateBookmark
@ -29,7 +29,7 @@ class Discord(commands.Cog):
"""Commands related to Discord itself, which work with resources like servers and members."""
client: Didier
timer: Timer
_event_timer: EventTimer
# Context-menu references
_bookmark_ctx_menu: app_commands.ContextMenu
@ -42,20 +42,25 @@ class Discord(commands.Cog):
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
self.client.tree.add_command(self._bookmark_ctx_menu)
self.client.tree.add_command(self._pin_ctx_menu)
self.timer = Timer(self.client)
self._event_timer = EventTimer(self.client)
async def cog_load(self) -> None:
"""Start any stored timers when the cog is loaded"""
await self._event_timer.update()
async def cog_unload(self) -> None:
"""Remove the commands when the cog is unloaded"""
"""Remove the commands and timers when the cog is unloaded"""
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
self._event_timer.cancel()
@commands.Cog.listener()
async def on_event_create(self, event: Event):
"""Custom listener called when an event is created"""
self.timer.maybe_replace_task(event)
self._event_timer.maybe_replace_task(event)
@commands.Cog.listener()
async def on_timer_end(self, event_id: int):
async def on_event_reminder(self, event_id: int):
"""Custom listener called when an event timer ends"""
async with self.client.postgres_session as session:
event = await events.get_event_by_id(session, event_id)
@ -89,7 +94,7 @@ class Discord(commands.Cog):
await events.delete_event_by_id(session, event.event_id)
# Set the next timer
self.client.loop.create_task(self.timer.update())
await self._event_timer.update()
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):

View File

@ -1,69 +1,128 @@
import abc
import asyncio
from datetime import datetime, timedelta
from typing import Optional
from typing import Generic, Optional, TypeVar
import discord.utils
from overrides import overrides
import settings
from database.crud.events import get_next_event
from database.schemas import Event
from database.crud.jail import get_next_jail_release
from database.schemas import Event, Jail
from didier import Didier
from didier.utils.types.datetime import tz_aware_now
__all__ = ["Timer"]
__all__ = ["JailTimer", "EventTimer"]
REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE)
# TODO make this generic
# TODO add timer for jail freeing
class Timer:
"""Class for scheduled timers"""
T = TypeVar("T")
class ABCTimer(abc.ABC, Generic[T]):
"""Base class for scheduled timers"""
client: Didier
upcoming_timer: Optional[datetime]
upcoming_event_id: Optional[int]
_task: Optional[asyncio.Task]
def __init__(self, client: Didier):
_delta: Optional[timedelta]
_event: str
def __init__(self, client: Didier, *, event: str, delta: Optional[timedelta] = None):
self.client = client
self.upcoming_timer = None
self.upcoming_event_id = None
self._task = None
self.client.loop.create_task(self.update())
self._delta = delta
self._event = event
@abc.abstractmethod
async def dissect_item(self, item: T) -> tuple[datetime, int]:
"""Method that takes an item and returns the corresponding timestamp and id"""
@abc.abstractmethod
async def get_next(self) -> Optional[T]:
"""Method that fetches the next item from the database"""
async def update(self):
"""Get & schedule the closest reminder"""
async with self.client.postgres_session as session:
event = await get_next_event(session, now=tz_aware_now())
"""Get & schedule the closest item"""
next_item = await self.get_next()
# No upcoming events
if event is None:
# No upcoming items
if next_item is None:
return
self.maybe_replace_task(event)
self.maybe_replace_task(next_item)
def maybe_replace_task(self, event: Event):
def cancel(self):
"""Cancel the running task"""
if self._task is not None:
self._task.cancel()
self._task = None
def maybe_replace_task(self, item: T):
"""Replace the current task if necessary"""
timestamp, item_id = self.dissect_item(item)
# If there is a current (pending) task, and the new timer is sooner than the
# pending one, cancel it
if self._task is not None and not self._task.done():
# The upcoming timer will never be None at this point, but Mypy is mad
if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp:
if self.upcoming_timer is not None and self.upcoming_timer > timestamp:
self._task.cancel()
else:
# The new task happens after the existing task, it has to wait for its turn
return
self.upcoming_timer = event.timestamp
self.upcoming_event_id = event.event_id
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
self.upcoming_timer = timestamp
self.upcoming_event_id = item_id
self._task = self.client.loop.create_task(self.end_timer(endtime=timestamp, event_id=item_id))
async def end_timer(self, *, endtime: datetime, event_id: int):
"""Wait until a timer runs out, and then trigger an event to send the message"""
await discord.utils.sleep_until(endtime - REMINDER_PREDELAY)
until = endtime
if self._delta is not None:
until -= self._delta
await discord.utils.sleep_until(until)
self.upcoming_timer = None
self.upcoming_event_id = None
self.client.dispatch("timer_end", event_id)
self.client.dispatch(self._event, event_id)
class EventTimer(ABCTimer[Event]):
"""Timer for upcoming IRL events"""
def __init__(self, client: Didier):
super().__init__(client, event="event_reminder", delta=REMINDER_PREDELAY)
@overrides
async def dissect_item(self, item: Event) -> tuple[datetime, int]:
return item.timestamp, item.event_id
@overrides
async def get_next(self) -> Optional[Event]:
async with self.client.postgres_session as session:
return await get_next_event(session, now=tz_aware_now())
class JailTimer(ABCTimer[Jail]):
"""Timer for people spending time in Didier Jail"""
def __init__(self, client: Didier):
super().__init__(client, event="jail_release")
@overrides
async def dissect_item(self, item: Jail) -> tuple[datetime, int]:
return item.until, item.jail_entry_id
@overrides
async def get_next(self) -> Optional[Jail]:
async with self.client.postgres_session as session:
return await get_next_jail_release(session)