Jail timer

feature/currency-improvements
stijndcl 2024-03-04 00:48:17 +01:00
parent a051423203
commit 3509073d7f
7 changed files with 154 additions and 38 deletions

View File

@ -34,14 +34,14 @@ def upgrade() -> None:
) )
op.create_table( op.create_table(
"jail", "jail",
sa.Column("jail_entry_i", sa.Integer(), nullable=False), sa.Column("jail_entry_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=False),
sa.Column("until", sa.DateTime(timezone=True), nullable=False), sa.Column("until", sa.DateTime(timezone=True), nullable=False),
sa.ForeignKeyConstraint( sa.ForeignKeyConstraint(
["user_id"], ["user_id"],
["users.user_id"], ["users.user_id"],
), ),
sa.PrimaryKeyConstraint("jail_entry_i"), sa.PrimaryKeyConstraint("jail_entry_id"),
) )
op.create_table( op.create_table(
"savings", "savings",

View File

@ -1,12 +1,19 @@
from datetime import datetime from datetime import datetime
from typing import Optional from typing import Optional
from sqlalchemy import select from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import Jail from database.schemas import Jail
__all__ = ["get_jail", "get_user_jail", "imprison"] __all__ = [
"get_jail",
"get_jail_entry_by_id",
"get_next_jail_release",
"get_user_jail",
"imprison",
"delete_prisoner_by_id",
]
async def get_jail(session: AsyncSession) -> list[Jail]: async def get_jail(session: AsyncSession) -> list[Jail]:
@ -15,14 +22,36 @@ async def get_jail(session: AsyncSession) -> list[Jail]:
return list((await session.execute(statement)).scalars().all()) return list((await session.execute(statement)).scalars().all())
async def get_jail_entry_by_id(session: AsyncSession, jail_id: int) -> Optional[Jail]:
"""Get a jail entry by its id"""
statement = select(Jail).where(Jail.jail_entry_id == jail_id)
return (await session.execute(statement)).scalar_one_or_none()
async def get_next_jail_release(session: AsyncSession) -> Optional[Jail]:
"""Get the next person being released from jail"""
statement = select(Jail).order_by(Jail.until)
return (await session.execute(statement)).scalars().first()
async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]: async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]:
"""Check how long a given user is still in jail for""" """Check how long a given user is still in jail for"""
statement = select(Jail).where(Jail.user_id == user_id) statement = select(Jail).where(Jail.user_id == user_id)
return (await session.execute(statement)).scalar_one_or_none() return (await session.execute(statement)).scalar_one_or_none()
async def imprison(session: AsyncSession, user_id: int, until: datetime): async def imprison(session: AsyncSession, user_id: int, until: datetime) -> Jail:
"""Put a user in Didier Jail""" """Put a user in Didier Jail"""
jail = Jail(user_id=user_id, until=until) jail = Jail(user_id=user_id, until=until)
session.add(jail) session.add(jail)
await session.commit() await session.commit()
await session.refresh(jail)
return jail
async def delete_prisoner_by_id(session: AsyncSession, jail_id: int):
"""Release a user from jail using their jail entry id"""
statement = delete(Jail).where(Jail.jail_entry_id == jail_id)
await session.execute(statement)
await session.commit()

View File

@ -1,4 +1,4 @@
__all__ = ["DoubleNightly", "NotEnoughDinks"] __all__ = ["DoubleNightly", "NotEnoughDinks", "SavingsCapExceeded"]
class DoubleNightly(Exception): class DoubleNightly(Exception):

View File

@ -235,7 +235,7 @@ class Jail(Base):
__tablename__ = "jail" __tablename__ = "jail"
jail_entry_i: Mapped[int] = mapped_column(primary_key=True) jail_entry_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
until: Mapped[datetime] = mapped_column(nullable=False) until: Mapped[datetime] = mapped_column(nullable=False)

View File

@ -11,7 +11,12 @@ from discord.ext import commands
import settings import settings
from database.crud import currency as crud from database.crud import currency as crud
from database.crud import users from database.crud import users
from database.crud.jail import get_user_jail, imprison from database.crud.jail import (
delete_prisoner_by_id,
get_jail_entry_by_id,
get_user_jail,
imprison,
)
from database.exceptions.currency import ( from database.exceptions.currency import (
DoubleNightly, DoubleNightly,
NotEnoughDinks, NotEnoughDinks,
@ -32,6 +37,7 @@ from didier import Didier
from didier.utils.discord import colours from didier.utils.discord import colours
from didier.utils.discord.checks import is_owner from didier.utils.discord.checks import is_owner
from didier.utils.discord.converters import abbreviated_number from didier.utils.discord.converters import abbreviated_number
from didier.utils.timer import JailTimer
from didier.utils.types.datetime import tz_aware_now from didier.utils.types.datetime import tz_aware_now
from didier.utils.types.string import pluralize from didier.utils.types.string import pluralize
@ -40,11 +46,14 @@ class Currency(commands.Cog):
"""Everything Dinks-related.""" """Everything Dinks-related."""
client: Didier client: Didier
_jail_timer: JailTimer
_rob_lock: asyncio.Lock _rob_lock: asyncio.Lock
def __init__(self, client: Didier): def __init__(self, client: Didier):
super().__init__() super().__init__()
self.client = client self.client = client
self._jail_timer = JailTimer(client)
self._rob_lock = asyncio.Lock() self._rob_lock = asyncio.Lock()
@commands.command(name="award") # type: ignore[arg-type] @commands.command(name="award") # type: ignore[arg-type]
@ -276,7 +285,8 @@ class Currency(commands.Cog):
if to_jail: if to_jail:
jail_t = jail_time(robber.rob_level) * punishment_factor jail_t = jail_time(robber.rob_level) * punishment_factor
until = tz_aware_now() + timedelta(hours=jail_t) until = tz_aware_now() + timedelta(hours=jail_t)
await imprison(session, ctx.author.id, until) jail = await imprison(session, ctx.author.id, until)
self._jail_timer.maybe_replace_task(jail)
return await ctx.reply( return await ctx.reply(
f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, " f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, "
@ -307,6 +317,19 @@ class Currency(commands.Cog):
return await ctx.reply(embed=embed, mention_author=False) return await ctx.reply(embed=embed, mention_author=False)
@commands.Cog.listener()
async def on_jail_release(self, jail_id: int):
"""Custom listener called when a jail timer ends"""
async with self.client.postgres_session as session:
entry = await get_jail_entry_by_id(session, jail_id)
if entry is None:
return await self.client.log_error(f"Unable to find jail entry with id {jail_id}.", log_to_discord=True)
await delete_prisoner_by_id(session, jail_id)
await self._jail_timer.update()
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -19,7 +19,7 @@ from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
from didier.utils.timer import Timer from didier.utils.timer import EventTimer
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
from didier.utils.types.string import abbreviate, leading from didier.utils.types.string import abbreviate, leading
from didier.views.modals import CreateBookmark from didier.views.modals import CreateBookmark
@ -29,7 +29,7 @@ class Discord(commands.Cog):
"""Commands related to Discord itself, which work with resources like servers and members.""" """Commands related to Discord itself, which work with resources like servers and members."""
client: Didier client: Didier
timer: Timer _event_timer: EventTimer
# Context-menu references # Context-menu references
_bookmark_ctx_menu: app_commands.ContextMenu _bookmark_ctx_menu: app_commands.ContextMenu
@ -42,20 +42,25 @@ class Discord(commands.Cog):
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx) self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
self.client.tree.add_command(self._bookmark_ctx_menu) self.client.tree.add_command(self._bookmark_ctx_menu)
self.client.tree.add_command(self._pin_ctx_menu) self.client.tree.add_command(self._pin_ctx_menu)
self.timer = Timer(self.client) self._event_timer = EventTimer(self.client)
async def cog_load(self) -> None:
"""Start any stored timers when the cog is loaded"""
await self._event_timer.update()
async def cog_unload(self) -> None: async def cog_unload(self) -> None:
"""Remove the commands when the cog is unloaded""" """Remove the commands and timers when the cog is unloaded"""
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type) self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type) self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
self._event_timer.cancel()
@commands.Cog.listener() @commands.Cog.listener()
async def on_event_create(self, event: Event): async def on_event_create(self, event: Event):
"""Custom listener called when an event is created""" """Custom listener called when an event is created"""
self.timer.maybe_replace_task(event) self._event_timer.maybe_replace_task(event)
@commands.Cog.listener() @commands.Cog.listener()
async def on_timer_end(self, event_id: int): async def on_event_reminder(self, event_id: int):
"""Custom listener called when an event timer ends""" """Custom listener called when an event timer ends"""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
event = await events.get_event_by_id(session, event_id) event = await events.get_event_by_id(session, event_id)
@ -89,7 +94,7 @@ class Discord(commands.Cog):
await events.delete_event_by_id(session, event.event_id) await events.delete_event_by_id(session, event.event_id)
# Set the next timer # Set the next timer
self.client.loop.create_task(self.timer.update()) await self._event_timer.update()
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None): async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):

