mirror of https://github.com/stijndcl/didier
Create connection fixture
parent
53a3e0e75a
commit
eb182b71f4
|
@ -24,4 +24,4 @@ env = [
|
||||||
"DB_USERNAME = postgres",
|
"DB_USERNAME = postgres",
|
||||||
"DB_HOST = localhost",
|
"DB_HOST = localhost",
|
||||||
"DISC_TOKEN = token"
|
"DISC_TOKEN = token"
|
||||||
]
|
]
|
||||||
|
|
|
@ -1,6 +1,12 @@
|
||||||
|
import os
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from alembic import command, config
|
from alembic import command, config
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.engine import engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
|
@ -9,7 +15,27 @@ def tables():
|
||||||
Starts from an empty database and runs through all the migrations to check those as well
|
Starts from an empty database and runs through all the migrations to check those as well
|
||||||
while we're at it
|
while we're at it
|
||||||
"""
|
"""
|
||||||
|
print("CWD: ", os.getcwd())
|
||||||
alembic_config = config.Config("alembic.ini")
|
alembic_config = config.Config("alembic.ini")
|
||||||
command.upgrade(alembic_config, "head")
|
command.upgrade(alembic_config, "head")
|
||||||
yield
|
yield
|
||||||
command.downgrade(alembic_config, "base")
|
command.downgrade(alembic_config, "base")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Fixture to create a session for every test
|
||||||
|
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||||
|
"""
|
||||||
|
connection = await engine.connect()
|
||||||
|
transaction = await connection.begin()
|
||||||
|
session = AsyncSession(bind=connection, expire_on_commit=False)
|
||||||
|
|
||||||
|
yield session
|
||||||
|
|
||||||
|
# Clean up session & rollback transactions
|
||||||
|
await session.close()
|
||||||
|
if transaction.is_valid:
|
||||||
|
await transaction.rollback()
|
||||||
|
|
||||||
|
await connection.close()
|
||||||
|
|
Loading…
Reference in New Issue