mirror of https://github.com/stijndcl/didier
				
				
				
			Add simple caching implementation for database queries that will be used in command autocompletion
							parent
							
								
									f0a05c8b4d
								
							
						
					
					
						commit
						72c3acbcc2
					
				|  | @ -5,6 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession | |||
| 
 | ||||
| from database.models import UforaCourse, UforaCourseAlias | ||||
| 
 | ||||
| __all__ = ["get_all_courses", "get_course_by_name"] | ||||
| 
 | ||||
| 
 | ||||
| async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: | ||||
|     """Get a list of all courses in the database""" | ||||
|  |  | |||
|  | @ -0,0 +1,74 @@ | |||
| from abc import ABC, abstractmethod | ||||
| 
 | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.crud import ufora_courses | ||||
| 
 | ||||
| __all__ = ["CacheManager"] | ||||
| 
 | ||||
| 
 | ||||
| class DatabaseCache(ABC): | ||||
|     """Base class for a simple cache-like structure | ||||
| 
 | ||||
|     The goal of this class is to store data for Discord auto-completion results | ||||
|     that would otherwise potentially put heavy load on the database. | ||||
| 
 | ||||
|     This only stores strings, to avoid having to constantly refresh these objects. | ||||
|     Once a choice has been made, it can just be pulled out of the database. | ||||
| 
 | ||||
|     Considering the fact that a user isn't obligated to choose something from the suggestions, | ||||
|     chances are high we have to go to the database for the final action either way. | ||||
| 
 | ||||
|     Also stores the data in lowercase to allow fast searching | ||||
|     """ | ||||
| 
 | ||||
|     data: list[str] = [] | ||||
|     data_transformed: list[str] = [] | ||||
| 
 | ||||
|     def clear(self): | ||||
|         """Remove everything""" | ||||
|         self.data.clear() | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     async def refresh(self, database_session: AsyncSession): | ||||
|         """Refresh the data stored in this cache""" | ||||
| 
 | ||||
|     async def invalidate(self, database_session: AsyncSession): | ||||
|         """Invalidate the data stored in this cache""" | ||||
|         await self.refresh(database_session) | ||||
| 
 | ||||
|     def get_autocomplete_suggestions(self, query: str): | ||||
|         """Filter the cache to find everything that matches the search query""" | ||||
|         query = query.lower() | ||||
|         # Return the original (non-transformed) version of the data for pretty display in Discord | ||||
|         return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] | ||||
| 
 | ||||
| 
 | ||||
| class UforaCourseCache(DatabaseCache): | ||||
|     """Cache to store the names of Ufora courses""" | ||||
| 
 | ||||
|     async def refresh(self, database_session: AsyncSession): | ||||
|         self.clear() | ||||
| 
 | ||||
|         courses = await ufora_courses.get_all_courses(database_session) | ||||
| 
 | ||||
|         # Load the course names + all the aliases | ||||
|         for course in courses: | ||||
|             aliases = list(map(lambda x: x.alias, course.aliases)) | ||||
|             self.data.extend([course.name, *aliases]) | ||||
| 
 | ||||
|         self.data.sort() | ||||
|         self.data_transformed = list(map(str.lower, self.data)) | ||||
| 
 | ||||
| 
 | ||||
| class CacheManager: | ||||
|     """Class that keeps track of all caches""" | ||||
| 
 | ||||
|     ufora_courses: UforaCourseCache | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.ufora_courses = UforaCourseCache() | ||||
| 
 | ||||
|     async def initialize_caches(self, database_session: AsyncSession): | ||||
|         """Initialize the contents of all caches""" | ||||
|         await self.ufora_courses.refresh(database_session) | ||||
|  | @ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession | |||
| import settings | ||||
| from database.crud import custom_commands | ||||
| from database.engine import DBSession | ||||
| from database.utils.caches import CacheManager | ||||
| from didier.utils.discord.prefix import get_prefix | ||||
| 
 | ||||
