diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py new file mode 100644 index 0000000..152d687 --- /dev/null +++ b/database/crud/ufora_courses.py @@ -0,0 +1,30 @@ +from typing import Optional + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import UforaCourse, UforaCourseAlias + + +async def get_all_courses(session: AsyncSession) -> list[UforaCourse]: + """Get a list of all courses in the database""" + statement = select(UforaCourse) + return (await session.execute(statement)).scalars().all() + + +async def get_course_by_name(session: AsyncSession, query: str) -> Optional[UforaCourse]: + """Try to find a course by its name + + This checks for regular name first, and then aliases + """ + # Search case-insensitively + query = query.lower() + + statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + if result: + return result + + statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + return result.course if result else None diff --git a/tests/test_database/test_crud/conftest.py b/tests/test_database/test_crud/conftest.py new file mode 100644 index 0000000..be5f889 --- /dev/null +++ b/tests/test_database/test_crud/conftest.py @@ -0,0 +1,34 @@ +import datetime + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from database.models import UforaAnnouncement, UforaCourse, UforaCourseAlias + + +@pytest.fixture +async def course(database_session: AsyncSession) -> UforaCourse: + """Fixture to create a course""" + course = UforaCourse(name="test", code="code", year=1, log_announcements=True) + database_session.add(course) + await database_session.commit() + return course + + +@pytest.fixture +async def course_with_alias(database_session: AsyncSession, course: UforaCourse) -> UforaCourse: + """Fixture to create a course with an alias""" + alias = UforaCourseAlias(course_id=course.course_id, alias="alias") + database_session.add(alias) + await database_session.commit() + await database_session.refresh(course) + return course + + +@pytest.fixture +async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: + """Fixture to create an announcement""" + announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now()) + database_session.add(announcement) + await database_session.commit() + return announcement diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index ba6564a..b2303a6 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -7,24 +7,6 @@ from database.crud import ufora_announcements as crud from database.models import UforaAnnouncement, UforaCourse -@pytest.fixture -async def course(database_session: AsyncSession) -> UforaCourse: - """Fixture to create a course""" - course = UforaCourse(name="test", code="code", year=1, log_announcements=True) - database_session.add(course) - await database_session.commit() - return course - - -@pytest.fixture -async def announcement(course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: - """Fixture to create an announcement""" - announcement = UforaAnnouncement(course_id=course.course_id, publication_date=datetime.datetime.now()) - database_session.add(announcement) - await database_session.commit() - return announcement - - async def test_get_courses_with_announcements_none(database_session: AsyncSession): """Test getting all courses with announcements when there are none""" results = await crud.get_courses_with_announcements(database_session) diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py new file mode 100644 index 0000000..efe8fe6 --- /dev/null +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -0,0 +1,22 @@ +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import ufora_courses as crud +from database.models import UforaCourse + + +async def test_get_course_by_name_exact(database_session: AsyncSession, course: UforaCourse): + """Test getting a course by its name when the query is an exact match""" + match = await crud.get_course_by_name(database_session, "Test") + assert match == course + + +async def test_get_course_by_name_substring(database_session: AsyncSession, course: UforaCourse): + """Test getting a course by its name when the query is a substring""" + match = await crud.get_course_by_name(database_session, "es") + assert match == course + + +async def test_get_course_by_name_alias(database_session: AsyncSession, course_with_alias: UforaCourse): + """Test getting a course by its name when the name doesn't match, but the alias does""" + match = await crud.get_course_by_name(database_session, "ali") + assert match == course_with_alias