Merge pull request #119 from stijndcl/pytest-migrations

Use migrations in tests
pull/121/head
Stijn De Clercq 2022-07-19 18:53:07 +02:00 committed by GitHub
commit 6c225bacc1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 105 additions and 84 deletions

View File

@ -1,8 +1,9 @@
import asyncio import asyncio
from logging.config import fileConfig from logging.config import fileConfig
from alembic import context from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context
from database.engine import engine from database.engine import engine
from database.models import Base from database.models import Base
@ -18,31 +19,6 @@ if config.config_file_name is not None:
target_metadata = Base.metadata target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""Run migrations in 'offline' mode.
This configures the context with just a URL
and not an Engine, though an Engine is acceptable
here as well. By skipping the Engine creation
we don't even need a DBAPI to be available.
Calls to context.execute() here emit the given string to the
script output.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
render_as_batch=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection): def do_run_migrations(connection):
context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True) context.configure(connection=connection, target_metadata=target_metadata, render_as_batch=True)
@ -50,22 +26,26 @@ def do_run_migrations(connection):
context.run_migrations() context.run_migrations()
async def run_migrations_online() -> None: async def run_async_migrations(connectable: AsyncEngine):
"""Run migrations in 'online' mode.
In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = engine
async with connectable.connect() as connection: async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations) await connection.run_sync(do_run_migrations)
await connectable.dispose() await connectable.dispose()
if context.is_offline_mode(): def run_migrations_online() -> None:
run_migrations_offline() """Run migrations in 'online' mode.
else:
asyncio.run(run_migrations_online()) In this scenario we need to create an Engine
and associate a connection with the context.
"""
connectable = context.config.attributes.get("connection", None) or engine
if isinstance(connectable, AsyncEngine):
asyncio.run(run_async_migrations(connectable))
else:
do_run_migrations(connectable)
run_migrations_online()

View File

@ -5,47 +5,52 @@ Revises: b2d511552a1f
Create Date: 2022-06-30 20:02:27.284759 Create Date: 2022-06-30 20:02:27.284759
""" """
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = '0d03c226d881' revision = "0d03c226d881"
down_revision = 'b2d511552a1f' down_revision = "b2d511552a1f"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
def upgrade() -> None: def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.create_table('users', op.create_table("users", sa.Column("user_id", sa.BigInteger(), nullable=False), sa.PrimaryKeyConstraint("user_id"))
sa.Column('user_id', sa.BigInteger(), nullable=False), op.create_table(
sa.PrimaryKeyConstraint('user_id') "bank",
sa.Column("bank_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("dinks", sa.BigInteger(), server_default="0", nullable=False),
sa.Column("interest_level", sa.Integer(), server_default="1", nullable=False),
sa.Column("capacity_level", sa.Integer(), server_default="1", nullable=False),
sa.Column("rob_level", sa.Integer(), server_default="1", nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("bank_id"),
) )
op.create_table('bank', op.create_table(
sa.Column('bank_id', sa.Integer(), nullable=False), "nightly_data",
sa.Column('user_id', sa.BigInteger(), nullable=True), sa.Column("nightly_id", sa.Integer(), nullable=False),
sa.Column('dinks', sa.BigInteger(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column('interest_level', sa.Integer(), nullable=False), sa.Column("last_nightly", sa.DateTime(timezone=True), nullable=True),
sa.Column('capacity_level', sa.Integer(), nullable=False), sa.Column("count", sa.Integer(), server_default="0", nullable=False),
sa.Column('rob_level', sa.Integer(), nullable=False), sa.ForeignKeyConstraint(
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ), ["user_id"],
sa.PrimaryKeyConstraint('bank_id') ["users.user_id"],
) ),
op.create_table('nightly_data', sa.PrimaryKeyConstraint("nightly_id"),
sa.Column('nightly_id', sa.Integer(), nullable=False),
sa.Column('user_id', sa.BigInteger(), nullable=True),
sa.Column('last_nightly', sa.DateTime(timezone=True), nullable=True),
sa.Column('count', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.user_id'], ),
sa.PrimaryKeyConstraint('nightly_id')
) )
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table('nightly_data') op.drop_table("nightly_data")
op.drop_table('bank') op.drop_table("bank")
op.drop_table('users') op.drop_table("users")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -7,15 +7,17 @@ from sqlalchemy.orm import sessionmaker
import settings import settings
encoded_password = quote_plus(settings.DB_PASSWORD) encoded_password = quote_plus(settings.DB_PASSWORD)
url = URL.create(
drivername="postgresql+asyncpg",
username=settings.DB_USERNAME,
password=encoded_password,
host=settings.DB_HOST,
port=settings.DB_PORT,
database=settings.DB_NAME,
)
engine = create_async_engine( engine = create_async_engine(
URL.create( url,
drivername="postgresql+asyncpg",
username=settings.DB_USERNAME,
password=encoded_password,
host=settings.DB_HOST,
port=settings.DB_PORT,
database=settings.DB_NAME,
),
pool_pre_ping=True, pool_pre_ping=True,
future=True, future=True,
) )

View File

@ -1,16 +1,27 @@
import logging import logging
from alembic import config, script from sqlalchemy.ext.asyncio import create_async_engine
from alembic.runtime import migration from sqlalchemy.orm import Session
from database.engine import engine
__all__ = ["ensure_latest_migration"] from alembic import command, config, script
from alembic.config import Config
from alembic.runtime import migration
from database.engine import engine, url
__config_path__ = "alembic.ini"
__migrations_path__ = "alembic/"
cfg = Config(__config_path__)
cfg.set_main_option("script_location", __migrations_path__)
__all__ = ["ensure_latest_migration", "migrate"]
async def ensure_latest_migration(): async def ensure_latest_migration():
"""Make sure we are currently on the latest revision, otherwise raise an exception""" """Make sure we are currently on the latest revision, otherwise raise an exception"""
alembic_config = config.Config("alembic.ini") alembic_script = script.ScriptDirectory.from_config(cfg)
alembic_script = script.ScriptDirectory.from_config(alembic_config)
async with engine.begin() as connection: async with engine.begin() as connection:
current_revision = await connection.run_sync( current_revision = await connection.run_sync(
@ -25,3 +36,20 @@ async def ensure_latest_migration():
) )
logging.error(error_message) logging.error(error_message)
raise RuntimeError(error_message) raise RuntimeError(error_message)
def __execute_upgrade(connection: Session):
cfg.attributes["connection"] = connection
command.upgrade(cfg, "head")
def __execute_downgrade(connection: Session):
cfg.attributes["connection"] = connection
command.downgrade(cfg, "base")
async def migrate(up: bool):
"""Migrate the database upwards or downwards"""
async_engine = create_async_engine(url, echo=True)
async with async_engine.begin() as connection:
await connection.run_sync(__execute_upgrade if up else __execute_downgrade)

View File

@ -6,11 +6,11 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine from database.engine import engine
from database.models import Base from database.migrations import ensure_latest_migration, migrate
from didier import Didier from didier import Didier
@pytest.fixture(scope="session") @pytest.fixture(scope="session", autouse=True)
def event_loop() -> Generator: def event_loop() -> Generator:
loop = asyncio.get_event_loop_policy().new_event_loop() loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop yield loop
@ -19,9 +19,15 @@ def event_loop() -> Generator:
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
async def tables(): async def tables():
"""Initialize a database before the tests, and then tear it down again""" """Initialize a database before the tests, and then tear it down again
async with engine.begin() as connection:
await connection.run_sync(Base.metadata.create_all) Checks that the migrations were successful by asserting that we are currently
on the latest migration
"""
await migrate(up=True)
await ensure_latest_migration()
yield
await migrate(up=False)
@pytest.fixture @pytest.fixture