| __all__ = ["Didier"] | ||||
|  | @ -16,6 +17,7 @@ __all__ = ["Didier"] | |||
| class Didier(commands.Bot): | ||||
|     """DIDIER <3""" | ||||
| 
 | ||||
|     database_caches: CacheManager | ||||
|     initial_extensions: tuple[str, ...] = () | ||||
|     http_session: ClientSession | ||||
| 
 | ||||
|  | @ -50,6 +52,11 @@ class Didier(commands.Bot): | |||
|         await self._load_initial_extensions() | ||||
|         await self._load_directory_extensions("didier/cogs") | ||||
| 
 | ||||
|         # Initialize caches | ||||
|         self.database_caches = CacheManager() | ||||
|         async with self.db_session as session: | ||||
|             await self.database_caches.initialize_caches(session) | ||||
| 
 | ||||
|         # Create aiohttp session | ||||
|         self.http_session = ClientSession() | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| import asyncio | ||||
| import datetime | ||||
| from typing import AsyncGenerator, Generator | ||||
| from unittest.mock import MagicMock | ||||
| 
 | ||||
|  | @ -6,9 +7,11 @@ import pytest | |||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.engine import engine | ||||
| from database.models import Base | ||||
| from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias | ||||
| from didier import Didier | ||||
| 
 | ||||
| """General fixtures""" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(scope="session") | ||||
| def event_loop() -> Generator: | ||||
|  | @ -54,3 +57,34 @@ def mock_client() -> Didier: | |||
|     mock_client.user = mock_user | ||||
| 
 | ||||
|     return mock_client | ||||
| 
 | ||||
| 
 | ||||
| """Fixtures to put fake data in the database""" | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def ufora_course(database_session: AsyncSession) -> UforaCourse: | ||||
|     """Fixture to create a course""" | ||||
|     course = UforaCourse(name="test", code="code", year=1, log_announcements=True) | ||||
|     database_session.add(course) | ||||
|     await database_session.commit() | ||||
|     return course | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: | ||||
|     """Fixture to create a course with an alias""" | ||||
|     alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") | ||||
|     database_session.add(alias) | ||||
|     await database_session.commit() | ||||
|     await database_session.refresh(ufora_course) | ||||
|     return ufora_course | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: | ||||
|     """Fixture to create an announcement""" | ||||
|     announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) | ||||
|     database_session.add(announcement) | ||||
|     await database_session.commit() | ||||
|     return announcement | ||||
|  |  | |||
|  | @ -1,34 +0,0 @@ | |||
| import datetime | ||||
| 
 | ||||
| import pytest | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def course(database_session: AsyncSession) -> UforaCourse: | ||||
|     """Fixture to create a course""" | ||||
|     course = UforaCourse(name="test", code="code", year=1, log_announcements=True) | ||||
|     database_session.add(course) | ||||
|     await database_session.commit() | ||||
|     return course | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def course_with_alias(database_session: AsyncSession, course: UforaCourse) -> UforaCourse: | ||||
|     """Fixture to create a course with an alias""" | ||||
|     alias = UforaCourseAlias(course_id=course.course_id, alias="alias") | ||||
|     database_session.add(alias) | ||||
|     await database_session.commit() | ||||
|     await database_session.refresh(course) | ||||
|     return course | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: | ||||
|     """Fixture to create an announcement""" | ||||
|     announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now()) | ||||
|     database_session.add(announcement) | ||||
|     await database_session.commit() | ||||
|     return announcement | ||||
|  | @ -25,19 +25,21 @@ async def test_get_courses_with_announcements(database_session: AsyncSession): | |||
|     assert results[0] == course_1 | ||||
| 
 | ||||
| 
 | ||||
