Merge pull request #123 from stijndcl/add-mongo

Add MongoDB
pull/125/head
Stijn De Clercq 2022-07-25 21:32:32 +02:00 committed by GitHub
commit 2f4c2c347f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
42 changed files with 377 additions and 273 deletions

View File

@ -38,6 +38,19 @@ jobs:
POSTGRES_DB: didier_pytest POSTGRES_DB: didier_pytest
POSTGRES_USER: pytest POSTGRES_USER: pytest
POSTGRES_PASSWORD: pytest POSTGRES_PASSWORD: pytest
mongo:
image: mongo:5.0
options: >-
--health-cmd mongo
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 27018:27017
env:
MONGO_DB: didier_pytest
MONGO_USER: pytest
MONGO_PASSWORD: pytest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup Python - name: Setup Python

View File

@ -4,8 +4,8 @@ from logging.config import fileConfig
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context from alembic import context
from database.engine import engine from database.engine import postgres_engine
from database.models import Base from database.schemas.relational import Base
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.
@ -40,7 +40,7 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = context.config.attributes.get("connection", None) or engine connectable = context.config.attributes.get("connection", None) or postgres_engine
if isinstance(connectable, AsyncEngine): if isinstance(connectable, AsyncEngine):
asyncio.run(run_async_migrations(connectable)) asyncio.run(run_async_migrations(connectable))

View File

@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from database.crud import users from database.crud import users
from database.models import Birthday, User from database.schemas.relational import Birthday, User
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users from database.crud import users
from database.exceptions import currency as exceptions from database.exceptions import currency as exceptions
from database.models import Bank, NightlyData from database.schemas.relational import Bank, NightlyData
from database.utils.math.currency import ( from database.utils.math.currency import (
capacity_upgrade_price, capacity_upgrade_price,
interest_upgrade_price, interest_upgrade_price,

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.models import CustomCommand, CustomCommandAlias from database.schemas.relational import CustomCommand, CustomCommandAlias
__all__ = [ __all__ = [
"clean_name", "clean_name",

View File

@ -2,7 +2,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.models import DadJoke from database.schemas.relational import DadJoke
__all__ = ["add_dad_joke", "get_random_dad_joke"] __all__ = ["add_dad_joke", "get_random_dad_joke"]

View File

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType from database.enums import TaskType
from database.models import Task from database.schemas.relational import Task
from database.utils.datetime import LOCAL_TIMEZONE from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"] __all__ = ["get_task_by_enum", "set_last_task_execution_time"]

View File

@ -3,7 +3,7 @@ import datetime
from sqlalchemy import delete, select from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaAnnouncement, UforaCourse from database.schemas.relational import UforaAnnouncement, UforaCourse
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]

View File

@ -3,7 +3,7 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.models import UforaCourse, UforaCourseAlias from database.schemas.relational import UforaCourse, UforaCourseAlias
__all__ = ["get_all_courses", "get_course_by_name"] __all__ = ["get_all_courses", "get_course_by_name"]

View File

@ -3,7 +3,7 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.models import Bank, NightlyData, User from database.schemas.relational import Bank, NightlyData, User
__all__ = [ __all__ = [
"get_or_add", "get_or_add",

View File

@ -1,24 +1,34 @@
from urllib.parse import quote_plus from urllib.parse import quote_plus
import motor.motor_asyncio
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
import settings import settings
encoded_password = quote_plus(settings.DB_PASSWORD) encoded_postgres_password = quote_plus(settings.POSTGRES_PASS)
engine = create_async_engine( # PostgreSQL engine
postgres_engine = create_async_engine(
URL.create( URL.create(
drivername="postgresql+asyncpg", drivername="postgresql+asyncpg",
username=settings.DB_USERNAME, username=settings.POSTGRES_USER,
password=encoded_password, password=encoded_postgres_password,
host=settings.DB_HOST, host=settings.POSTGRES_HOST,
port=settings.DB_PORT, port=settings.POSTGRES_PORT,
database=settings.DB_NAME, database=settings.POSTGRES_DB,
), ),
pool_pre_ping=True, pool_pre_ping=True,
future=True, future=True,
) )
DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False) DBSession = sessionmaker(
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
)
# MongoDB client
encoded_mongo_username = quote_plus(settings.MONGO_USER)
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
mongo_url = f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)

View File

@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from alembic import command, script from alembic import command, script
from alembic.config import Config from alembic.config import Config
from alembic.runtime import migration from alembic.runtime import migration
from database.engine import engine from database.engine import postgres_engine
__config_path__ = "alembic.ini" __config_path__ = "alembic.ini"
__migrations_path__ = "alembic/" __migrations_path__ = "alembic/"
@ -22,7 +22,7 @@ async def ensure_latest_migration():
"""Make sure we are currently on the latest revision, otherwise raise an exception""" """Make sure we are currently on the latest revision, otherwise raise an exception"""
alembic_script = script.ScriptDirectory.from_config(cfg) alembic_script = script.ScriptDirectory.from_config(cfg)
async with engine.begin() as connection: async with postgres_engine.begin() as connection:
current_revision = await connection.run_sync( current_revision = await connection.run_sync(
lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision() lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision()
) )
@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session):
async def migrate(up: bool): async def migrate(up: bool):
"""Migrate the database upwards or downwards""" """Migrate the database upwards or downwards"""
async with engine.begin() as connection: async with postgres_engine.begin() as connection:
await connection.run_sync(__execute_upgrade if up else __execute_downgrade) await connection.run_sync(__execute_upgrade if up else __execute_downgrade)

View File

View File

@ -0,0 +1,38 @@
from bson import ObjectId
from pydantic import BaseModel, Field
__all__ = ["MongoBase"]
class PyObjectId(str):
"""Custom type for bson ObjectIds"""
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, value: str):
"""Check that a string is a valid bson ObjectId"""
if not ObjectId.is_valid(value):
raise ValueError(f"Invalid ObjectId: '{value}'")
return ObjectId(value)
@classmethod
def __modify_schema__(cls, field_schema: dict):
field_schema.update(type="string")
class MongoBase(BaseModel):
"""Base model that properly sets the _id field, and adds one by default"""
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
class Config:
"""Configuration for encoding and construction"""
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str, PyObjectId: str}
use_enum_values = True

