From e6b4c3fd7671c26b60a58c4fd6609fe2402395ff Mon Sep 17 00:00:00 2001 From: stijndcl Date: Mon, 25 Jul 2022 21:20:09 +0200 Subject: [PATCH] Create base model for Mongo schemas --- alembic/env.py | 2 +- database/crud/birthdays.py | 2 +- database/crud/currency.py | 2 +- database/crud/custom_commands.py | 2 +- database/crud/dad_jokes.py | 2 +- database/crud/tasks.py | 2 +- database/crud/ufora_announcements.py | 2 +- database/crud/ufora_courses.py | 2 +- database/crud/users.py | 2 +- database/schemas/__init__.py | 0 database/schemas/mongo.py | 38 +++++++++++++++++++ database/{models.py => schemas/relational.py} | 0 didier/data/embeds/ufora/announcements.py | 2 +- tests/test_database/conftest.py | 8 +++- .../test_database/test_crud/test_birthdays.py | 2 +- .../test_database/test_crud/test_currency.py | 2 +- .../test_crud/test_custom_commands.py | 2 +- .../test_database/test_crud/test_dad_jokes.py | 2 +- tests/test_database/test_crud/test_tasks.py | 2 +- .../test_crud/test_ufora_announcements.py | 2 +- .../test_crud/test_ufora_courses.py | 2 +- tests/test_database/test_crud/test_users.py | 2 +- tests/test_database/test_utils/test_caches.py | 2 +- 23 files changed, 64 insertions(+), 20 deletions(-) create mode 100644 database/schemas/__init__.py create mode 100644 database/schemas/mongo.py rename database/{models.py => schemas/relational.py} (100%) diff --git a/alembic/env.py b/alembic/env.py index 72d5170..beaa206 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine from alembic import context from database.engine import postgres_engine -from database.models import Base +from database.schemas.relational import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index df59dfc..f078488 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from database.crud import users -from database.models import Birthday, User +from database.schemas.relational import Birthday, User __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] diff --git a/database/crud/currency.py b/database/crud/currency.py index 1bb2d11..382801d 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.models import Bank, NightlyData +from database.schemas.relational import Bank, NightlyData from database.utils.math.currency import ( capacity_upgrade_price, interest_upgrade_price, diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index 85ecf56..d0e86a1 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.models import CustomCommand, CustomCommandAlias +from database.schemas.relational import CustomCommand, CustomCommandAlias __all__ = [ "clean_name", diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 871c34d..c481ec3 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException -from database.models import DadJoke +from database.schemas.relational import DadJoke __all__ = ["add_dad_joke", "get_random_dad_joke"] diff --git a/database/crud/tasks.py b/database/crud/tasks.py index dd1a607..a3b6f38 100644 --- a/database/crud/tasks.py +++ b/database/crud/tasks.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.enums import TaskType -from database.models import Task +from database.schemas.relational import Task from database.utils.datetime import LOCAL_TIMEZONE __all__ = ["get_task_by_enum", "set_last_task_execution_time"] diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 48a06ae..e2dbd16 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index d41846c..f6dd853 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import UforaCourse, UforaCourseAlias +from database.schemas.relational import UforaCourse, UforaCourseAlias __all__ = ["get_all_courses", "get_course_by_name"] diff --git a/database/crud/users.py b/database/crud/users.py index ba3011d..3024f26 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.models import Bank, NightlyData, User +from database.schemas.relational import Bank, NightlyData, User __all__ = [ "get_or_add", diff --git a/database/schemas/__init__.py b/database/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/schemas/mongo.py b/database/schemas/mongo.py new file mode 100644 index 0000000..6917466 --- /dev/null +++ b/database/schemas/mongo.py @@ -0,0 +1,38 @@ +from bson import ObjectId +from pydantic import BaseModel, Field + +__all__ = [] + + +class PyObjectId(str): + """Custom type for bson ObjectIds""" + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: str): + """Check that a string is a valid bson ObjectId""" + if not ObjectId.is_valid(value): + raise ValueError(f"Invalid ObjectId: '{value}'") + + return ObjectId(value) + + @classmethod + def __modify_schema__(cls, field_schema: dict): + field_schema.update(type="string") + + +class MongoBase(BaseModel): + """Base model that properly sets the _id field, and adds one by default""" + + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + + class Config: + """Configuration for encoding and construction""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {ObjectId: str, PyObjectId: str} + use_enum_values = True diff --git a/database/models.py b/database/schemas/relational.py similarity index 100% rename from database/models.py rename to database/schemas/relational.py diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index f4a8bdd..d906ea0 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud -from database.models import UforaCourse +from database.schemas.relational import UforaCourse from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import int_to_weekday from didier.utils.types.string import leading diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index a99c770..b2556c4 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -3,7 +3,13 @@ import datetime import pytest from database.crud import users -from database.models import Bank, UforaAnnouncement, UforaCourse, UforaCourseAlias, User +from database.schemas.relational import ( + Bank, + UforaAnnouncement, + UforaCourse, + UforaCourseAlias, + User, +) @pytest.fixture(scope="session") diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 7433573..21639b1 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -4,7 +4,7 @@ from freezegun import freeze_time from database.crud import birthdays as crud from database.crud import users -from database.models import User +from database.schemas.relational import User async def test_add_birthday_not_present(postgres, user: User): diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index b1e5192..e5cdc0c 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -5,7 +5,7 @@ from freezegun import freeze_time from database.crud import currency as crud from database.exceptions import currency as exceptions -from database.models import Bank +from database.schemas.relational import Bank async def test_add_dinks(postgres, bank: Bank): diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index ec25637..88810d4 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -4,7 +4,7 @@ from sqlalchemy import select from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.models import CustomCommand +from database.schemas.relational import CustomCommand async def test_create_command_non_existing(postgres): diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index 8138495..22c28c2 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -1,7 +1,7 @@ from sqlalchemy import select from database.crud import dad_jokes as crud -from database.models import DadJoke +from database.schemas.relational import DadJoke async def test_add_dad_joke(postgres): diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index 4831e03..c4c7ba0 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -6,7 +6,7 @@ from sqlalchemy import select from database.crud import tasks as crud from database.enums import TaskType -from database.models import Task +from database.schemas.relational import Task @pytest.fixture diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index c6054ff..1aa45ee 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,7 +1,7 @@ import datetime from database.crud import ufora_announcements as crud -from database.models import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse async def test_get_courses_with_announcements_none(postgres): diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index 5935fd9..34748c0 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,5 +1,5 @@ from database.crud import ufora_courses as crud -from database.models import UforaCourse +from database.schemas.relational import UforaCourse async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse): diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index d6584de..e852298 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -1,7 +1,7 @@ from sqlalchemy import select from database.crud import users as crud -from database.models import User +from database.schemas.relational import User async def test_get_or_add_non_existing(postgres): diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 0a19e98..69a6ff2 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,4 +1,4 @@ -from database.models import UforaCourse +from database.schemas.relational import UforaCourse from database.utils.caches import UforaCourseCache