Compare commits

...

4 Commits

Author SHA1 Message Date
stijndcl 8227190a8d Fix mypy & tests 2022-07-16 00:19:05 +02:00
stijndcl 3debd18d82 Add dad jokes 2022-07-16 00:14:02 +02:00
stijndcl 3d0f771f94 Load year from settings 2022-07-15 23:14:56 +02:00
stijndcl 5b47397f29 Add study guide commands, get auto-completion for full course names based on aliases 2022-07-15 23:06:40 +02:00
20 changed files with 278 additions and 57 deletions

View File

@ -26,7 +26,7 @@ extend-ignore =
E203, E203,
# Don't require docstrings when overriding a method, # Don't require docstrings when overriding a method,
# the base method should have a docstring but the rest not # the base method should have a docstring but the rest not
ignore-decorator=overrides ignore-decorators=overrides
max-line-length = 120 max-line-length = 120
# Disable some rules for entire files # Disable some rules for entire files
per-file-ignores = per-file-ignores =

View File

@ -0,0 +1,33 @@
"""Add dad jokes
Revision ID: 581ae6511b98
Revises: 632b69cdadde
Create Date: 2022-07-15 23:37:08.147611
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "581ae6511b98"
down_revision = "632b69cdadde"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"dad_jokes",
sa.Column("dad_joke_id", sa.Integer(), nullable=False),
sa.Column("joke", sa.Text(), nullable=False),
sa.PrimaryKeyConstraint("dad_joke_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("dad_jokes")
# ### end Alembic commands ###

View File

@ -33,6 +33,7 @@ async def create_command(session: AsyncSession, name: str, response: str) -> Cus
command = CustomCommand(name=name, indexed_name=clean_name(name), response=response) command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
session.add(command) session.add(command)
await session.commit() await session.commit()
return command return command

View File

@ -0,0 +1,42 @@
from typing import Optional
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.not_found import NoResultFoundException
from database.models import DadJoke
__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"]
async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
"""Add a new dad joke to the database"""
dad_joke = DadJoke(joke=joke)
session.add(dad_joke)
await session.commit()
return dad_joke
async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke:
"""Edit an existing dad joke"""
statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id)
dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none()
if dad_joke is None:
raise NoResultFoundException
dad_joke.joke = new_joke
session.add(dad_joke)
await session.commit()
return dad_joke
async def get_random_dad_joke(session: AsyncSession) -> DadJoke:
"""Return a random database entry"""
statement = select(DadJoke).order_by(func.random())
row = (await session.execute(statement)).first()
if row is None:
raise NoResultFoundException
return row[0]

View File

@ -14,6 +14,7 @@ __all__ = [
"Bank", "Bank",
"CustomCommand", "CustomCommand",
"CustomCommandAlias", "CustomCommandAlias",
"DadJoke",
"NightlyData", "NightlyData",
"UforaAnnouncement", "UforaAnnouncement",
"UforaCourse", "UforaCourse",
@ -73,6 +74,15 @@ class CustomCommandAlias(Base):
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
class DadJoke(Base):
"""When I finally understood asymptotic notation, it was a big "oh" moment"""
__tablename__ = "dad_jokes"
dad_joke_id: int = Column(Integer, primary_key=True)
joke: str = Column(Text, nullable=False)
class NightlyData(Base): class NightlyData(Base):
"""Data for a user's Nightly stats""" """Data for a user's Nightly stats"""

View File

@ -1,5 +1,6 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from overrides import overrides
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_courses from database.crud import ufora_courses
@ -47,19 +48,47 @@ class DatabaseCache(ABC):
class UforaCourseCache(DatabaseCache): class UforaCourseCache(DatabaseCache):
"""Cache to store the names of Ufora courses""" """Cache to store the names of Ufora courses"""
# Also store the aliases to add additional support
aliases: dict[str, str] = {}
@overrides
def clear(self):
self.aliases.clear()
super().clear()
@overrides
async def refresh(self, database_session: AsyncSession): async def refresh(self, database_session: AsyncSession):
self.clear() self.clear()
courses = await ufora_courses.get_all_courses(database_session) courses = await ufora_courses.get_all_courses(database_session)
# Load the course names + all the aliases self.data = list(map(lambda c: c.name, courses))
# Load the aliases
for course in courses: for course in courses:
aliases = list(map(lambda x: x.alias, course.aliases)) for alias in course.aliases:
self.data.extend([course.name, *aliases]) # Store aliases in lowercase
self.aliases[alias.alias.lower()] = course.name
self.data.sort() self.data.sort()
self.data_transformed = list(map(str.lower, self.data)) self.data_transformed = list(map(str.lower, self.data))
@overrides
def get_autocomplete_suggestions(self, query: str):
query = query.lower()
results = set()
# Return the original (not-lowercase) version
for index, course in enumerate(self.data_transformed):
if query in course:
results.add(self.data[index])
for alias, course in self.aliases.items():
if query in alias:
results.add(course)
return sorted(list(results))
class CacheManager: class CacheManager:
"""Class that keeps track of all caches""" """Class that keeps track of all caches"""

29
didier/cogs/fun.py 100644
View File

@ -0,0 +1,29 @@
from discord.ext import commands
from database.crud.dad_jokes import get_random_dad_joke
from didier import Didier
class Fun(commands.Cog):
"""Cog with lots of random fun stuff"""
client: Didier
def __init__(self, client: Didier):
self.client = client
@commands.hybrid_command(
name="dadjoke",
aliases=["Dad", "Dj"],
description="Why does Yoda's code always crash? Because there is no try.",
)
async def dad_joke(self, ctx: commands.Context):
"""Get a random dad joke"""
async with self.client.db_session as session:
joke = await get_random_dad_joke(session)
return await ctx.reply(joke.joke, mention_author=False)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(Fun(client))

View File

@ -9,7 +9,7 @@ from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from didier import Didier from didier import Didier
from didier.data.flags.owner import EditCustomFlags from didier.data.flags.owner import EditCustomFlags
from didier.data.modals.custom_commands import CreateCustomCommand, EditCustomCommand from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
class Owner(commands.Cog): class Owner(commands.Cog):
@ -29,7 +29,6 @@ class Owner(commands.Cog):
This means that we don't have to add is_owner() to every single command separately This means that we don't have to add is_owner() to every single command separately
""" """
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
return await self.client.is_owner(ctx.author) return await self.client.is_owner(ctx.author)
@commands.command(name="Error") @commands.command(name="Error")
@ -48,7 +47,7 @@ class Owner(commands.Cog):
await ctx.message.add_reaction("🔄") await ctx.message.add_reaction("🔄")
@commands.group(name="Add", case_insensitive=True, invoke_without_command=False) @commands.group(name="Add", aliases=["Create"], case_insensitive=True, invoke_without_command=False)
async def add_msg(self, ctx: commands.Context): async def add_msg(self, ctx: commands.Context):
"""Command group for [add X] message commands""" """Command group for [add X] message commands"""
@ -88,6 +87,17 @@ class Owner(commands.Cog):
modal = CreateCustomCommand(self.client) modal = CreateCustomCommand(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@add_slash.command(name="dadjoke", description="Add a dad joke")
async def add_dad_joke_slash(self, interaction: discord.Interaction):
"""Slash command to add a dad joke"""
if not await self.client.is_owner(interaction.user):
return interaction.response.send_message(
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
)
modal = AddDadJoke(self.client)
await interaction.response.send_modal(modal)
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
async def edit_msg(self, ctx: commands.Context): async def edit_msg(self, ctx: commands.Context):
"""Command group for [edit X] commands""" """Command group for [edit X] commands"""

View File

@ -4,7 +4,9 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud import ufora_courses
from didier import Didier from didier import Didier
from didier.data import constants
class School(commands.Cog): class School(commands.Cog):
@ -57,6 +59,31 @@ class School(commands.Cog):
await message.add_reaction("📌") await message.add_reaction("📌")
return await interaction.response.send_message("📌", ephemeral=True) return await interaction.response.send_message("📌", ephemeral=True)
@commands.hybrid_command(
name="fiche", description="Stuurt de link naar de studiefiche voor [Vak]", aliases=["guide", "studiefiche"]
)
@app_commands.describe(course="vak")
async def study_guide(self, ctx: commands.Context, course: str):
"""Create links to study guides"""
async with self.client.db_session as session:
ufora_course = await ufora_courses.get_course_by_name(session, course)
if ufora_course is None:
return await ctx.reply(f"Geen vak gevonden voor ``{course}``", ephemeral=True)
return await ctx.reply(
f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{constants.CURRENT_YEAR}",
mention_author=False,
)
@study_guide.autocomplete("course")
async def study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
"""Autocompletion for the 'course'-parameter"""
return [
app_commands.Choice(name=course, value=course)
for course in self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
]
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -14,7 +14,6 @@ class Tasks(commands.Cog):
client: Didier client: Didier
def __init__(self, client: Didier): def __init__(self, client: Didier):
# pylint: disable=no-member
self.client = client self.client = client
# Only pull announcements if a token was provided # Only pull announcements if a token was provided

View File

@ -1 +1,10 @@
# The year in which we were in 1Ba
import settings
FIRST_YEAR = 2019
# Year to use when adding the current year of our education
# to find the academic year
OFFSET_FIRST_YEAR = FIRST_YEAR - 1
# The current academic year
CURRENT_YEAR = OFFSET_FIRST_YEAR + settings.YEAR
PREFIXES = ["didier", "big d"] PREFIXES = ["didier", "big d"]

View File

@ -1,12 +0,0 @@
__all__ = ["MissingEnvironmentVariable"]
class MissingEnvironmentVariable(RuntimeError):
"""Exception raised when an environment variable is missing
These are not necessarily checked on startup, because they may be unused
during a given test run, and random unrelated crashes would be annoying
"""
def __init__(self, variable: str):
super().__init__(f"Missing environment variable: {variable}")

View File

@ -0,0 +1,4 @@
from .custom_commands import CreateCustomCommand, EditCustomCommand
from .dad_jokes import AddDadJoke
__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand"]

View File

@ -0,0 +1,37 @@
import traceback
import discord
from overrides import overrides
from database.crud.dad_jokes import add_dad_joke
from didier import Didier
__all__ = ["AddDadJoke"]
class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
"""Modal to add a new dad joke"""
name: discord.ui.TextInput = discord.ui.TextInput(
label="Joke",
placeholder="I sold our vacuum cleaner, it was just gathering dust.",
style=discord.TextStyle.long,
)
client: Didier
def __init__(self, client: Didier, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = client
@overrides
async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session:
joke = await add_dad_joke(session, str(self.name.value))
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)

View File

@ -15,7 +15,8 @@ omit = [
"./didier/cogs/*", "./didier/cogs/*",
"./didier/didier.py", "./didier/didier.py",
"./didier/data/*", "./didier/data/*",
"./didier/utils/discord/colours.py" "./didier/utils/discord/colours.py",
"./didier/utils/discord/constants.py"
] ]
[tool.isort] [tool.isort]
@ -38,5 +39,5 @@ env = [
"DB_PASSWORD = pytest", "DB_PASSWORD = pytest",
"DB_HOST = localhost", "DB_HOST = localhost",
"DB_PORT = 5433", "DB_PORT = 5433",
"DISC_TOKEN = token" "DISCORD_TOKEN = token"
] ]

View File

@ -29,6 +29,8 @@ __all__ = [
"""General config""" """General config"""
SANDBOX: bool = env.bool("SANDBOX", True) SANDBOX: bool = env.bool("SANDBOX", True)
LOGFILE: str = env.str("LOGFILE", "didier.log") LOGFILE: str = env.str("LOGFILE", "didier.log")
SEMESTER: int = env.int("SEMESTER", 2)
YEAR: int = env.int("YEAR", 3)
"""Database""" """Database"""
DB_NAME: str = env.str("DB_NAME", "didier") DB_NAME: str = env.str("DB_NAME", "didier")

View File

@ -1,5 +1,4 @@
import asyncio import asyncio
import datetime
from typing import AsyncGenerator, Generator from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
@ -7,11 +6,9 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine from database.engine import engine
from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias from database.models import Base
from didier import Didier from didier import Didier
"""General fixtures"""
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def event_loop() -> Generator: def event_loop() -> Generator:
@ -57,34 +54,3 @@ def mock_client() -> Didier:
mock_client.user = mock_user mock_client.user = mock_user
return mock_client return mock_client
"""Fixtures to put fake data in the database"""
@pytest.fixture
async def ufora_course(database_session: AsyncSession) -> UforaCourse:
"""Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course)
await database_session.commit()
return course
@pytest.fixture
async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
"""Fixture to create a course with an alias"""
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
database_session.add(alias)
await database_session.commit()
await database_session.refresh(ufora_course)
return ufora_course
@pytest.fixture
async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
"""Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement)
await database_session.commit()
return announcement

View File

@ -0,0 +1,34 @@
import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
@pytest.fixture
async def ufora_course(database_session: AsyncSession) -> UforaCourse:
"""Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course)
await database_session.commit()
return course
@pytest.fixture
async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
"""Fixture to create a course with an alias"""
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
database_session.add(alias)
await database_session.commit()
await database_session.refresh(ufora_course)
return ufora_course
@pytest.fixture
async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
"""Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement)
await database_session.commit()
return announcement