diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 30aa010..871c34d 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -1,12 +1,10 @@ -from typing import Optional - from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException from database.models import DadJoke -__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"] +__all__ = ["add_dad_joke", "get_random_dad_joke"] async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: @@ -18,20 +16,6 @@ async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: return dad_joke -async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke: - """Edit an existing dad joke""" - statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id) - dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none() - if dad_joke is None: - raise NoResultFoundException - - dad_joke.joke = new_joke - session.add(dad_joke) - await session.commit() - - return dad_joke - - async def get_random_dad_joke(session: AsyncSession) -> DadJoke: """Return a random database entry""" statement = select(DadJoke).order_by(func.random()) diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py new file mode 100644 index 0000000..0c499c8 --- /dev/null +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -0,0 +1,16 @@ +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import dad_jokes as crud +from database.models import DadJoke + + +async def test_add_dad_joke(database_session: AsyncSession): + """Test creating a new joke""" + statement = select(DadJoke) + result = (await database_session.execute(statement)).scalars().all() + assert len(result) == 0 + + await crud.add_dad_joke(database_session, "joke") + result = (await database_session.execute(statement)).scalars().all() + assert len(result) == 1