diff --git a/tests/conftest.py b/tests/conftest.py index 2e425ef..219568c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,7 +35,7 @@ async def tables(): @pytest.fixture -async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: +async def postgres(tables) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test Rollbacks the transaction afterwards so that the future tests start with a clean database diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index 8bc765c..a99c770 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -1,7 +1,6 @@ import datetime import pytest -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User @@ -17,44 +16,44 @@ def test_user_id() -> int: @pytest.fixture -async def user(database_session: AsyncSession, test_user_id) -> User: +async def user(postgres, test_user_id) -> User: """Fixture to create a user""" - _user = await users.get_or_add(database_session, test_user_id) - await database_session.refresh(_user) + _user = await users.get_or_add(postgres, test_user_id) + await postgres.refresh(_user) return _user @pytest.fixture -async def bank(database_session: AsyncSession, user: User) -> Bank: +async def bank(postgres, user: User) -> Bank: """Fixture to fetch the test user's bank""" _bank = user.bank - await database_session.refresh(_bank) + await postgres.refresh(_bank) return _bank @pytest.fixture -async def ufora_course(database_session: AsyncSession) -> UforaCourse: +async def ufora_course(postgres) -> UforaCourse: """Fixture to create a course""" course = UforaCourse(name="test", code="code", year=1, log_announcements=True) - database_session.add(course) - await database_session.commit() + postgres.add(course) + await postgres.commit() return course @pytest.fixture -async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: +async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse: """Fixture to create a course with an alias""" alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") - database_session.add(alias) - await database_session.commit() - await database_session.refresh(ufora_course) + postgres.add(alias) + await postgres.commit() + await postgres.refresh(ufora_course) return ufora_course @pytest.fixture -async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: +async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement: """Fixture to create an announcement""" announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) - database_session.add(announcement) - await database_session.commit() + postgres.add(announcement) + await postgres.commit() return announcement diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 544e5b0..7433573 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,74 +1,73 @@ from datetime import datetime, timedelta from freezegun import freeze_time -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud from database.crud import users from database.models import User -async def test_add_birthday_not_present(database_session: AsyncSession, user: User): +async def test_add_birthday_not_present(postgres, user: User): """Test setting a user's birthday when it doesn't exist yet""" assert user.birthday is None bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) assert user.birthday is not None assert user.birthday.birthday == bd_date -async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): +async def test_add_birthday_overwrite(postgres, user: User): """Test that setting a user's birthday when it already exists overwrites it""" bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) assert user.birthday is not None new_bd_date = bd_date + timedelta(weeks=1) - await crud.add_birthday(database_session, user.user_id, new_bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, new_bd_date) + await postgres.refresh(user) assert user.birthday.birthday == new_bd_date -async def test_get_birthday_exists(database_session: AsyncSession, user: User): +async def test_get_birthday_exists(postgres, user: User): """Test getting a user's birthday when it exists""" bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) - bd = await crud.get_birthday_for_user(database_session, user.user_id) + bd = await crud.get_birthday_for_user(postgres, user.user_id) assert bd is not None assert bd.birthday == bd_date -async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): +async def test_get_birthday_not_exists(postgres, user: User): """Test getting a user's birthday when it doesn't exist""" - bd = await crud.get_birthday_for_user(database_session, user.user_id) + bd = await crud.get_birthday_for_user(postgres, user.user_id) assert bd is None @freeze_time("2022/07/23") -async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): +async def test_get_birthdays_on_day(postgres, user: User): """Test getting all birthdays on a given day""" - await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001)) + await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) - user_2 = await users.get_or_add(database_session, user.user_id + 1) - await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + user_2 = await users.get_or_add(postgres, user.user_id + 1) + await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1)) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 1 assert birthdays[0].user_id == user.user_id @freeze_time("2022/07/23") -async def test_get_birthdays_none_present(database_session: AsyncSession): +async def test_get_birthdays_none_present(postgres): """Test getting all birthdays when there are none""" - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 0 # Add a random birthday that is not today - await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) + await crud.add_birthday(postgres, 1, datetime.today() + timedelta(days=1)) - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 0 diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index a2eeec8..b1e5192 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -2,78 +2,77 @@ import datetime import pytest from freezegun import freeze_time -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud from database.exceptions import currency as exceptions from database.models import Bank -async def test_add_dinks(database_session: AsyncSession, bank: Bank): +async def test_add_dinks(postgres, bank: Bank): """Test adding dinks to an account""" assert bank.dinks == 0 - await crud.add_dinks(database_session, bank.user_id, 10) - await database_session.refresh(bank) + await crud.add_dinks(postgres, bank.user_id, 10) + await postgres.refresh(bank) assert bank.dinks == 10 @freeze_time("2022/07/23") -async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): +async def test_claim_nightly_available(postgres, bank: Bank): """Test claiming nightlies when it hasn't been done yet""" - await crud.claim_nightly(database_session, bank.user_id) - await database_session.refresh(bank) + await crud.claim_nightly(postgres, bank.user_id) + await postgres.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT - nightly_data = await crud.get_nightly_data(database_session, bank.user_id) + nightly_data = await crud.get_nightly_data(postgres, bank.user_id) assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23) @freeze_time("2022/07/23") -async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): +async def test_claim_nightly_unavailable(postgres, bank: Bank): """Test claiming nightlies twice in a day""" - await crud.claim_nightly(database_session, bank.user_id) + await crud.claim_nightly(postgres, bank.user_id) with pytest.raises(exceptions.DoubleNightly): - await crud.claim_nightly(database_session, bank.user_id) + await crud.claim_nightly(postgres, bank.user_id) - await database_session.refresh(bank) + await postgres.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT -async def test_invest(database_session: AsyncSession, bank: Bank): +async def test_invest(postgres, bank: Bank): """Test investing some Dinks""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, 20) - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, 20) + await postgres.refresh(bank) assert bank.dinks == 80 assert bank.invested == 20 -async def test_invest_all(database_session: AsyncSession, bank: Bank): +async def test_invest_all(postgres, bank: Bank): """Test investing all dinks""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, "all") - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, "all") + await postgres.refresh(bank) assert bank.dinks == 0 assert bank.invested == 100 -async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank): +async def test_invest_more_than_owned(postgres, bank: Bank): """Test investing more Dinks than you own""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, 200) - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, 200) + await postgres.refresh(bank) assert bank.dinks == 0 assert bank.invested == 100 diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index a5c4092..ec25637 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -1,6 +1,5 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException @@ -8,112 +7,112 @@ from database.exceptions.not_found import NoResultFoundException from database.models import CustomCommand -async def test_create_command_non_existing(database_session: AsyncSession): +async def test_create_command_non_existing(postgres): """Test creating a new command when it doesn't exist yet""" - await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "name", "response") - commands = (await database_session.execute(select(CustomCommand))).scalars().all() + commands = (await postgres.execute(select(CustomCommand))).scalars().all() assert len(commands) == 1 assert commands[0].name == "name" -async def test_create_command_duplicate_name(database_session: AsyncSession): +async def test_create_command_duplicate_name(postgres): """Test creating a command when the name already exists""" - await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "name", "response") with pytest.raises(DuplicateInsertException): - await crud.create_command(database_session, "name", "other response") + await crud.create_command(postgres, "name", "other response") -async def test_create_command_name_is_alias(database_session: AsyncSession): +async def test_create_command_name_is_alias(postgres): """Test creating a command when the name is taken by an alias""" - await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, "name", "n") + await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, "name", "n") with pytest.raises(DuplicateInsertException): - await crud.create_command(database_session, "n", "other response") + await crud.create_command(postgres, "n", "other response") -async def test_create_alias(database_session: AsyncSession): +async def test_create_alias(postgres): """Test creating an alias when the name is still free""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "n") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "n") - await database_session.refresh(command) + await postgres.refresh(command) assert len(command.aliases) == 1 assert command.aliases[0].alias == "n" -async def test_create_alias_non_existing(database_session: AsyncSession): +async def test_create_alias_non_existing(postgres): """Test creating an alias when the command doesn't exist""" with pytest.raises(NoResultFoundException): - await crud.create_alias(database_session, "name", "alias") + await crud.create_alias(postgres, "name", "alias") -async def test_create_alias_duplicate(database_session: AsyncSession): +async def test_create_alias_duplicate(postgres): """Test creating an alias when another alias already has this name""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "n") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "n") with pytest.raises(DuplicateInsertException): - await crud.create_alias(database_session, command.name, "n") + await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_is_command(database_session: AsyncSession): +async def test_create_alias_is_command(postgres): """Test creating an alias when the name is taken by a command""" - await crud.create_command(database_session, "n", "response") - command = await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "n", "response") + command = await crud.create_command(postgres, "name", "response") with pytest.raises(DuplicateInsertException): - await crud.create_alias(database_session, command.name, "n") + await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_match_by_alias(database_session: AsyncSession): +async def test_create_alias_match_by_alias(postgres): """Test creating an alias for a command when matching the name to another alias""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "a1") - alias = await crud.create_alias(database_session, "a1", "a2") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "a1") + alias = await crud.create_alias(postgres, "a1", "a2") assert alias.command == command -async def test_get_command_by_name_exists(database_session: AsyncSession): +async def test_get_command_by_name_exists(postgres): """Test getting a command by name""" - await crud.create_command(database_session, "name", "response") - command = await crud.get_command(database_session, "name") + await crud.create_command(postgres, "name", "response") + command = await crud.get_command(postgres, "name") assert command is not None -async def test_get_command_by_cleaned_name(database_session: AsyncSession): +async def test_get_command_by_cleaned_name(postgres): """Test getting a command by the cleaned version of the name""" - command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") - found = await crud.get_command(database_session, "capitalizednamewithspaces") + command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response") + found = await crud.get_command(postgres, "capitalizednamewithspaces") assert command == found -async def test_get_command_by_alias(database_session: AsyncSession): +async def test_get_command_by_alias(postgres): """Test getting a command by an alias""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "a1") - await crud.create_alias(database_session, command.name, "a2") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "a1") + await crud.create_alias(postgres, command.name, "a2") - found = await crud.get_command(database_session, "a1") + found = await crud.get_command(postgres, "a1") assert command == found -async def test_get_command_non_existing(database_session: AsyncSession): +async def test_get_command_non_existing(postgres): """Test getting a command when it doesn't exist""" - assert await crud.get_command(database_session, "name") is None + assert await crud.get_command(postgres, "name") is None -async def test_edit_command(database_session: AsyncSession): +async def test_edit_command(postgres): """Test editing an existing command""" - command = await crud.create_command(database_session, "name", "response") - await crud.edit_command(database_session, command.name, "new name", "new response") + command = await crud.create_command(postgres, "name", "response") + await crud.edit_command(postgres, command.name, "new name", "new response") assert command.name == "new name" assert command.response == "new response" -async def test_edit_command_non_existing(database_session: AsyncSession): +async def test_edit_command_non_existing(postgres): """Test editing a command that doesn't exist""" with pytest.raises(NoResultFoundException): - await crud.edit_command(database_session, "name", "n", "r") + await crud.edit_command(postgres, "name", "n", "r") diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index 0c499c8..8138495 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -1,16 +1,15 @@ from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import dad_jokes as crud from database.models import DadJoke -async def test_add_dad_joke(database_session: AsyncSession): +async def test_add_dad_joke(postgres): """Test creating a new joke""" statement = select(DadJoke) - result = (await database_session.execute(statement)).scalars().all() + result = (await postgres.execute(statement)).scalars().all() assert len(result) == 0 - await crud.add_dad_joke(database_session, "joke") - result = (await database_session.execute(statement)).scalars().all() + await crud.add_dad_joke(postgres, "joke") + result = (await postgres.execute(statement)).scalars().all() assert len(result) == 1 diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index e1e4f97..4831e03 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -3,7 +3,6 @@ import datetime import pytest from freezegun import freeze_time from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import tasks as crud from database.enums import TaskType @@ -17,47 +16,47 @@ def task_type() -> TaskType: @pytest.fixture -async def task(database_session: AsyncSession, task_type: TaskType) -> Task: +async def task(postgres, task_type: TaskType) -> Task: """Fixture to create a task""" task = Task(task=task_type) - database_session.add(task) - await database_session.commit() + postgres.add(task) + await postgres.commit() return task -async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType): +async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType): """Test getting a task by its enum type when it exists""" - result = await crud.get_task_by_enum(database_session, task_type) + result = await crud.get_task_by_enum(postgres, task_type) assert result is not None assert result == task -async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType): +async def test_get_task_by_enum_not_present(postgres, task_type: TaskType): """Test getting a task by its enum type when it doesn't exist""" - result = await crud.get_task_by_enum(database_session, task_type) + result = await crud.get_task_by_enum(postgres, task_type) assert result is None @freeze_time("2022/07/24") -async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType): +async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType): """Test setting the execution time of an existing task""" - await database_session.refresh(task) + await postgres.refresh(task) assert task.previous_run is None - await crud.set_last_task_execution_time(database_session, task_type) - await database_session.refresh(task) + await crud.set_last_task_execution_time(postgres, task_type) + await postgres.refresh(task) assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) @freeze_time("2022/07/24") -async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType): +async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType): """Test setting the execution time of a non-existing task""" statement = select(Task).where(Task.task == task_type) - results = list((await database_session.execute(statement)).scalars().all()) + results = list((await postgres.execute(statement)).scalars().all()) assert len(results) == 0 - await crud.set_last_task_execution_time(database_session, task_type) - results = list((await database_session.execute(statement)).scalars().all()) + await crud.set_last_task_execution_time(postgres, task_type) + results = list((await postgres.execute(statement)).scalars().all()) assert len(results) == 1 task = results[0] assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index 4e6fc47..c6054ff 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,50 +1,46 @@ import datetime -from sqlalchemy.ext.asyncio import AsyncSession - from database.crud import ufora_announcements as crud from database.models import UforaAnnouncement, UforaCourse -async def test_get_courses_with_announcements_none(database_session: AsyncSession): +async def test_get_courses_with_announcements_none(postgres): """Test getting all courses with announcements when there are none""" - results = await crud.get_courses_with_announcements(database_session) + results = await crud.get_courses_with_announcements(postgres) assert len(results) == 0 -async def test_get_courses_with_announcements(database_session: AsyncSession): +async def test_get_courses_with_announcements(postgres): """Test getting all courses with announcements""" course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True) course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False) - database_session.add_all([course_1, course_2]) - await database_session.commit() + postgres.add_all([course_1, course_2]) + await postgres.commit() - results = await crud.get_courses_with_announcements(database_session) + results = await crud.get_courses_with_announcements(postgres) assert len(results) == 1 assert results[0] == course_1 -async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): +async def test_create_new_announcement(ufora_course: UforaCourse, postgres): """Test creating a new announcement""" - await crud.create_new_announcement( - database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() - ) - await database_session.refresh(ufora_course) + await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now()) + await postgres.refresh(ufora_course) assert len(ufora_course.announcements) == 1 -async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): +async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres): """Test removing all stale announcements""" course = ufora_announcement.course ufora_announcement.publication_date -= datetime.timedelta(weeks=2) announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) - database_session.add_all([ufora_announcement, announcement_2]) - await database_session.commit() - await database_session.refresh(course) + postgres.add_all([ufora_announcement, announcement_2]) + await postgres.commit() + await postgres.refresh(course) assert len(course.announcements) == 2 - await crud.remove_old_announcements(database_session) + await crud.remove_old_announcements(postgres) - await database_session.refresh(course) + await postgres.refresh(course) assert len(course.announcements) == 1 assert announcement_2.course.announcements[0] == announcement_2 diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index d2d5e1b..5935fd9 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,22 +1,20 @@ -from sqlalchemy.ext.asyncio import AsyncSession - from database.crud import ufora_courses as crud from database.models import UforaCourse -async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): +async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse): """Test getting a course by its name when the query is an exact match""" - match = await crud.get_course_by_name(database_session, "Test") + match = await crud.get_course_by_name(postgres, "Test") assert match == ufora_course -async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): +async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse): """Test getting a course by its name when the query is a substring""" - match = await crud.get_course_by_name(database_session, "es") + match = await crud.get_course_by_name(postgres, "es") assert match == ufora_course -async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): +async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse): """Test getting a course by its name when the name doesn't match, but the alias does""" - match = await crud.get_course_by_name(database_session, "ali") + match = await crud.get_course_by_name(postgres, "ali") assert match == ufora_course_with_alias diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index 08b4c81..d6584de 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -1,25 +1,24 @@ from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users as crud from database.models import User -async def test_get_or_add_non_existing(database_session: AsyncSession): +async def test_get_or_add_non_existing(postgres): """Test get_or_add for a user that doesn't exist""" - await crud.get_or_add(database_session, 1) + await crud.get_or_add(postgres, 1) statement = select(User) - res = (await database_session.execute(statement)).scalars().all() + res = (await postgres.execute(statement)).scalars().all() assert len(res) == 1 assert res[0].bank is not None assert res[0].nightly_data is not None -async def test_get_or_add_existing(database_session: AsyncSession): +async def test_get_or_add_existing(postgres): """Test get_or_add for a user that does exist""" - user = await crud.get_or_add(database_session, 1) + user = await crud.get_or_add(postgres, 1) bank = user.bank - assert await crud.get_or_add(database_session, 1) == user - assert (await crud.get_or_add(database_session, 1)).bank == bank + assert await crud.get_or_add(postgres, 1) == user + assert (await crud.get_or_add(postgres, 1)).bank == bank diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 2e10664..0a19e98 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,28 +1,24 @@ -from sqlalchemy.ext.asyncio import AsyncSession - from database.models import UforaCourse from database.utils.caches import UforaCourseCache -async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): +async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's empty""" cache = UforaCourseCache() - await cache.refresh(database_session) + await cache.refresh(postgres) assert len(cache.data) == 1 assert cache.data == ["test"] assert cache.aliases == {"alias": "test"} -async def test_ufora_course_cache_refresh_not_empty( - database_session: AsyncSession, ufora_course_with_alias: UforaCourse -): +async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's not empty anymore""" cache = UforaCourseCache() cache.data = ["Something"] cache.data_transformed = ["something"] - await cache.refresh(database_session) + await cache.refresh(postgres) assert len(cache.data) == 1 assert cache.data == ["test"]