View File

@ -1,69 +1,128 @@
import abc
import asyncio import asyncio
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Optional from typing import Generic, Optional, TypeVar
import discord.utils import discord.utils
from overrides import overrides
import settings import settings
from database.crud.events import get_next_event from database.crud.events import get_next_event
from database.schemas import Event from database.crud.jail import get_next_jail_release
from database.schemas import Event, Jail
from didier import Didier from didier import Didier
from didier.utils.types.datetime import tz_aware_now from didier.utils.types.datetime import tz_aware_now
__all__ = ["Timer"] __all__ = ["JailTimer", "EventTimer"]
REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE) REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE)
# TODO make this generic T = TypeVar("T")
# TODO add timer for jail freeing
class Timer:
"""Class for scheduled timers""" class ABCTimer(abc.ABC, Generic[T]):
"""Base class for scheduled timers"""
client: Didier client: Didier
upcoming_timer: Optional[datetime] upcoming_timer: Optional[datetime]
upcoming_event_id: Optional[int] upcoming_event_id: Optional[int]
_task: Optional[asyncio.Task] _task: Optional[asyncio.Task]
def __init__(self, client: Didier): _delta: Optional[timedelta]
_event: str
def __init__(self, client: Didier, *, event: str, delta: Optional[timedelta] = None):
self.client = client self.client = client
self.upcoming_timer = None self.upcoming_timer = None
self.upcoming_event_id = None self.upcoming_event_id = None
self._task = None self._task = None
self.client.loop.create_task(self.update()) self._delta = delta
self._event = event
@abc.abstractmethod
async def dissect_item(self, item: T) -> tuple[datetime, int]:
"""Method that takes an item and returns the corresponding timestamp and id"""
@abc.abstractmethod
async def get_next(self) -> Optional[T]:
"""Method that fetches the next item from the database"""
async def update(self): async def update(self):
"""Get & schedule the closest reminder""" """Get & schedule the closest item"""
async with self.client.postgres_session as session: next_item = await self.get_next()
event = await get_next_event(session, now=tz_aware_now())
# No upcoming events # No upcoming items
if event is None: if next_item is None:
return return
self.maybe_replace_task(event) self.maybe_replace_task(next_item)
def maybe_replace_task(self, event: Event): def cancel(self):
"""Cancel the running task"""
if self._task is not None:
self._task.cancel()
self._task = None
def maybe_replace_task(self, item: T):
"""Replace the current task if necessary""" """Replace the current task if necessary"""
timestamp, item_id = self.dissect_item(item)
# If there is a current (pending) task, and the new timer is sooner than the # If there is a current (pending) task, and the new timer is sooner than the
# pending one, cancel it # pending one, cancel it
if self._task is not None and not self._task.done(): if self._task is not None and not self._task.done():
# The upcoming timer will never be None at this point, but Mypy is mad # The upcoming timer will never be None at this point, but Mypy is mad
if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp: if self.upcoming_timer is not None and self.upcoming_timer > timestamp:
self._task.cancel() self._task.cancel()
else: else:
# The new task happens after the existing task, it has to wait for its turn # The new task happens after the existing task, it has to wait for its turn
return return
self.upcoming_timer = event.timestamp self.upcoming_timer = timestamp
self.upcoming_event_id = event.event_id self.upcoming_event_id = item_id
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id)) self._task = self.client.loop.create_task(self.end_timer(endtime=timestamp, event_id=item_id))
async def end_timer(self, *, endtime: datetime, event_id: int): async def end_timer(self, *, endtime: datetime, event_id: int):
"""Wait until a timer runs out, and then trigger an event to send the message""" """Wait until a timer runs out, and then trigger an event to send the message"""
await discord.utils.sleep_until(endtime - REMINDER_PREDELAY) until = endtime
if self._delta is not None:
until -= self._delta
await discord.utils.sleep_until(until)
self.upcoming_timer = None self.upcoming_timer = None
self.upcoming_event_id = None self.upcoming_event_id = None
self.client.dispatch("timer_end", event_id) self.client.dispatch(self._event, event_id)
class EventTimer(ABCTimer[Event]):
"""Timer for upcoming IRL events"""
def __init__(self, client: Didier):
super().__init__(client, event="event_reminder", delta=REMINDER_PREDELAY)
@overrides
async def dissect_item(self, item: Event) -> tuple[datetime, int]:
return item.timestamp, item.event_id
@overrides
async def get_next(self) -> Optional[Event]:
async with self.client.postgres_session as session:
return await get_next_event(session, now=tz_aware_now())
class JailTimer(ABCTimer[Jail]):
"""Timer for people spending time in Didier Jail"""
def __init__(self, client: Didier):
super().__init__(client, event="jail_release")
@overrides
async def dissect_item(self, item: Jail) -> tuple[datetime, int]:
return item.until, item.jail_entry_id
@overrides
async def get_next(self) -> Optional[Jail]:
async with self.client.postgres_session as session:
return await get_next_jail_release(session)