mirror of https://github.com/stijndcl/didier
commit
2f4c2c347f
|
@ -38,6 +38,19 @@ jobs:
|
|||
POSTGRES_DB: didier_pytest
|
||||
POSTGRES_USER: pytest
|
||||
POSTGRES_PASSWORD: pytest
|
||||
mongo:
|
||||
image: mongo:5.0
|
||||
options: >-
|
||||
--health-cmd mongo
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 27018:27017
|
||||
env:
|
||||
MONGO_DB: didier_pytest
|
||||
MONGO_USER: pytest
|
||||
MONGO_PASSWORD: pytest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup Python
|
||||
|
|
|
@ -4,8 +4,8 @@ from logging.config import fileConfig
|
|||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
|
||||
from alembic import context
|
||||
from database.engine import engine
|
||||
from database.models import Base
|
||||
from database.engine import postgres_engine
|
||||
from database.schemas.relational import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
@ -40,7 +40,7 @@ def run_migrations_online() -> None:
|
|||
and associate a connection with the context.
|
||||
|
||||
"""
|
||||
connectable = context.config.attributes.get("connection", None) or engine
|
||||
connectable = context.config.attributes.get("connection", None) or postgres_engine
|
||||
|
||||
if isinstance(connectable, AsyncEngine):
|
||||
asyncio.run(run_async_migrations(connectable))
|
||||
|
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.crud import users
|
||||
from database.models import Birthday, User
|
||||
from database.schemas.relational import Birthday, User
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import users
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.models import Bank, NightlyData
|
||||
from database.schemas.relational import Bank, NightlyData
|
||||
from database.utils.math.currency import (
|
||||
capacity_upgrade_price,
|
||||
interest_upgrade_price,
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.models import CustomCommand, CustomCommandAlias
|
||||
from database.schemas.relational import CustomCommand, CustomCommandAlias
|
||||
|
||||
__all__ = [
|
||||
"clean_name",
|
||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.models import DadJoke
|
||||
from database.schemas.relational import DadJoke
|
||||
|
||||
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.enums import TaskType
|
||||
from database.models import Task
|
||||
from database.schemas.relational import Task
|
||||
from database.utils.datetime import LOCAL_TIMEZONE
|
||||
|
||||
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
|||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.models import UforaAnnouncement, UforaCourse
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
|
||||
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.models import UforaCourse, UforaCourseAlias
|
||||
from database.schemas.relational import UforaCourse, UforaCourseAlias
|
||||
|
||||
__all__ = ["get_all_courses", "get_course_by_name"]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.models import Bank, NightlyData, User
|
||||
from database.schemas.relational import Bank, NightlyData, User
|
||||
|
||||
__all__ = [
|
||||
"get_or_add",
|
||||
|
|
|
@ -1,24 +1,34 @@
|
|||
from urllib.parse import quote_plus
|
||||
|
||||
import motor.motor_asyncio
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
import settings
|
||||
|
||||
encoded_password = quote_plus(settings.DB_PASSWORD)
|
||||
encoded_postgres_password = quote_plus(settings.POSTGRES_PASS)
|
||||
|
||||
engine = create_async_engine(
|
||||
# PostgreSQL engine
|
||||
postgres_engine = create_async_engine(
|
||||
URL.create(
|
||||
drivername="postgresql+asyncpg",
|
||||
username=settings.DB_USERNAME,
|
||||
password=encoded_password,
|
||||
host=settings.DB_HOST,
|
||||
port=settings.DB_PORT,
|
||||
database=settings.DB_NAME,
|
||||
username=settings.POSTGRES_USER,
|
||||
password=encoded_postgres_password,
|
||||
host=settings.POSTGRES_HOST,
|
||||
port=settings.POSTGRES_PORT,
|
||||
database=settings.POSTGRES_DB,
|
||||
),
|
||||
pool_pre_ping=True,
|
||||
future=True,
|
||||
)
|
||||
|
||||
DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False)
|
||||
DBSession = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
# MongoDB client
|
||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
||||
mongo_url = f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
|||
from alembic import command, script
|
||||
from alembic.config import Config
|
||||
from alembic.runtime import migration
|
||||
from database.engine import engine
|
||||
from database.engine import postgres_engine
|
||||
|
||||
__config_path__ = "alembic.ini"
|
||||
__migrations_path__ = "alembic/"
|
||||
|
@ -22,7 +22,7 @@ async def ensure_latest_migration():
|
|||
"""Make sure we are currently on the latest revision, otherwise raise an exception"""
|
||||
alembic_script = script.ScriptDirectory.from_config(cfg)
|
||||
|
||||
async with engine.begin() as connection:
|
||||
async with postgres_engine.begin() as connection:
|
||||
current_revision = await connection.run_sync(
|
||||
lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision()
|
||||
)
|
||||
|
@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session):
|
|||
|
||||
async def migrate(up: bool):
|
||||
"""Migrate the database upwards or downwards"""
|
||||
async with engine.begin() as connection:
|
||||
async with postgres_engine.begin() as connection:
|
||||
await connection.run_sync(__execute_upgrade if up else __execute_downgrade)
|
||||
|
|
|
@ -0,0 +1,38 @@
|
|||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ["MongoBase"]
|
||||
|
||||
|
||||
class PyObjectId(str):
|
||||
"""Custom type for bson ObjectIds"""
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: str):
|
||||
"""Check that a string is a valid bson ObjectId"""
|
||||
if not ObjectId.is_valid(value):
|
||||
raise ValueError(f"Invalid ObjectId: '{value}'")
|
||||
|
||||
return ObjectId(value)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: dict):
|
||||
field_schema.update(type="string")
|
||||
|
||||
|
||||
class MongoBase(BaseModel):
|
||||
"""Base model that properly sets the _id field, and adds one by default"""
|
||||
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
|
||||
class Config:
|
||||
"""Configuration for encoding and construction"""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
||||
use_enum_values = True
|
|
@ -31,7 +31,7 @@ class Currency(commands.Cog):
|
|||
"""Award a user a given amount of Didier Dinks"""
|
||||
amount = typing.cast(int, amount)
|
||||
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
await crud.add_dinks(session, user.id, amount)
|
||||
plural = pluralize("Didier Dink", amount)
|
||||
await ctx.reply(
|
||||
|
@ -42,7 +42,7 @@ class Currency(commands.Cog):
|
|||
@commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True)
|
||||
async def bank(self, ctx: commands.Context):
|
||||
"""Show your Didier Bank information"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
bank = await crud.get_bank(session, ctx.author.id)
|
||||
|
||||
embed = discord.Embed(colour=discord.Colour.blue())
|
||||
|
@ -58,7 +58,7 @@ class Currency(commands.Cog):
|
|||
@bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True)
|
||||
async def bank_upgrades(self, ctx: commands.Context):
|
||||
"""List the upgrades you can buy & their prices"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
bank = await crud.get_bank(session, ctx.author.id)
|
||||
|
||||
embed = discord.Embed(colour=discord.Colour.blue())
|
||||
|
@ -79,7 +79,7 @@ class Currency(commands.Cog):
|
|||
@bank_upgrades.command(name="Capacity", aliases=["C"])
|
||||
async def bank_upgrade_capacity(self, ctx: commands.Context):
|
||||
"""Upgrade the capacity level of your bank"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await crud.upgrade_capacity(session, ctx.author.id)
|
||||
await ctx.message.add_reaction("⏫")
|
||||
|
@ -90,7 +90,7 @@ class Currency(commands.Cog):
|
|||
@bank_upgrades.command(name="Interest", aliases=["I"])
|
||||
async def bank_upgrade_interest(self, ctx: commands.Context):
|
||||
"""Upgrade the interest level of your bank"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await crud.upgrade_interest(session, ctx.author.id)
|
||||
await ctx.message.add_reaction("⏫")
|
||||
|
@ -101,7 +101,7 @@ class Currency(commands.Cog):
|
|||
@bank_upgrades.command(name="Rob", aliases=["R"])
|
||||
async def bank_upgrade_rob(self, ctx: commands.Context):
|
||||
"""Upgrade the rob level of your bank"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await crud.upgrade_rob(session, ctx.author.id)
|
||||
await ctx.message.add_reaction("⏫")
|
||||
|
@ -112,7 +112,7 @@ class Currency(commands.Cog):
|
|||
@commands.hybrid_command(name="dinks")
|
||||
async def dinks(self, ctx: commands.Context):
|
||||
"""Check your Didier Dinks"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
bank = await crud.get_bank(session, ctx.author.id)
|
||||
plural = pluralize("Didier Dink", bank.dinks)
|
||||
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
|
||||
|
@ -122,7 +122,7 @@ class Currency(commands.Cog):
|
|||
"""Invest a given amount of Didier Dinks"""
|
||||
amount = typing.cast(typing.Union[str, int], amount)
|
||||
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
invested = await crud.invest(session, ctx.author.id, amount)
|
||||
plural = pluralize("Didier Dink", invested)
|
||||
|
||||
|
@ -136,7 +136,7 @@ class Currency(commands.Cog):
|
|||
@commands.hybrid_command(name="nightly")
|
||||
async def nightly(self, ctx: commands.Context):
|
||||
"""Claim nightly Didier Dinks"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await crud.claim_nightly(session, ctx.author.id)
|
||||
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")
|
||||
|
|
|
@ -19,7 +19,7 @@ class Discord(commands.Cog):
|
|||
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
||||
"""Command to check the birthday of a user"""
|
||||
user_id = (user and user.id) or ctx.author.id
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
birthday = await birthdays.get_birthday_for_user(session, user_id)
|
||||
|
||||
name = "Jouw" if user is None else f"{user.display_name}'s"
|
||||
|
@ -45,7 +45,7 @@ class Discord(commands.Cog):
|
|||
except ValueError:
|
||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
await birthdays.add_birthday(session, ctx.author.id, date)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
|
|
|
@ -19,7 +19,7 @@ class Fun(commands.Cog):
|
|||
)
|
||||
async def dad_joke(self, ctx: commands.Context):
|
||||
"""Get a random dad joke"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
joke = await get_random_dad_joke(session)
|
||||
return await ctx.reply(joke.joke, mention_author=False)
|
||||
|
||||
|
|
|
@ -83,7 +83,7 @@ class Owner(commands.Cog):
|
|||
@add_msg.command(name="Custom")
|
||||
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
|
||||
"""Add a new custom command"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await custom_commands.create_command(session, name, response)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
@ -94,7 +94,7 @@ class Owner(commands.Cog):
|
|||
@add_msg.command(name="Alias")
|
||||
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
|
||||
"""Add a new alias for a custom command"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await custom_commands.create_alias(session, command, alias)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
@ -130,7 +130,7 @@ class Owner(commands.Cog):
|
|||
@edit_msg.command(name="Custom")
|
||||
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
|
||||
"""Edit an existing custom command"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
try:
|
||||
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
||||
return await self.client.confirm_message(ctx.message)
|
||||
|
@ -147,7 +147,7 @@ class Owner(commands.Cog):
|
|||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
||||
)
|
||||
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
_command = await custom_commands.get_command(session, command)
|
||||
if _command is None:
|
||||
return await interaction.response.send_message(
|
||||
|
|
|
@ -68,7 +68,7 @@ class School(commands.Cog):
|
|||
@app_commands.describe(course="vak")
|
||||
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
|
||||
"""Create links to study guides"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
ufora_course = await ufora_courses.get_course_by_name(session, course)
|
||||
|
||||
if ufora_course is None:
|
||||
|
|
|
@ -72,7 +72,7 @@ class Tasks(commands.Cog):
|
|||
async def check_birthdays(self):
|
||||
"""Check if it's currently anyone's birthday"""
|
||||
now = tz_aware_now().date()
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
birthdays = await get_birthdays_on_day(session, now)
|
||||
|
||||
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||
|
@ -96,7 +96,7 @@ class Tasks(commands.Cog):
|
|||
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
||||
return
|
||||
|
||||
async with self.client.db_session as db_session:
|
||||
async with self.client.postgres_session as db_session:
|
||||
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
||||
announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
|
||||
|
||||
|
@ -110,7 +110,7 @@ class Tasks(commands.Cog):
|
|||
@tasks.loop(hours=24)
|
||||
async def remove_old_ufora_announcements(self):
|
||||
"""Remove all announcements that are over 1 week old, once per day"""
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
await remove_old_announcements(session)
|
||||
|
||||
@check_birthdays.error
|
||||
|
|
|
@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
import settings
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.models import UforaCourse
|
||||
from database.schemas.relational import UforaCourse
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.types.datetime import int_to_weekday
|
||||
from didier.utils.types.string import leading
|
||||
|
|
|
@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType):
|
|||
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
|
||||
await func(tasks_cog, *args, **kwargs)
|
||||
|
||||
async with tasks_cog.client.db_session as session:
|
||||
async with tasks_cog.client.postgres_session as session:
|
||||
await set_last_task_execution_time(session, task)
|
||||
|
||||
return _wrapper
|
||||
|
|
|
@ -2,13 +2,14 @@ import logging
|
|||
import os
|
||||
|
||||
import discord
|
||||
import motor.motor_asyncio
|
||||
from aiohttp import ClientSession
|
||||
from discord.ext import commands
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import settings
|
||||
from database.crud import custom_commands
|
||||
from database.engine import DBSession
|
||||
from database.engine import DBSession, mongo_client
|
||||
from database.utils.caches import CacheManager
|
||||
from didier.data.embeds.error_embed import create_error_embed
|
||||
from didier.utils.discord.prefix import get_prefix
|
||||
|
@ -45,10 +46,15 @@ class Didier(commands.Bot):
|
|||
)
|
||||
|
||||
@property
|
||||
def db_session(self) -> AsyncSession:
|
||||
"""Obtain a database session"""
|
||||
def postgres_session(self) -> AsyncSession:
|
||||
"""Obtain a session for the PostgreSQL database"""
|
||||
return DBSession()
|
||||
|
||||
@property
|
||||
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
||||
"""Obtain a reference to the MongoDB database"""
|
||||
return mongo_client[settings.MONGO_DB]
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Do some initial setup
|
||||
|
||||
|
@ -60,7 +66,7 @@ class Didier(commands.Bot):
|
|||
|
||||
# Initialize caches
|
||||
self.database_caches = CacheManager()
|
||||
async with self.db_session as session:
|
||||
async with self.postgres_session as session:
|
||||
await self.database_caches.initialize_caches(session)
|
||||
|
||||
# Create aiohttp session
|
||||
|
@ -153,7 +159,7 @@ class Didier(commands.Bot):
|
|||
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
||||
return False
|
||||
|
||||
async with self.db_session as session:
|
||||
async with self.postgres_session as session:
|
||||
# Remove the prefix
|
||||
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
|
||||
command = await custom_commands.get_command(session, content)
|
||||
|
|
|
@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
|
|||
|
||||
@overrides
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
command = await create_command(session, str(self.name.value), str(self.response.value))
|
||||
|
||||
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
|
||||
|
@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
|
|||
name_field = typing.cast(discord.ui.TextInput, self.children[0])
|
||||
response_field = typing.cast(discord.ui.TextInput, self.children[1])
|
||||
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
await edit_command(session, self.original_name, name_field.value, response_field.value)
|
||||
|
||||
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)
|
||||
|
|
|
@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
|
|||
|
||||
@overrides
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
async with self.client.db_session as session:
|
||||
async with self.client.postgres_session as session:
|
||||
joke = await add_dad_joke(session, str(self.name.value))
|
||||
|
||||
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
version: '3.9'
|
||||
services:
|
||||
postgres-pytest:
|
||||
image: postgres:14
|
||||
container_name: didier-pytest
|
||||
restart: always
|
||||
environment:
|
||||
- POSTGRES_DB=didier_pytest
|
||||
- POSTGRES_USER=pytest
|
||||
- POSTGRES_PASSWORD=pytest
|
||||
ports:
|
||||
- "5433:5432"
|
||||
mongo-pytest:
|
||||
image: mongo:5.0
|
||||
restart: always
|
||||
environment:
|
||||
- MONGO_INITDB_ROOT_USERNAME=pytest
|
||||
- MONGO_INITDB_ROOT_PASSWORD=pytest
|
||||
- MONGO_INITDB_DATABASE=didier_pytest
|
||||
ports:
|
||||
- "27018:27017"
|
|
@ -1,26 +1,29 @@
|
|||
version: '3.9'
|
||||
services:
|
||||
db:
|
||||
postgres:
|
||||
image: postgres:14
|
||||
container_name: didier
|
||||
restart: always
|
||||
environment:
|
||||
- POSTGRES_DB=${DB_NAME:-didier_dev}
|
||||
- POSTGRES_USER=${DB_USERNAME:-postgres}
|
||||
- POSTGRES_PASSWORD=${DB_PASSWORD:-postgres}
|
||||
- POSTGRES_DB=${POSTGRES_DB:-didier_dev}
|
||||
- POSTGRES_USER=${POSTGRES_USER:-postgres}
|
||||
- POSTGRES_PASSWORD=${POSTGRES_PASS:-postgres}
|
||||
ports:
|
||||
- "${DB_PORT:-5432}:${DB_PORT:-5432}"
|
||||
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
|
||||
volumes:
|
||||
- db:/var/lib/postgresql/data
|
||||
db-pytest:
|
||||
image: postgres:14
|
||||
container_name: didier-pytest
|
||||
- postgres:/var/lib/postgresql/data
|
||||
mongo:
|
||||
image: mongo:5.0
|
||||
restart: always
|
||||
environment:
|
||||
- POSTGRES_DB=didier_pytest
|
||||
- POSTGRES_USER=pytest
|
||||
- POSTGRES_PASSWORD=pytest
|
||||
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
|
||||
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
|
||||
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
|
||||
command: [--auth]
|
||||
ports:
|
||||
- "5433:5432"
|
||||
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
|
||||
volumes:
|
||||
db:
|
||||
- mongo:/data/db
|
||||
volumes:
|
||||
postgres:
|
||||
mongo:
|
||||
|
|
|
@ -36,16 +36,21 @@ plugins = [
|
|||
"sqlalchemy.ext.mypy.plugin"
|
||||
]
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["discord.*", "feedparser.*", "markdownify.*"]
|
||||
module = ["discord.*", "feedparser.*", "markdownify.*", "motor.*"]
|
||||
ignore_missing_imports = true
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
env = [
|
||||
"DB_NAME = didier_pytest",
|
||||
"DB_USERNAME = pytest",
|
||||
"DB_PASSWORD = pytest",
|
||||
"DB_HOST = localhost",
|
||||
"DB_PORT = 5433",
|
||||
"MONGO_DB = didier_pytest",
|
||||
"MONGO_USER = pytest",
|
||||
"MONGO_PASS = pytest",
|
||||
"MONGO_HOST = localhost",
|
||||
"MONGO_PORT = 27018",
|
||||
"POSTGRES_DB = didier_pytest",
|
||||
"POSTGRES_USER = pytest",
|
||||
"POSTGRES_PASS = pytest",
|
||||
"POSTGRES_HOST = localhost",
|
||||
"POSTGRES_PORT = 5433",
|
||||
"DISCORD_TOKEN = token"
|
||||
]
|
||||
|
|
|
@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py
|
|||
environs==9.5.0
|
||||
feedparser==6.0.10
|
||||
markdownify==0.11.2
|
||||
motor==3.0.0
|
||||
overrides==6.1.0
|
||||
pydantic==1.9.1
|
||||
python-dateutil==2.8.2
|
||||
|
|
28
settings.py
28
settings.py
|
@ -9,11 +9,11 @@ env.read_env()
|
|||
__all__ = [
|
||||
"SANDBOX",
|
||||
"LOGFILE",
|
||||
"DB_NAME",
|
||||
"DB_USERNAME",
|
||||
"DB_PASSWORD",
|
||||
"DB_HOST",
|
||||
"DB_PORT",
|
||||
"POSTGRES_DB",
|
||||
"POSTGRES_USER",
|
||||
"POSTGRES_PASS",
|
||||
"POSTGRES_HOST",
|
||||
"POSTGRES_PORT",
|
||||
"DISCORD_TOKEN",
|
||||
"DISCORD_READY_MESSAGE",
|
||||
"DISCORD_STATUS_MESSAGE",
|
||||
|
@ -33,11 +33,19 @@ SEMESTER: int = env.int("SEMESTER", 2)
|
|||
YEAR: int = env.int("YEAR", 3)
|
||||
|
||||
"""Database"""
|
||||
DB_NAME: str = env.str("DB_NAME", "didier")
|
||||
DB_USERNAME: str = env.str("DB_USERNAME", "postgres")
|
||||
DB_PASSWORD: str = env.str("DB_PASSWORD", "")
|
||||
DB_HOST: str = env.str("DB_HOST", "localhost")
|
||||
DB_PORT: int = env.int("DB_PORT", "5432")
|
||||
# MongoDB
|
||||
MONGO_DB: str = env.str("MONGO_DB", "didier")
|
||||
MONGO_USER: str = env.str("MONGO_USER", "root")
|
||||
MONGO_PASS: str = env.str("MONGO_PASS", "root")
|
||||
MONGO_HOST: str = env.str("MONGO_HOST", "localhost")
|
||||
MONGO_PORT: int = env.int("MONGO_PORT", "27017")
|
||||
|
||||
# PostgreSQL
|
||||
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
|
||||
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")
|
||||
POSTGRES_PASS: str = env.str("POSTGRES_PASS", "")
|
||||
POSTGRES_HOST: str = env.str("POSTGRES_HOST", "localhost")
|
||||
POSTGRES_PORT: int = env.int("POSTGRES_PORT", "5432")
|
||||
|
||||
"""Discord"""
|
||||
DISCORD_TOKEN: str = env.str("DISCORD_TOKEN")
|
||||
|
|
|
@ -2,10 +2,12 @@ import asyncio
|
|||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import motor.motor_asyncio
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.engine import engine
|
||||
import settings
|
||||
from database.engine import mongo_client, postgres_engine
|
||||
from database.migrations import ensure_latest_migration, migrate
|
||||
from didier import Didier
|
||||
|
||||
|
@ -35,12 +37,12 @@ async def tables():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||
async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Fixture to create a session for every test
|
||||
|
||||
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||
"""
|
||||
connection = await engine.connect()
|
||||
connection = await postgres_engine.connect()
|
||||
transaction = await connection.begin()
|
||||
session = AsyncSession(bind=connection, expire_on_commit=False)
|
||||
|
||||
|
@ -54,6 +56,14 @@ async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
|||
await connection.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
||||
"""Fixture to get a MongoDB connection"""
|
||||
database = mongo_client[settings.MONGO_DB]
|
||||
yield database
|
||||
mongo_client.drop_database(settings.MONGO_DB)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client() -> Didier:
|
||||
"""Fixture to get a mock Didier instance
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users
|
||||
from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User
|
||||
from database.schemas.relational import (
|
||||
Bank,
|
||||
UforaAnnouncement,
|
||||
UforaCourse,
|
||||
UforaCourseAlias,
|
||||
User,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
|
@ -17,44 +22,44 @@ def test_user_id() -> int:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def user(database_session: AsyncSession, test_user_id) -> User:
|
||||
async def user(postgres, test_user_id) -> User:
|
||||
"""Fixture to create a user"""
|
||||
_user = await users.get_or_add(database_session, test_user_id)
|
||||
await database_session.refresh(_user)
|
||||
_user = await users.get_or_add(postgres, test_user_id)
|
||||
await postgres.refresh(_user)
|
||||
return _user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def bank(database_session: AsyncSession, user: User) -> Bank:
|
||||
async def bank(postgres, user: User) -> Bank:
|
||||
"""Fixture to fetch the test user's bank"""
|
||||
_bank = user.bank
|
||||
await database_session.refresh(_bank)
|
||||
await postgres.refresh(_bank)
|
||||
return _bank
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_course(database_session: AsyncSession) -> UforaCourse:
|
||||
async def ufora_course(postgres) -> UforaCourse:
|
||||
"""Fixture to create a course"""
|
||||
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||
database_session.add(course)
|
||||
await database_session.commit()
|
||||
postgres.add(course)
|
||||
await postgres.commit()
|
||||
return course
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_course_with_alias(database_session: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
|
||||
async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse:
|
||||
"""Fixture to create a course with an alias"""
|
||||
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
|
||||
database_session.add(alias)
|
||||
await database_session.commit()
|
||||
await database_session.refresh(ufora_course)
|
||||
postgres.add(alias)
|
||||
await postgres.commit()
|
||||
await postgres.refresh(ufora_course)
|
||||
return ufora_course
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_announcement(ufora_course: UforaCourse, database_session: AsyncSession) -> UforaAnnouncement:
|
||||
async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement:
|
||||
"""Fixture to create an announcement"""
|
||||
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
|
||||
database_session.add(announcement)
|
||||
await database_session.commit()
|
||||
postgres.add(announcement)
|
||||
await postgres.commit()
|
||||
return announcement
|
||||
|
|
|
@ -1,74 +1,73 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import birthdays as crud
|
||||
from database.crud import users
|
||||
from database.models import User
|
||||
from database.schemas.relational import User
|
||||
|
||||
|
||||
async def test_add_birthday_not_present(database_session: AsyncSession, user: User):
|
||||
async def test_add_birthday_not_present(postgres, user: User):
|
||||
"""Test setting a user's birthday when it doesn't exist yet"""
|
||||
assert user.birthday is None
|
||||
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||
await postgres.refresh(user)
|
||||
assert user.birthday is not None
|
||||
assert user.birthday.birthday == bd_date
|
||||
|
||||
|
||||
async def test_add_birthday_overwrite(database_session: AsyncSession, user: User):
|
||||
async def test_add_birthday_overwrite(postgres, user: User):
|
||||
"""Test that setting a user's birthday when it already exists overwrites it"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||
await postgres.refresh(user)
|
||||
assert user.birthday is not None
|
||||
|
||||
new_bd_date = bd_date + timedelta(weeks=1)
|
||||
await crud.add_birthday(database_session, user.user_id, new_bd_date)
|
||||
await database_session.refresh(user)
|
||||
await crud.add_birthday(postgres, user.user_id, new_bd_date)
|
||||
await postgres.refresh(user)
|
||||
assert user.birthday.birthday == new_bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_exists(database_session: AsyncSession, user: User):
|
||||
async def test_get_birthday_exists(postgres, user: User):
|
||||
"""Test getting a user's birthday when it exists"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(database_session, user.user_id, bd_date)
|
||||
await database_session.refresh(user)
|
||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||
await postgres.refresh(user)
|
||||
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
bd = await crud.get_birthday_for_user(postgres, user.user_id)
|
||||
assert bd is not None
|
||||
assert bd.birthday == bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_not_exists(database_session: AsyncSession, user: User):
|
||||
async def test_get_birthday_not_exists(postgres, user: User):
|
||||
"""Test getting a user's birthday when it doesn't exist"""
|
||||
bd = await crud.get_birthday_for_user(database_session, user.user_id)
|
||||
bd = await crud.get_birthday_for_user(postgres, user.user_id)
|
||||
assert bd is None
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_on_day(database_session: AsyncSession, user: User):
|
||||
async def test_get_birthdays_on_day(postgres, user: User):
|
||||
"""Test getting all birthdays on a given day"""
|
||||
await crud.add_birthday(database_session, user.user_id, datetime.today().replace(year=2001))
|
||||
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
|
||||
|
||||
user_2 = await users.get_or_add(database_session, user.user_id + 1)
|
||||
await crud.add_birthday(database_session, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
user_2 = await users.get_or_add(postgres, user.user_id + 1)
|
||||
await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||
assert len(birthdays) == 1
|
||||
assert birthdays[0].user_id == user.user_id
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_none_present(database_session: AsyncSession):
|
||||
async def test_get_birthdays_none_present(postgres):
|
||||
"""Test getting all birthdays when there are none"""
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||
assert len(birthdays) == 0
|
||||
|
||||
# Add a random birthday that is not today
|
||||
await crud.add_birthday(database_session, 1, datetime.today() + timedelta(days=1))
|
||||
await crud.add_birthday(postgres, 1, datetime.today() + timedelta(days=1))
|
||||
|
||||
birthdays = await crud.get_birthdays_on_day(database_session, datetime.today())
|
||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||
assert len(birthdays) == 0
|
||||
|
|
|
@ -2,78 +2,77 @@ import datetime
|
|||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import currency as crud
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.models import Bank
|
||||
from database.schemas.relational import Bank
|
||||
|
||||
|
||||
async def test_add_dinks(database_session: AsyncSession, bank: Bank):
|
||||
async def test_add_dinks(postgres, bank: Bank):
|
||||
"""Test adding dinks to an account"""
|
||||
assert bank.dinks == 0
|
||||
await crud.add_dinks(database_session, bank.user_id, 10)
|
||||
await database_session.refresh(bank)
|
||||
await crud.add_dinks(postgres, bank.user_id, 10)
|
||||
await postgres.refresh(bank)
|
||||
assert bank.dinks == 10
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_available(database_session: AsyncSession, bank: Bank):
|
||||
async def test_claim_nightly_available(postgres, bank: Bank):
|
||||
"""Test claiming nightlies when it hasn't been done yet"""
|
||||
await crud.claim_nightly(database_session, bank.user_id)
|
||||
await database_session.refresh(bank)
|
||||
await crud.claim_nightly(postgres, bank.user_id)
|
||||
await postgres.refresh(bank)
|
||||
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||
|
||||
nightly_data = await crud.get_nightly_data(database_session, bank.user_id)
|
||||
nightly_data = await crud.get_nightly_data(postgres, bank.user_id)
|
||||
assert nightly_data.last_nightly == datetime.date(year=2022, month=7, day=23)
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_unavailable(database_session: AsyncSession, bank: Bank):
|
||||
async def test_claim_nightly_unavailable(postgres, bank: Bank):
|
||||
"""Test claiming nightlies twice in a day"""
|
||||
await crud.claim_nightly(database_session, bank.user_id)
|
||||
await crud.claim_nightly(postgres, bank.user_id)
|
||||
|
||||
with pytest.raises(exceptions.DoubleNightly):
|
||||
await crud.claim_nightly(database_session, bank.user_id)
|
||||
await crud.claim_nightly(postgres, bank.user_id)
|
||||
|
||||
await database_session.refresh(bank)
|
||||
await postgres.refresh(bank)
|
||||
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||
|
||||
|
||||
async def test_invest(database_session: AsyncSession, bank: Bank):
|
||||
async def test_invest(postgres, bank: Bank):
|
||||
"""Test investing some Dinks"""
|
||||
bank.dinks = 100
|
||||
database_session.add(bank)
|
||||
await database_session.commit()
|
||||
postgres.add(bank)
|
||||
await postgres.commit()
|
||||
|
||||
await crud.invest(database_session, bank.user_id, 20)
|
||||
await database_session.refresh(bank)
|
||||
await crud.invest(postgres, bank.user_id, 20)
|
||||
await postgres.refresh(bank)
|
||||
|
||||
assert bank.dinks == 80
|
||||
assert bank.invested == 20
|
||||
|
||||
|
||||
async def test_invest_all(database_session: AsyncSession, bank: Bank):
|
||||
async def test_invest_all(postgres, bank: Bank):
|
||||
"""Test investing all dinks"""
|
||||
bank.dinks = 100
|
||||
database_session.add(bank)
|
||||
await database_session.commit()
|
||||
postgres.add(bank)
|
||||
await postgres.commit()
|
||||
|
||||
await crud.invest(database_session, bank.user_id, "all")
|
||||
await database_session.refresh(bank)
|
||||
await crud.invest(postgres, bank.user_id, "all")
|
||||
await postgres.refresh(bank)
|
||||
|
||||
assert bank.dinks == 0
|
||||
assert bank.invested == 100
|
||||
|
||||
|
||||
async def test_invest_more_than_owned(database_session: AsyncSession, bank: Bank):
|
||||
async def test_invest_more_than_owned(postgres, bank: Bank):
|
||||
"""Test investing more Dinks than you own"""
|
||||
bank.dinks = 100
|
||||
database_session.add(bank)
|
||||
await database_session.commit()
|
||||
postgres.add(bank)
|
||||
await postgres.commit()
|
||||
|
||||
await crud.invest(database_session, bank.user_id, 200)
|
||||
await database_session.refresh(bank)
|
||||
await crud.invest(postgres, bank.user_id, 200)
|
||||
await postgres.refresh(bank)
|
||||
|
||||
assert bank.dinks == 0
|
||||
assert bank.invested == 100
|
||||
|
|
|
@ -1,119 +1,118 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import custom_commands as crud
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.models import CustomCommand
|
||||
from database.schemas.relational import CustomCommand
|
||||
|
||||
|
||||
async def test_create_command_non_existing(database_session: AsyncSession):
|
||||
async def test_create_command_non_existing(postgres):
|
||||
"""Test creating a new command when it doesn't exist yet"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
|
||||
commands = (await database_session.execute(select(CustomCommand))).scalars().all()
|
||||
commands = (await postgres.execute(select(CustomCommand))).scalars().all()
|
||||
assert len(commands) == 1
|
||||
assert commands[0].name == "name"
|
||||
|
||||
|
||||
async def test_create_command_duplicate_name(database_session: AsyncSession):
|
||||
async def test_create_command_duplicate_name(postgres):
|
||||
"""Test creating a command when the name already exists"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_command(database_session, "name", "other response")
|
||||
await crud.create_command(postgres, "name", "other response")
|
||||
|
||||
|
||||
async def test_create_command_name_is_alias(database_session: AsyncSession):
|
||||
async def test_create_command_name_is_alias(postgres):
|
||||
"""Test creating a command when the name is taken by an alias"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, "name", "n")
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, "name", "n")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_command(database_session, "n", "other response")
|
||||
await crud.create_command(postgres, "n", "other response")
|
||||
|
||||
|
||||
async def test_create_alias(database_session: AsyncSession):
|
||||
async def test_create_alias(postgres):
|
||||
"""Test creating an alias when the name is still free"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
await database_session.refresh(command)
|
||||
await postgres.refresh(command)
|
||||
assert len(command.aliases) == 1
|
||||
assert command.aliases[0].alias == "n"
|
||||
|
||||
|
||||
async def test_create_alias_non_existing(database_session: AsyncSession):
|
||||
async def test_create_alias_non_existing(postgres):
|
||||
"""Test creating an alias when the command doesn't exist"""
|
||||
with pytest.raises(NoResultFoundException):
|
||||
await crud.create_alias(database_session, "name", "alias")
|
||||
await crud.create_alias(postgres, "name", "alias")
|
||||
|
||||
|
||||
async def test_create_alias_duplicate(database_session: AsyncSession):
|
||||
async def test_create_alias_duplicate(postgres):
|
||||
"""Test creating an alias when another alias already has this name"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_is_command(database_session: AsyncSession):
|
||||
async def test_create_alias_is_command(postgres):
|
||||
"""Test creating an alias when the name is taken by a command"""
|
||||
await crud.create_command(database_session, "n", "response")
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_command(postgres, "n", "response")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
|
||||
with pytest.raises(DuplicateInsertException):
|
||||
await crud.create_alias(database_session, command.name, "n")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_match_by_alias(database_session: AsyncSession):
|
||||
async def test_create_alias_match_by_alias(postgres):
|
||||
"""Test creating an alias for a command when matching the name to another alias"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "a1")
|
||||
alias = await crud.create_alias(database_session, "a1", "a2")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "a1")
|
||||
alias = await crud.create_alias(postgres, "a1", "a2")
|
||||
assert alias.command == command
|
||||
|
||||
|
||||
async def test_get_command_by_name_exists(database_session: AsyncSession):
|
||||
async def test_get_command_by_name_exists(postgres):
|
||||
"""Test getting a command by name"""
|
||||
await crud.create_command(database_session, "name", "response")
|
||||
command = await crud.get_command(database_session, "name")
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
command = await crud.get_command(postgres, "name")
|
||||
assert command is not None
|
||||
|
||||
|
||||
async def test_get_command_by_cleaned_name(database_session: AsyncSession):
|
||||
async def test_get_command_by_cleaned_name(postgres):
|
||||
"""Test getting a command by the cleaned version of the name"""
|
||||
command = await crud.create_command(database_session, "CAPITALIZED NAME WITH SPACES", "response")
|
||||
found = await crud.get_command(database_session, "capitalizednamewithspaces")
|
||||
command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response")
|
||||
found = await crud.get_command(postgres, "capitalizednamewithspaces")
|
||||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_by_alias(database_session: AsyncSession):
|
||||
async def test_get_command_by_alias(postgres):
|
||||
"""Test getting a command by an alias"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.create_alias(database_session, command.name, "a1")
|
||||
await crud.create_alias(database_session, command.name, "a2")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "a1")
|
||||
await crud.create_alias(postgres, command.name, "a2")
|
||||
|
||||
found = await crud.get_command(database_session, "a1")
|
||||
found = await crud.get_command(postgres, "a1")
|
||||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_non_existing(database_session: AsyncSession):
|
||||
async def test_get_command_non_existing(postgres):
|
||||
"""Test getting a command when it doesn't exist"""
|
||||
assert await crud.get_command(database_session, "name") is None
|
||||
assert await crud.get_command(postgres, "name") is None
|
||||
|
||||
|
||||
async def test_edit_command(database_session: AsyncSession):
|
||||
async def test_edit_command(postgres):
|
||||
"""Test editing an existing command"""
|
||||
command = await crud.create_command(database_session, "name", "response")
|
||||
await crud.edit_command(database_session, command.name, "new name", "new response")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.edit_command(postgres, command.name, "new name", "new response")
|
||||
assert command.name == "new name"
|
||||
assert command.response == "new response"
|
||||
|
||||
|
||||
async def test_edit_command_non_existing(database_session: AsyncSession):
|
||||
async def test_edit_command_non_existing(postgres):
|
||||
"""Test editing a command that doesn't exist"""
|
||||
with pytest.raises(NoResultFoundException):
|
||||
await crud.edit_command(database_session, "name", "n", "r")
|
||||
await crud.edit_command(postgres, "name", "n", "r")
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import dad_jokes as crud
|
||||
from database.models import DadJoke
|
||||
from database.schemas.relational import DadJoke
|
||||
|
||||
|
||||
async def test_add_dad_joke(database_session: AsyncSession):
|
||||
async def test_add_dad_joke(postgres):
|
||||
"""Test creating a new joke"""
|
||||
statement = select(DadJoke)
|
||||
result = (await database_session.execute(statement)).scalars().all()
|
||||
result = (await postgres.execute(statement)).scalars().all()
|
||||
assert len(result) == 0
|
||||
|
||||
await crud.add_dad_joke(database_session, "joke")
|
||||
result = (await database_session.execute(statement)).scalars().all()
|
||||
await crud.add_dad_joke(postgres, "joke")
|
||||
result = (await postgres.execute(statement)).scalars().all()
|
||||
assert len(result) == 1
|
||||
|
|
|
@ -3,11 +3,10 @@ import datetime
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import tasks as crud
|
||||
from database.enums import TaskType
|
||||
from database.models import Task
|
||||
from database.schemas.relational import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -17,47 +16,47 @@ def task_type() -> TaskType:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def task(database_session: AsyncSession, task_type: TaskType) -> Task:
|
||||
async def task(postgres, task_type: TaskType) -> Task:
|
||||
"""Fixture to create a task"""
|
||||
task = Task(task=task_type)
|
||||
database_session.add(task)
|
||||
await database_session.commit()
|
||||
postgres.add(task)
|
||||
await postgres.commit()
|
||||
return task
|
||||
|
||||
|
||||
async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType):
|
||||
async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType):
|
||||
"""Test getting a task by its enum type when it exists"""
|
||||
result = await crud.get_task_by_enum(database_session, task_type)
|
||||
result = await crud.get_task_by_enum(postgres, task_type)
|
||||
assert result is not None
|
||||
assert result == task
|
||||
|
||||
|
||||
async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType):
|
||||
async def test_get_task_by_enum_not_present(postgres, task_type: TaskType):
|
||||
"""Test getting a task by its enum type when it doesn't exist"""
|
||||
result = await crud.get_task_by_enum(database_session, task_type)
|
||||
result = await crud.get_task_by_enum(postgres, task_type)
|
||||
assert result is None
|
||||
|
||||
|
||||
@freeze_time("2022/07/24")
|
||||
async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType):
|
||||
async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType):
|
||||
"""Test setting the execution time of an existing task"""
|
||||
await database_session.refresh(task)
|
||||
await postgres.refresh(task)
|
||||
assert task.previous_run is None
|
||||
|
||||
await crud.set_last_task_execution_time(database_session, task_type)
|
||||
await database_session.refresh(task)
|
||||
await crud.set_last_task_execution_time(postgres, task_type)
|
||||
await postgres.refresh(task)
|
||||
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
|
||||
|
||||
|
||||
@freeze_time("2022/07/24")
|
||||
async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType):
|
||||
async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType):
|
||||
"""Test setting the execution time of a non-existing task"""
|
||||
statement = select(Task).where(Task.task == task_type)
|
||||
results = list((await database_session.execute(statement)).scalars().all())
|
||||
results = list((await postgres.execute(statement)).scalars().all())
|
||||
assert len(results) == 0
|
||||
|
||||
await crud.set_last_task_execution_time(database_session, task_type)
|
||||
results = list((await database_session.execute(statement)).scalars().all())
|
||||
await crud.set_last_task_execution_time(postgres, task_type)
|
||||
results = list((await postgres.execute(statement)).scalars().all())
|
||||
assert len(results) == 1
|
||||
task = results[0]
|
||||
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
|
||||
|
|
|
@ -1,50 +1,46 @@
|
|||
import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.models import UforaAnnouncement, UforaCourse
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
|
||||
|
||||
async def test_get_courses_with_announcements_none(database_session: AsyncSession):
|
||||
async def test_get_courses_with_announcements_none(postgres):
|
||||
"""Test getting all courses with announcements when there are none"""
|
||||
results = await crud.get_courses_with_announcements(database_session)
|
||||
results = await crud.get_courses_with_announcements(postgres)
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
async def test_get_courses_with_announcements(database_session: AsyncSession):
|
||||
async def test_get_courses_with_announcements(postgres):
|
||||
"""Test getting all courses with announcements"""
|
||||
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
|
||||
database_session.add_all([course_1, course_2])
|
||||
await database_session.commit()
|
||||
postgres.add_all([course_1, course_2])
|
||||
await postgres.commit()
|
||||
|
||||
results = await crud.get_courses_with_announcements(database_session)
|
||||
results = await crud.get_courses_with_announcements(postgres)
|
||||
assert len(results) == 1
|
||||
assert results[0] == course_1
|
||||
|
||||
|
||||
async def test_create_new_announcement(ufora_course: UforaCourse, database_session: AsyncSession):
|
||||
async def test_create_new_announcement(ufora_course: UforaCourse, postgres):
|
||||
"""Test creating a new announcement"""
|
||||
await crud.create_new_announcement(
|
||||
database_session, 1, course=ufora_course, publication_date=datetime.datetime.now()
|
||||
)
|
||||
await database_session.refresh(ufora_course)
|
||||
await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now())
|
||||
await postgres.refresh(ufora_course)
|
||||
assert len(ufora_course.announcements) == 1
|
||||
|
||||
|
||||
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, database_session: AsyncSession):
|
||||
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres):
|
||||
"""Test removing all stale announcements"""
|
||||
course = ufora_announcement.course
|
||||
ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
|
||||
announcement_2 = UforaAnnouncement(course_id=ufora_announcement.course_id, publication_date=datetime.datetime.now())
|
||||
database_session.add_all([ufora_announcement, announcement_2])
|
||||
await database_session.commit()
|
||||
await database_session.refresh(course)
|
||||
postgres.add_all([ufora_announcement, announcement_2])
|
||||
await postgres.commit()
|
||||
await postgres.refresh(course)
|
||||
assert len(course.announcements) == 2
|
||||
|
||||
await crud.remove_old_announcements(database_session)
|
||||
await crud.remove_old_announcements(postgres)
|
||||
|
||||
await database_session.refresh(course)
|
||||
await postgres.refresh(course)
|
||||
assert len(course.announcements) == 1
|
||||
assert announcement_2.course.announcements[0] == announcement_2
|
||||
|
|
|
@ -1,22 +1,20 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_courses as crud
|
||||
from database.models import UforaCourse
|
||||
from database.schemas.relational import UforaCourse
|
||||
|
||||
|
||||
async def test_get_course_by_name_exact(database_session: AsyncSession, ufora_course: UforaCourse):
|
||||
async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse):
|
||||
"""Test getting a course by its name when the query is an exact match"""
|
||||
match = await crud.get_course_by_name(database_session, "Test")
|
||||
match = await crud.get_course_by_name(postgres, "Test")
|
||||
assert match == ufora_course
|
||||
|
||||
|
||||
async def test_get_course_by_name_substring(database_session: AsyncSession, ufora_course: UforaCourse):
|
||||
async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse):
|
||||
"""Test getting a course by its name when the query is a substring"""
|
||||
match = await crud.get_course_by_name(database_session, "es")
|
||||
match = await crud.get_course_by_name(postgres, "es")
|
||||
assert match == ufora_course
|
||||
|
||||
|
||||
async def test_get_course_by_name_alias(database_session: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||
async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse):
|
||||
"""Test getting a course by its name when the name doesn't match, but the alias does"""
|
||||
match = await crud.get_course_by_name(database_session, "ali")
|
||||
match = await crud.get_course_by_name(postgres, "ali")
|
||||
assert match == ufora_course_with_alias
|
||||
|
|
|
@ -1,25 +1,24 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users as crud
|
||||
from database.models import User
|
||||
from database.schemas.relational import User
|
||||
|
||||
|
||||
async def test_get_or_add_non_existing(database_session: AsyncSession):
|
||||
async def test_get_or_add_non_existing(postgres):
|
||||
"""Test get_or_add for a user that doesn't exist"""
|
||||
await crud.get_or_add(database_session, 1)
|
||||
await crud.get_or_add(postgres, 1)
|
||||
statement = select(User)
|
||||
res = (await database_session.execute(statement)).scalars().all()
|
||||
res = (await postgres.execute(statement)).scalars().all()
|
||||
|
||||
assert len(res) == 1
|
||||
assert res[0].bank is not None
|
||||
assert res[0].nightly_data is not None
|
||||
|
||||
|
||||
async def test_get_or_add_existing(database_session: AsyncSession):
|
||||
async def test_get_or_add_existing(postgres):
|
||||
"""Test get_or_add for a user that does exist"""
|
||||
user = await crud.get_or_add(database_session, 1)
|
||||
user = await crud.get_or_add(postgres, 1)
|
||||
bank = user.bank
|
||||
|
||||
assert await crud.get_or_add(database_session, 1) == user
|
||||
assert (await crud.get_or_add(database_session, 1)).bank == bank
|
||||
assert await crud.get_or_add(postgres, 1) == user
|
||||
assert (await crud.get_or_add(postgres, 1)).bank == bank
|
||||
|
|
|
@ -1,28 +1,24 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.models import UforaCourse
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.utils.caches import UforaCourseCache
|
||||
|
||||
|
||||
async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||
async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse):
|
||||
"""Test loading the data for the Ufora Course cache when it's empty"""
|
||||
cache = UforaCourseCache()
|
||||
await cache.refresh(database_session)
|
||||
await cache.refresh(postgres)
|
||||
|
||||
assert len(cache.data) == 1
|
||||
assert cache.data == ["test"]
|
||||
assert cache.aliases == {"alias": "test"}
|
||||
|
||||
|
||||
async def test_ufora_course_cache_refresh_not_empty(
|
||||
database_session: AsyncSession, ufora_course_with_alias: UforaCourse
|
||||
):
|
||||
async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse):
|
||||
"""Test loading the data for the Ufora Course cache when it's not empty anymore"""
|
||||
cache = UforaCourseCache()
|
||||
cache.data = ["Something"]
|
||||
cache.data_transformed = ["something"]
|
||||
|
||||
await cache.refresh(database_session)
|
||||
await cache.refresh(postgres)
|
||||
|
||||
assert len(cache.data) == 1
|
||||
assert cache.data == ["test"]
|
||||
|
|
Loading…
Reference in New Issue