View File

@ -31,7 +31,7 @@ class Currency(commands.Cog):
"""Award a user a given amount of Didier Dinks""" """Award a user a given amount of Didier Dinks"""
amount = typing.cast(int, amount) amount = typing.cast(int, amount)
async with self.client.db_session as session: async with self.client.postgres_session as session:
await crud.add_dinks(session, user.id, amount) await crud.add_dinks(session, user.id, amount)
plural = pluralize("Didier Dink", amount) plural = pluralize("Didier Dink", amount)
await ctx.reply( await ctx.reply(
@ -42,7 +42,7 @@ class Currency(commands.Cog):
@commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True) @commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True)
async def bank(self, ctx: commands.Context): async def bank(self, ctx: commands.Context):
"""Show your Didier Bank information""" """Show your Didier Bank information"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(colour=discord.Colour.blue()) embed = discord.Embed(colour=discord.Colour.blue())
@ -58,7 +58,7 @@ class Currency(commands.Cog):
@bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True) @bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True)
async def bank_upgrades(self, ctx: commands.Context): async def bank_upgrades(self, ctx: commands.Context):
"""List the upgrades you can buy & their prices""" """List the upgrades you can buy & their prices"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(colour=discord.Colour.blue()) embed = discord.Embed(colour=discord.Colour.blue())
@ -79,7 +79,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Capacity", aliases=["C"]) @bank_upgrades.command(name="Capacity", aliases=["C"])
async def bank_upgrade_capacity(self, ctx: commands.Context): async def bank_upgrade_capacity(self, ctx: commands.Context):
"""Upgrade the capacity level of your bank""" """Upgrade the capacity level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_capacity(session, ctx.author.id) await crud.upgrade_capacity(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -90,7 +90,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Interest", aliases=["I"]) @bank_upgrades.command(name="Interest", aliases=["I"])
async def bank_upgrade_interest(self, ctx: commands.Context): async def bank_upgrade_interest(self, ctx: commands.Context):
"""Upgrade the interest level of your bank""" """Upgrade the interest level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_interest(session, ctx.author.id) await crud.upgrade_interest(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -101,7 +101,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Rob", aliases=["R"]) @bank_upgrades.command(name="Rob", aliases=["R"])
async def bank_upgrade_rob(self, ctx: commands.Context): async def bank_upgrade_rob(self, ctx: commands.Context):
"""Upgrade the rob level of your bank""" """Upgrade the rob level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_rob(session, ctx.author.id) await crud.upgrade_rob(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -112,7 +112,7 @@ class Currency(commands.Cog):
@commands.hybrid_command(name="dinks") @commands.hybrid_command(name="dinks")
async def dinks(self, ctx: commands.Context): async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks""" """Check your Didier Dinks"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
plural = pluralize("Didier Dink", bank.dinks) plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@ -122,7 +122,7 @@ class Currency(commands.Cog):
"""Invest a given amount of Didier Dinks""" """Invest a given amount of Didier Dinks"""
amount = typing.cast(typing.Union[str, int], amount) amount = typing.cast(typing.Union[str, int], amount)
async with self.client.db_session as session: async with self.client.postgres_session as session:
invested = await crud.invest(session, ctx.author.id, amount) invested = await crud.invest(session, ctx.author.id, amount)
plural = pluralize("Didier Dink", invested) plural = pluralize("Didier Dink", invested)
@ -136,7 +136,7 @@ class Currency(commands.Cog):
@commands.hybrid_command(name="nightly") @commands.hybrid_command(name="nightly")
async def nightly(self, ctx: commands.Context): async def nightly(self, ctx: commands.Context):
"""Claim nightly Didier Dinks""" """Claim nightly Didier Dinks"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.claim_nightly(session, ctx.author.id) await crud.claim_nightly(session, ctx.author.id)
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")

View File

@ -19,7 +19,7 @@ class Discord(commands.Cog):
async def birthday(self, ctx: commands.Context, user: discord.User = None): async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of a user""" """Command to check the birthday of a user"""
user_id = (user and user.id) or ctx.author.id user_id = (user and user.id) or ctx.author.id
async with self.client.db_session as session: async with self.client.postgres_session as session:
birthday = await birthdays.get_birthday_for_user(session, user_id) birthday = await birthdays.get_birthday_for_user(session, user_id)
name = "Jouw" if user is None else f"{user.display_name}'s" name = "Jouw" if user is None else f"{user.display_name}'s"
@ -45,7 +45,7 @@ class Discord(commands.Cog):
except ValueError: except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
async with self.client.db_session as session: async with self.client.postgres_session as session:
await birthdays.add_birthday(session, ctx.author.id, date) await birthdays.add_birthday(session, ctx.author.id, date)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)

View File

@ -19,7 +19,7 @@ class Fun(commands.Cog):
) )
async def dad_joke(self, ctx: commands.Context): async def dad_joke(self, ctx: commands.Context):
"""Get a random dad joke""" """Get a random dad joke"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
joke = await get_random_dad_joke(session) joke = await get_random_dad_joke(session)
return await ctx.reply(joke.joke, mention_author=False) return await ctx.reply(joke.joke, mention_author=False)

View File

@ -83,7 +83,7 @@ class Owner(commands.Cog):
@add_msg.command(name="Custom") @add_msg.command(name="Custom")
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command""" """Add a new custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.create_command(session, name, response) await custom_commands.create_command(session, name, response)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@ -94,7 +94,7 @@ class Owner(commands.Cog):
@add_msg.command(name="Alias") @add_msg.command(name="Alias")
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command""" """Add a new alias for a custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.create_alias(session, command, alias) await custom_commands.create_alias(session, command, alias)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@ -130,7 +130,7 @@ class Owner(commands.Cog):
@edit_msg.command(name="Custom") @edit_msg.command(name="Custom")
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
"""Edit an existing custom command""" """Edit an existing custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.edit_command(session, command, flags.name, flags.response) await custom_commands.edit_command(session, command, flags.name, flags.response)
return await self.client.confirm_message(ctx.message) return await self.client.confirm_message(ctx.message)
@ -147,7 +147,7 @@ class Owner(commands.Cog):
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
) )
async with self.client.db_session as session: async with self.client.postgres_session as session:
_command = await custom_commands.get_command(session, command) _command = await custom_commands.get_command(session, command)
if _command is None: if _command is None:
return await interaction.response.send_message( return await interaction.response.send_message(

View File

@ -68,7 +68,7 @@ class School(commands.Cog):
@app_commands.describe(course="vak") @app_commands.describe(course="vak")
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
"""Create links to study guides""" """Create links to study guides"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
ufora_course = await ufora_courses.get_course_by_name(session, course) ufora_course = await ufora_courses.get_course_by_name(session, course)
if ufora_course is None: if ufora_course is None:

View File

@ -72,7 +72,7 @@ class Tasks(commands.Cog):
async def check_birthdays(self): async def check_birthdays(self):
"""Check if it's currently anyone's birthday""" """Check if it's currently anyone's birthday"""
now = tz_aware_now().date() now = tz_aware_now().date()
async with self.client.db_session as session: async with self.client.postgres_session as session:
birthdays = await get_birthdays_on_day(session, now) birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
@ -96,7 +96,7 @@ class Tasks(commands.Cog):
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
return return
async with self.client.db_session as db_session: async with self.client.postgres_session as db_session:
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
announcements = await fetch_ufora_announcements(self.client.http_session, db_session) announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
@ -110,7 +110,7 @@ class Tasks(commands.Cog):
@tasks.loop(hours=24) @tasks.loop(hours=24)
async def remove_old_ufora_announcements(self): async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day""" """Remove all announcements that are over 1 week old, once per day"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
await remove_old_announcements(session) await remove_old_announcements(session)
@check_birthdays.error @check_birthdays.error

View File

@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud
from database.models import UforaCourse from database.schemas.relational import UforaCourse
from didier.data.embeds.base import EmbedBaseModel from didier.data.embeds.base import EmbedBaseModel
from didier.utils.types.datetime import int_to_weekday from didier.utils.types.datetime import int_to_weekday
from didier.utils.types.string import leading from didier.utils.types.string import leading

View File

@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType):
async def _wrapper(tasks_cog: Tasks, *args, **kwargs): async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
await func(tasks_cog, *args, **kwargs) await func(tasks_cog, *args, **kwargs)
async with tasks_cog.client.db_session as session: async with tasks_cog.client.postgres_session as session:
await set_last_task_execution_time(session, task) await set_last_task_execution_time(session, task)
return _wrapper return _wrapper

