mirror of https://github.com/stijndcl/didier
Jail timer
parent
a051423203
commit
3509073d7f
|
@ -34,14 +34,14 @@ def upgrade() -> None:
|
|||
)
|
||||
op.create_table(
|
||||
"jail",
|
||||
sa.Column("jail_entry_i", sa.Integer(), nullable=False),
|
||||
sa.Column("jail_entry_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("until", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("jail_entry_i"),
|
||||
sa.PrimaryKeyConstraint("jail_entry_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"savings",
|
||||
|
|
|
@ -1,12 +1,19 @@
|
|||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas import Jail
|
||||
|
||||
__all__ = ["get_jail", "get_user_jail", "imprison"]
|
||||
__all__ = [
|
||||
"get_jail",
|
||||
"get_jail_entry_by_id",
|
||||
"get_next_jail_release",
|
||||
"get_user_jail",
|
||||
"imprison",
|
||||
"delete_prisoner_by_id",
|
||||
]
|
||||
|
||||
|
||||
async def get_jail(session: AsyncSession) -> list[Jail]:
|
||||
|
@ -15,14 +22,36 @@ async def get_jail(session: AsyncSession) -> list[Jail]:
|
|||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_jail_entry_by_id(session: AsyncSession, jail_id: int) -> Optional[Jail]:
|
||||
"""Get a jail entry by its id"""
|
||||
statement = select(Jail).where(Jail.jail_entry_id == jail_id)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def get_next_jail_release(session: AsyncSession) -> Optional[Jail]:
|
||||
"""Get the next person being released from jail"""
|
||||
statement = select(Jail).order_by(Jail.until)
|
||||
return (await session.execute(statement)).scalars().first()
|
||||
|
||||
|
||||
async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]:
|
||||
"""Check how long a given user is still in jail for"""
|
||||
statement = select(Jail).where(Jail.user_id == user_id)
|
||||
return (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
|
||||
async def imprison(session: AsyncSession, user_id: int, until: datetime):
|
||||
async def imprison(session: AsyncSession, user_id: int, until: datetime) -> Jail:
|
||||
"""Put a user in Didier Jail"""
|
||||
jail = Jail(user_id=user_id, until=until)
|
||||
session.add(jail)
|
||||
await session.commit()
|
||||
await session.refresh(jail)
|
||||
|
||||
return jail
|
||||
|
||||
|
||||
async def delete_prisoner_by_id(session: AsyncSession, jail_id: int):
|
||||
"""Release a user from jail using their jail entry id"""
|
||||
statement = delete(Jail).where(Jail.jail_entry_id == jail_id)
|
||||
await session.execute(statement)
|
||||
await session.commit()
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
__all__ = ["DoubleNightly", "NotEnoughDinks"]
|
||||
__all__ = ["DoubleNightly", "NotEnoughDinks", "SavingsCapExceeded"]
|
||||
|
||||
|
||||
class DoubleNightly(Exception):
|
||||
|
|
|
@ -235,7 +235,7 @@ class Jail(Base):
|
|||
|
||||
__tablename__ = "jail"
|
||||
|
||||
jail_entry_i: Mapped[int] = mapped_column(primary_key=True)
|
||||
jail_entry_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
until: Mapped[datetime] = mapped_column(nullable=False)
|
||||
|
||||
|
|
|
@ -11,7 +11,12 @@ from discord.ext import commands
|
|||
import settings
|
||||
from database.crud import currency as crud
|
||||
from database.crud import users
|
||||
from database.crud.jail import get_user_jail, imprison
|
||||
from database.crud.jail import (
|
||||
delete_prisoner_by_id,
|
||||
get_jail_entry_by_id,
|
||||
get_user_jail,
|
||||
imprison,
|
||||
)
|
||||
from database.exceptions.currency import (
|
||||
DoubleNightly,
|
||||
NotEnoughDinks,
|
||||
|
@ -32,6 +37,7 @@ from didier import Didier
|
|||
from didier.utils.discord import colours
|
||||
from didier.utils.discord.checks import is_owner
|
||||
from didier.utils.discord.converters import abbreviated_number
|
||||
from didier.utils.timer import JailTimer
|
||||
from didier.utils.types.datetime import tz_aware_now
|
||||
from didier.utils.types.string import pluralize
|
||||
|
||||
|
@ -40,11 +46,14 @@ class Currency(commands.Cog):
|
|||
"""Everything Dinks-related."""
|
||||
|
||||
client: Didier
|
||||
|
||||
_jail_timer: JailTimer
|
||||
_rob_lock: asyncio.Lock
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
super().__init__()
|
||||
self.client = client
|
||||
self._jail_timer = JailTimer(client)
|
||||
self._rob_lock = asyncio.Lock()
|
||||
|
||||
@commands.command(name="award") # type: ignore[arg-type]
|
||||
|
@ -276,7 +285,8 @@ class Currency(commands.Cog):
|
|||
if to_jail:
|
||||
jail_t = jail_time(robber.rob_level) * punishment_factor
|
||||
until = tz_aware_now() + timedelta(hours=jail_t)
|
||||
await imprison(session, ctx.author.id, until)
|
||||
jail = await imprison(session, ctx.author.id, until)
|
||||
self._jail_timer.maybe_replace_task(jail)
|
||||
|
||||
return await ctx.reply(
|
||||
f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, "
|
||||
|
@ -307,6 +317,19 @@ class Currency(commands.Cog):
|
|||
|
||||
return await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_jail_release(self, jail_id: int):
|
||||
"""Custom listener called when a jail timer ends"""
|
||||
async with self.client.postgres_session as session:
|
||||
entry = await get_jail_entry_by_id(session, jail_id)
|
||||
|
||||
if entry is None:
|
||||
return await self.client.log_error(f"Unable to find jail entry with id {jail_id}.", log_to_discord=True)
|
||||
|
||||
await delete_prisoner_by_id(session, jail_id)
|
||||
|
||||
await self._jail_timer.update()
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
|
|
|
@ -19,7 +19,7 @@ from didier.utils.discord import colours
|
|||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
||||
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
||||
from didier.utils.discord.constants import Limits
|
||||
from didier.utils.timer import Timer
|
||||
from didier.utils.timer import EventTimer
|
||||
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
|
||||
from didier.utils.types.string import abbreviate, leading
|
||||
from didier.views.modals import CreateBookmark
|
||||
|
@ -29,7 +29,7 @@ class Discord(commands.Cog):
|
|||
"""Commands related to Discord itself, which work with resources like servers and members."""
|
||||
|
||||
client: Didier
|
||||
timer: Timer
|
||||
_event_timer: EventTimer
|
||||
|
||||
# Context-menu references
|
||||
_bookmark_ctx_menu: app_commands.ContextMenu
|
||||
|
@ -42,20 +42,25 @@ class Discord(commands.Cog):
|
|||
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
|
||||
self.client.tree.add_command(self._bookmark_ctx_menu)
|
||||
self.client.tree.add_command(self._pin_ctx_menu)
|
||||
self.timer = Timer(self.client)
|
||||
self._event_timer = EventTimer(self.client)
|
||||
|
||||
async def cog_load(self) -> None:
|
||||
"""Start any stored timers when the cog is loaded"""
|
||||
await self._event_timer.update()
|
||||
|
||||
async def cog_unload(self) -> None:
|
||||
"""Remove the commands when the cog is unloaded"""
|
||||
"""Remove the commands and timers when the cog is unloaded"""
|
||||
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
|
||||
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||
self._event_timer.cancel()
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_event_create(self, event: Event):
|
||||
"""Custom listener called when an event is created"""
|
||||
self.timer.maybe_replace_task(event)
|
||||
self._event_timer.maybe_replace_task(event)
|
||||
|
||||
@commands.Cog.listener()
|
||||
async def on_timer_end(self, event_id: int):
|
||||
async def on_event_reminder(self, event_id: int):
|
||||
"""Custom listener called when an event timer ends"""
|
||||
async with self.client.postgres_session as session:
|
||||
event = await events.get_event_by_id(session, event_id)
|
||||
|
@ -89,7 +94,7 @@ class Discord(commands.Cog):
|
|||
await events.delete_event_by_id(session, event.event_id)
|
||||
|
||||
# Set the next timer
|
||||
self.client.loop.create_task(self.timer.update())
|
||||
await self._event_timer.update()
|
||||
|
||||
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
||||
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
|
||||
|
|
|
@ -1,69 +1,128 @@
|
|||
import abc
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
from typing import Generic, Optional, TypeVar
|
||||
|
||||
import discord.utils
|
||||
from overrides import overrides
|
||||
|
||||
import settings
|
||||
from database.crud.events import get_next_event
|
||||
from database.schemas import Event
|
||||
from database.crud.jail import get_next_jail_release
|
||||
from database.schemas import Event, Jail
|
||||
from didier import Didier
|
||||
from didier.utils.types.datetime import tz_aware_now
|
||||
|
||||
__all__ = ["Timer"]
|
||||
__all__ = ["JailTimer", "EventTimer"]
|
||||
|
||||
REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE)
|
||||
|
||||
|
||||
# TODO make this generic
|
||||
# TODO add timer for jail freeing
|
||||
class Timer:
|
||||
"""Class for scheduled timers"""
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class ABCTimer(abc.ABC, Generic[T]):
|
||||
"""Base class for scheduled timers"""
|
||||
|
||||
client: Didier
|
||||
upcoming_timer: Optional[datetime]
|
||||
upcoming_event_id: Optional[int]
|
||||
_task: Optional[asyncio.Task]
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
_delta: Optional[timedelta]
|
||||
_event: str
|
||||
|
||||
def __init__(self, client: Didier, *, event: str, delta: Optional[timedelta] = None):
|
||||
self.client = client
|
||||
|
||||
self.upcoming_timer = None
|
||||
self.upcoming_event_id = None
|
||||
self._task = None
|
||||
|
||||
self.client.loop.create_task(self.update())
|
||||
self._delta = delta
|
||||
self._event = event
|
||||
|
||||
@abc.abstractmethod
|
||||
async def dissect_item(self, item: T) -> tuple[datetime, int]:
|
||||
"""Method that takes an item and returns the corresponding timestamp and id"""
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get_next(self) -> Optional[T]:
|
||||
"""Method that fetches the next item from the database"""
|
||||
|
||||
async def update(self):
|
||||
"""Get & schedule the closest reminder"""
|
||||
async with self.client.postgres_session as session:
|
||||
event = await get_next_event(session, now=tz_aware_now())
|
||||
"""Get & schedule the closest item"""
|
||||
next_item = await self.get_next()
|
||||
|
||||
# No upcoming events
|
||||
if event is None:
|
||||
# No upcoming items
|
||||
if next_item is None:
|
||||
return
|
||||
|
||||
self.maybe_replace_task(event)
|
||||
self.maybe_replace_task(next_item)
|
||||
|
||||
def maybe_replace_task(self, event: Event):
|
||||
def cancel(self):
|
||||
"""Cancel the running task"""
|
||||
if self._task is not None:
|
||||
self._task.cancel()
|
||||
self._task = None
|
||||
|
||||
def maybe_replace_task(self, item: T):
|
||||
"""Replace the current task if necessary"""
|
||||
timestamp, item_id = self.dissect_item(item)
|
||||
|
||||
# If there is a current (pending) task, and the new timer is sooner than the
|
||||
# pending one, cancel it
|
||||
if self._task is not None and not self._task.done():
|
||||
# The upcoming timer will never be None at this point, but Mypy is mad
|
||||
if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp:
|
||||
if self.upcoming_timer is not None and self.upcoming_timer > timestamp:
|
||||
self._task.cancel()
|
||||
else:
|
||||
# The new task happens after the existing task, it has to wait for its turn
|
||||
return
|
||||
|
||||
self.upcoming_timer = event.timestamp
|
||||
self.upcoming_event_id = event.event_id
|
||||
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
|
||||
self.upcoming_timer = timestamp
|
||||
self.upcoming_event_id = item_id
|
||||
self._task = self.client.loop.create_task(self.end_timer(endtime=timestamp, event_id=item_id))
|
||||
|
||||
async def end_timer(self, *, endtime: datetime, event_id: int):
|
||||
"""Wait until a timer runs out, and then trigger an event to send the message"""
|
||||
await discord.utils.sleep_until(endtime - REMINDER_PREDELAY)
|
||||
until = endtime
|
||||
if self._delta is not None:
|
||||
until -= self._delta
|
||||
|
||||
await discord.utils.sleep_until(until)
|
||||
self.upcoming_timer = None
|
||||
self.upcoming_event_id = None
|
||||
self.client.dispatch("timer_end", event_id)
|
||||
self.client.dispatch(self._event, event_id)
|
||||
|
||||
|
||||
class EventTimer(ABCTimer[Event]):
|
||||
"""Timer for upcoming IRL events"""
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
super().__init__(client, event="event_reminder", delta=REMINDER_PREDELAY)
|
||||
|
||||
@overrides
|
||||
async def dissect_item(self, item: Event) -> tuple[datetime, int]:
|
||||
return item.timestamp, item.event_id
|
||||
|
||||
@overrides
|
||||
async def get_next(self) -> Optional[Event]:
|
||||
async with self.client.postgres_session as session:
|
||||
return await get_next_event(session, now=tz_aware_now())
|
||||
|
||||
|
||||
class JailTimer(ABCTimer[Jail]):
|
||||
"""Timer for people spending time in Didier Jail"""
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
super().__init__(client, event="jail_release")
|
||||
|
||||
@overrides
|
||||
async def dissect_item(self, item: Jail) -> tuple[datetime, int]:
|
||||
return item.until, item.jail_entry_id
|
||||
|
||||
@overrides
|
||||
async def get_next(self) -> Optional[Jail]:
|
||||
async with self.client.postgres_session as session:
|
||||
return await get_next_jail_release(session)
|
||||
|
|
Loading…
Reference in New Issue