diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index 08dd04f..d41846c 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -5,6 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.models import UforaCourse, UforaCourseAlias +__all__ = ["get_all_courses", "get_course_by_name"] + async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: """Get a list of all courses in the database""" diff --git a/database/utils/caches.py b/database/utils/caches.py new file mode 100644 index 0000000..5e5be4f --- /dev/null +++ b/database/utils/caches.py @@ -0,0 +1,74 @@ +from abc import ABC, abstractmethod + +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import ufora_courses + +__all__ = ["CacheManager"] + + +class DatabaseCache(ABC): + """Base class for a simple cache-like structure + + The goal of this class is to store data for Discord auto-completion results + that would otherwise potentially put heavy load on the database. + + This only stores strings, to avoid having to constantly refresh these objects. + Once a choice has been made, it can just be pulled out of the database. + + Considering the fact that a user isn't obligated to choose something from the suggestions, + chances are high we have to go to the database for the final action either way. + + Also stores the data in lowercase to allow fast searching + """ + + data: list[str] = [] + data_transformed: list[str] = [] + + def clear(self): + """Remove everything""" + self.data.clear() + + @abstractmethod + async def refresh(self, database_session: AsyncSession): + """Refresh the data stored in this cache""" + + async def invalidate(self, database_session: AsyncSession): + """Invalidate the data stored in this cache""" + await self.refresh(database_session) + + def get_autocomplete_suggestions(self, query: str): + """Filter the cache to find everything that matches the search query""" + query = query.lower() + # Return the original (non-transformed) version of the data for pretty display in Discord + return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] + + +class UforaCourseCache(DatabaseCache): + """Cache to store the names of Ufora courses""" + + async def refresh(self, database_session: AsyncSession): + self.clear() + + courses = await ufora_courses.get_all_courses(database_session) + + # Load the course names + all the aliases + for course in courses: + aliases = list(map(lambda x: x.alias, course.aliases)) + self.data.extend([course.name, *aliases]) + + self.data.sort() + self.data_transformed = list(map(str.lower, self.data)) + + +class CacheManager: + """Class that keeps track of all caches""" + + ufora_courses: UforaCourseCache + + def __init__(self): + self.ufora_courses = UforaCourseCache() + + async def initialize_caches(self, database_session: AsyncSession): + """Initialize the contents of all caches""" + await self.ufora_courses.refresh(database_session) diff --git a/didier/didier.py b/didier/didier.py index fbe855b..9fef227 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import custom_commands from database.engine import DBSession +from database.utils.caches import CacheManager from didier.utils.discord.prefix import get_prefix __all__ = ["Didier"] @@ -16,6 +17,7 @@ __all__ = ["Didier"] class Didier(commands.Bot): """DIDIER <3""" + database_caches: CacheManager initial_extensions: tuple[str, ...] = () http_session: ClientSession @@ -50,6 +52,11 @@ class Didier(commands.Bot): await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") + # Initialize caches + self.database_caches = CacheManager() + async with self.db_session as session: + await self.database_caches.initialize_caches(session) + # Create aiohttp session self.http_session = ClientSession() diff --git a/tests/conftest.py b/tests/conftest.py index b2a1e04..c8ab65f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import asyncio +import datetime from typing import AsyncGenerator, Generator from unittest.mock import MagicMock @@ -6,9 +7,11 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.engine import engine -from database.models import Base +from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias from didier import Didier +"""General fixtures""" + @pytest.fixture(scope="session") def event_loop() -> Generator: @@ -54,3 +57,34 @@ def mock_client() -> Didier: mock_client.user = mock_user return mock_client + + +"""Fixtures to put fake data in the database""" + + +@pytest.fixture +async def ufora_course(database_session: AsyncSession) -> UforaCourse: + """Fixture to create a course""" + course = UforaCourse(name="test", code="code", year=1, log_announcements=True) + database_session.add(course) + await database_session.commit() + return course + + +@pytest.fixture +async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: + """Fixture to create a course with an alias""" + alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") + database_session.add(alias) + await database_session.commit() + await database_session.refresh(ufora_course) + return ufora_course + + +@pytest.fixture +async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: + """Fixture to create an announcement""" + announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) + database_session.add(announcement) + await database_session.commit() + return announcement diff --git a/tests/test_database/test_crud/conftest.py b/tests/test_database/test_crud/conftest.py deleted file mode 100644 index be5f889..0000000 --- a/tests/test_database/test_crud/conftest.py +++ /dev/null @@ -1,34 +0,0 @@ -import datetime - -import pytest -from sqlalchemy.ext.asyncio import AsyncSession - -from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias - - -@pytest.fixture -async def course(database_session: AsyncSession) -> UforaCourse: - """Fixture to create a course""" - course = UforaCourse(name="test", code="code", year=1, log_announcements=True) - database_session.add(course) - await database_session.commit() - return course - - -@pytest.fixture -async def course_with_alias(database_session: AsyncSession, course: UforaCourse) -> UforaCourse: - """Fixture to create a course with an alias""" - alias = UforaCourseAlias(course_id=course.course_id, alias="alias") - database_session.add(alias) - await database_session.commit() - await database_session.refresh(course) - return course - - -@pytest.fixture -async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: - """Fixture to create an announcement""" - announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now()) - database_session.add(announcement) - await database_session.commit() - return announcement diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index b2303a6..b2385a2 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -25,19 +25,21 @@ async def test_get_courses_with_announcements(database_session: AsyncSession): assert results[0] == course_1 -async def test_create_new_announcement(course: UforaCourse, database_session: AsyncSession): +async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): """Test creating a new announcement""" - await crud.create_new_announcement(database_session, 1, course=course, publication_date=datetime.datetime.now()) - await database_session.refresh(course) - assert len(course.announcements) == 1 + await crud.create_new_announcement( + database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() + ) + await database_session.refresh(ufora_course) + assert len(ufora_course.announcements) == 1 -async def test_remove_old_announcements(announcement: UforaAnnouncement, database_session: AsyncSession): +async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): """Test removing all stale announcements""" - course = announcement.course - announcement.publication_date -= datetime.timedelta(weeks=2) - announcement_2 = UforaAnnouncement(course_id=announcement.course_id, publication_date=datetime.datetime.now()) - database_session.add_all([announcement, announcement_2]) + course = ufora_announcement.course + ufora_announcement.publication_date -= datetime.timedelta(weeks=2) + announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) + database_session.add_all([ufora_announcement, announcement_2]) await database_session.commit() await database_session.refresh(course) assert len(course.announcements) == 2 diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index efe8fe6..d2d5e1b 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -4,19 +4,19 @@ from database.crud import ufora_courses as crud from database.models import UforaCourse -async def test_get_course_by_name_exact(database_session: AsyncSession, course: UforaCourse): +async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): """Test getting a course by its name when the query is an exact match""" match = await crud.get_course_by_name(database_session, "Test") - assert match == course + assert match == ufora_course -async def test_get_course_by_name_substring(database_session: AsyncSession, course: UforaCourse): +async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): """Test getting a course by its name when the query is a substring""" match = await crud.get_course_by_name(database_session, "es") - assert match == course + assert match == ufora_course -async def test_get_course_by_name_alias(database_session: AsyncSession, course_with_alias: UforaCourse): +async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): """Test getting a course by its name when the name doesn't match, but the alias does""" match = await crud.get_course_by_name(database_session, "ali") - assert match == course_with_alias + assert match == ufora_course_with_alias diff --git a/tests/test_database/test_utils/__init__.py b/tests/test_database/test_utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py new file mode 100644 index 0000000..09583d3 --- /dev/null +++ b/tests/test_database/test_utils/test_caches.py @@ -0,0 +1,27 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import UforaCourse +from database.utils.caches import UforaCourseCache + + +async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): + """Test loading the data for the Ufora Course cache when it's empty""" + cache = UforaCourseCache() + await cache.refresh(database_session) + + assert len(cache.data) == 2 + assert cache.data == ["alias", "test"] + + +async def test_ufora_course_cache_refresh_not_empty( + database_session: AsyncSession, ufora_course_with_alias: UforaCourse +): + """Test loading the data for the Ufora Course cache when it's not empty anymore""" + cache = UforaCourseCache() + cache.data = ["Something"] + cache.data_transformed = ["something"] + + await cache.refresh(database_session) + + assert len(cache.data) == 2 + assert cache.data == ["alias", "test"]