View File

@ -2,13 +2,14 @@ import logging
import os import os
import discord import discord
import motor.motor_asyncio
from aiohttp import ClientSession from aiohttp import ClientSession
from discord.ext import commands from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import custom_commands from database.crud import custom_commands
from database.engine import DBSession from database.engine import DBSession, mongo_client
from database.utils.caches import CacheManager from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.error_embed import create_error_embed
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
@ -45,10 +46,15 @@ class Didier(commands.Bot):
) )
@property @property
def db_session(self) -> AsyncSession: def postgres_session(self) -> AsyncSession:
"""Obtain a database session""" """Obtain a session for the PostgreSQL database"""
return DBSession() return DBSession()
@property
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
"""Obtain a reference to the MongoDB database"""
return mongo_client[settings.MONGO_DB]
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Do some initial setup """Do some initial setup
@ -60,7 +66,7 @@ class Didier(commands.Bot):
# Initialize caches # Initialize caches
self.database_caches = CacheManager() self.database_caches = CacheManager()
async with self.db_session as session: async with self.postgres_session as session:
await self.database_caches.initialize_caches(session) await self.database_caches.initialize_caches(session)
# Create aiohttp session # Create aiohttp session
@ -153,7 +159,7 @@ class Didier(commands.Bot):
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False return False
async with self.db_session as session: async with self.postgres_session as session:
# Remove the prefix # Remove the prefix
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
command = await custom_commands.get_command(session, content) command = await custom_commands.get_command(session, content)

View File

@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session: async with self.client.postgres_session as session:
command = await create_command(session, str(self.name.value), str(self.response.value)) command = await create_command(session, str(self.name.value), str(self.response.value))
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
name_field = typing.cast(discord.ui.TextInput, self.children[0]) name_field = typing.cast(discord.ui.TextInput, self.children[0])
response_field = typing.cast(discord.ui.TextInput, self.children[1]) response_field = typing.cast(discord.ui.TextInput, self.children[1])
async with self.client.db_session as session: async with self.client.postgres_session as session:
await edit_command(session, self.original_name, name_field.value, response_field.value) await edit_command(session, self.original_name, name_field.value, response_field.value)
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)

View File

@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session: async with self.client.postgres_session as session:
joke = await add_dad_joke(session, str(self.name.value)) joke = await add_dad_joke(session, str(self.name.value))
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)

View File

@ -0,0 +1,21 @@
version: '3.9'
services:
postgres-pytest:
image: postgres:14
container_name: didier-pytest
restart: always
environment:
- POSTGRES_DB=didier_pytest
- POSTGRES_USER=pytest
- POSTGRES_PASSWORD=pytest
ports:
- "5433:5432"
mongo-pytest:
image: mongo:5.0
restart: always
environment:
- MONGO_INITDB_ROOT_USERNAME=pytest
- MONGO_INITDB_ROOT_PASSWORD=pytest
- MONGO_INITDB_DATABASE=didier_pytest
ports:
- "27018:27017"

View File

@ -1,26 +1,29 @@
version: '3.9' version: '3.9'
services: services:
db: postgres:
image: postgres:14 image: postgres:14
container_name: didier container_name: didier
restart: always restart: always
environment: environment:
- POSTGRES_DB=${DB_NAME:-didier_dev} - POSTGRES_DB=${POSTGRES_DB:-didier_dev}
- POSTGRES_USER=${DB_USERNAME:-postgres} - POSTGRES_USER=${POSTGRES_USER:-postgres}
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres} - POSTGRES_PASSWORD=${POSTGRES_PASS:-postgres}
ports: ports:
- "${DB_PORT:-5432}:${DB_PORT:-5432}" - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
volumes: volumes:
- db:/var/lib/postgresql/data - postgres:/var/lib/postgresql/data
db-pytest: mongo:
image: postgres:14 image: mongo:5.0
container_name: didier-pytest
restart: always restart: always
environment: environment:
- POSTGRES_DB=didier_pytest - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
- POSTGRES_USER=pytest - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
- POSTGRES_PASSWORD=pytest - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
command: [--auth]
ports: ports:
- "5433:5432" - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
volumes:
- mongo:/data/db
volumes: volumes:
db: postgres:
mongo:

View File

@ -36,16 +36,21 @@ plugins = [
"sqlalchemy.ext.mypy.plugin" "sqlalchemy.ext.mypy.plugin"
] ]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = ["discord.*", "feedparser.*", "markdownify.*"] module = ["discord.*", "feedparser.*", "markdownify.*", "motor.*"]
ignore_missing_imports = true ignore_missing_imports = true
[tool.pytest.ini_options] [tool.pytest.ini_options]
asyncio_mode = "auto" asyncio_mode = "auto"
env = [ env = [
"DB_NAME = didier_pytest", "MONGO_DB = didier_pytest",
"DB_USERNAME = pytest", "MONGO_USER = pytest",
"DB_PASSWORD = pytest", "MONGO_PASS = pytest",
"DB_HOST = localhost", "MONGO_HOST = localhost",
"DB_PORT = 5433", "MONGO_PORT = 27018",
"POSTGRES_DB = didier_pytest",
"POSTGRES_USER = pytest",
"POSTGRES_PASS = pytest",
"POSTGRES_HOST = localhost",
"POSTGRES_PORT = 5433",
"DISCORD_TOKEN = token" "DISCORD_TOKEN = token"
] ]

View File

@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py
environs==9.5.0 environs==9.5.0
feedparser==6.0.10 feedparser==6.0.10
markdownify==0.11.2 markdownify==0.11.2
motor==3.0.0
overrides==6.1.0 overrides==6.1.0
pydantic==1.9.1 pydantic==1.9.1
python-dateutil==2.8.2 python-dateutil==2.8.2

View File

@ -9,11 +9,11 @@ env.read_env()
__all__ = [ __all__ = [
"SANDBOX", "SANDBOX",
"LOGFILE", "LOGFILE",
"DB_NAME", "POSTGRES_DB",
"DB_USERNAME", "POSTGRES_USER",
"DB_PASSWORD", "POSTGRES_PASS",
"DB_HOST", "POSTGRES_HOST",
"DB_PORT", "POSTGRES_PORT",
"DISCORD_TOKEN", "DISCORD_TOKEN",
"DISCORD_READY_MESSAGE", "DISCORD_READY_MESSAGE",
"DISCORD_STATUS_MESSAGE", "DISCORD_STATUS_MESSAGE",
@ -33,11 +33,19 @@ SEMESTER: int = env.int("SEMESTER", 2)
YEAR: int = env.int("YEAR", 3) YEAR: int = env.int("YEAR", 3)
"""Database""" """Database"""
DB_NAME: str = env.str("DB_NAME", "didier") # MongoDB
DB_USERNAME: str = env.str("DB_USERNAME", "postgres") MONGO_DB: str = env.str("MONGO_DB", "didier")
DB_PASSWORD: str = env.str("DB_PASSWORD", "") MONGO_USER: str = env.str("MONGO_USER", "root")
DB_HOST: str = env.str("DB_HOST", "localhost") MONGO_PASS: str = env.str("MONGO_PASS", "root")
DB_PORT: int = env.int("DB_PORT", "5432") MONGO_HOST: str = env.str("MONGO_HOST", "localhost")
MONGO_PORT: int = env.int("MONGO_PORT", "27017")
# PostgreSQL
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")
POSTGRES_PASS: str = env.str("POSTGRES_PASS", "")
POSTGRES_HOST: str = env.str("POSTGRES_HOST", "localhost")
POSTGRES_PORT: int = env.int("POSTGRES_PORT", "5432")
"""Discord""" """Discord"""
DISCORD_TOKEN: str = env.str("DISCORD_TOKEN") DISCORD_TOKEN: str = env.str("DISCORD_TOKEN")

View File

@ -2,10 +2,12 @@ import asyncio
from typing import AsyncGenerator, Generator from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
import motor.motor_asyncio
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine import settings
from database.engine import mongo_client, postgres_engine
from database.migrations import ensure_latest_migration, migrate from database.migrations import ensure_latest_migration, migrate
from didier import Didier from didier import Didier
@ -35,12 +37,12 @@ async def tables():
@pytest.fixture @pytest.fixture
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]: async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
"""Fixture to create a session for every test """Fixture to create a session for every test
Rollbacks the transaction afterwards so that the future tests start with a clean database Rollbacks the transaction afterwards so that the future tests start with a clean database
""" """
connection = await engine.connect() connection = await postgres_engine.connect()
transaction = await connection.begin() transaction = await connection.begin()
session = AsyncSession(bind=connection, expire_on_commit=False) session = AsyncSession(bind=connection, expire_on_commit=False)
@ -54,6 +56,14 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
await connection.close() await connection.close()
@pytest.fixture
async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase:
"""Fixture to get a MongoDB connection"""
database = mongo_client[settings.MONGO_DB]
yield database
mongo_client.drop_database(settings.MONGO_DB)
@pytest.fixture @pytest.fixture
def mock_client() -> Didier: def mock_client() -> Didier:
"""Fixture to get a mock Didier instance """Fixture to get a mock Didier instance

View File

@ -1,10 +1,15 @@
import datetime import datetime
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users from database.crud import users
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User from database.schemas.relational import (
Bank,
UforaAnnouncement,
UforaCourse,
UforaCourseAlias,
User,
)
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
@ -17,44 +22,44 @@ def test_user_id() -> int:
@pytest.fixture @pytest.fixture
async def user(database_session: AsyncSession, test_user_id) -> User: async def user(postgres, test_user_id) -> User:
"""Fixture to create a user""" """Fixture to create a user"""
_user = await users.get_or_add(database_session, test_user_id) _user = await users.get_or_add(postgres, test_user_id)
await database_session.refresh(_user) await postgres.refresh(_user)
return _user return _user
@pytest.fixture @pytest.fixture
async def bank(database_session: AsyncSession, user: User) -> Bank: async def bank(postgres, user: User) -> Bank:
"""Fixture to fetch the test user's bank""" """Fixture to fetch the test user's bank"""
_bank = user.bank _bank = user.bank
await database_session.refresh(_bank) await postgres.refresh(_bank)
return _bank return _bank
@pytest.fixture @pytest.fixture
async def ufora_course(database_session: AsyncSession) -> UforaCourse: async def ufora_course(postgres) -> UforaCourse:
"""Fixture to create a course""" """Fixture to create a course"""
course = UforaCourse(name="test", code="code", year=1, log_announcements=True) course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
database_session.add(course) postgres.add(course)
await database_session.commit() await postgres.commit()
return course return course
@pytest.fixture @pytest.fixture
async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse:
"""Fixture to create a course with an alias""" """Fixture to create a course with an alias"""
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
database_session.add(alias) postgres.add(alias)
await database_session.commit() await postgres.commit()
await database_session.refresh(ufora_course) await postgres.refresh(ufora_course)
return ufora_course return ufora_course
@pytest.fixture @pytest.fixture
async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement: async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement:
"""Fixture to create an announcement""" """Fixture to create an announcement"""
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
database_session.add(announcement) postgres.add(announcement)
await database_session.commit() await postgres.commit()
return announcement return announcement

View File

@ -1,74 +1,73 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud from database.crud import birthdays as crud
from database.crud import users from database.crud import users
from database.models import User from database.schemas.relational import User
async def test_add_birthday_not_present(database_session: AsyncSession, user: User): async def test_add_birthday_not_present(postgres, user: User):
"""Test setting a user's birthday when it doesn't exist yet""" """Test setting a user's birthday when it doesn't exist yet"""
assert user.birthday is None assert user.birthday is None
bd_date = datetime.today().date() bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date) await crud.add_birthday(postgres, user.user_id, bd_date)
await database_session.refresh(user) await postgres.refresh(user)
assert user.birthday is not None assert user.birthday is not None
assert user.birthday.birthday == bd_date assert user.birthday.birthday == bd_date
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User): async def test_add_birthday_overwrite(postgres, user: User):
"""Test that setting a user's birthday when it already exists overwrites it""" """Test that setting a user's birthday when it already exists overwrites it"""
bd_date = datetime.today().date() bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date) await crud.add_birthday(postgres, user.user_id, bd_date)
await database_session.refresh(user) await postgres.refresh(user)
assert user.birthday is not None assert user.birthday is not None
new_bd_date = bd_date + timedelta(weeks=1) new_bd_date = bd_date + timedelta(weeks=1)
await crud.add_birthday(database_session, user.user_id, new_bd_date) await crud.add_birthday(postgres, user.user_id, new_bd_date)
await database_session.refresh(user) await postgres.refresh(user)
assert user.birthday.birthday == new_bd_date assert user.birthday.birthday == new_bd_date
async def test_get_birthday_exists(database_session: AsyncSession, user: User): async def test_get_birthday_exists(postgres, user: User):
"""Test getting a user's birthday when it exists""" """Test getting a user's birthday when it exists"""
bd_date = datetime.today().date() bd_date = datetime.today().date()
await crud.add_birthday(database_session, user.user_id, bd_date) await crud.add_birthday(postgres, user.user_id, bd_date)
await database_session.refresh(user) await postgres.refresh(user)
bd = await crud.get_birthday_for_user(database_session, user.user_id) bd = await crud.get_birthday_for_user(postgres, user.user_id)
assert bd is not None assert bd is not None
assert bd.birthday == bd_date assert bd.birthday == bd_date
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User): async def test_get_birthday_not_exists(postgres, user: User):
"""Test getting a user's birthday when it doesn't exist""" """Test getting a user's birthday when it doesn't exist"""
bd = await crud.get_birthday_for_user(database_session, user.user_id) bd = await crud.get_birthday_for_user(postgres, user.user_id)
assert bd is None assert bd is None
@freeze_time("2022/07/23") @freeze_time("2022/07/23")
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User): async def test_get_birthdays_on_day(postgres, user: User):
"""Test getting all birthdays on a given day""" """Test getting all birthdays on a given day"""
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001)) await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
user_2 = await users.get_or_add(database_session, user.user_id + 1) user_2 = await users.get_or_add(postgres, user.user_id + 1)
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1)) await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
assert len(birthdays) == 1 assert len(birthdays) == 1
assert birthdays[0].user_id == user.user_id assert birthdays[0].user_id == user.user_id
@freeze_time("2022/07/23") @freeze_time("2022/07/23")
async def test_get_birthdays_none_present(database_session: AsyncSession): async def test_get_birthdays_none_present(postgres):
"""Test getting all birthdays when there are none""" """Test getting all birthdays when there are none"""
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
assert len(birthdays) == 0 assert len(birthdays) == 0
# Add a random birthday that is not today # Add a random birthday that is not today
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1)) await crud.add_birthday(postgres, 1, datetime.today() + timedelta(days=1))
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today()) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
assert len(birthdays) == 0 assert len(birthdays) == 0

View File

@ -2,78 +2,77 @@ import datetime
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import currency as crud from database.crud import currency as crud
from database.exceptions import currency as exceptions from database.exceptions import currency as exceptions
from database.models import Bank from database.schemas.relational import Bank
async def test_add_dinks(database_session: AsyncSession, bank: Bank): async def test_add_dinks(postgres, bank: Bank):
"""Test adding dinks to an account""" """Test adding dinks to an account"""
assert bank.dinks == 0 assert bank.dinks == 0
await crud.add_dinks(database_session, bank.user_id, 10) await crud.add_dinks(postgres, bank.user_id, 10)
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == 10 assert bank.dinks == 10
@freeze_time("2022/07/23") @freeze_time("2022/07/23")
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank): async def test_claim_nightly_available(postgres, bank: Bank):
"""Test claiming nightlies when it hasn't been done yet""" """Test claiming nightlies when it hasn't been done yet"""
await crud.claim_nightly(database_session, bank.user_id) await crud.claim_nightly(postgres, bank.user_id)
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT assert bank.dinks == crud.NIGHTLY_AMOUNT
nightly_data = await crud.get_nightly_data(database_session, bank.user_id) nightly_data = await crud.get_nightly_data(postgres, bank.user_id)
assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23) assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23)
@freeze_time("2022/07/23") @freeze_time("2022/07/23")
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank): async def test_claim_nightly_unavailable(postgres, bank: Bank):
"""Test claiming nightlies twice in a day""" """Test claiming nightlies twice in a day"""
await crud.claim_nightly(database_session, bank.user_id) await crud.claim_nightly(postgres, bank.user_id)
with pytest.raises(exceptions.DoubleNightly): with pytest.raises(exceptions.DoubleNightly):
await crud.claim_nightly(database_session, bank.user_id) await crud.claim_nightly(postgres, bank.user_id)
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == crud.NIGHTLY_AMOUNT assert bank.dinks == crud.NIGHTLY_AMOUNT
async def test_invest(database_session: AsyncSession, bank: Bank): async def test_invest(postgres, bank: Bank):
"""Test investing some Dinks""" """Test investing some Dinks"""
bank.dinks = 100 bank.dinks = 100
database_session.add(bank) postgres.add(bank)
await database_session.commit() await postgres.commit()
await crud.invest(database_session, bank.user_id, 20) await crud.invest(postgres, bank.user_id, 20)
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == 80 assert bank.dinks == 80
assert bank.invested == 20 assert bank.invested == 20
async def test_invest_all(database_session: AsyncSession, bank: Bank): async def test_invest_all(postgres, bank: Bank):
"""Test investing all dinks""" """Test investing all dinks"""
bank.dinks = 100 bank.dinks = 100
database_session.add(bank) postgres.add(bank)
await database_session.commit() await postgres.commit()
await crud.invest(database_session, bank.user_id, "all") await crud.invest(postgres, bank.user_id, "all")
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == 0 assert bank.dinks == 0
assert bank.invested == 100 assert bank.invested == 100
async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank): async def test_invest_more_than_owned(postgres, bank: Bank):
"""Test investing more Dinks than you own""" """Test investing more Dinks than you own"""
bank.dinks = 100 bank.dinks = 100
database_session.add(bank) postgres.add(bank)
await database_session.commit() await postgres.commit()
await crud.invest(database_session, bank.user_id, 200) await crud.invest(postgres, bank.user_id, 200)
await database_session.refresh(bank) await postgres.refresh(bank)
assert bank.dinks == 0 assert bank.dinks == 0
assert bank.invested == 100 assert bank.invested == 100

View File

@ -1,119 +1,118 @@
import pytest import pytest
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import custom_commands as crud from database.crud import custom_commands as crud
from database.exceptions.constraints import DuplicateInsertException from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.models import CustomCommand from database.schemas.relational import CustomCommand
async def test_create_command_non_existing(database_session: AsyncSession): async def test_create_command_non_existing(postgres):
"""Test creating a new command when it doesn't exist yet""" """Test creating a new command when it doesn't exist yet"""
await crud.create_command(database_session, "name", "response") await crud.create_command(postgres, "name", "response")
commands = (await database_session.execute(select(CustomCommand))).scalars().all() commands = (await postgres.execute(select(CustomCommand))).scalars().all()
assert len(commands) == 1 assert len(commands) == 1
assert commands[0].name == "name" assert commands[0].name == "name"
async def test_create_command_duplicate_name(database_session: AsyncSession): async def test_create_command_duplicate_name(postgres):
"""Test creating a command when the name already exists""" """Test creating a command when the name already exists"""
await crud.create_command(database_session, "name", "response") await crud.create_command(postgres, "name", "response")
with pytest.raises(DuplicateInsertException): with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "name", "other response") await crud.create_command(postgres, "name", "other response")
async def test_create_command_name_is_alias(database_session: AsyncSession): async def test_create_command_name_is_alias(postgres):
"""Test creating a command when the name is taken by an alias""" """Test creating a command when the name is taken by an alias"""
await crud.create_command(database_session, "name", "response") await crud.create_command(postgres, "name", "response")
await crud.create_alias(database_session, "name", "n") await crud.create_alias(postgres, "name", "n")
with pytest.raises(DuplicateInsertException): with pytest.raises(DuplicateInsertException):
await crud.create_command(database_session, "n", "other response") await crud.create_command(postgres, "n", "other response")
async def test_create_alias(database_session: AsyncSession): async def test_create_alias(postgres):
"""Test creating an alias when the name is still free""" """Test creating an alias when the name is still free"""
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
await crud.create_alias(database_session, command.name, "n") await crud.create_alias(postgres, command.name, "n")
await database_session.refresh(command) await postgres.refresh(command)
assert len(command.aliases) == 1 assert len(command.aliases) == 1
assert command.aliases[0].alias == "n" assert command.aliases[0].alias == "n"
async def test_create_alias_non_existing(database_session: AsyncSession): async def test_create_alias_non_existing(postgres):
"""Test creating an alias when the command doesn't exist""" """Test creating an alias when the command doesn't exist"""
with pytest.raises(NoResultFoundException): with pytest.raises(NoResultFoundException):
await crud.create_alias(database_session, "name", "alias") await crud.create_alias(postgres, "name", "alias")
async def test_create_alias_duplicate(database_session: AsyncSession): async def test_create_alias_duplicate(postgres):
"""Test creating an alias when another alias already has this name""" """Test creating an alias when another alias already has this name"""
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
await crud.create_alias(database_session, command.name, "n") await crud.create_alias(postgres, command.name, "n")
with pytest.raises(DuplicateInsertException): with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n") await crud.create_alias(postgres, command.name, "n")
async def test_create_alias_is_command(database_session: AsyncSession): async def test_create_alias_is_command(postgres):
"""Test creating an alias when the name is taken by a command""" """Test creating an alias when the name is taken by a command"""
await crud.create_command(database_session, "n", "response") await crud.create_command(postgres, "n", "response")
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
with pytest.raises(DuplicateInsertException): with pytest.raises(DuplicateInsertException):
await crud.create_alias(database_session, command.name, "n") await crud.create_alias(postgres, command.name, "n")
async def test_create_alias_match_by_alias(database_session: AsyncSession): async def test_create_alias_match_by_alias(postgres):
"""Test creating an alias for a command when matching the name to another alias""" """Test creating an alias for a command when matching the name to another alias"""
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
await crud.create_alias(database_session, command.name, "a1") await crud.create_alias(postgres, command.name, "a1")
alias = await crud.create_alias(database_session, "a1", "a2") alias = await crud.create_alias(postgres, "a1", "a2")
assert alias.command == command assert alias.command == command
async def test_get_command_by_name_exists(database_session: AsyncSession): async def test_get_command_by_name_exists(postgres):
"""Test getting a command by name""" """Test getting a command by name"""
await crud.create_command(database_session, "name", "response") await crud.create_command(postgres, "name", "response")
command = await crud.get_command(database_session, "name") command = await crud.get_command(postgres, "name")
assert command is not None assert command is not None
async def test_get_command_by_cleaned_name(database_session: AsyncSession): async def test_get_command_by_cleaned_name(postgres):
"""Test getting a command by the cleaned version of the name""" """Test getting a command by the cleaned version of the name"""
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response") command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response")
found = await crud.get_command(database_session, "capitalizednamewithspaces") found = await crud.get_command(postgres, "capitalizednamewithspaces")
assert command == found assert command == found
async def test_get_command_by_alias(database_session: AsyncSession): async def test_get_command_by_alias(postgres):
"""Test getting a command by an alias""" """Test getting a command by an alias"""
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
await crud.create_alias(database_session, command.name, "a1") await crud.create_alias(postgres, command.name, "a1")
await crud.create_alias(database_session, command.name, "a2") await crud.create_alias(postgres, command.name, "a2")
found = await crud.get_command(database_session, "a1") found = await crud.get_command(postgres, "a1")
assert command == found assert command == found
async def test_get_command_non_existing(database_session: AsyncSession): async def test_get_command_non_existing(postgres):
"""Test getting a command when it doesn't exist""" """Test getting a command when it doesn't exist"""
assert await crud.get_command(database_session, "name") is None assert await crud.get_command(postgres, "name") is None
async def test_edit_command(database_session: AsyncSession): async def test_edit_command(postgres):
"""Test editing an existing command""" """Test editing an existing command"""
command = await crud.create_command(database_session, "name", "response") command = await crud.create_command(postgres, "name", "response")
await crud.edit_command(database_session, command.name, "new name", "new response") await crud.edit_command(postgres, command.name, "new name", "new response")
assert command.name == "new name" assert command.name == "new name"
assert command.response == "new response" assert command.response == "new response"
async def test_edit_command_non_existing(database_session: AsyncSession): async def test_edit_command_non_existing(postgres):
"""Test editing a command that doesn't exist""" """Test editing a command that doesn't exist"""
with pytest.raises(NoResultFoundException): with pytest.raises(NoResultFoundException):
await crud.edit_command(database_session, "name", "n", "r") await crud.edit_command(postgres, "name", "n", "r")

View File

@ -1,16 +1,15 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import dad_jokes as crud from database.crud import dad_jokes as crud
from database.models import DadJoke from database.schemas.relational import DadJoke
async def test_add_dad_joke(database_session: AsyncSession): async def test_add_dad_joke(postgres):
"""Test creating a new joke""" """Test creating a new joke"""
statement = select(DadJoke) statement = select(DadJoke)
result = (await database_session.execute(statement)).scalars().all() result = (await postgres.execute(statement)).scalars().all()
assert len(result) == 0 assert len(result) == 0
await crud.add_dad_joke(database_session, "joke") await crud.add_dad_joke(postgres, "joke")
result = (await database_session.execute(statement)).scalars().all() result = (await postgres.execute(statement)).scalars().all()
assert len(result) == 1 assert len(result) == 1

View File

@ -3,11 +3,10 @@ import datetime
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import tasks as crud from database.crud import tasks as crud
from database.enums import TaskType from database.enums import TaskType
from database.models import Task from database.schemas.relational import Task
@pytest.fixture @pytest.fixture
@ -17,47 +16,47 @@ def task_type() -> TaskType:
@pytest.fixture @pytest.fixture
async def task(database_session: AsyncSession, task_type: TaskType) -> Task: async def task(postgres, task_type: TaskType) -> Task:
"""Fixture to create a task""" """Fixture to create a task"""
task = Task(task=task_type) task = Task(task=task_type)
database_session.add(task) postgres.add(task)
await database_session.commit() await postgres.commit()
return task return task
async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType): async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType):
"""Test getting a task by its enum type when it exists""" """Test getting a task by its enum type when it exists"""
result = await crud.get_task_by_enum(database_session, task_type) result = await crud.get_task_by_enum(postgres, task_type)
assert result is not None assert result is not None
assert result == task assert result == task
async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType): async def test_get_task_by_enum_not_present(postgres, task_type: TaskType):
"""Test getting a task by its enum type when it doesn't exist""" """Test getting a task by its enum type when it doesn't exist"""
result = await crud.get_task_by_enum(database_session, task_type) result = await crud.get_task_by_enum(postgres, task_type)
assert result is None assert result is None
@freeze_time("2022/07/24") @freeze_time("2022/07/24")
async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType): async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType):
"""Test setting the execution time of an existing task""" """Test setting the execution time of an existing task"""
await database_session.refresh(task) await postgres.refresh(task)
assert task.previous_run is None assert task.previous_run is None
await crud.set_last_task_execution_time(database_session, task_type) await crud.set_last_task_execution_time(postgres, task_type)
await database_session.refresh(task) await postgres.refresh(task)
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
@freeze_time("2022/07/24") @freeze_time("2022/07/24")
async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType): async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType):
"""Test setting the execution time of a non-existing task""" """Test setting the execution time of a non-existing task"""
statement = select(Task).where(Task.task == task_type) statement = select(Task).where(Task.task == task_type)
results = list((await database_session.execute(statement)).scalars().all()) results = list((await postgres.execute(statement)).scalars().all())
assert len(results) == 0 assert len(results) == 0
await crud.set_last_task_execution_time(database_session, task_type) await crud.set_last_task_execution_time(postgres, task_type)
results = list((await database_session.execute(statement)).scalars().all()) results = list((await postgres.execute(statement)).scalars().all())
assert len(results) == 1 assert len(results) == 1
task = results[0] task = results[0]
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24) assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)

View File

@ -1,50 +1,46 @@
import datetime import datetime
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud
from database.models import UforaAnnouncement, UforaCourse from database.schemas.relational import UforaAnnouncement, UforaCourse
async def test_get_courses_with_announcements_none(database_session: AsyncSession): async def test_get_courses_with_announcements_none(postgres):
"""Test getting all courses with announcements when there are none""" """Test getting all courses with announcements when there are none"""
results = await crud.get_courses_with_announcements(database_session) results = await crud.get_courses_with_announcements(postgres)
assert len(results) == 0 assert len(results) == 0
async def test_get_courses_with_announcements(database_session: AsyncSession): async def test_get_courses_with_announcements(postgres):
"""Test getting all courses with announcements""" """Test getting all courses with announcements"""
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True) course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False) course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
database_session.add_all([course_1, course_2]) postgres.add_all([course_1, course_2])
await database_session.commit() await postgres.commit()
results = await crud.get_courses_with_announcements(database_session) results = await crud.get_courses_with_announcements(postgres)
assert len(results) == 1 assert len(results) == 1
assert results[0] == course_1 assert results[0] == course_1
async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession): async def test_create_new_announcement(ufora_course: UforaCourse, postgres):
"""Test creating a new announcement""" """Test creating a new announcement"""
await crud.create_new_announcement( await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now())
database_session, 1, course=ufora_course, publication_date=datetime.datetime.now() await postgres.refresh(ufora_course)
)
await database_session.refresh(ufora_course)
assert len(ufora_course.announcements) == 1 assert len(ufora_course.announcements) == 1
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession): async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres):
"""Test removing all stale announcements""" """Test removing all stale announcements"""
course = ufora_announcement.course course = ufora_announcement.course
ufora_announcement.publication_date -= datetime.timedelta(weeks=2) ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now()) announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now())
database_session.add_all([ufora_announcement, announcement_2]) postgres.add_all([ufora_announcement, announcement_2])
await database_session.commit() await postgres.commit()
await database_session.refresh(course) await postgres.refresh(course)
assert len(course.announcements) == 2 assert len(course.announcements) == 2
await crud.remove_old_announcements(database_session) await crud.remove_old_announcements(postgres)
await database_session.refresh(course) await postgres.refresh(course)
assert len(course.announcements) == 1 assert len(course.announcements) == 1
assert announcement_2.course.announcements[0] == announcement_2 assert announcement_2.course.announcements[0] == announcement_2

View File

@ -1,22 +1,20 @@
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_courses as crud from database.crud import ufora_courses as crud
from database.models import UforaCourse from database.schemas.relational import UforaCourse
async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse): async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse):
"""Test getting a course by its name when the query is an exact match""" """Test getting a course by its name when the query is an exact match"""
match = await crud.get_course_by_name(database_session, "Test") match = await crud.get_course_by_name(postgres, "Test")
assert match == ufora_course assert match == ufora_course
async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse): async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse):
"""Test getting a course by its name when the query is a substring""" """Test getting a course by its name when the query is a substring"""
match = await crud.get_course_by_name(database_session, "es") match = await crud.get_course_by_name(postgres, "es")
assert match == ufora_course assert match == ufora_course
async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse):
"""Test getting a course by its name when the name doesn't match, but the alias does""" """Test getting a course by its name when the name doesn't match, but the alias does"""
match = await crud.get_course_by_name(database_session, "ali") match = await crud.get_course_by_name(postgres, "ali")
assert match == ufora_course_with_alias assert match == ufora_course_with_alias

View File

@ -1,25 +1,24 @@
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users as crud from database.crud import users as crud
from database.models import User from database.schemas.relational import User
async def test_get_or_add_non_existing(database_session: AsyncSession): async def test_get_or_add_non_existing(postgres):
"""Test get_or_add for a user that doesn't exist""" """Test get_or_add for a user that doesn't exist"""
await crud.get_or_add(database_session, 1) await crud.get_or_add(postgres, 1)
statement = select(User) statement = select(User)
res = (await database_session.execute(statement)).scalars().all() res = (await postgres.execute(statement)).scalars().all()
assert len(res) == 1 assert len(res) == 1
assert res[0].bank is not None assert res[0].bank is not None
assert res[0].nightly_data is not None assert res[0].nightly_data is not None
async def test_get_or_add_existing(database_session: AsyncSession): async def test_get_or_add_existing(postgres):
"""Test get_or_add for a user that does exist""" """Test get_or_add for a user that does exist"""
user = await crud.get_or_add(database_session, 1) user = await crud.get_or_add(postgres, 1)
bank = user.bank bank = user.bank
assert await crud.get_or_add(database_session, 1) == user assert await crud.get_or_add(postgres, 1) == user
assert (await crud.get_or_add(database_session, 1)).bank == bank assert (await crud.get_or_add(postgres, 1)).bank == bank

View File

@ -1,28 +1,24 @@
from sqlalchemy.ext.asyncio import AsyncSession from database.schemas.relational import UforaCourse
from database.models import UforaCourse
from database.utils.caches import UforaCourseCache from database.utils.caches import UforaCourseCache
async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse): async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse):
"""Test loading the data for the Ufora Course cache when it's empty""" """Test loading the data for the Ufora Course cache when it's empty"""
cache = UforaCourseCache() cache = UforaCourseCache()
await cache.refresh(database_session) await cache.refresh(postgres)
assert len(cache.data) == 1 assert len(cache.data) == 1
assert cache.data == ["test"] assert cache.data == ["test"]
assert cache.aliases == {"alias": "test"} assert cache.aliases == {"alias": "test"}
async def test_ufora_course_cache_refresh_not_empty( async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse):
database_session: AsyncSession, ufora_course_with_alias: UforaCourse
):
"""Test loading the data for the Ufora Course cache when it's not empty anymore""" """Test loading the data for the Ufora Course cache when it's not empty anymore"""
cache = UforaCourseCache() cache = UforaCourseCache()
cache.data = ["Something"] cache.data = ["Something"]
cache.data_transformed = ["something"] cache.data_transformed = ["something"]
await cache.refresh(database_session) await cache.refresh(postgres)
assert len(cache.data) == 1 assert len(cache.data) == 1
assert cache.data == ["test"] assert cache.data == ["test"]