| async def test_create_new_announcement(course: UforaCourse, database_session: AsyncSession): | ||||
| async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): | ||||
|     """Test creating a new announcement""" | ||||
|     await crud.create_new_announcement(database_session, 1, course=course, publication_date=datetime.datetime.now()) | ||||
|     await database_session.refresh(course) | ||||
|     assert len(course.announcements) == 1 | ||||
|     await crud.create_new_announcement( | ||||
|         database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() | ||||
|     ) | ||||
|     await database_session.refresh(ufora_course) | ||||
|     assert len(ufora_course.announcements) == 1 | ||||
| 
 | ||||
| 
 | ||||
| async def test_remove_old_announcements(announcement: UforaAnnouncement, database_session: AsyncSession): | ||||
| async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): | ||||
|     """Test removing all stale announcements""" | ||||
|     course = announcement.course | ||||
|     announcement.publication_date -= datetime.timedelta(weeks=2) | ||||
|     announcement_2 = UforaAnnouncement(course_id=announcement.course_id, publication_date=datetime.datetime.now()) | ||||
|     database_session.add_all([announcement, announcement_2]) | ||||
|     course = ufora_announcement.course | ||||
|     ufora_announcement.publication_date -= datetime.timedelta(weeks=2) | ||||
|     announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) | ||||
|     database_session.add_all([ufora_announcement, announcement_2]) | ||||
|     await database_session.commit() | ||||
|     await database_session.refresh(course) | ||||
|     assert len(course.announcements) == 2 | ||||
|  |  | |||
|  | @ -4,19 +4,19 @@ from database.crud import ufora_courses as crud | |||
| from database.models import UforaCourse | ||||
| 
 | ||||
| 
 | ||||
| async def test_get_course_by_name_exact(database_session: AsyncSession, course: UforaCourse): | ||||
| async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): | ||||
|     """Test getting a course by its name when the query is an exact match""" | ||||
|     match = await crud.get_course_by_name(database_session, "Test") | ||||
|     assert match == course | ||||
|     assert match == ufora_course | ||||
| 
 | ||||
| 
 | ||||
| async def test_get_course_by_name_substring(database_session: AsyncSession, course: UforaCourse): | ||||
| async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): | ||||
|     """Test getting a course by its name when the query is a substring""" | ||||
|     match = await crud.get_course_by_name(database_session, "es") | ||||
|     assert match == course | ||||
|     assert match == ufora_course | ||||
| 
 | ||||
| 
 | ||||
| async def test_get_course_by_name_alias(database_session: AsyncSession, course_with_alias: UforaCourse): | ||||
| async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): | ||||
|     """Test getting a course by its name when the name doesn't match, but the alias does""" | ||||
|     match = await crud.get_course_by_name(database_session, "ali") | ||||
|     assert match == course_with_alias | ||||
|     assert match == ufora_course_with_alias | ||||
|  |  | |||
|  | @ -0,0 +1,27 @@ | |||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.models import UforaCourse | ||||
| from database.utils.caches import UforaCourseCache | ||||
| 
 | ||||
| 
 | ||||
| async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): | ||||
|     """Test loading the data for the Ufora Course cache when it's empty""" | ||||
|     cache = UforaCourseCache() | ||||
|     await cache.refresh(database_session) | ||||
| 
 | ||||
|     assert len(cache.data) == 2 | ||||
|     assert cache.data == ["alias", "test"] | ||||
| 
 | ||||
| 
 | ||||
| async def test_ufora_course_cache_refresh_not_empty( | ||||
|     database_session: AsyncSession, ufora_course_with_alias: UforaCourse | ||||
| ): | ||||
|     """Test loading the data for the Ufora Course cache when it's not empty anymore""" | ||||
|     cache = UforaCourseCache() | ||||
|     cache.data = ["Something"] | ||||
|     cache.data_transformed = ["something"] | ||||
| 
 | ||||
|     await cache.refresh(database_session) | ||||
| 
 | ||||
|     assert len(cache.data) == 2 | ||||
|     assert cache.data == ["alias", "test"] | ||||
		Loading…
	
		Reference in New Issue