mirror of https://github.com/stijndcl/didier
				
				
				
			Add simple caching implementation for database queries that will be used in command autocompletion
							parent
							
								
									f0a05c8b4d
								
							
						
					
					
						commit
						72c3acbcc2
					
				|  | @ -5,6 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession | ||||||
| 
 | 
 | ||||||
| from database.models import UforaCourse, UforaCourseAlias | from database.models import UforaCourse, UforaCourseAlias | ||||||
| 
 | 
 | ||||||
|  | __all__ = ["get_all_courses", "get_course_by_name"] | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: | async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: | ||||||
|     """Get a list of all courses in the database""" |     """Get a list of all courses in the database""" | ||||||
|  |  | ||||||
|  | @ -0,0 +1,74 @@ | ||||||
|  | from abc import ABC, abstractmethod | ||||||
|  | 
 | ||||||
|  | from sqlalchemy.ext.asyncio import AsyncSession | ||||||
|  | 
 | ||||||
|  | from database.crud import ufora_courses | ||||||
|  | 
 | ||||||
|  | __all__ = ["CacheManager"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class DatabaseCache(ABC): | ||||||
|  |     """Base class for a simple cache-like structure | ||||||
|  | 
 | ||||||
|  |     The goal of this class is to store data for Discord auto-completion results | ||||||
|  |     that would otherwise potentially put heavy load on the database. | ||||||
|  | 
 | ||||||
|  |     This only stores strings, to avoid having to constantly refresh these objects. | ||||||
|  |     Once a choice has been made, it can just be pulled out of the database. | ||||||
|  | 
 | ||||||
|  |     Considering the fact that a user isn't obligated to choose something from the suggestions, | ||||||
|  |     chances are high we have to go to the database for the final action either way. | ||||||
|  | 
 | ||||||
|  |     Also stores the data in lowercase to allow fast searching | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     data: list[str] = [] | ||||||
|  |     data_transformed: list[str] = [] | ||||||
|  | 
 | ||||||
|  |     def clear(self): | ||||||
|  |         """Remove everything""" | ||||||
|  |         self.data.clear() | ||||||
|  | 
 | ||||||
|  |     @abstractmethod | ||||||
|  |     async def refresh(self, database_session: AsyncSession): | ||||||
|  |         """Refresh the data stored in this cache""" | ||||||
|  | 
 | ||||||
|  |     async def invalidate(self, database_session: AsyncSession): | ||||||
|  |         """Invalidate the data stored in this cache""" | ||||||
|  |         await self.refresh(database_session) | ||||||
|  | 
 | ||||||
|  |     def get_autocomplete_suggestions(self, query: str): | ||||||
|  |         """Filter the cache to find everything that matches the search query""" | ||||||
|  |         query = query.lower() | ||||||
|  |         # Return the original (non-transformed) version of the data for pretty display in Discord | ||||||
|  |         return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class UforaCourseCache(DatabaseCache): | ||||||
|  |     """Cache to store the names of Ufora courses""" | ||||||
|  | 
 | ||||||
|  |     async def refresh(self, database_session: AsyncSession): | ||||||
|  |         self.clear() | ||||||
|  | 
 | ||||||
|  |         courses = await ufora_courses.get_all_courses(database_session) | ||||||
|  | 
 | ||||||
|  |         # Load the course names + all the aliases | ||||||
|  |         for course in courses: | ||||||
|  |             aliases = list(map(lambda x: x.alias, course.aliases)) | ||||||
|  |             self.data.extend([course.name, *aliases]) | ||||||
|  | 
 | ||||||
|  |         self.data.sort() | ||||||
|  |         self.data_transformed = list(map(str.lower, self.data)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class CacheManager: | ||||||
|  |     """Class that keeps track of all caches""" | ||||||
|  | 
 | ||||||
|  |     ufora_courses: UforaCourseCache | ||||||
|  | 
 | ||||||
|  |     def __init__(self): | ||||||
|  |         self.ufora_courses = UforaCourseCache() | ||||||
|  | 
 | ||||||
|  |     async def initialize_caches(self, database_session: AsyncSession): | ||||||
|  |         """Initialize the contents of all caches""" | ||||||
|  |         await self.ufora_courses.refresh(database_session) | ||||||
|  | @ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession | ||||||
| import settings | import settings | ||||||
| from database.crud import custom_commands | from database.crud import custom_commands | ||||||
| from database.engine import DBSession | from database.engine import DBSession | ||||||
|  | from database.utils.caches import CacheManager | ||||||
| from didier.utils.discord.prefix import get_prefix | from didier.utils.discord.prefix import get_prefix | ||||||
| 
 | 
 | ||||||
| __all__ = ["Didier"] | __all__ = ["Didier"] | ||||||
|  | @ -16,6 +17,7 @@ __all__ = ["Didier"] | ||||||
| class Didier(commands.Bot): | class Didier(commands.Bot): | ||||||
|     """DIDIER <3""" |     """DIDIER <3""" | ||||||
| 
 | 
 | ||||||
|  |     database_caches: CacheManager | ||||||
|     initial_extensions: tuple[str, ...] = () |     initial_extensions: tuple[str, ...] = () | ||||||
|     http_session: ClientSession |     http_session: ClientSession | ||||||
| 
 | 
 | ||||||
|  | @ -50,6 +52,11 @@ class Didier(commands.Bot): | ||||||
|         await self._load_initial_extensions() |         await self._load_initial_extensions() | ||||||
|         await self._load_directory_extensions("didier/cogs") |         await self._load_directory_extensions("didier/cogs") | ||||||
| 
 | 
 | ||||||
|  |         # Initialize caches | ||||||
|  |         self.database_caches = CacheManager() | ||||||
|  |         async with self.db_session as session: | ||||||
|  |             await self.database_caches.initialize_caches(session) | ||||||
|  | 
 | ||||||
|         # Create aiohttp session |         # Create aiohttp session | ||||||
|         self.http_session = ClientSession() |         self.http_session = ClientSession() | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -1,4 +1,5 @@ | ||||||
| import asyncio | import asyncio | ||||||
|  | import datetime | ||||||
| from typing import AsyncGenerator, Generator | from typing import AsyncGenerator, Generator | ||||||
| from unittest.mock import MagicMock | from unittest.mock import MagicMock | ||||||
| 
 | 
 | ||||||
|  | @ -6,9 +7,11 @@ import pytest | ||||||
| from sqlalchemy.ext.asyncio import AsyncSession | from sqlalchemy.ext.asyncio import AsyncSession | ||||||
| 
 | 
 | ||||||
| from database.engine import engine | from database.engine import engine | ||||||
| from database.models import Base | from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias | ||||||
| from didier import Didier | from didier import Didier | ||||||
| 
 | 
 | ||||||
|  | """General fixtures""" | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture(scope="session") | @pytest.fixture(scope="session") | ||||||
| def event_loop() -> Generator: | def event_loop() -> Generator: | ||||||
|  | @ -54,3 +57,34 @@ def mock_client() -> Didier: | ||||||
|     mock_client.user = mock_user |     mock_client.user = mock_user | ||||||
| 
 | 
 | ||||||
|     return mock_client |     return mock_client | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | """Fixtures to put fake data in the database""" | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | async def ufora_course(database_session: AsyncSession) -> UforaCourse: | ||||||
|  |     """Fixture to create a course""" | ||||||
|  |     course = UforaCourse(name="test", code="code", year=1, log_announcements=True) | ||||||
|  |     database_session.add(course) | ||||||
|  |     await database_session.commit() | ||||||
|  |     return course | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: | ||||||
|  |     """Fixture to create a course with an alias""" | ||||||
|  |     alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") | ||||||
|  |     database_session.add(alias) | ||||||
|  |     await database_session.commit() | ||||||
|  |     await database_session.refresh(ufora_course) | ||||||
|  |     return ufora_course | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.fixture | ||||||
|  | async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: | ||||||
|  |     """Fixture to create an announcement""" | ||||||
|  |     announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) | ||||||
|  |     database_session.add(announcement) | ||||||
|  |     await database_session.commit() | ||||||
|  |     return announcement | ||||||
|  |  | ||||||
|  | @ -1,34 +0,0 @@ | ||||||
| import datetime |  | ||||||
| 
 |  | ||||||
| import pytest |  | ||||||
| from sqlalchemy.ext.asyncio import AsyncSession |  | ||||||
| 
 |  | ||||||
| from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @pytest.fixture |  | ||||||
| async def course(database_session: AsyncSession) -> UforaCourse: |  | ||||||
|     """Fixture to create a course""" |  | ||||||
|     course = UforaCourse(name="test", code="code", year=1, log_announcements=True) |  | ||||||
|     database_session.add(course) |  | ||||||
|     await database_session.commit() |  | ||||||
|     return course |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @pytest.fixture |  | ||||||
| async def course_with_alias(database_session: AsyncSession, course: UforaCourse) -> UforaCourse: |  | ||||||
|     """Fixture to create a course with an alias""" |  | ||||||
|     alias = UforaCourseAlias(course_id=course.course_id, alias="alias") |  | ||||||
|     database_session.add(alias) |  | ||||||
|     await database_session.commit() |  | ||||||
|     await database_session.refresh(course) |  | ||||||
|     return course |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| @pytest.fixture |  | ||||||
| async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: |  | ||||||
|     """Fixture to create an announcement""" |  | ||||||
|     announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now()) |  | ||||||
|     database_session.add(announcement) |  | ||||||
|     await database_session.commit() |  | ||||||
|     return announcement |  | ||||||
|  | @ -25,19 +25,21 @@ async def test_get_courses_with_announcements(database_session: AsyncSession): | ||||||
|     assert results[0] == course_1 |     assert results[0] == course_1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_create_new_announcement(course: UforaCourse, database_session: AsyncSession): | async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): | ||||||
|     """Test creating a new announcement""" |     """Test creating a new announcement""" | ||||||
|     await crud.create_new_announcement(database_session, 1, course=course, publication_date=datetime.datetime.now()) |     await crud.create_new_announcement( | ||||||
|     await database_session.refresh(course) |         database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() | ||||||
|     assert len(course.announcements) == 1 |     ) | ||||||
|  |     await database_session.refresh(ufora_course) | ||||||
|  |     assert len(ufora_course.announcements) == 1 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_remove_old_announcements(announcement: UforaAnnouncement, database_session: AsyncSession): | async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): | ||||||
|     """Test removing all stale announcements""" |     """Test removing all stale announcements""" | ||||||
|     course = announcement.course |     course = ufora_announcement.course | ||||||
|     announcement.publication_date -= datetime.timedelta(weeks=2) |     ufora_announcement.publication_date -= datetime.timedelta(weeks=2) | ||||||
|     announcement_2 = UforaAnnouncement(course_id=announcement.course_id, publication_date=datetime.datetime.now()) |     announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) | ||||||
|     database_session.add_all([announcement, announcement_2]) |     database_session.add_all([ufora_announcement, announcement_2]) | ||||||
|     await database_session.commit() |     await database_session.commit() | ||||||
|     await database_session.refresh(course) |     await database_session.refresh(course) | ||||||
|     assert len(course.announcements) == 2 |     assert len(course.announcements) == 2 | ||||||
|  |  | ||||||
|  | @ -4,19 +4,19 @@ from database.crud import ufora_courses as crud | ||||||
| from database.models import UforaCourse | from database.models import UforaCourse | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_get_course_by_name_exact(database_session: AsyncSession, course: UforaCourse): | async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): | ||||||
|     """Test getting a course by its name when the query is an exact match""" |     """Test getting a course by its name when the query is an exact match""" | ||||||
|     match = await crud.get_course_by_name(database_session, "Test") |     match = await crud.get_course_by_name(database_session, "Test") | ||||||
|     assert match == course |     assert match == ufora_course | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_get_course_by_name_substring(database_session: AsyncSession, course: UforaCourse): | async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): | ||||||
|     """Test getting a course by its name when the query is a substring""" |     """Test getting a course by its name when the query is a substring""" | ||||||
|     match = await crud.get_course_by_name(database_session, "es") |     match = await crud.get_course_by_name(database_session, "es") | ||||||
|     assert match == course |     assert match == ufora_course | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| async def test_get_course_by_name_alias(database_session: AsyncSession, course_with_alias: UforaCourse): | async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): | ||||||
|     """Test getting a course by its name when the name doesn't match, but the alias does""" |     """Test getting a course by its name when the name doesn't match, but the alias does""" | ||||||
|     match = await crud.get_course_by_name(database_session, "ali") |     match = await crud.get_course_by_name(database_session, "ali") | ||||||
|     assert match == course_with_alias |     assert match == ufora_course_with_alias | ||||||
|  |  | ||||||
|  | @ -0,0 +1,27 @@ | ||||||
|  | from sqlalchemy.ext.asyncio import AsyncSession | ||||||
|  | 
 | ||||||
|  | from database.models import UforaCourse | ||||||
|  | from database.utils.caches import UforaCourseCache | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): | ||||||
|  |     """Test loading the data for the Ufora Course cache when it's empty""" | ||||||
|  |     cache = UforaCourseCache() | ||||||
|  |     await cache.refresh(database_session) | ||||||
|  | 
 | ||||||
|  |     assert len(cache.data) == 2 | ||||||
|  |     assert cache.data == ["alias", "test"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def test_ufora_course_cache_refresh_not_empty( | ||||||
|  |     database_session: AsyncSession, ufora_course_with_alias: UforaCourse | ||||||
|  | ): | ||||||
|  |     """Test loading the data for the Ufora Course cache when it's not empty anymore""" | ||||||
|  |     cache = UforaCourseCache() | ||||||
|  |     cache.data = ["Something"] | ||||||
|  |     cache.data_transformed = ["something"] | ||||||
|  | 
 | ||||||
|  |     await cache.refresh(database_session) | ||||||
|  | 
 | ||||||
|  |     assert len(cache.data) == 2 | ||||||
|  |     assert cache.data == ["alias", "test"] | ||||||
		Loading…
	
		Reference in New Issue