diff --git a/alembic/versions/ea9811f060aa_initial_migration.py b/alembic/versions/5bdb99885a5d_initial_migration.py similarity index 98% rename from alembic/versions/ea9811f060aa_initial_migration.py rename to alembic/versions/5bdb99885a5d_initial_migration.py index dbf5580..9c86c48 100644 --- a/alembic/versions/ea9811f060aa_initial_migration.py +++ b/alembic/versions/5bdb99885a5d_initial_migration.py @@ -1,8 +1,8 @@ """Initial migration -Revision ID: ea9811f060aa +Revision ID: 5bdb99885a5d Revises: -Create Date: 2022-09-17 17:31:20.593318 +Create Date: 2022-09-17 22:39:15.969694 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = "ea9811f060aa" +revision = "5bdb99885a5d" down_revision = None branch_labels = None depends_on = None @@ -70,6 +70,7 @@ def upgrade() -> None: sa.Column("year", sa.Integer(), nullable=False), sa.Column("compulsory", sa.Boolean(), server_default="1", nullable=False), sa.Column("role_id", sa.Integer(), nullable=True), + sa.Column("overarching_role_id", sa.Integer(), nullable=True), sa.Column("log_announcements", sa.Boolean(), server_default="0", nullable=False), sa.PrimaryKeyConstraint("course_id"), sa.UniqueConstraint("code"), diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index 19369c1..5374c07 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.schemas import UforaCourse, UforaCourseAlias -__all__ = ["get_all_courses", "get_course_by_name"] +__all__ = ["get_all_courses", "get_course_by_code", "get_course_by_name"] async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: @@ -14,6 +14,12 @@ async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: return list((await session.execute(statement)).scalars().all()) +async def get_course_by_code(session: AsyncSession, code: str) -> Optional[UforaCourse]: + """Try to find a course by its code""" + statement = select(UforaCourse).where(UforaCourse.code == code) + return (await session.execute(statement)).scalar_one_or_none() + + async def get_course_by_name(session: AsyncSession, query: str) -> Optional[UforaCourse]: """Try to find a course by its name diff --git a/database/schemas.py b/database/schemas.py index 945781d..2a2da34 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -199,6 +199,7 @@ class UforaCourse(Base): year: int = Column(Integer, nullable=False) compulsory: bool = Column(Boolean, server_default="1", nullable=False) role_id: Optional[int] = Column(Integer, nullable=True, unique=False) + overarching_role_id: Optional[int] = Column(Integer, nullable=True, unique=False) log_announcements: bool = Column(Boolean, server_default="0", nullable=False) announcements: list[UforaAnnouncement] = relationship( diff --git a/database/scripts/debug_add_courses.py b/database/scripts/debug_add_courses.py new file mode 100644 index 0000000..aeaa9b1 --- /dev/null +++ b/database/scripts/debug_add_courses.py @@ -0,0 +1,16 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from database.engine import DBSession +from database.schemas import UforaCourse + +__all__ = ["main"] + + +async def main(): + """Add debug Ufora courses""" + session: AsyncSession + async with DBSession() as session: + modsim = UforaCourse(course_id=439235, code="C003786", name="Modelleren en Simuleren", year=3, compulsory=False) + + session.add_all([modsim]) + await session.commit() diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 2fa4b4e..974896e 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -53,6 +53,15 @@ class Owner(commands.Cog): """Raise an exception for debugging purposes""" raise Exception(message) + @commands.command(name="Reload") + async def reload(self, ctx: commands.Context, *cogs: str): + """Reload the cogs passed as an argument""" + for cog in cogs: + await self.client.reload_extension(f"didier.cogs.{cog}") + + await self.client.confirm_message(ctx.message) + return await ctx.reply(f"Successfully reloaded {', '.join(cogs)}.", mention_author=False) + @commands.command(name="Sync") async def sync( self, diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 42a7d0b..36c2467 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -12,6 +12,7 @@ from database.crud.ufora_announcements import remove_old_announcements from database.crud.wordle import set_daily_word from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements +from didier.data.schedules import parse_schedule_from_content from didier.decorators.tasks import timed_task from didier.utils.discord.checks import is_owner from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now @@ -121,22 +122,25 @@ class Tasks(commands.Cog): """ _ = kwargs - for data in settings.SCHEDULE_DATA: - if data.schedule_url is None: - return + async with self.client.postgres_session as session: + for data in settings.SCHEDULE_DATA: + if data.schedule_url is None: + return - async with self.client.http_session.get(data.schedule_url) as response: - # If a schedule couldn't be fetched, log it and move on - if response.status != 200: - await self.client.log_warning( - f"Unable to fetch schedule {data.name} (status {response.status}).", log_to_discord=False - ) - continue + async with self.client.http_session.get(data.schedule_url) as response: + # If a schedule couldn't be fetched, log it and move on + if response.status != 200: + await self.client.log_warning( + f"Unable to fetch schedule {data.name} (status {response.status}).", log_to_discord=False + ) + continue - # Write the content to a file - content = await response.text() - with open(f"files/schedules/{data.name}.ics", "w+") as fp: - fp.write(content) + # Write the content to a file + content = await response.text() + with open(f"files/schedules/{data.name}.ics", "w+") as fp: + fp.write(content) + + await parse_schedule_from_content(content, database_session=session) @tasks.loop(minutes=10) @timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS) @@ -194,4 +198,4 @@ async def setup(client: Didier): cog = Tasks(client) await client.add_cog(cog) await cog.reset_wordle_word() - await cog.pull_schedules() + # await cog.pull_schedules() diff --git a/didier/data/schedules.py b/didier/data/schedules.py new file mode 100644 index 0000000..132fe79 --- /dev/null +++ b/didier/data/schedules.py @@ -0,0 +1,116 @@ +from __future__ import annotations + +import pathlib +import re +from dataclasses import dataclass, field +from datetime import datetime +from typing import Optional + +from arrow import Arrow +from ics import Calendar +from overrides import overrides +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud.ufora_courses import get_course_by_code +from database.schemas import UforaCourse +from didier.utils.types.datetime import LOCAL_TIMEZONE +from settings import ScheduleType + +__all__ = ["Schedule", "parse_schedule_from_content", "parse_schedule"] + + +@dataclass +class Schedule: + """An entire schedule""" + + slots: set[ScheduleSlot] + + +@dataclass +class ScheduleSlot: + """A slot in the schedule""" + + course: UforaCourse + start_time: datetime + end_time: datetime + location: str + _hash: int = field(init=False) + + def __post_init__(self): + """Fix some properties to display more nicely""" + # Re-format the location data + room, building, campus = re.search(r"Leslokaal (.*)\. Gebouw (.*)\. Campus (.*)\. ", self.location).groups() + self.location = f"{campus} {building} {room}" + + self._hash = hash(f"{self.course.course_id} {str(self.start_time)}") + + @overrides + def __hash__(self) -> int: + return self._hash + + @overrides + def __eq__(self, other: ScheduleSlot): + return self._hash == other._hash + + +def parse_course_code(summary: str) -> str: + """Parse a course's code out of the summary""" + code = re.search(r"^([^ ]+)\. ", summary).groups()[0] + + # Strip off last character as it's not relevant + if code[-1].isalpha(): + return code[:-1] + + return code + + +def parse_time_string(string: str) -> datetime: + """Parse an ISO string to a timezone-aware datetime instance""" + return datetime.fromisoformat(string).astimezone(LOCAL_TIMEZONE) + + +async def parse_schedule_from_content(content: str, *, database_session: AsyncSession) -> Schedule: + """Parse a schedule file, taking the file content as an argument + + This can be used to avoid unnecessarily opening the file again if you already have its contents + """ + calendar = Calendar(content) + day = Arrow(year=2022, month=9, day=26) + events = list(calendar.timeline.on(day)) + course_codes: dict[str, UforaCourse] = {} + slots: set[ScheduleSlot] = set() + + for event in events: + code = parse_course_code(event.name) + + if code not in course_codes: + course = await get_course_by_code(database_session, code) + if course is None: + # raise ValueError(f"Unable to find course with code {code} (event {event.name})") + continue # TODO uncomment the line above + + course_codes[code] = course + + # Overwrite the name to be the sanitized value + event.name = code + + slot = ScheduleSlot( + course=course_codes[code], + start_time=parse_time_string(str(event.begin)), + end_time=parse_time_string(str(event.end)), + location=event.location, + ) + + slots.add(slot) + + return Schedule(slots=slots) + + +async def parse_schedule(name: ScheduleType, *, database_session: AsyncSession) -> Optional[Schedule]: + """Read and then parse a schedule file""" + schedule_path = pathlib.Path(f"files/schedules/{name}.ics") + if not schedule_path.exists(): + return None + + with open(schedule_path, "r", encoding="utf-8") as fp: + return await parse_schedule_from_content(fp.read(), database_session=database_session) diff --git a/didier/didier.py b/didier/didier.py index 4d4434e..d4337e5 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -13,6 +13,7 @@ from database.crud import custom_commands from database.engine import DBSession from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed +from didier.data.schedules import Schedule, parse_schedule from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix @@ -29,6 +30,7 @@ class Didier(commands.Bot): error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession + schedules: dict[settings.ScheduleType, Schedule] = {} wordle_words: set[str] = set() def __init__(self): @@ -63,6 +65,9 @@ class Didier(commands.Bot): # Create directories that are ignored on GitHub self._create_ignored_directories() + # Load schedules + await self.load_schedules() + # Load the Wordle dictionary self._load_wordle_words() @@ -120,6 +125,18 @@ class Didier(commands.Bot): for line in fp: self.wordle_words.add(line.strip()) + async def load_schedules(self): + """Parse & load all schedules into memory""" + self.schedules = {} + + async with self.postgres_session as session: + for schedule_data in settings.SCHEDULE_DATA: + schedule = await parse_schedule(schedule_data.name, database_session=session) + if schedule is None: + continue + + self.schedules[schedule_data.name] = schedule + async def get_reply_target(self, ctx: commands.Context) -> discord.Message: """Get the target message that should be replied to diff --git a/run_db_scripts.py b/run_db_scripts.py new file mode 100644 index 0000000..12eb230 --- /dev/null +++ b/run_db_scripts.py @@ -0,0 +1,28 @@ +"""Script to run database-related scripts + +This is slightly ugly, but running the scripts directly isn't possible because of imports +This could be cleaned up a bit using importlib but this is safer +""" +import asyncio +import sys +from typing import Callable + +from database.scripts.debug_add_courses import main as debug_add_courses + +script_mapping: dict[str, Callable] = {"debug_add_courses.py": debug_add_courses} + + +if __name__ == "__main__": + scripts = sys.argv[1:] + if not scripts: + print("No scripts provided.", file=sys.stderr) + exit(1) + + for script in scripts: + script_main = script_mapping.get(script.removeprefix("database/scripts/"), None) + if script_main is None: + print(f'Script "{script}" not found.', file=sys.stderr) + exit(1) + + asyncio.run(script_main()) + print(f"Successfully ran {script}") diff --git a/settings.py b/settings.py index ad72779..ff8bb35 100644 --- a/settings.py +++ b/settings.py @@ -1,4 +1,5 @@ from dataclasses import dataclass +from enum import Enum from typing import Optional from environs import Env @@ -29,6 +30,8 @@ __all__ = [ "IMGFLIP_NAME", "IMGFLIP_PASSWORD", "BA3_SCHEDULE_URL", + "ScheduleType", + "ScheduleInfo", "SCHEDULE_DATA", ] @@ -53,6 +56,7 @@ POSTGRES_PORT: int = env.int("POSTGRES_PORT", "5432") DISCORD_TOKEN: str = env.str("DISCORD_TOKEN") DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY") DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.") +DISCORD_MAIN_GUILD: Optional[int] = env.int("DISCORD_MAIN_GUILD", 626699611192688641) DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int) DISCORD_OWNER_GUILDS: Optional[list[int]] = env.list("DISCORD_OWNER_GUILDS", [], subcast=int) or None DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>") @@ -77,6 +81,12 @@ BA3_SCHEDULE_URL: Optional[str] = env.str("BA3_SCHEDULE_URL", None) """Computed properties""" +class ScheduleType(str, Enum): + """Enum to differentiate schedules""" + + BA3 = "ba3" + + @dataclass class ScheduleInfo: """Dataclass to hold and combine some information about schedule-related settings""" @@ -86,4 +96,4 @@ class ScheduleInfo: name: Optional[str] = None -SCHEDULE_DATA = [ScheduleInfo(name="ba3", role_id=BA3_ROLE, schedule_url=BA3_SCHEDULE_URL)] +SCHEDULE_DATA = [ScheduleInfo(name=ScheduleType.BA3, role_id=BA3_ROLE, schedule_url=BA3_SCHEDULE_URL)]