mirror of https://github.com/stijndcl/didier
Jail timer
parent
a051423203
commit
3509073d7f
|
@ -34,14 +34,14 @@ def upgrade() -> None:
|
||||||
)
|
)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"jail",
|
"jail",
|
||||||
sa.Column("jail_entry_i", sa.Integer(), nullable=False),
|
sa.Column("jail_entry_id", sa.Integer(), nullable=False),
|
||||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||||
sa.Column("until", sa.DateTime(timezone=True), nullable=False),
|
sa.Column("until", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.ForeignKeyConstraint(
|
sa.ForeignKeyConstraint(
|
||||||
["user_id"],
|
["user_id"],
|
||||||
["users.user_id"],
|
["users.user_id"],
|
||||||
),
|
),
|
||||||
sa.PrimaryKeyConstraint("jail_entry_i"),
|
sa.PrimaryKeyConstraint("jail_entry_id"),
|
||||||
)
|
)
|
||||||
op.create_table(
|
op.create_table(
|
||||||
"savings",
|
"savings",
|
||||||
|
|
|
@ -1,12 +1,19 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas import Jail
|
from database.schemas import Jail
|
||||||
|
|
||||||
__all__ = ["get_jail", "get_user_jail", "imprison"]
|
__all__ = [
|
||||||
|
"get_jail",
|
||||||
|
"get_jail_entry_by_id",
|
||||||
|
"get_next_jail_release",
|
||||||
|
"get_user_jail",
|
||||||
|
"imprison",
|
||||||
|
"delete_prisoner_by_id",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def get_jail(session: AsyncSession) -> list[Jail]:
|
async def get_jail(session: AsyncSession) -> list[Jail]:
|
||||||
|
@ -15,14 +22,36 @@ async def get_jail(session: AsyncSession) -> list[Jail]:
|
||||||
return list((await session.execute(statement)).scalars().all())
|
return list((await session.execute(statement)).scalars().all())
|
||||||
|
|
||||||
|
|
||||||
|
async def get_jail_entry_by_id(session: AsyncSession, jail_id: int) -> Optional[Jail]:
|
||||||
|
"""Get a jail entry by its id"""
|
||||||
|
statement = select(Jail).where(Jail.jail_entry_id == jail_id)
|
||||||
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_next_jail_release(session: AsyncSession) -> Optional[Jail]:
|
||||||
|
"""Get the next person being released from jail"""
|
||||||
|
statement = select(Jail).order_by(Jail.until)
|
||||||
|
return (await session.execute(statement)).scalars().first()
|
||||||
|
|
||||||
|
|
||||||
async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]:
|
async def get_user_jail(session: AsyncSession, user_id: int) -> Optional[Jail]:
|
||||||
"""Check how long a given user is still in jail for"""
|
"""Check how long a given user is still in jail for"""
|
||||||
statement = select(Jail).where(Jail.user_id == user_id)
|
statement = select(Jail).where(Jail.user_id == user_id)
|
||||||
return (await session.execute(statement)).scalar_one_or_none()
|
return (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
|
|
||||||
async def imprison(session: AsyncSession, user_id: int, until: datetime):
|
async def imprison(session: AsyncSession, user_id: int, until: datetime) -> Jail:
|
||||||
"""Put a user in Didier Jail"""
|
"""Put a user in Didier Jail"""
|
||||||
jail = Jail(user_id=user_id, until=until)
|
jail = Jail(user_id=user_id, until=until)
|
||||||
session.add(jail)
|
session.add(jail)
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
await session.refresh(jail)
|
||||||
|
|
||||||
|
return jail
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_prisoner_by_id(session: AsyncSession, jail_id: int):
|
||||||
|
"""Release a user from jail using their jail entry id"""
|
||||||
|
statement = delete(Jail).where(Jail.jail_entry_id == jail_id)
|
||||||
|
await session.execute(statement)
|
||||||
|
await session.commit()
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
__all__ = ["DoubleNightly", "NotEnoughDinks"]
|
__all__ = ["DoubleNightly", "NotEnoughDinks", "SavingsCapExceeded"]
|
||||||
|
|
||||||
|
|
||||||
class DoubleNightly(Exception):
|
class DoubleNightly(Exception):
|
||||||
|
|
|
@ -235,7 +235,7 @@ class Jail(Base):
|
||||||
|
|
||||||
__tablename__ = "jail"
|
__tablename__ = "jail"
|
||||||
|
|
||||||
jail_entry_i: Mapped[int] = mapped_column(primary_key=True)
|
jail_entry_id: Mapped[int] = mapped_column(primary_key=True)
|
||||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||||
until: Mapped[datetime] = mapped_column(nullable=False)
|
until: Mapped[datetime] = mapped_column(nullable=False)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,12 @@ from discord.ext import commands
|
||||||
import settings
|
import settings
|
||||||
from database.crud import currency as crud
|
from database.crud import currency as crud
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.crud.jail import get_user_jail, imprison
|
from database.crud.jail import (
|
||||||
|
delete_prisoner_by_id,
|
||||||
|
get_jail_entry_by_id,
|
||||||
|
get_user_jail,
|
||||||
|
imprison,
|
||||||
|
)
|
||||||
from database.exceptions.currency import (
|
from database.exceptions.currency import (
|
||||||
DoubleNightly,
|
DoubleNightly,
|
||||||
NotEnoughDinks,
|
NotEnoughDinks,
|
||||||
|
@ -32,6 +37,7 @@ from didier import Didier
|
||||||
from didier.utils.discord import colours
|
from didier.utils.discord import colours
|
||||||
from didier.utils.discord.checks import is_owner
|
from didier.utils.discord.checks import is_owner
|
||||||
from didier.utils.discord.converters import abbreviated_number
|
from didier.utils.discord.converters import abbreviated_number
|
||||||
|
from didier.utils.timer import JailTimer
|
||||||
from didier.utils.types.datetime import tz_aware_now
|
from didier.utils.types.datetime import tz_aware_now
|
||||||
from didier.utils.types.string import pluralize
|
from didier.utils.types.string import pluralize
|
||||||
|
|
||||||
|
@ -40,11 +46,14 @@ class Currency(commands.Cog):
|
||||||
"""Everything Dinks-related."""
|
"""Everything Dinks-related."""
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
|
|
||||||
|
_jail_timer: JailTimer
|
||||||
_rob_lock: asyncio.Lock
|
_rob_lock: asyncio.Lock
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.client = client
|
self.client = client
|
||||||
|
self._jail_timer = JailTimer(client)
|
||||||
self._rob_lock = asyncio.Lock()
|
self._rob_lock = asyncio.Lock()
|
||||||
|
|
||||||
@commands.command(name="award") # type: ignore[arg-type]
|
@commands.command(name="award") # type: ignore[arg-type]
|
||||||
|
@ -276,7 +285,8 @@ class Currency(commands.Cog):
|
||||||
if to_jail:
|
if to_jail:
|
||||||
jail_t = jail_time(robber.rob_level) * punishment_factor
|
jail_t = jail_time(robber.rob_level) * punishment_factor
|
||||||
until = tz_aware_now() + timedelta(hours=jail_t)
|
until = tz_aware_now() + timedelta(hours=jail_t)
|
||||||
await imprison(session, ctx.author.id, until)
|
jail = await imprison(session, ctx.author.id, until)
|
||||||
|
self._jail_timer.maybe_replace_task(jail)
|
||||||
|
|
||||||
return await ctx.reply(
|
return await ctx.reply(
|
||||||
f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, "
|
f"Robbery attempt failed! You've lost {lost_dinks} Didier Dinks, "
|
||||||
|
@ -307,6 +317,19 @@ class Currency(commands.Cog):
|
||||||
|
|
||||||
return await ctx.reply(embed=embed, mention_author=False)
|
return await ctx.reply(embed=embed, mention_author=False)
|
||||||
|
|
||||||
|
@commands.Cog.listener()
|
||||||
|
async def on_jail_release(self, jail_id: int):
|
||||||
|
"""Custom listener called when a jail timer ends"""
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
entry = await get_jail_entry_by_id(session, jail_id)
|
||||||
|
|
||||||
|
if entry is None:
|
||||||
|
return await self.client.log_error(f"Unable to find jail entry with id {jail_id}.", log_to_discord=True)
|
||||||
|
|
||||||
|
await delete_prisoner_by_id(session, jail_id)
|
||||||
|
|
||||||
|
await self._jail_timer.update()
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog"""
|
"""Load the cog"""
|
||||||
|
|
|
@ -19,7 +19,7 @@ from didier.utils.discord import colours
|
||||||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
||||||
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
||||||
from didier.utils.discord.constants import Limits
|
from didier.utils.discord.constants import Limits
|
||||||
from didier.utils.timer import Timer
|
from didier.utils.timer import EventTimer
|
||||||
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
|
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
|
||||||
from didier.utils.types.string import abbreviate, leading
|
from didier.utils.types.string import abbreviate, leading
|
||||||
from didier.views.modals import CreateBookmark
|
from didier.views.modals import CreateBookmark
|
||||||
|
@ -29,7 +29,7 @@ class Discord(commands.Cog):
|
||||||
"""Commands related to Discord itself, which work with resources like servers and members."""
|
"""Commands related to Discord itself, which work with resources like servers and members."""
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
timer: Timer
|
_event_timer: EventTimer
|
||||||
|
|
||||||
# Context-menu references
|
# Context-menu references
|
||||||
_bookmark_ctx_menu: app_commands.ContextMenu
|
_bookmark_ctx_menu: app_commands.ContextMenu
|
||||||
|
@ -42,20 +42,25 @@ class Discord(commands.Cog):
|
||||||
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
|
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self._pin_ctx)
|
||||||
self.client.tree.add_command(self._bookmark_ctx_menu)
|
self.client.tree.add_command(self._bookmark_ctx_menu)
|
||||||
self.client.tree.add_command(self._pin_ctx_menu)
|
self.client.tree.add_command(self._pin_ctx_menu)
|
||||||
self.timer = Timer(self.client)
|
self._event_timer = EventTimer(self.client)
|
||||||
|
|
||||||
|
async def cog_load(self) -> None:
|
||||||
|
"""Start any stored timers when the cog is loaded"""
|
||||||
|
await self._event_timer.update()
|
||||||
|
|
||||||
async def cog_unload(self) -> None:
|
async def cog_unload(self) -> None:
|
||||||
"""Remove the commands when the cog is unloaded"""
|
"""Remove the commands and timers when the cog is unloaded"""
|
||||||
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
|
self.client.tree.remove_command(self._bookmark_ctx_menu.name, type=self._bookmark_ctx_menu.type)
|
||||||
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||||
|
self._event_timer.cancel()
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_event_create(self, event: Event):
|
async def on_event_create(self, event: Event):
|
||||||
"""Custom listener called when an event is created"""
|
"""Custom listener called when an event is created"""
|
||||||
self.timer.maybe_replace_task(event)
|
self._event_timer.maybe_replace_task(event)
|
||||||
|
|
||||||
@commands.Cog.listener()
|
@commands.Cog.listener()
|
||||||
async def on_timer_end(self, event_id: int):
|
async def on_event_reminder(self, event_id: int):
|
||||||
"""Custom listener called when an event timer ends"""
|
"""Custom listener called when an event timer ends"""
|
||||||
async with self.client.postgres_session as session:
|
async with self.client.postgres_session as session:
|
||||||
event = await events.get_event_by_id(session, event_id)
|
event = await events.get_event_by_id(session, event_id)
|
||||||
|
@ -89,7 +94,7 @@ class Discord(commands.Cog):
|
||||||
await events.delete_event_by_id(session, event.event_id)
|
await events.delete_event_by_id(session, event.event_id)
|
||||||
|
|
||||||
# Set the next timer
|
# Set the next timer
|
||||||
self.client.loop.create_task(self.timer.update())
|
await self._event_timer.update()
|
||||||
|
|
||||||
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
||||||
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
|
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
|
||||||
|
|
|
@ -1,69 +1,128 @@
|
||||||
|
import abc
|
||||||
import asyncio
|
import asyncio
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Optional
|
from typing import Generic, Optional, TypeVar
|
||||||
|
|
||||||
import discord.utils
|
import discord.utils
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.crud.events import get_next_event
|
from database.crud.events import get_next_event
|
||||||
from database.schemas import Event
|
from database.crud.jail import get_next_jail_release
|
||||||
|
from database.schemas import Event, Jail
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.utils.types.datetime import tz_aware_now
|
from didier.utils.types.datetime import tz_aware_now
|
||||||
|
|
||||||
__all__ = ["Timer"]
|
__all__ = ["JailTimer", "EventTimer"]
|
||||||
|
|
||||||
REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE)
|
REMINDER_PREDELAY = timedelta(minutes=settings.REMINDER_PRE)
|
||||||
|
|
||||||
|
|
||||||
# TODO make this generic
|
T = TypeVar("T")
|
||||||
# TODO add timer for jail freeing
|
|
||||||
class Timer:
|
|
||||||
"""Class for scheduled timers"""
|
class ABCTimer(abc.ABC, Generic[T]):
|
||||||
|
"""Base class for scheduled timers"""
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
upcoming_timer: Optional[datetime]
|
upcoming_timer: Optional[datetime]
|
||||||
upcoming_event_id: Optional[int]
|
upcoming_event_id: Optional[int]
|
||||||
_task: Optional[asyncio.Task]
|
_task: Optional[asyncio.Task]
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
_delta: Optional[timedelta]
|
||||||
|
_event: str
|
||||||
|
|
||||||
|
def __init__(self, client: Didier, *, event: str, delta: Optional[timedelta] = None):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
self.upcoming_timer = None
|
self.upcoming_timer = None
|
||||||
self.upcoming_event_id = None
|
self.upcoming_event_id = None
|
||||||
self._task = None
|
self._task = None
|
||||||
|
|
||||||
self.client.loop.create_task(self.update())
|
self._delta = delta
|
||||||
|
self._event = event
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def dissect_item(self, item: T) -> tuple[datetime, int]:
|
||||||
|
"""Method that takes an item and returns the corresponding timestamp and id"""
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
|
async def get_next(self) -> Optional[T]:
|
||||||
|
"""Method that fetches the next item from the database"""
|
||||||
|
|
||||||
async def update(self):
|
async def update(self):
|
||||||
"""Get & schedule the closest reminder"""
|
"""Get & schedule the closest item"""
|
||||||
async with self.client.postgres_session as session:
|
next_item = await self.get_next()
|
||||||
event = await get_next_event(session, now=tz_aware_now())
|
|
||||||
|
|
||||||
# No upcoming events
|
# No upcoming items
|
||||||
if event is None:
|
if next_item is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
self.maybe_replace_task(event)
|
self.maybe_replace_task(next_item)
|
||||||
|
|
||||||
def maybe_replace_task(self, event: Event):
|
def cancel(self):
|
||||||
|
"""Cancel the running task"""
|
||||||
|
if self._task is not None:
|
||||||
|
self._task.cancel()
|
||||||
|
self._task = None
|
||||||
|
|
||||||
|
def maybe_replace_task(self, item: T):
|
||||||
"""Replace the current task if necessary"""
|
"""Replace the current task if necessary"""
|
||||||
|
timestamp, item_id = self.dissect_item(item)
|
||||||
|
|
||||||
# If there is a current (pending) task, and the new timer is sooner than the
|
# If there is a current (pending) task, and the new timer is sooner than the
|
||||||
# pending one, cancel it
|
# pending one, cancel it
|
||||||
if self._task is not None and not self._task.done():
|
if self._task is not None and not self._task.done():
|
||||||
# The upcoming timer will never be None at this point, but Mypy is mad
|
# The upcoming timer will never be None at this point, but Mypy is mad
|
||||||
if self.upcoming_timer is not None and self.upcoming_timer > event.timestamp:
|
if self.upcoming_timer is not None and self.upcoming_timer > timestamp:
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
else:
|
else:
|
||||||
# The new task happens after the existing task, it has to wait for its turn
|
# The new task happens after the existing task, it has to wait for its turn
|
||||||
return
|
return
|
||||||
|
|
||||||
self.upcoming_timer = event.timestamp
|
self.upcoming_timer = timestamp
|
||||||
self.upcoming_event_id = event.event_id
|
self.upcoming_event_id = item_id
|
||||||
self._task = self.client.loop.create_task(self.end_timer(endtime=event.timestamp, event_id=event.event_id))
|
self._task = self.client.loop.create_task(self.end_timer(endtime=timestamp, event_id=item_id))
|
||||||
|
|
||||||
async def end_timer(self, *, endtime: datetime, event_id: int):
|
async def end_timer(self, *, endtime: datetime, event_id: int):
|
||||||
"""Wait until a timer runs out, and then trigger an event to send the message"""
|
"""Wait until a timer runs out, and then trigger an event to send the message"""
|
||||||
await discord.utils.sleep_until(endtime - REMINDER_PREDELAY)
|
until = endtime
|
||||||
|
if self._delta is not None:
|
||||||
|
until -= self._delta
|
||||||
|
|
||||||
|
await discord.utils.sleep_until(until)
|
||||||
self.upcoming_timer = None
|
self.upcoming_timer = None
|
||||||
self.upcoming_event_id = None
|
self.upcoming_event_id = None
|
||||||
self.client.dispatch("timer_end", event_id)
|
self.client.dispatch(self._event, event_id)
|
||||||
|
|
||||||
|
|
||||||
|
class EventTimer(ABCTimer[Event]):
|
||||||
|
"""Timer for upcoming IRL events"""
|
||||||
|
|
||||||
|
def __init__(self, client: Didier):
|
||||||
|
super().__init__(client, event="event_reminder", delta=REMINDER_PREDELAY)
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def dissect_item(self, item: Event) -> tuple[datetime, int]:
|
||||||
|
return item.timestamp, item.event_id
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def get_next(self) -> Optional[Event]:
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
return await get_next_event(session, now=tz_aware_now())
|
||||||
|
|
||||||
|
|
||||||
|
class JailTimer(ABCTimer[Jail]):
|
||||||
|
"""Timer for people spending time in Didier Jail"""
|
||||||
|
|
||||||
|
def __init__(self, client: Didier):
|
||||||
|
super().__init__(client, event="jail_release")
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def dissect_item(self, item: Jail) -> tuple[datetime, int]:
|
||||||
|
return item.until, item.jail_entry_id
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
async def get_next(self) -> Optional[Jail]:
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
return await get_next_jail_release(session)
|
||||||
|
|
Loading…
Reference in New Issue