diff --git a/alembic/versions/b84bb10fb8de_easter_eggs.py b/alembic/versions/b84bb10fb8de_easter_eggs.py new file mode 100644 index 0000000..dbadf49 --- /dev/null +++ b/alembic/versions/b84bb10fb8de_easter_eggs.py @@ -0,0 +1,36 @@ +"""Easter eggs + +Revision ID: b84bb10fb8de +Revises: 515dc3f52c6d +Create Date: 2022-09-20 00:23:53.160168 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "b84bb10fb8de" +down_revision = "515dc3f52c6d" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "easter_eggs", + sa.Column("easter_egg_id", sa.Integer(), nullable=False), + sa.Column("match", sa.Text(), nullable=False), + sa.Column("response", sa.Text(), nullable=False), + sa.Column("exact", sa.Boolean(), server_default="1", nullable=False), + sa.Column("startswith", sa.Boolean(), server_default="1", nullable=False), + sa.PrimaryKeyConstraint("easter_egg_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("easter_eggs") + # ### end Alembic commands ### diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py new file mode 100644 index 0000000..d4c25d9 --- /dev/null +++ b/database/crud/easter_eggs.py @@ -0,0 +1,12 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import EasterEgg + +__all__ = ["get_all_easter_eggs"] + + +async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: + """Return a list of all easter eggs""" + statement = select(EasterEgg) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas.py b/database/schemas.py index e18ffe7..8b952fd 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -31,6 +31,7 @@ __all__ = [ "CustomCommandAlias", "DadJoke", "Deadline", + "EasterEgg", "Link", "MemeTemplate", "NightlyData", @@ -144,6 +145,18 @@ class Deadline(Base): course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") +class EasterEgg(Base): + """An easter egg response""" + + __tablename__ = "easter_eggs" + + easter_egg_id: int = Column(Integer, primary_key=True) + match: str = Column(Text, nullable=False) + response: str = Column(Text, nullable=False) + exact: bool = Column(Boolean, nullable=False, server_default="1") + startswith: bool = Column(Boolean, nullable=False, server_default="1") + + class Link(Base): """Useful links that go useful places""" diff --git a/database/scripts/db01_initial_easter_eggs.py b/database/scripts/db01_initial_easter_eggs.py new file mode 100644 index 0000000..d202994 --- /dev/null +++ b/database/scripts/db01_initial_easter_eggs.py @@ -0,0 +1,69 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from database.engine import DBSession +from database.schemas import EasterEgg + +__all__ = ["main"] + + +async def main(): + """Add the initial easter egg responses""" + session: AsyncSession + async with DBSession() as session: + # https://www.youtube.com/watch?v=Vd6hVYkkq88 + do_not_cite_deep_magic = EasterEgg( + match=r"((don'?t)|(do not)) cite the deep magic to me,? witch", + response="_I was there when it was written_", + exact=True, + ) + + # https://www.youtube.com/watch?v=LrHTR22pIhw + dormammu = EasterEgg(match=r"dormammu", response="_I've come to bargain_", exact=True) + + # https://youtu.be/rEq1Z0bjdwc?t=7 + hello_there = EasterEgg(match=r"hello there", response="_General Kenobi_", exact=True) + + # https://www.youtube.com/watch?v=_WZCvQ5J3pk + hey = EasterEgg( + match=r"hey,? ?(?:you)?", + response="_You're finally awake!_", + exact=True, + ) + + # https://www.youtube.com/watch?v=2z5ZDC1eQEA + is_this_the_kk = EasterEgg( + match=r"is (this|dis) (.*)", response="No, this is Patrick.", exact=False, startswith=True + ) + + # https://youtu.be/d6uckPRKvSg?t=4 + its_over_anakin = EasterEgg( + match=r"it'?s over ", response="_I have the high ground_", exact=False, startswith=True + ) + + # https://www.youtube.com/watch?v=Vx5prDjKAcw + perfectly_balanced = EasterEgg(match=r"perfectly balanced", response="_As all things should be_", exact=True) + + # ( ͡◉ ͜ʖ ͡◉) + sixty_nine = EasterEgg(match=r"(^69$)|(^69 )|( 69 )|( 69$)", response="_Nice_", exact=False, startswith=False) + + # https://youtu.be/7mbLzkNFDs8?t=19 + what_did_it_cost = EasterEgg(match=r"what did it cost\??", response="_Everything_", exact=True) + + # https://youtu.be/EJfYh-JVbJA?t=10 + you_cant_defeat_me = EasterEgg(match=r"you can'?t defeat me", response="_I know, but he can_", exact=False) + + session.add_all( + [ + do_not_cite_deep_magic, + dormammu, + hello_there, + hey, + is_this_the_kk, + its_over_anakin, + perfectly_balanced, + sixty_nine, + what_did_it_cost, + you_cant_defeat_me, + ] + ) + await session.commit() diff --git a/database/utils/caches.py b/database/utils/caches.py index 25c9cb5..ac260b3 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -4,11 +4,10 @@ from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import links, memes, ufora_courses, wordle +from database.crud import easter_eggs, links, memes, ufora_courses, wordle +from database.schemas import EasterEgg, WordleWord -__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] - -from database.schemas import WordleWord +__all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"] class DatabaseCache(ABC): @@ -46,6 +45,22 @@ class DatabaseCache(ABC): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] +class EasterEggCache(DatabaseCache): + """Cache to store easter eggs invoked by messages""" + + easter_eggs: list[EasterEgg] = [] + + @overrides + async def clear(self): + self.easter_eggs.clear() + + @overrides + async def invalidate(self, database_session: AsyncSession): + """Invalidate the data stored in this cache""" + await self.clear() + self.easter_eggs = await easter_eggs.get_all_easter_eggs(database_session) + + class LinkCache(DatabaseCache): """Cache to store the names of links""" @@ -131,12 +146,14 @@ class WordleCache(DatabaseCache): class CacheManager: """Class that keeps track of all caches""" + easter_eggs: EasterEggCache links: LinkCache memes: MemeCache ufora_courses: UforaCourseCache wordle_word: WordleCache def __init__(self): + self.easter_eggs = EasterEggCache() self.links = LinkCache() self.memes = MemeCache() self.ufora_courses = UforaCourseCache() @@ -144,6 +161,7 @@ class CacheManager: async def initialize_caches(self, postgres_session: AsyncSession): """Initialize the contents of all caches""" + await self.easter_eggs.invalidate(postgres_session) await self.links.invalidate(postgres_session) await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 06ec2a4..20c4fd8 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -7,7 +7,7 @@ from discord.ext import commands from database.crud.links import get_link_by_name from database.schemas import Link from didier import Didier -from didier.data.apis import urban_dictionary +from didier.data.apis import inspirobot, urban_dictionary from didier.data.embeds.google import GoogleSearch from didier.data.scrapers import google @@ -48,6 +48,13 @@ class Other(commands.Cog): embed = GoogleSearch(results).to_embed() await ctx.reply(embed=embed, mention_author=False) + @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") + async def inspire(self, ctx: commands.Context): + """Generate an [InspiroBot](https://inspirobot.me/) quote.""" + async with ctx.typing(): + link = await inspirobot.get_inspirobot_quote(self.client.http_session) + await ctx.reply(link, mention_author=False, ephemeral=False) + async def _get_link(self, name: str) -> Optional[Link]: async with self.client.postgres_session as session: return await get_link_by_name(session, name.lower()) diff --git a/didier/data/apis/inspirobot.py b/didier/data/apis/inspirobot.py new file mode 100644 index 0000000..730738c --- /dev/null +++ b/didier/data/apis/inspirobot.py @@ -0,0 +1,16 @@ +from http import HTTPStatus + +from aiohttp import ClientSession + +from didier.exceptions import HTTPException + +__all__ = ["get_inspirobot_quote"] + + +async def get_inspirobot_quote(http_session: ClientSession) -> str: + """Get a new InspiroBot quote""" + async with http_session.get("https://inspirobot.me/api?generate=true") as response: + if response.status != HTTPStatus.OK: + raise HTTPException(response.status) + + return await response.text() diff --git a/didier/didier.py b/didier/didier.py index f7bf388..a72196e 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -17,6 +17,7 @@ from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.schedules import Schedule, parse_schedule from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix +from didier.utils.easter_eggs import detect_easter_egg __all__ = ["Didier"] @@ -213,8 +214,9 @@ class Didier(commands.Bot): await self.process_commands(message) - # TODO easter eggs - # TODO stats + easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs) + if easter_egg is not None: + await message.reply(easter_egg, mention_author=False) async def _try_invoke_custom_command(self, message: discord.Message) -> bool: """Check if the message tries to invoke a custom command diff --git a/didier/utils/discord/prefix.py b/didier/utils/discord/prefix.py index df62ad4..f3fa7c4 100644 --- a/didier/utils/discord/prefix.py +++ b/didier/utils/discord/prefix.py @@ -1,18 +1,19 @@ import re +from typing import Optional from discord import Message from discord.ext import commands from didier.data import constants -__all__ = ["get_prefix"] +__all__ = ["get_prefix", "match_prefix"] -def get_prefix(client: commands.Bot, message: Message) -> str: - """Match a prefix against a message +def match_prefix(client: commands.Bot, message: Message) -> Optional[str]: + """Try to match a prefix against a message, returning None instead of a default value - This is done dynamically to allow variable amounts of whitespace, - and through regexes to allow case-insensitivity among other things. + This is done dynamically through regexes to allow case-insensitivity + and variable amounts of whitespace among other things. """ mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" @@ -26,5 +27,13 @@ def get_prefix(client: commands.Bot, message: Message) -> str: # .group() is inconsistent with whitespace, so that can't be used return message.content[: match.end()] - # Matched nothing - return "didier" + return None + + +def get_prefix(client: commands.Bot, message: Message) -> str: + """Match a prefix against a message, with a fallback + + This is the main prefix function that is used by the bot + """ + # If nothing was matched, return "didier" as a fallback + return match_prefix(client, message) or "didier" diff --git a/didier/utils/easter_eggs.py b/didier/utils/easter_eggs.py new file mode 100644 index 0000000..9720986 --- /dev/null +++ b/didier/utils/easter_eggs.py @@ -0,0 +1,53 @@ +import random +import re +from typing import Optional + +import discord +from discord.ext import commands + +import settings +from database.utils.caches import EasterEggCache +from didier.utils.discord.prefix import match_prefix + +__all__ = ["detect_easter_egg"] + + +def _roll_easter_egg(response: str) -> Optional[str]: + """Roll a random chance for an easter egg to be responded with""" + rolled = random.randint(0, 100) < settings.EASTER_EGG_CHANCE + return response if rolled else None + + +async def detect_easter_egg(client: commands.Bot, message: discord.Message, cache: EasterEggCache) -> Optional[str]: + """Try to detect an easter egg in a message""" + prefix = match_prefix(client, message) + + # Remove markdown and whitespace for better matches + content = message.content.strip().strip("_* \t\n").lower() + + # Message calls Didier + if prefix is not None: + prefix = prefix.strip().lower() + + # Message is only "Didier" + if content == prefix: + return "Hmm?" + else: + # Message invokes a command: do nothing + return None + + for easter_egg in cache.easter_eggs: + pattern = easter_egg.match + + # Use regular expressions to easily allow slight variations + if easter_egg.exact: + pattern = rf"^{pattern}$" + elif easter_egg.startswith: + pattern = rf"^{pattern}" + + matched = re.search(pattern, content) + + if matched is not None: + return _roll_easter_egg(easter_egg.response) + + return None diff --git a/run_db_scripts.py b/run_db_scripts.py index 91f4af2..dbf325f 100644 --- a/run_db_scripts.py +++ b/run_db_scripts.py @@ -4,14 +4,10 @@ This is slightly ugly, but running the scripts directly isn't possible because o This could be cleaned up a bit using importlib but this is safer """ import asyncio +import importlib import sys from typing import Callable -from database.scripts.db00_example import main as debug_add_courses - -script_mapping: dict[str, Callable] = {"debug_add_courses.py": debug_add_courses} - - if __name__ == "__main__": scripts = sys.argv[1:] if not scripts: @@ -19,10 +15,13 @@ if __name__ == "__main__": exit(1) for script in scripts: - script_main = script_mapping.get(script.removeprefix("database/scripts/"), None) - if script_main is None: + script = script.replace("/", ".").removesuffix(".py") + module = importlib.import_module(script) + + try: + script_main: Callable = module.main + asyncio.run(script_main()) + print(f"Successfully ran {script}") + except AttributeError: print(f'Script "{script}" not found.', file=sys.stderr) exit(1) - - asyncio.run(script_main()) - print(f"Successfully ran {script}") diff --git a/settings.py b/settings.py index 79efc78..1dca717 100644 --- a/settings.py +++ b/settings.py @@ -12,6 +12,10 @@ __all__ = [ "SANDBOX", "TESTING", "LOGFILE", + "SEMESTER", + "YEAR", + "MENU_TIMEOUT", + "EASTER_EGG_CHANCE", "POSTGRES_DB", "POSTGRES_USER", "POSTGRES_PASS", @@ -24,12 +28,10 @@ __all__ = [ "DISCORD_BOOS_REACT", "DISCORD_CUSTOM_COMMAND_PREFIX", "UFORA_ANNOUNCEMENTS_CHANNEL", - "BA3_ROLE", "UFORA_RSS_TOKEN", "URBAN_DICTIONARY_TOKEN", "IMGFLIP_NAME", "IMGFLIP_PASSWORD", - "BA3_SCHEDULE_URL", "ScheduleType", "ScheduleInfo", "SCHEDULE_DATA", @@ -43,6 +45,7 @@ LOGFILE: str = env.str("LOGFILE", "didier.log") SEMESTER: int = env.int("SEMESTER", 2) YEAR: int = env.int("YEAR", 3) MENU_TIMEOUT: int = env.int("MENU_TIMEOUT", 30) +EASTER_EGG_CHANCE: int = env.int("EASTER_EGG_CHANCE", 15) """Database""" # PostgreSQL