From 181118aa1d19b2eb840da2e81ecf18b49887186d Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 20 Sep 2022 00:31:33 +0200 Subject: [PATCH] Easter eggs --- alembic/versions/b84bb10fb8de_easter_eggs.py | 36 +++++++++++++++ database/crud/easter_eggs.py | 12 +++++ database/schemas.py | 13 ++++++ database/utils/caches.py | 26 +++++++++-- didier/didier.py | 6 ++- didier/utils/discord/prefix.py | 23 +++++++--- didier/utils/easter_eggs.py | 48 ++++++++++++++++++++ 7 files changed, 151 insertions(+), 13 deletions(-) create mode 100644 alembic/versions/b84bb10fb8de_easter_eggs.py create mode 100644 database/crud/easter_eggs.py create mode 100644 didier/utils/easter_eggs.py diff --git a/alembic/versions/b84bb10fb8de_easter_eggs.py b/alembic/versions/b84bb10fb8de_easter_eggs.py new file mode 100644 index 0000000..dbadf49 --- /dev/null +++ b/alembic/versions/b84bb10fb8de_easter_eggs.py @@ -0,0 +1,36 @@ +"""Easter eggs + +Revision ID: b84bb10fb8de +Revises: 515dc3f52c6d +Create Date: 2022-09-20 00:23:53.160168 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b84bb10fb8de" +down_revision = "515dc3f52c6d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "easter_eggs", + sa.Column("easter_egg_id", sa.Integer(), nullable=False), + sa.Column("match", sa.Text(), nullable=False), + sa.Column("response", sa.Text(), nullable=False), + sa.Column("exact", sa.Boolean(), server_default="1", nullable=False), + sa.Column("startswith", sa.Boolean(), server_default="1", nullable=False), + sa.PrimaryKeyConstraint("easter_egg_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("easter_eggs") + # ### end Alembic commands ### diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py new file mode 100644 index 0000000..d4c25d9 --- /dev/null +++ b/database/crud/easter_eggs.py @@ -0,0 +1,12 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import EasterEgg + +__all__ = ["get_all_easter_eggs"] + + +async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: + """Return a list of all easter eggs""" + statement = select(EasterEgg) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas.py b/database/schemas.py index e18ffe7..8b952fd 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -31,6 +31,7 @@ __all__ = [ "CustomCommandAlias", "DadJoke", "Deadline", + "EasterEgg", "Link", "MemeTemplate", "NightlyData", @@ -144,6 +145,18 @@ class Deadline(Base): course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") +class EasterEgg(Base): + """An easter egg response""" + + __tablename__ = "easter_eggs" + + easter_egg_id: int = Column(Integer, primary_key=True) + match: str = Column(Text, nullable=False) + response: str = Column(Text, nullable=False) + exact: bool = Column(Boolean, nullable=False, server_default="1") + startswith: bool = Column(Boolean, nullable=False, server_default="1") + + class Link(Base): """Useful links that go useful places""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 25c9cb5..ac260b3 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -4,11 +4,10 @@ from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import links, memes, ufora_courses, wordle +from database.crud import easter_eggs, links, memes, ufora_courses, wordle +from database.schemas import EasterEgg, WordleWord -__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] - -from database.schemas import WordleWord +__all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"] class DatabaseCache(ABC): @@ -46,6 +45,22 @@ class DatabaseCache(ABC): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] +class EasterEggCache(DatabaseCache): + """Cache to store easter eggs invoked by messages""" + + easter_eggs: list[EasterEgg] = [] + + @overrides + async def clear(self): + self.easter_eggs.clear() + + @overrides + async def invalidate(self, database_session: AsyncSession): + """Invalidate the data stored in this cache""" + await self.clear() + self.easter_eggs = await easter_eggs.get_all_easter_eggs(database_session) + + class LinkCache(DatabaseCache): """Cache to store the names of links""" @@ -131,12 +146,14 @@ class WordleCache(DatabaseCache): class CacheManager: """Class that keeps track of all caches""" + easter_eggs: EasterEggCache links: LinkCache memes: MemeCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): + self.easter_eggs = EasterEggCache() self.links = LinkCache() self.memes = MemeCache() self.ufora_courses = UforaCourseCache() @@ -144,6 +161,7 @@ class CacheManager: async def initialize_caches(self, postgres_session: AsyncSession): """Initialize the contents of all caches""" + await self.easter_eggs.invalidate(postgres_session) await self.links.invalidate(postgres_session) await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) diff --git a/didier/didier.py b/didier/didier.py index f7bf388..a72196e 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -17,6 +17,7 @@ from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.schedules import Schedule, parse_schedule from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix +from didier.utils.easter_eggs import detect_easter_egg __all__ = ["Didier"] @@ -213,8 +214,9 @@ class Didier(commands.Bot): await self.process_commands(message) - # TODO easter eggs - # TODO stats + easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs) + if easter_egg is not None: + await message.reply(easter_egg, mention_author=False) async def _try_invoke_custom_command(self, message: discord.Message) -> bool: """Check if the message tries to invoke a custom command diff --git a/didier/utils/discord/prefix.py b/didier/utils/discord/prefix.py index df62ad4..f3fa7c4 100644 --- a/didier/utils/discord/prefix.py +++ b/didier/utils/discord/prefix.py @@ -1,18 +1,19 @@ import re +from typing import Optional from discord import Message from discord.ext import commands from didier.data import constants -__all__ = ["get_prefix"] +__all__ = ["get_prefix", "match_prefix"] -def get_prefix(client: commands.Bot, message: Message) -> str: - """Match a prefix against a message +def match_prefix(client: commands.Bot, message: Message) -> Optional[str]: + """Try to match a prefix against a message, returning None instead of a default value - This is done dynamically to allow variable amounts of whitespace, - and through regexes to allow case-insensitivity among other things. + This is done dynamically through regexes to allow case-insensitivity + and variable amounts of whitespace among other things. """ mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" @@ -26,5 +27,13 @@ def get_prefix(client: commands.Bot, message: Message) -> str: # .group() is inconsistent with whitespace, so that can't be used return message.content[: match.end()] - # Matched nothing - return "didier" + return None + + +def get_prefix(client: commands.Bot, message: Message) -> str: + """Match a prefix against a message, with a fallback + + This is the main prefix function that is used by the bot + """ + # If nothing was matched, return "didier" as a fallback + return match_prefix(client, message) or "didier" diff --git a/didier/utils/easter_eggs.py b/didier/utils/easter_eggs.py new file mode 100644 index 0000000..5ed3a24 --- /dev/null +++ b/didier/utils/easter_eggs.py @@ -0,0 +1,48 @@ +import random +from typing import Optional + +import discord +from discord.ext import commands + +from database.utils.caches import EasterEggCache +from didier.utils.discord.prefix import match_prefix + +__all__ = ["detect_easter_egg"] + + +def _roll_easter_egg(response: str) -> Optional[str]: + """Roll a random chance for an easter egg to be responded with + + The chance for an easter egg to be used is 33% + """ + rolled = random.randint(0, 100) < 33 + return response if rolled else None + + +async def detect_easter_egg(client: commands.Bot, message: discord.Message, cache: EasterEggCache) -> Optional[str]: + """Try to detect an easter egg in a message""" + prefix = match_prefix(client, message) + + content = message.content.strip().lower() + + # Message calls Didier + if prefix is not None: + prefix = prefix.strip().lower() + + # Message is only "Didier" + if content == prefix: + return "Hmm?" + else: + # Message invokes a command: do nothing + return None + + for easter_egg in cache.easter_eggs: + # Exact matches + if easter_egg.exact and easter_egg.match == content: + return _roll_easter_egg(easter_egg.response) + + # Matches that start with a certain term + if easter_egg.startswith and content.startswith(easter_egg.match): + return _roll_easter_egg(easter_egg.response) + + return None