diff --git a/.flake8 b/.flake8 index 259dc4e..cab8ba8 100644 --- a/.flake8 +++ b/.flake8 @@ -26,7 +26,7 @@ extend-ignore = E203, # Don't require docstrings when overriding a method, # the base method should have a docstring but the rest not -ignore-decorator=overrides +ignore-decorators=overrides max-line-length = 120 # Disable some rules for entire files per-file-ignores = diff --git a/alembic/versions/581ae6511b98_add_dad_jokes.py b/alembic/versions/581ae6511b98_add_dad_jokes.py new file mode 100644 index 0000000..b3bed89 --- /dev/null +++ b/alembic/versions/581ae6511b98_add_dad_jokes.py @@ -0,0 +1,33 @@ +"""Add dad jokes + +Revision ID: 581ae6511b98 +Revises: 632b69cdadde +Create Date: 2022-07-15 23:37:08.147611 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "581ae6511b98" +down_revision = "632b69cdadde" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "dad_jokes", + sa.Column("dad_joke_id", sa.Integer(), nullable=False), + sa.Column("joke", sa.Text(), nullable=False), + sa.PrimaryKeyConstraint("dad_joke_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("dad_jokes") + # ### end Alembic commands ### diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index c6377c6..85ecf56 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -33,6 +33,7 @@ async def create_command(session: AsyncSession, name: str, response: str) -> Cus command = CustomCommand(name=name, indexed_name=clean_name(name), response=response) session.add(command) await session.commit() + return command diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py new file mode 100644 index 0000000..30aa010 --- /dev/null +++ b/database/crud/dad_jokes.py @@ -0,0 +1,42 @@ +from typing import Optional + +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions.not_found import NoResultFoundException +from database.models import DadJoke + +__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"] + + +async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: + """Add a new dad joke to the database""" + dad_joke = DadJoke(joke=joke) + session.add(dad_joke) + await session.commit() + + return dad_joke + + +async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke: + """Edit an existing dad joke""" + statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id) + dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none() + if dad_joke is None: + raise NoResultFoundException + + dad_joke.joke = new_joke + session.add(dad_joke) + await session.commit() + + return dad_joke + + +async def get_random_dad_joke(session: AsyncSession) -> DadJoke: + """Return a random database entry""" + statement = select(DadJoke).order_by(func.random()) + row = (await session.execute(statement)).first() + if row is None: + raise NoResultFoundException + + return row[0] diff --git a/database/models.py b/database/models.py index 71d4f3c..d663204 100644 --- a/database/models.py +++ b/database/models.py @@ -14,6 +14,7 @@ __all__ = [ "Bank", "CustomCommand", "CustomCommandAlias", + "DadJoke", "NightlyData", "UforaAnnouncement", "UforaCourse", @@ -73,6 +74,15 @@ class CustomCommandAlias(Base): command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") +class DadJoke(Base): + """When I finally understood asymptotic notation, it was a big "oh" moment""" + + __tablename__ = "dad_jokes" + + dad_joke_id: int = Column(Integer, primary_key=True) + joke: str = Column(Text, nullable=False) + + class NightlyData(Base): """Data for a user's Nightly stats""" diff --git a/database/utils/caches.py b/database/utils/caches.py index 5e5be4f..edc3a5e 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_courses @@ -47,19 +48,47 @@ class DatabaseCache(ABC): class UforaCourseCache(DatabaseCache): """Cache to store the names of Ufora courses""" + # Also store the aliases to add additional support + aliases: dict[str, str] = {} + + @overrides + def clear(self): + self.aliases.clear() + super().clear() + + @overrides async def refresh(self, database_session: AsyncSession): self.clear() courses = await ufora_courses.get_all_courses(database_session) - # Load the course names + all the aliases + self.data = list(map(lambda c: c.name, courses)) + + # Load the aliases for course in courses: - aliases = list(map(lambda x: x.alias, course.aliases)) - self.data.extend([course.name, *aliases]) + for alias in course.aliases: + # Store aliases in lowercase + self.aliases[alias.alias.lower()] = course.name self.data.sort() self.data_transformed = list(map(str.lower, self.data)) + @overrides + def get_autocomplete_suggestions(self, query: str): + query = query.lower() + results = set() + + # Return the original (not-lowercase) version + for index, course in enumerate(self.data_transformed): + if query in course: + results.add(self.data[index]) + + for alias, course in self.aliases.items(): + if query in alias: + results.add(course) + + return sorted(list(results)) + class CacheManager: """Class that keeps track of all caches""" diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py new file mode 100644 index 0000000..ddc119b --- /dev/null +++ b/didier/cogs/fun.py @@ -0,0 +1,29 @@ +from discord.ext import commands + +from database.crud.dad_jokes import get_random_dad_joke +from didier import Didier + + +class Fun(commands.Cog): + """Cog with lots of random fun stuff""" + + client: Didier + + def __init__(self, client: Didier): + self.client = client + + @commands.hybrid_command( + name="dadjoke", + aliases=["Dad", "Dj"], + description="Why does Yoda's code always crash? Because there is no try.", + ) + async def dad_joke(self, ctx: commands.Context): + """Get a random dad joke""" + async with self.client.db_session as session: + joke = await get_random_dad_joke(session) + return await ctx.reply(joke.joke, mention_author=False) + + +async def setup(client: Didier): + """Load the cog""" + await client.add_cog(Fun(client)) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index ca72cd6..f43df42 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -9,7 +9,7 @@ from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException from didier import Didier from didier.data.flags.owner import EditCustomFlags -from didier.data.modals.custom_commands import CreateCustomCommand, EditCustomCommand +from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand class Owner(commands.Cog): @@ -29,7 +29,6 @@ class Owner(commands.Cog): This means that we don't have to add is_owner() to every single command separately """ - # pylint: disable=W0236 # Pylint thinks this can't be async, but it can return await self.client.is_owner(ctx.author) @commands.command(name="Error") @@ -48,7 +47,7 @@ class Owner(commands.Cog): await ctx.message.add_reaction("🔄") - @commands.group(name="Add", case_insensitive=True, invoke_without_command=False) + @commands.group(name="Add", aliases=["Create"], case_insensitive=True, invoke_without_command=False) async def add_msg(self, ctx: commands.Context): """Command group for [add X] message commands""" @@ -88,6 +87,17 @@ class Owner(commands.Cog): modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) + @add_slash.command(name="dadjoke", description="Add a dad joke") + async def add_dad_joke_slash(self, interaction: discord.Interaction): + """Slash command to add a dad joke""" + if not await self.client.is_owner(interaction.user): + return interaction.response.send_message( + "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True + ) + + modal = AddDadJoke(self.client) + await interaction.response.send_modal(modal) + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 9f9cadf..716c59d 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -4,7 +4,9 @@ import discord from discord import app_commands from discord.ext import commands +from database.crud import ufora_courses from didier import Didier +from didier.data import constants class School(commands.Cog): @@ -57,6 +59,31 @@ class School(commands.Cog): await message.add_reaction("📌") return await interaction.response.send_message("📌", ephemeral=True) + @commands.hybrid_command( + name="fiche", description="Stuurt de link naar de studiefiche voor [Vak]", aliases=["guide", "studiefiche"] + ) + @app_commands.describe(course="vak") + async def study_guide(self, ctx: commands.Context, course: str): + """Create links to study guides""" + async with self.client.db_session as session: + ufora_course = await ufora_courses.get_course_by_name(session, course) + + if ufora_course is None: + return await ctx.reply(f"Geen vak gevonden voor ``{course}``", ephemeral=True) + + return await ctx.reply( + f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{constants.CURRENT_YEAR}", + mention_author=False, + ) + + @study_guide.autocomplete("course") + async def study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: + """Autocompletion for the 'course'-parameter""" + return [ + app_commands.Choice(name=course, value=course) + for course in self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) + ] + async def setup(client: Didier): """Load the cog""" diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 285945c..b8488ae 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -14,7 +14,6 @@ class Tasks(commands.Cog): client: Didier def __init__(self, client: Didier): - # pylint: disable=no-member self.client = client # Only pull announcements if a token was provided diff --git a/didier/data/constants.py b/didier/data/constants.py index d5c1021..cea951c 100644 --- a/didier/data/constants.py +++ b/didier/data/constants.py @@ -1 +1,10 @@ +# The year in which we were in 1Ba +import settings + +FIRST_YEAR = 2019 +# Year to use when adding the current year of our education +# to find the academic year +OFFSET_FIRST_YEAR = FIRST_YEAR - 1 +# The current academic year +CURRENT_YEAR = OFFSET_FIRST_YEAR + settings.YEAR PREFIXES = ["didier", "big d"] diff --git a/didier/exceptions/config.py b/didier/exceptions/config.py deleted file mode 100644 index df73f7e..0000000 --- a/didier/exceptions/config.py +++ /dev/null @@ -1,12 +0,0 @@ -__all__ = ["MissingEnvironmentVariable"] - - -class MissingEnvironmentVariable(RuntimeError): - """Exception raised when an environment variable is missing - - These are not necessarily checked on startup, because they may be unused - during a given test run, and random unrelated crashes would be annoying - """ - - def __init__(self, variable: str): - super().__init__(f"Missing environment variable: {variable}") diff --git a/didier/data/modals/__init__.py b/didier/views/__init__.py similarity index 100% rename from didier/data/modals/__init__.py rename to didier/views/__init__.py diff --git a/didier/views/modals/__init__.py b/didier/views/modals/__init__.py new file mode 100644 index 0000000..b28a4de --- /dev/null +++ b/didier/views/modals/__init__.py @@ -0,0 +1,4 @@ +from .custom_commands import CreateCustomCommand, EditCustomCommand +from .dad_jokes import AddDadJoke + +__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand"] diff --git a/didier/data/modals/custom_commands.py b/didier/views/modals/custom_commands.py similarity index 100% rename from didier/data/modals/custom_commands.py rename to didier/views/modals/custom_commands.py diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py new file mode 100644 index 0000000..9632197 --- /dev/null +++ b/didier/views/modals/dad_jokes.py @@ -0,0 +1,37 @@ +import traceback + +import discord +from overrides import overrides + +from database.crud.dad_jokes import add_dad_joke +from didier import Didier + +__all__ = ["AddDadJoke"] + + +class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): + """Modal to add a new dad joke""" + + name: discord.ui.TextInput = discord.ui.TextInput( + label="Joke", + placeholder="I sold our vacuum cleaner, it was just gathering dust.", + style=discord.TextStyle.long, + ) + + client: Didier + + def __init__(self, client: Didier, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = client + + @overrides + async def on_submit(self, interaction: discord.Interaction): + async with self.client.db_session as session: + joke = await add_dad_joke(session, str(self.name.value)) + + await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) + + @overrides + async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore + await interaction.response.send_message("Something went wrong.", ephemeral=True) + traceback.print_tb(error.__traceback__) diff --git a/pyproject.toml b/pyproject.toml index 2a28e94..75d3054 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,8 @@ omit = [ "./didier/cogs/*", "./didier/didier.py", "./didier/data/*", - "./didier/utils/discord/colours.py" + "./didier/utils/discord/colours.py", + "./didier/utils/discord/constants.py" ] [tool.isort] @@ -38,5 +39,5 @@ env = [ "DB_PASSWORD = pytest", "DB_HOST = localhost", "DB_PORT = 5433", - "DISC_TOKEN = token" + "DISCORD_TOKEN = token" ] diff --git a/settings.py b/settings.py index bc35ab4..bb2296f 100644 --- a/settings.py +++ b/settings.py @@ -29,6 +29,8 @@ __all__ = [ """General config""" SANDBOX: bool = env.bool("SANDBOX", True) LOGFILE: str = env.str("LOGFILE", "didier.log") +SEMESTER: int = env.int("SEMESTER", 2) +YEAR: int = env.int("YEAR", 3) """Database""" DB_NAME: str = env.str("DB_NAME", "didier") diff --git a/tests/conftest.py b/tests/conftest.py index c8ab65f..b2a1e04 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,5 +1,4 @@ import asyncio -import datetime from typing import AsyncGenerator, Generator from unittest.mock import MagicMock @@ -7,11 +6,9 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.engine import engine -from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias +from database.models import Base from didier import Didier -"""General fixtures""" - @pytest.fixture(scope="session") def event_loop() -> Generator: @@ -57,34 +54,3 @@ def mock_client() -> Didier: mock_client.user = mock_user return mock_client - - -"""Fixtures to put fake data in the database""" - - -@pytest.fixture -async def ufora_course(database_session: AsyncSession) -> UforaCourse: - """Fixture to create a course""" - course = UforaCourse(name="test", code="code", year=1, log_announcements=True) - database_session.add(course) - await database_session.commit() - return course - - -@pytest.fixture -async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: - """Fixture to create a course with an alias""" - alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") - database_session.add(alias) - await database_session.commit() - await database_session.refresh(ufora_course) - return ufora_course - - -@pytest.fixture -async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: - """Fixture to create an announcement""" - announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) - database_session.add(announcement) - await database_session.commit() - return announcement diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py new file mode 100644 index 0000000..de1e939 --- /dev/null +++ b/tests/test_database/conftest.py @@ -0,0 +1,34 @@ +import datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias + + +@pytest.fixture +async def ufora_course(database_session: AsyncSession) -> UforaCourse: + """Fixture to create a course""" + course = UforaCourse(name="test", code="code", year=1, log_announcements=True) + database_session.add(course) + await database_session.commit() + return course + + +@pytest.fixture +async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: + """Fixture to create a course with an alias""" + alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") + database_session.add(alias) + await database_session.commit() + await database_session.refresh(ufora_course) + return ufora_course + + +@pytest.fixture +async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: + """Fixture to create an announcement""" + announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) + database_session.add(announcement) + await database_session.commit() + return announcement