From 9401111bee4899b4467cddb4a4fd462677f54ca9 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 19 Jul 2022 18:49:22 +0200 Subject: [PATCH] Try to use migrations in tests --- alembic/env.py | 58 ++++++------------- .../0d03c226d881_initial_currency_models.py | 57 +++++++++--------- database/engine.py | 18 +++--- database/migrations.py | 40 +++++++++++-- tests/conftest.py | 16 +++-- 5 files changed, 105 insertions(+), 84 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index f28faf3..3cca2cf 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -1,8 +1,9 @@ import asyncio from logging.config import fileConfig -from alembic import context +from sqlalchemy.ext.asyncio import AsyncEngine +from alembic import context from database.engine import engine from database.models import Base @@ -18,31 +19,6 @@ if config.config_file_name is not None: target_metadata = Base.metadata -def run_migrations_offline() -> None: - """Run migrations in 'offline' mode. - - This configures the context with just a URL - and not an Engine, though an Engine is acceptable - here as well. By skipping the Engine creation - we don't even need a DBAPI to be available. - - Calls to context.execute() here emit the given string to the - script output. - - """ - url = config.get_main_option("sqlalchemy.url") - context.configure( - url=url, - target_metadata=target_metadata, - literal_binds=True, - dialect_opts={"paramstyle": "named"}, - render_as_batch=True, - ) - - with context.begin_transaction(): - context.run_migrations() - - def do_run_migrations(connection): context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True) @@ -50,22 +26,26 @@ def do_run_migrations(connection): context.run_migrations() -async def run_migrations_online() -> None: - """Run migrations in 'online' mode. - - In this scenario we need to create an Engine - and associate a connection with the context. - - """ - connectable = engine - +async def run_async_migrations(connectable: AsyncEngine): async with connectable.connect() as connection: await connection.run_sync(do_run_migrations) await connectable.dispose() -if context.is_offline_mode(): - run_migrations_offline() -else: - asyncio.run(run_migrations_online()) +def run_migrations_online() -> None: + """Run migrations in 'online' mode. + + In this scenario we need to create an Engine + and associate a connection with the context. + + """ + connectable = context.config.attributes.get("connection", None) or engine + + if isinstance(connectable, AsyncEngine): + asyncio.run(run_async_migrations(connectable)) + else: + do_run_migrations(connectable) + + +run_migrations_online() diff --git a/alembic/versions/0d03c226d881_initial_currency_models.py b/alembic/versions/0d03c226d881_initial_currency_models.py index 45a5e26..7478410 100644 --- a/alembic/versions/0d03c226d881_initial_currency_models.py +++ b/alembic/versions/0d03c226d881_initial_currency_models.py @@ -5,47 +5,52 @@ Revises: b2d511552a1f Create Date: 2022-06-30 20:02:27.284759 """ -from alembic import op import sqlalchemy as sa +from alembic import op # revision identifiers, used by Alembic. -revision = '0d03c226d881' -down_revision = 'b2d511552a1f' +revision = "0d03c226d881" +down_revision = "b2d511552a1f" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.create_table('users', - sa.Column('user_id', sa.BigInteger(), nullable=False), - sa.PrimaryKeyConstraint('user_id') + op.create_table("users", sa.Column("user_id", sa.BigInteger(), nullable=False), sa.PrimaryKeyConstraint("user_id")) + op.create_table( + "bank", + sa.Column("bank_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("dinks", sa.BigInteger(), server_default="0", nullable=False), + sa.Column("interest_level", sa.Integer(), server_default="1", nullable=False), + sa.Column("capacity_level", sa.Integer(), server_default="1", nullable=False), + sa.Column("rob_level", sa.Integer(), server_default="1", nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("bank_id"), ) - op.create_table('bank', - sa.Column('bank_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.BigInteger(), nullable=True), - sa.Column('dinks', sa.BigInteger(), nullable=False), - sa.Column('interest_level', sa.Integer(), nullable=False), - sa.Column('capacity_level', sa.Integer(), nullable=False), - sa.Column('rob_level', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), - sa.PrimaryKeyConstraint('bank_id') - ) - op.create_table('nightly_data', - sa.Column('nightly_id', sa.Integer(), nullable=False), - sa.Column('user_id', sa.BigInteger(), nullable=True), - sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True), - sa.Column('count', sa.Integer(), nullable=False), - sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), - sa.PrimaryKeyConstraint('nightly_id') + op.create_table( + "nightly_data", + sa.Column("nightly_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True), + sa.Column("count", sa.Integer(), server_default="0", nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("nightly_id"), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### - op.drop_table('nightly_data') - op.drop_table('bank') - op.drop_table('users') + op.drop_table("nightly_data") + op.drop_table("bank") + op.drop_table("users") # ### end Alembic commands ### diff --git a/database/engine.py b/database/engine.py index 44b5e92..4330629 100644 --- a/database/engine.py +++ b/database/engine.py @@ -7,15 +7,17 @@ from sqlalchemy.orm import sessionmaker import settings encoded_password = quote_plus(settings.DB_PASSWORD) +url = URL.create( + drivername="postgresql+asyncpg", + username=settings.DB_USERNAME, + password=encoded_password, + host=settings.DB_HOST, + port=settings.DB_PORT, + database=settings.DB_NAME, +) + engine = create_async_engine( - URL.create( - drivername="postgresql+asyncpg", - username=settings.DB_USERNAME, - password=encoded_password, - host=settings.DB_HOST, - port=settings.DB_PORT, - database=settings.DB_NAME, - ), + url, pool_pre_ping=True, future=True, ) diff --git a/database/migrations.py b/database/migrations.py index 1effdff..f81ec2a 100644 --- a/database/migrations.py +++ b/database/migrations.py @@ -1,16 +1,27 @@ import logging -from alembic import config, script -from alembic.runtime import migration -from database.engine import engine +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import Session -__all__ = ["ensure_latest_migration"] +from alembic import command, config, script +from alembic.config import Config +from alembic.runtime import migration +from database.engine import engine, url + +__config_path__ = "alembic.ini" +__migrations_path__ = "alembic/" + + +cfg = Config(__config_path__) +cfg.set_main_option("script_location", __migrations_path__) + + +__all__ = ["ensure_latest_migration", "migrate"] async def ensure_latest_migration(): """Make sure we are currently on the latest revision, otherwise raise an exception""" - alembic_config = config.Config("alembic.ini") - alembic_script = script.ScriptDirectory.from_config(alembic_config) + alembic_script = script.ScriptDirectory.from_config(cfg) async with engine.begin() as connection: current_revision = await connection.run_sync( @@ -25,3 +36,20 @@ async def ensure_latest_migration(): ) logging.error(error_message) raise RuntimeError(error_message) + + +def __execute_upgrade(connection: Session): + cfg.attributes["connection"] = connection + command.upgrade(cfg, "head") + + +def __execute_downgrade(connection: Session): + cfg.attributes["connection"] = connection + command.downgrade(cfg, "base") + + +async def migrate(up: bool): + """Migrate the database upwards or downwards""" + async_engine = create_async_engine(url, echo=True) + async with async_engine.begin() as connection: + await connection.run_sync(__execute_upgrade if up else __execute_downgrade) diff --git a/tests/conftest.py b/tests/conftest.py index b2a1e04..95b44db 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,11 +6,11 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.engine import engine -from database.models import Base +from database.migrations import ensure_latest_migration, migrate from didier import Didier -@pytest.fixture(scope="session") +@pytest.fixture(scope="session", autouse=True) def event_loop() -> Generator: loop = asyncio.get_event_loop_policy().new_event_loop() yield loop @@ -19,9 +19,15 @@ def event_loop() -> Generator: @pytest.fixture(scope="session") async def tables(): - """Initialize a database before the tests, and then tear it down again""" - async with engine.begin() as connection: - await connection.run_sync(Base.metadata.create_all) + """Initialize a database before the tests, and then tear it down again + + Checks that the migrations were successful by asserting that we are currently + on the latest migration + """ + await migrate(up=True) + await ensure_latest_migration() + yield + await migrate(up=False) @pytest.fixture