Add simple caching implementation for database queries that will be used in command autocompletion

pull/119/head
stijndcl 2022-07-14 22:44:22 +02:00
parent f0a05c8b4d
commit 72c3acbcc2
9 changed files with 162 additions and 50 deletions

View File

@ -5,6 +5,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaCourse, UforaCourseAlias
__all__ = ["get_all_courses", "get_course_by_name"]
async def get_all_courses(session: AsyncSession) -> list[UforaCourse]:
"""Get a list of all courses in the database"""

View File

@ -0,0 +1,74 @@
from abc import ABC, abstractmethod
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_courses
__all__ = ["CacheManager"]
class DatabaseCache(ABC):
"""Base class for a simple cache-like structure
The goal of this class is to store data for Discord auto-completion results
that would otherwise potentially put heavy load on the database.
This only stores strings, to avoid having to constantly refresh these objects.
Once a choice has been made, it can just be pulled out of the database.
Considering the fact that a user isn't obligated to choose something from the suggestions,
chances are high we have to go to the database for the final action either way.
Also stores the data in lowercase to allow fast searching
"""
data: list[str] = []
data_transformed: list[str] = []
def clear(self):
"""Remove everything"""
self.data.clear()
@abstractmethod
async def refresh(self, database_session: AsyncSession):
"""Refresh the data stored in this cache"""
async def invalidate(self, database_session: AsyncSession):
"""Invalidate the data stored in this cache"""
await self.refresh(database_session)
def get_autocomplete_suggestions(self, query: str):
"""Filter the cache to find everything that matches the search query"""
query = query.lower()
# Return the original (non-transformed) version of the data for pretty display in Discord
return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
class UforaCourseCache(DatabaseCache):
"""Cache to store the names of Ufora courses"""
async def refresh(self, database_session: AsyncSession):
self.clear()
courses = await ufora_courses.get_all_courses(database_session)
# Load the course names + all the aliases
for course in courses:
aliases = list(map(lambda x: x.alias, course.aliases))
self.data.extend([course.name, *aliases])
self.data.sort()
self.data_transformed = list(map(str.lower, self.data))
class CacheManager:
"""Class that keeps track of all caches"""
ufora_courses: UforaCourseCache
def __init__(self):
self.ufora_courses = UforaCourseCache()
async def initialize_caches(self, database_session: AsyncSession):
"""Initialize the contents of all caches"""
await self.ufora_courses.refresh(database_session)

View File

@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings
from database.crud import custom_commands
from database.engine import DBSession
from database.utils.caches import CacheManager
from didier.utils.discord.prefix import get_prefix
__all__ = ["Didier"]
@ -16,6 +17,7 @@ __all__ = ["Didier"]
class Didier(commands.Bot):
"""DIDIER <3"""
database_caches: CacheManager
initial_extensions: tuple[str, ...] = ()
http_session: ClientSession
@ -50,6 +52,11 @@ class Didier(commands.Bot):
await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs")
# Initialize caches
self.database_caches = CacheManager()
async with self.db_session as session:
await self.database_caches.initialize_caches(session)
# Create aiohttp session
self.http_session = ClientSession()

View File

@ -1,4 +1,5 @@
import asyncio
import datetime
from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock
@ -6,9 +7,11 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine
from database.models import Base
from database.models import Base, UforaAnnouncement, UforaCourse, UforaCourseAlias
from didier import Didier
"""General fixtures"""
@pytest.fixture(scope="session")
def event_loop() -> Generator:
@ -54,3 +57,34 @@ def mock_client() -> Didier:
mock_client.user = mock_user
return mock_client
"""Fixtures to put fake data in the database"""
@pytest.fixture
async def ufora_course(database_session: AsyncSession) -> UforaCourse:
"""Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course)
await database_session.commit()
return course
@pytest.fixture
async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
"""Fixture to create a course with an alias"""
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
database_session.add(alias)
await database_session.commit()
await database_session.refresh(ufora_course)
return ufora_course
@pytest.fixture
async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
"""Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement)
await database_session.commit()
return announcement

View File

@ -1,34 +0,0 @@
import datetime
import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias
@pytest.fixture
async def course(database_session: AsyncSession) -> UforaCourse:
"""Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course)
await database_session.commit()
return course
@pytest.fixture
async def course_with_alias(database_session: AsyncSession, course: UforaCourse) -> UforaCourse:
"""Fixture to create a course with an alias"""
alias = UforaCourseAlias(course_id=course.course_id, alias="alias")
database_session.add(alias)
await database_session.commit()
await database_session.refresh(course)
return course
@pytest.fixture
async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
"""Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement)
await database_session.commit()
return announcement

View File

@ -25,19 +25,21 @@ async def test_get_courses_with_announcements(database_session: AsyncSession):
assert results[0] == course_1
async def test_create_new_announcement(course: UforaCourse, database_session: AsyncSession):
async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession):
"""Test creating a new announcement"""
await crud.create_new_announcement(database_session, 1, course=course, publication_date=datetime.datetime.now())
await database_session.refresh(course)
assert len(course.announcements) == 1
await crud.create_new_announcement(
database_session, 1, course=ufora_course, publication_date=datetime.datetime.now()
)
await database_session.refresh(ufora_course)
assert len(ufora_course.announcements) == 1
async def test_remove_old_announcements(announcement: UforaAnnouncement, database_session: AsyncSession):
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession):
"""Test removing all stale announcements"""
course = announcement.course
announcement.publication_date -= datetime.timedelta(weeks=2)
announcement_2 = UforaAnnouncement(course_id=announcement.course_id, publication_date=datetime.datetime.now())
database_session.add_all([announcement, announcement_2])
course = ufora_announcement.course
ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now())
database_session.add_all([ufora_announcement, announcement_2])
await database_session.commit()
await database_session.refresh(course)
assert len(course.announcements) == 2

View File

@ -4,19 +4,19 @@ from database.crud import ufora_courses as crud
from database.models import UforaCourse
async def test_get_course_by_name_exact(database_session: AsyncSession, course: UforaCourse):
async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse):
"""Test getting a course by its name when the query is an exact match"""
match = await crud.get_course_by_name(database_session, "Test")
assert match == course
assert match == ufora_course
async def test_get_course_by_name_substring(database_session: AsyncSession, course: UforaCourse):
async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse):
"""Test getting a course by its name when the query is a substring"""
match = await crud.get_course_by_name(database_session, "es")
assert match == course
assert match == ufora_course
async def test_get_course_by_name_alias(database_session: AsyncSession, course_with_alias: UforaCourse):
async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse):
"""Test getting a course by its name when the name doesn't match, but the alias does"""
match = await crud.get_course_by_name(database_session, "ali")
assert match == course_with_alias
assert match == ufora_course_with_alias

View File

@ -0,0 +1,27 @@
from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaCourse
from database.utils.caches import UforaCourseCache
async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse):
"""Test loading the data for the Ufora Course cache when it's empty"""
cache = UforaCourseCache()
await cache.refresh(database_session)
assert len(cache.data) == 2
assert cache.data == ["alias", "test"]
async def test_ufora_course_cache_refresh_not_empty(
database_session: AsyncSession, ufora_course_with_alias: UforaCourse
):
"""Test loading the data for the Ufora Course cache when it's not empty anymore"""
cache = UforaCourseCache()
cache.data = ["Something"]
cache.data_transformed = ["something"]
await cache.refresh(database_session)
assert len(cache.data) == 2
assert cache.data == ["alias", "test"]