diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 23384df..7ef8dab 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -38,6 +38,19 @@ jobs: POSTGRES_DB: didier_pytest POSTGRES_USER: pytest POSTGRES_PASSWORD: pytest + mongo: + image: mongo:5.0 + options: >- + --health-cmd mongo + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 27018:27017 + env: + MONGO_DB: didier_pytest + MONGO_USER: pytest + MONGO_PASSWORD: pytest steps: - uses: actions/checkout@v3 - name: Setup Python diff --git a/alembic/env.py b/alembic/env.py index 3cca2cf..beaa206 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,8 +4,8 @@ from logging.config import fileConfig from sqlalchemy.ext.asyncio import AsyncEngine from alembic import context -from database.engine import engine -from database.models import Base +from database.engine import postgres_engine +from database.schemas.relational import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. @@ -40,7 +40,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = context.config.attributes.get("connection", None) or engine + connectable = context.config.attributes.get("connection", None) or postgres_engine if isinstance(connectable, AsyncEngine): asyncio.run(run_async_migrations(connectable)) diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index df59dfc..f078488 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from database.crud import users -from database.models import Birthday, User +from database.schemas.relational import Birthday, User __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] diff --git a/database/crud/currency.py b/database/crud/currency.py index 1bb2d11..382801d 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.models import Bank, NightlyData +from database.schemas.relational import Bank, NightlyData from database.utils.math.currency import ( capacity_upgrade_price, interest_upgrade_price, diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index 85ecf56..d0e86a1 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.models import CustomCommand, CustomCommandAlias +from database.schemas.relational import CustomCommand, CustomCommandAlias __all__ = [ "clean_name", diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 871c34d..c481ec3 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException -from database.models import DadJoke +from database.schemas.relational import DadJoke __all__ = ["add_dad_joke", "get_random_dad_joke"] diff --git a/database/crud/tasks.py b/database/crud/tasks.py index dd1a607..a3b6f38 100644 --- a/database/crud/tasks.py +++ b/database/crud/tasks.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.enums import TaskType -from database.models import Task +from database.schemas.relational import Task from database.utils.datetime import LOCAL_TIMEZONE __all__ = ["get_task_by_enum", "set_last_task_execution_time"] diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 48a06ae..e2dbd16 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index d41846c..f6dd853 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaCourse, UforaCourseAlias +from database.schemas.relational import UforaCourse, UforaCourseAlias __all__ = ["get_all_courses", "get_course_by_name"] diff --git a/database/crud/users.py b/database/crud/users.py index ba3011d..3024f26 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import Bank, NightlyData, User +from database.schemas.relational import Bank, NightlyData, User __all__ = [ "get_or_add", diff --git a/database/engine.py b/database/engine.py index 9b603cd..e98c8df 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,24 +1,34 @@ from urllib.parse import quote_plus +import motor.motor_asyncio from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker import settings -encoded_password = quote_plus(settings.DB_PASSWORD) +encoded_postgres_password = quote_plus(settings.POSTGRES_PASS) -engine = create_async_engine( +# PostgreSQL engine +postgres_engine = create_async_engine( URL.create( drivername="postgresql+asyncpg", - username=settings.DB_USERNAME, - password=encoded_password, - host=settings.DB_HOST, - port=settings.DB_PORT, - database=settings.DB_NAME, + username=settings.POSTGRES_USER, + password=encoded_postgres_password, + host=settings.POSTGRES_HOST, + port=settings.POSTGRES_PORT, + database=settings.POSTGRES_DB, ), pool_pre_ping=True, future=True, ) -DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False) +DBSession = sessionmaker( + autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False +) + +# MongoDB client +encoded_mongo_username = quote_plus(settings.MONGO_USER) +encoded_mongo_password = quote_plus(settings.MONGO_PASS) +mongo_url = f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/" +mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url) diff --git a/database/migrations.py b/database/migrations.py index e842410..20c43c8 100644 --- a/database/migrations.py +++ b/database/migrations.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from alembic import command, script from alembic.config import Config from alembic.runtime import migration -from database.engine import engine +from database.engine import postgres_engine __config_path__ = "alembic.ini" __migrations_path__ = "alembic/" @@ -22,7 +22,7 @@ async def ensure_latest_migration(): """Make sure we are currently on the latest revision, otherwise raise an exception""" alembic_script = script.ScriptDirectory.from_config(cfg) - async with engine.begin() as connection: + async with postgres_engine.begin() as connection: current_revision = await connection.run_sync( lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision() ) @@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session): async def migrate(up: bool): """Migrate the database upwards or downwards""" - async with engine.begin() as connection: + async with postgres_engine.begin() as connection: await connection.run_sync(__execute_upgrade if up else __execute_downgrade) diff --git a/database/schemas/__init__.py b/database/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/schemas/mongo.py b/database/schemas/mongo.py new file mode 100644 index 0000000..95d2a2a --- /dev/null +++ b/database/schemas/mongo.py @@ -0,0 +1,38 @@ +from bson import ObjectId +from pydantic import BaseModel, Field + +__all__ = ["MongoBase"] + + +class PyObjectId(str): + """Custom type for bson ObjectIds""" + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: str): + """Check that a string is a valid bson ObjectId""" + if not ObjectId.is_valid(value): + raise ValueError(f"Invalid ObjectId: '{value}'") + + return ObjectId(value) + + @classmethod + def __modify_schema__(cls, field_schema: dict): + field_schema.update(type="string") + + +class MongoBase(BaseModel): + """Base model that properly sets the _id field, and adds one by default""" + + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + + class Config: + """Configuration for encoding and construction""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {ObjectId: str, PyObjectId: str} + use_enum_values = True diff --git a/database/models.py b/database/schemas/relational.py similarity index 100% rename from database/models.py rename to database/schemas/relational.py diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 4c76e58..1b3dea5 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -31,7 +31,7 @@ class Currency(commands.Cog): """Award a user a given amount of Didier Dinks""" amount = typing.cast(int, amount) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await crud.add_dinks(session, user.id, amount) plural = pluralize("Didier Dink", amount) await ctx.reply( @@ -42,7 +42,7 @@ class Currency(commands.Cog): @commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True) async def bank(self, ctx: commands.Context): """Show your Didier Bank information""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) @@ -58,7 +58,7 @@ class Currency(commands.Cog): @bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True) async def bank_upgrades(self, ctx: commands.Context): """List the upgrades you can buy & their prices""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) @@ -79,7 +79,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Capacity", aliases=["C"]) async def bank_upgrade_capacity(self, ctx: commands.Context): """Upgrade the capacity level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_capacity(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -90,7 +90,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Interest", aliases=["I"]) async def bank_upgrade_interest(self, ctx: commands.Context): """Upgrade the interest level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_interest(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -101,7 +101,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Rob", aliases=["R"]) async def bank_upgrade_rob(self, ctx: commands.Context): """Upgrade the rob level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_rob(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -112,7 +112,7 @@ class Currency(commands.Cog): @commands.hybrid_command(name="dinks") async def dinks(self, ctx: commands.Context): """Check your Didier Dinks""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) @@ -122,7 +122,7 @@ class Currency(commands.Cog): """Invest a given amount of Didier Dinks""" amount = typing.cast(typing.Union[str, int], amount) - async with self.client.db_session as session: + async with self.client.postgres_session as session: invested = await crud.invest(session, ctx.author.id, amount) plural = pluralize("Didier Dink", invested) @@ -136,7 +136,7 @@ class Currency(commands.Cog): @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): """Claim nightly Didier Dinks""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.claim_nightly(session, ctx.author.id) await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index db9ae7d..9dbcc9d 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -19,7 +19,7 @@ class Discord(commands.Cog): async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of a user""" user_id = (user and user.id) or ctx.author.id - async with self.client.db_session as session: + async with self.client.postgres_session as session: birthday = await birthdays.get_birthday_for_user(session, user_id) name = "Jouw" if user is None else f"{user.display_name}'s" @@ -45,7 +45,7 @@ class Discord(commands.Cog): except ValueError: return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await birthdays.add_birthday(session, ctx.author.id, date) await self.client.confirm_message(ctx.message) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index ddc119b..0aade01 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -19,7 +19,7 @@ class Fun(commands.Cog): ) async def dad_joke(self, ctx: commands.Context): """Get a random dad joke""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: joke = await get_random_dad_joke(session) return await ctx.reply(joke.joke, mention_author=False) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 30090df..7d623b0 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -83,7 +83,7 @@ class Owner(commands.Cog): @add_msg.command(name="Custom") async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.create_command(session, name, response) await self.client.confirm_message(ctx.message) @@ -94,7 +94,7 @@ class Owner(commands.Cog): @add_msg.command(name="Alias") async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.create_alias(session, command, alias) await self.client.confirm_message(ctx.message) @@ -130,7 +130,7 @@ class Owner(commands.Cog): @edit_msg.command(name="Custom") async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): """Edit an existing custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.edit_command(session, command, flags.name, flags.response) return await self.client.confirm_message(ctx.message) @@ -147,7 +147,7 @@ class Owner(commands.Cog): "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True ) - async with self.client.db_session as session: + async with self.client.postgres_session as session: _command = await custom_commands.get_command(session, command) if _command is None: return await interaction.response.send_message( diff --git a/didier/cogs/school.py b/didier/cogs/school.py index ee13a0c..32eb8b9 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -68,7 +68,7 @@ class School(commands.Cog): @app_commands.describe(course="vak") async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): """Create links to study guides""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: ufora_course = await ufora_courses.get_course_by_name(session, course) if ufora_course is None: diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 3764331..6a1e5c6 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -72,7 +72,7 @@ class Tasks(commands.Cog): async def check_birthdays(self): """Check if it's currently anyone's birthday""" now = tz_aware_now().date() - async with self.client.db_session as session: + async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) @@ -96,7 +96,7 @@ class Tasks(commands.Cog): if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: return - async with self.client.db_session as db_session: + async with self.client.postgres_session as db_session: announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) announcements = await fetch_ufora_announcements(self.client.http_session, db_session) @@ -110,7 +110,7 @@ class Tasks(commands.Cog): @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: await remove_old_announcements(session) @check_birthdays.error diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index f4a8bdd..d906ea0 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud -from database.models import UforaCourse +from database.schemas.relational import UforaCourse from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import int_to_weekday from didier.utils.types.string import leading diff --git a/didier/decorators/tasks.py b/didier/decorators/tasks.py index 36b4f89..67e3e19 100644 --- a/didier/decorators/tasks.py +++ b/didier/decorators/tasks.py @@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType): async def _wrapper(tasks_cog: Tasks, *args, **kwargs): await func(tasks_cog, *args, **kwargs) - async with tasks_cog.client.db_session as session: + async with tasks_cog.client.postgres_session as session: await set_last_task_execution_time(session, task) return _wrapper diff --git a/didier/didier.py b/didier/didier.py index d5745c2..ce81307 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -2,13 +2,14 @@ import logging import os import discord +import motor.motor_asyncio from aiohttp import ClientSession from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import custom_commands -from database.engine import DBSession +from database.engine import DBSession, mongo_client from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.utils.discord.prefix import get_prefix @@ -45,10 +46,15 @@ class Didier(commands.Bot): ) @property - def db_session(self) -> AsyncSession: - """Obtain a database session""" + def postgres_session(self) -> AsyncSession: + """Obtain a session for the PostgreSQL database""" return DBSession() + @property + def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: + """Obtain a reference to the MongoDB database""" + return mongo_client[settings.MONGO_DB] + async def setup_hook(self) -> None: """Do some initial setup @@ -60,7 +66,7 @@ class Didier(commands.Bot): # Initialize caches self.database_caches = CacheManager() - async with self.db_session as session: + async with self.postgres_session as session: await self.database_caches.initialize_caches(session) # Create aiohttp session @@ -153,7 +159,7 @@ class Didier(commands.Bot): if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): return False - async with self.db_session as session: + async with self.postgres_session as session: # Remove the prefix content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] command = await custom_commands.get_command(session, content) diff --git a/didier/views/modals/custom_commands.py b/didier/views/modals/custom_commands.py index 35ac158..2116bd9 100644 --- a/didier/views/modals/custom_commands.py +++ b/didier/views/modals/custom_commands.py @@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): @overrides async def on_submit(self, interaction: discord.Interaction): - async with self.client.db_session as session: + async with self.client.postgres_session as session: command = await create_command(session, str(self.name.value), str(self.response.value)) await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) @@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): name_field = typing.cast(discord.ui.TextInput, self.children[0]) response_field = typing.cast(discord.ui.TextInput, self.children[1]) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await edit_command(session, self.original_name, name_field.value, response_field.value) await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py index 9632197..f52b051 100644 --- a/didier/views/modals/dad_jokes.py +++ b/didier/views/modals/dad_jokes.py @@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): @overrides async def on_submit(self, interaction: discord.Interaction): - async with self.client.db_session as session: + async with self.client.postgres_session as session: joke = await add_dad_joke(session, str(self.name.value)) await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) diff --git a/docker-compose.test.yml b/docker-compose.test.yml new file mode 100644 index 0000000..a446613 --- /dev/null +++ b/docker-compose.test.yml @@ -0,0 +1,21 @@ +version: '3.9' +services: + postgres-pytest: + image: postgres:14 + container_name: didier-pytest + restart: always + environment: + - POSTGRES_DB=didier_pytest + - POSTGRES_USER=pytest + - POSTGRES_PASSWORD=pytest + ports: + - "5433:5432" + mongo-pytest: + image: mongo:5.0 + restart: always + environment: + - MONGO_INITDB_ROOT_USERNAME=pytest + - MONGO_INITDB_ROOT_PASSWORD=pytest + - MONGO_INITDB_DATABASE=didier_pytest + ports: + - "27018:27017" diff --git a/docker-compose.yml b/docker-compose.yml index 25cd531..4af0a0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,26 +1,29 @@ version: '3.9' services: - db: + postgres: image: postgres:14 container_name: didier restart: always environment: - - POSTGRES_DB=${DB_NAME:-didier_dev} - - POSTGRES_USER=${DB_USERNAME:-postgres} - - POSTGRES_PASSWORD=${DB_PASSWORD:-postgres} + - POSTGRES_DB=${POSTGRES_DB:-didier_dev} + - POSTGRES_USER=${POSTGRES_USER:-postgres} + - POSTGRES_PASSWORD=${POSTGRES_PASS:-postgres} ports: - - "${DB_PORT:-5432}:${DB_PORT:-5432}" + - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}" volumes: - - db:/var/lib/postgresql/data - db-pytest: - image: postgres:14 - container_name: didier-pytest + - postgres:/var/lib/postgresql/data + mongo: + image: mongo:5.0 restart: always environment: - - POSTGRES_DB=didier_pytest - - POSTGRES_USER=pytest - - POSTGRES_PASSWORD=pytest + - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root} + - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root} + - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev} + command: [--auth] ports: - - "5433:5432" + - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}" + volumes: + - mongo:/data/db volumes: - db: + postgres: + mongo: diff --git a/pyproject.toml b/pyproject.toml index f59d25e..60927e6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,16 +36,21 @@ plugins = [ "sqlalchemy.ext.mypy.plugin" ] [[tool.mypy.overrides]] -module = ["discord.*", "feedparser.*", "markdownify.*"] +module = ["discord.*", "feedparser.*", "markdownify.*", "motor.*"] ignore_missing_imports = true [tool.pytest.ini_options] asyncio_mode = "auto" env = [ - "DB_NAME = didier_pytest", - "DB_USERNAME = pytest", - "DB_PASSWORD = pytest", - "DB_HOST = localhost", - "DB_PORT = 5433", + "MONGO_DB = didier_pytest", + "MONGO_USER = pytest", + "MONGO_PASS = pytest", + "MONGO_HOST = localhost", + "MONGO_PORT = 27018", + "POSTGRES_DB = didier_pytest", + "POSTGRES_USER = pytest", + "POSTGRES_PASS = pytest", + "POSTGRES_HOST = localhost", + "POSTGRES_PORT = 5433", "DISCORD_TOKEN = token" ] diff --git a/requirements.txt b/requirements.txt index 285b936..950679b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py environs==9.5.0 feedparser==6.0.10 markdownify==0.11.2 +motor==3.0.0 overrides==6.1.0 pydantic==1.9.1 python-dateutil==2.8.2 diff --git a/settings.py b/settings.py index cf20b95..71f41ce 100644 --- a/settings.py +++ b/settings.py @@ -9,11 +9,11 @@ env.read_env() __all__ = [ "SANDBOX", "LOGFILE", - "DB_NAME", - "DB_USERNAME", - "DB_PASSWORD", - "DB_HOST", - "DB_PORT", + "POSTGRES_DB", + "POSTGRES_USER", + "POSTGRES_PASS", + "POSTGRES_HOST", + "POSTGRES_PORT", "DISCORD_TOKEN", "DISCORD_READY_MESSAGE", "DISCORD_STATUS_MESSAGE", @@ -33,11 +33,19 @@ SEMESTER: int = env.int("SEMESTER", 2) YEAR: int = env.int("YEAR", 3) """Database""" -DB_NAME: str = env.str("DB_NAME", "didier") -DB_USERNAME: str = env.str("DB_USERNAME", "postgres") -DB_PASSWORD: str = env.str("DB_PASSWORD", "") -DB_HOST: str = env.str("DB_HOST", "localhost") -DB_PORT: int = env.int("DB_PORT", "5432") +# MongoDB +MONGO_DB: str = env.str("MONGO_DB", "didier") +MONGO_USER: str = env.str("MONGO_USER", "root") +MONGO_PASS: str = env.str("MONGO_PASS", "root") +MONGO_HOST: str = env.str("MONGO_HOST", "localhost") +MONGO_PORT: int = env.int("MONGO_PORT", "27017") + +# PostgreSQL +POSTGRES_DB: str = env.str("POSTGRES_DB", "didier") +POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres") +POSTGRES_PASS: str = env.str("POSTGRES_PASS", "") +POSTGRES_HOST: str = env.str("POSTGRES_HOST", "localhost") +POSTGRES_PORT: int = env.int("POSTGRES_PORT", "5432") """Discord""" DISCORD_TOKEN: str = env.str("DISCORD_TOKEN") diff --git a/tests/conftest.py b/tests/conftest.py index 2e425ef..55919fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,12 @@ import asyncio from typing import AsyncGenerator, Generator from unittest.mock import MagicMock +import motor.motor_asyncio import pytest from sqlalchemy.ext.asyncio import AsyncSession -from database.engine import engine +import settings +from database.engine import mongo_client, postgres_engine from database.migrations import ensure_latest_migration, migrate from didier import Didier @@ -35,12 +37,12 @@ async def tables(): @pytest.fixture -async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: +async def postgres(tables) -> AsyncGenerator[AsyncSession, None]: """Fixture to create a session for every test Rollbacks the transaction afterwards so that the future tests start with a clean database """ - connection = await engine.connect() + connection = await postgres_engine.connect() transaction = await connection.begin() session = AsyncSession(bind=connection, expire_on_commit=False) @@ -54,6 +56,14 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: await connection.close() +@pytest.fixture +async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase: + """Fixture to get a MongoDB connection""" + database = mongo_client[settings.MONGO_DB] + yield database + mongo_client.drop_database(settings.MONGO_DB) + + @pytest.fixture def mock_client() -> Didier: """Fixture to get a mock Didier instance diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index 8bc765c..b2556c4 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -1,10 +1,15 @@ import datetime import pytest -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users -from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User +from database.schemas.relational import ( + Bank, + UforaAnnouncement, + UforaCourse, + UforaCourseAlias, + User, +) @pytest.fixture(scope="session") @@ -17,44 +22,44 @@ def test_user_id() -> int: @pytest.fixture -async def user(database_session: AsyncSession, test_user_id) -> User: +async def user(postgres, test_user_id) -> User: """Fixture to create a user""" - _user = await users.get_or_add(database_session, test_user_id) - await database_session.refresh(_user) + _user = await users.get_or_add(postgres, test_user_id) + await postgres.refresh(_user) return _user @pytest.fixture -async def bank(database_session: AsyncSession, user: User) -> Bank: +async def bank(postgres, user: User) -> Bank: """Fixture to fetch the test user's bank""" _bank = user.bank - await database_session.refresh(_bank) + await postgres.refresh(_bank) return _bank @pytest.fixture -async def ufora_course(database_session: AsyncSession) -> UforaCourse: +async def ufora_course(postgres) -> UforaCourse: """Fixture to create a course""" course = UforaCourse(name="test", code="code", year=1, log_announcements=True) - database_session.add(course) - await database_session.commit() + postgres.add(course) + await postgres.commit() return course @pytest.fixture -async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: +async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse: """Fixture to create a course with an alias""" alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") - database_session.add(alias) - await database_session.commit() - await database_session.refresh(ufora_course) + postgres.add(alias) + await postgres.commit() + await postgres.refresh(ufora_course) return ufora_course @pytest.fixture -async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: +async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement: """Fixture to create an announcement""" announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) - database_session.add(announcement) - await database_session.commit() + postgres.add(announcement) + await postgres.commit() return announcement diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 544e5b0..21639b1 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,74 +1,73 @@ from datetime import datetime, timedelta from freezegun import freeze_time -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud from database.crud import users -from database.models import User +from database.schemas.relational import User -async def test_add_birthday_not_present(database_session: AsyncSession, user: User): +async def test_add_birthday_not_present(postgres, user: User): """Test setting a user's birthday when it doesn't exist yet""" assert user.birthday is None bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) assert user.birthday is not None assert user.birthday.birthday == bd_date -async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): +async def test_add_birthday_overwrite(postgres, user: User): """Test that setting a user's birthday when it already exists overwrites it""" bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) assert user.birthday is not None new_bd_date = bd_date + timedelta(weeks=1) - await crud.add_birthday(database_session, user.user_id, new_bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, new_bd_date) + await postgres.refresh(user) assert user.birthday.birthday == new_bd_date -async def test_get_birthday_exists(database_session: AsyncSession, user: User): +async def test_get_birthday_exists(postgres, user: User): """Test getting a user's birthday when it exists""" bd_date = datetime.today().date() - await crud.add_birthday(database_session, user.user_id, bd_date) - await database_session.refresh(user) + await crud.add_birthday(postgres, user.user_id, bd_date) + await postgres.refresh(user) - bd = await crud.get_birthday_for_user(database_session, user.user_id) + bd = await crud.get_birthday_for_user(postgres, user.user_id) assert bd is not None assert bd.birthday == bd_date -async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): +async def test_get_birthday_not_exists(postgres, user: User): """Test getting a user's birthday when it doesn't exist""" - bd = await crud.get_birthday_for_user(database_session, user.user_id) + bd = await crud.get_birthday_for_user(postgres, user.user_id) assert bd is None @freeze_time("2022/07/23") -async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): +async def test_get_birthdays_on_day(postgres, user: User): """Test getting all birthdays on a given day""" - await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001)) + await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) - user_2 = await users.get_or_add(database_session, user.user_id + 1) - await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + user_2 = await users.get_or_add(postgres, user.user_id + 1) + await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1)) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 1 assert birthdays[0].user_id == user.user_id @freeze_time("2022/07/23") -async def test_get_birthdays_none_present(database_session: AsyncSession): +async def test_get_birthdays_none_present(postgres): """Test getting all birthdays when there are none""" - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 0 # Add a random birthday that is not today - await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) + await crud.add_birthday(postgres, 1, datetime.today() + timedelta(days=1)) - birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) + birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 0 diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index a2eeec8..e5cdc0c 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -2,78 +2,77 @@ import datetime import pytest from freezegun import freeze_time -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud from database.exceptions import currency as exceptions -from database.models import Bank +from database.schemas.relational import Bank -async def test_add_dinks(database_session: AsyncSession, bank: Bank): +async def test_add_dinks(postgres, bank: Bank): """Test adding dinks to an account""" assert bank.dinks == 0 - await crud.add_dinks(database_session, bank.user_id, 10) - await database_session.refresh(bank) + await crud.add_dinks(postgres, bank.user_id, 10) + await postgres.refresh(bank) assert bank.dinks == 10 @freeze_time("2022/07/23") -async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): +async def test_claim_nightly_available(postgres, bank: Bank): """Test claiming nightlies when it hasn't been done yet""" - await crud.claim_nightly(database_session, bank.user_id) - await database_session.refresh(bank) + await crud.claim_nightly(postgres, bank.user_id) + await postgres.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT - nightly_data = await crud.get_nightly_data(database_session, bank.user_id) + nightly_data = await crud.get_nightly_data(postgres, bank.user_id) assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23) @freeze_time("2022/07/23") -async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): +async def test_claim_nightly_unavailable(postgres, bank: Bank): """Test claiming nightlies twice in a day""" - await crud.claim_nightly(database_session, bank.user_id) + await crud.claim_nightly(postgres, bank.user_id) with pytest.raises(exceptions.DoubleNightly): - await crud.claim_nightly(database_session, bank.user_id) + await crud.claim_nightly(postgres, bank.user_id) - await database_session.refresh(bank) + await postgres.refresh(bank) assert bank.dinks == crud.NIGHTLY_AMOUNT -async def test_invest(database_session: AsyncSession, bank: Bank): +async def test_invest(postgres, bank: Bank): """Test investing some Dinks""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, 20) - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, 20) + await postgres.refresh(bank) assert bank.dinks == 80 assert bank.invested == 20 -async def test_invest_all(database_session: AsyncSession, bank: Bank): +async def test_invest_all(postgres, bank: Bank): """Test investing all dinks""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, "all") - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, "all") + await postgres.refresh(bank) assert bank.dinks == 0 assert bank.invested == 100 -async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank): +async def test_invest_more_than_owned(postgres, bank: Bank): """Test investing more Dinks than you own""" bank.dinks = 100 - database_session.add(bank) - await database_session.commit() + postgres.add(bank) + await postgres.commit() - await crud.invest(database_session, bank.user_id, 200) - await database_session.refresh(bank) + await crud.invest(postgres, bank.user_id, 200) + await postgres.refresh(bank) assert bank.dinks == 0 assert bank.invested == 100 diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index a5c4092..88810d4 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -1,119 +1,118 @@ import pytest from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.models import CustomCommand +from database.schemas.relational import CustomCommand -async def test_create_command_non_existing(database_session: AsyncSession): +async def test_create_command_non_existing(postgres): """Test creating a new command when it doesn't exist yet""" - await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "name", "response") - commands = (await database_session.execute(select(CustomCommand))).scalars().all() + commands = (await postgres.execute(select(CustomCommand))).scalars().all() assert len(commands) == 1 assert commands[0].name == "name" -async def test_create_command_duplicate_name(database_session: AsyncSession): +async def test_create_command_duplicate_name(postgres): """Test creating a command when the name already exists""" - await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "name", "response") with pytest.raises(DuplicateInsertException): - await crud.create_command(database_session, "name", "other response") + await crud.create_command(postgres, "name", "other response") -async def test_create_command_name_is_alias(database_session: AsyncSession): +async def test_create_command_name_is_alias(postgres): """Test creating a command when the name is taken by an alias""" - await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, "name", "n") + await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, "name", "n") with pytest.raises(DuplicateInsertException): - await crud.create_command(database_session, "n", "other response") + await crud.create_command(postgres, "n", "other response") -async def test_create_alias(database_session: AsyncSession): +async def test_create_alias(postgres): """Test creating an alias when the name is still free""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "n") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "n") - await database_session.refresh(command) + await postgres.refresh(command) assert len(command.aliases) == 1 assert command.aliases[0].alias == "n" -async def test_create_alias_non_existing(database_session: AsyncSession): +async def test_create_alias_non_existing(postgres): """Test creating an alias when the command doesn't exist""" with pytest.raises(NoResultFoundException): - await crud.create_alias(database_session, "name", "alias") + await crud.create_alias(postgres, "name", "alias") -async def test_create_alias_duplicate(database_session: AsyncSession): +async def test_create_alias_duplicate(postgres): """Test creating an alias when another alias already has this name""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "n") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "n") with pytest.raises(DuplicateInsertException): - await crud.create_alias(database_session, command.name, "n") + await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_is_command(database_session: AsyncSession): +async def test_create_alias_is_command(postgres): """Test creating an alias when the name is taken by a command""" - await crud.create_command(database_session, "n", "response") - command = await crud.create_command(database_session, "name", "response") + await crud.create_command(postgres, "n", "response") + command = await crud.create_command(postgres, "name", "response") with pytest.raises(DuplicateInsertException): - await crud.create_alias(database_session, command.name, "n") + await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_match_by_alias(database_session: AsyncSession): +async def test_create_alias_match_by_alias(postgres): """Test creating an alias for a command when matching the name to another alias""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "a1") - alias = await crud.create_alias(database_session, "a1", "a2") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "a1") + alias = await crud.create_alias(postgres, "a1", "a2") assert alias.command == command -async def test_get_command_by_name_exists(database_session: AsyncSession): +async def test_get_command_by_name_exists(postgres): """Test getting a command by name""" - await crud.create_command(database_session, "name", "response") - command = await crud.get_command(database_session, "name") + await crud.create_command(postgres, "name", "response") + command = await crud.get_command(postgres, "name") assert command is not None -async def test_get_command_by_cleaned_name(database_session: AsyncSession): +async def test_get_command_by_cleaned_name(postgres): """Test getting a command by the cleaned version of the name""" - command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") - found = await crud.get_command(database_session, "capitalizednamewithspaces") + command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response") + found = await crud.get_command(postgres, "capitalizednamewithspaces") assert command == found -async def test_get_command_by_alias(database_session: AsyncSession): +async def test_get_command_by_alias(postgres): """Test getting a command by an alias""" - command = await crud.create_command(database_session, "name", "response") - await crud.create_alias(database_session, command.name, "a1") - await crud.create_alias(database_session, command.name, "a2") + command = await crud.create_command(postgres, "name", "response") + await crud.create_alias(postgres, command.name, "a1") + await crud.create_alias(postgres, command.name, "a2") - found = await crud.get_command(database_session, "a1") + found = await crud.get_command(postgres, "a1") assert command == found -async def test_get_command_non_existing(database_session: AsyncSession): +async def test_get_command_non_existing(postgres): """Test getting a command when it doesn't exist""" - assert await crud.get_command(database_session, "name") is None + assert await crud.get_command(postgres, "name") is None -async def test_edit_command(database_session: AsyncSession): +async def test_edit_command(postgres): """Test editing an existing command""" - command = await crud.create_command(database_session, "name", "response") - await crud.edit_command(database_session, command.name, "new name", "new response") + command = await crud.create_command(postgres, "name", "response") + await crud.edit_command(postgres, command.name, "new name", "new response") assert command.name == "new name" assert command.response == "new response" -async def test_edit_command_non_existing(database_session: AsyncSession): +async def test_edit_command_non_existing(postgres): """Test editing a command that doesn't exist""" with pytest.raises(NoResultFoundException): - await crud.edit_command(database_session, "name", "n", "r") + await crud.edit_command(postgres, "name", "n", "r") diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index 0c499c8..22c28c2 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -1,16 +1,15 @@ from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import dad_jokes as crud -from database.models import DadJoke +from database.schemas.relational import DadJoke -async def test_add_dad_joke(database_session: AsyncSession): +async def test_add_dad_joke(postgres): """Test creating a new joke""" statement = select(DadJoke) - result = (await database_session.execute(statement)).scalars().all() + result = (await postgres.execute(statement)).scalars().all() assert len(result) == 0 - await crud.add_dad_joke(database_session, "joke") - result = (await database_session.execute(statement)).scalars().all() + await crud.add_dad_joke(postgres, "joke") + result = (await postgres.execute(statement)).scalars().all() assert len(result) == 1 diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index e1e4f97..c4c7ba0 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -3,11 +3,10 @@ import datetime import pytest from freezegun import freeze_time from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import tasks as crud from database.enums import TaskType -from database.models import Task +from database.schemas.relational import Task @pytest.fixture @@ -17,47 +16,47 @@ def task_type() -> TaskType: @pytest.fixture -async def task(database_session: AsyncSession, task_type: TaskType) -> Task: +async def task(postgres, task_type: TaskType) -> Task: """Fixture to create a task""" task = Task(task=task_type) - database_session.add(task) - await database_session.commit() + postgres.add(task) + await postgres.commit() return task -async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType): +async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType): """Test getting a task by its enum type when it exists""" - result = await crud.get_task_by_enum(database_session, task_type) + result = await crud.get_task_by_enum(postgres, task_type) assert result is not None assert result == task -async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType): +async def test_get_task_by_enum_not_present(postgres, task_type: TaskType): """Test getting a task by its enum type when it doesn't exist""" - result = await crud.get_task_by_enum(database_session, task_type) + result = await crud.get_task_by_enum(postgres, task_type) assert result is None @freeze_time("2022/07/24") -async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType): +async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType): """Test setting the execution time of an existing task""" - await database_session.refresh(task) + await postgres.refresh(task) assert task.previous_run is None - await crud.set_last_task_execution_time(database_session, task_type) - await database_session.refresh(task) + await crud.set_last_task_execution_time(postgres, task_type) + await postgres.refresh(task) assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) @freeze_time("2022/07/24") -async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType): +async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType): """Test setting the execution time of a non-existing task""" statement = select(Task).where(Task.task == task_type) - results = list((await database_session.execute(statement)).scalars().all()) + results = list((await postgres.execute(statement)).scalars().all()) assert len(results) == 0 - await crud.set_last_task_execution_time(database_session, task_type) - results = list((await database_session.execute(statement)).scalars().all()) + await crud.set_last_task_execution_time(postgres, task_type) + results = list((await postgres.execute(statement)).scalars().all()) assert len(results) == 1 task = results[0] assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index 4e6fc47..1aa45ee 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,50 +1,46 @@ import datetime -from sqlalchemy.ext.asyncio import AsyncSession - from database.crud import ufora_announcements as crud -from database.models import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse -async def test_get_courses_with_announcements_none(database_session: AsyncSession): +async def test_get_courses_with_announcements_none(postgres): """Test getting all courses with announcements when there are none""" - results = await crud.get_courses_with_announcements(database_session) + results = await crud.get_courses_with_announcements(postgres) assert len(results) == 0 -async def test_get_courses_with_announcements(database_session: AsyncSession): +async def test_get_courses_with_announcements(postgres): """Test getting all courses with announcements""" course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True) course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False) - database_session.add_all([course_1, course_2]) - await database_session.commit() + postgres.add_all([course_1, course_2]) + await postgres.commit() - results = await crud.get_courses_with_announcements(database_session) + results = await crud.get_courses_with_announcements(postgres) assert len(results) == 1 assert results[0] == course_1 -async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): +async def test_create_new_announcement(ufora_course: UforaCourse, postgres): """Test creating a new announcement""" - await crud.create_new_announcement( - database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() - ) - await database_session.refresh(ufora_course) + await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now()) + await postgres.refresh(ufora_course) assert len(ufora_course.announcements) == 1 -async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): +async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres): """Test removing all stale announcements""" course = ufora_announcement.course ufora_announcement.publication_date -= datetime.timedelta(weeks=2) announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) - database_session.add_all([ufora_announcement, announcement_2]) - await database_session.commit() - await database_session.refresh(course) + postgres.add_all([ufora_announcement, announcement_2]) + await postgres.commit() + await postgres.refresh(course) assert len(course.announcements) == 2 - await crud.remove_old_announcements(database_session) + await crud.remove_old_announcements(postgres) - await database_session.refresh(course) + await postgres.refresh(course) assert len(course.announcements) == 1 assert announcement_2.course.announcements[0] == announcement_2 diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index d2d5e1b..34748c0 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,22 +1,20 @@ -from sqlalchemy.ext.asyncio import AsyncSession - from database.crud import ufora_courses as crud -from database.models import UforaCourse +from database.schemas.relational import UforaCourse -async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): +async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse): """Test getting a course by its name when the query is an exact match""" - match = await crud.get_course_by_name(database_session, "Test") + match = await crud.get_course_by_name(postgres, "Test") assert match == ufora_course -async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): +async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse): """Test getting a course by its name when the query is a substring""" - match = await crud.get_course_by_name(database_session, "es") + match = await crud.get_course_by_name(postgres, "es") assert match == ufora_course -async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): +async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse): """Test getting a course by its name when the name doesn't match, but the alias does""" - match = await crud.get_course_by_name(database_session, "ali") + match = await crud.get_course_by_name(postgres, "ali") assert match == ufora_course_with_alias diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index 08b4c81..e852298 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -1,25 +1,24 @@ from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users as crud -from database.models import User +from database.schemas.relational import User -async def test_get_or_add_non_existing(database_session: AsyncSession): +async def test_get_or_add_non_existing(postgres): """Test get_or_add for a user that doesn't exist""" - await crud.get_or_add(database_session, 1) + await crud.get_or_add(postgres, 1) statement = select(User) - res = (await database_session.execute(statement)).scalars().all() + res = (await postgres.execute(statement)).scalars().all() assert len(res) == 1 assert res[0].bank is not None assert res[0].nightly_data is not None -async def test_get_or_add_existing(database_session: AsyncSession): +async def test_get_or_add_existing(postgres): """Test get_or_add for a user that does exist""" - user = await crud.get_or_add(database_session, 1) + user = await crud.get_or_add(postgres, 1) bank = user.bank - assert await crud.get_or_add(database_session, 1) == user - assert (await crud.get_or_add(database_session, 1)).bank == bank + assert await crud.get_or_add(postgres, 1) == user + assert (await crud.get_or_add(postgres, 1)).bank == bank diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 2e10664..69a6ff2 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,28 +1,24 @@ -from sqlalchemy.ext.asyncio import AsyncSession - -from database.models import UforaCourse +from database.schemas.relational import UforaCourse from database.utils.caches import UforaCourseCache -async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): +async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's empty""" cache = UforaCourseCache() - await cache.refresh(database_session) + await cache.refresh(postgres) assert len(cache.data) == 1 assert cache.data == ["test"] assert cache.aliases == {"alias": "test"} -async def test_ufora_course_cache_refresh_not_empty( - database_session: AsyncSession, ufora_course_with_alias: UforaCourse -): +async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's not empty anymore""" cache = UforaCourseCache() cache.data = ["Something"] cache.data_transformed = ["something"] - await cache.refresh(database_session) + await cache.refresh(postgres) assert len(cache.data) == 1 assert cache.data == ["test"]