Add extra tests

pull/119/head
stijndcl 2022-07-18 22:00:39 +02:00
parent 8227190a8d
commit 1aeaa71ef8
2 changed files with 17 additions and 17 deletions

View File

@ -1,12 +1,10 @@
from typing import Optional
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.models import DadJoke from database.models import DadJoke
__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"] __all__ = ["add_dad_joke", "get_random_dad_joke"]
async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
@ -18,20 +16,6 @@ async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
return dad_joke return dad_joke
async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke:
"""Edit an existing dad joke"""
statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id)
dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none()
if dad_joke is None:
raise NoResultFoundException
dad_joke.joke = new_joke
session.add(dad_joke)
await session.commit()
return dad_joke
async def get_random_dad_joke(session: AsyncSession) -> DadJoke: async def get_random_dad_joke(session: AsyncSession) -> DadJoke:
"""Return a random database entry""" """Return a random database entry"""
statement = select(DadJoke).order_by(func.random()) statement = select(DadJoke).order_by(func.random())

View File

@ -0,0 +1,16 @@
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import dad_jokes as crud
from database.models import DadJoke
async def test_add_dad_joke(database_session: AsyncSession):
"""Test creating a new joke"""
statement = select(DadJoke)
result = (await database_session.execute(statement)).scalars().all()
assert len(result) == 0
await crud.add_dad_joke(database_session, "joke")
result = (await database_session.execute(statement)).scalars().all()
assert len(result) == 1