From 8a4baf6bb8f161d40a1b463a6cd42b279f71c155 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Mon, 29 Aug 2022 20:24:42 +0200 Subject: [PATCH] Remove mongo & fix tests --- .pre-commit-config.yaml | 6 - alembic/env.py | 2 +- alembic/versions/38b7c29f10ee_wordle.py | 63 +++++++ database/crud/birthdays.py | 2 +- database/crud/currency.py | 2 +- database/crud/custom_commands.py | 2 +- database/crud/dad_jokes.py | 2 +- database/crud/deadlines.py | 2 +- database/crud/game_stats.py | 59 ------- database/crud/links.py | 2 +- database/crud/memes.py | 2 +- database/crud/tasks.py | 2 +- database/crud/ufora_announcements.py | 2 +- database/crud/ufora_courses.py | 2 +- database/crud/users.py | 2 +- database/crud/wordle.py | 80 ++++----- database/crud/wordle_stats.py | 57 ++++++ database/engine.py | 14 -- database/enums.py | 9 +- database/mongo_types.py | 6 - .../{schemas/relational.py => schemas.py} | 47 +++++ database/schemas/__init__.py | 0 database/schemas/mongo/__init__.py | 0 database/schemas/mongo/common.py | 53 ------ database/schemas/mongo/game_stats.py | 40 ----- database/schemas/mongo/temporary_storage.py | 16 -- database/schemas/mongo/wordle.py | 44 ----- database/utils/caches.py | 24 ++- database/utils/datetime.py | 12 +- didier/cogs/{test_cog.py => debug_cog.py} | 8 +- didier/cogs/other.py | 2 +- didier/data/apis/imgflip.py | 2 +- didier/data/embeds/deadlines.py | 2 +- didier/data/embeds/ufora/announcements.py | 2 +- didier/didier.py | 10 +- didier/views/modals/deadlines.py | 2 +- didier/views/modals/memes.py | 2 +- docker-compose.test.yml | 7 - docker-compose.yml | 13 -- pyproject.toml | 4 - requirements.txt | 1 - settings.py | 7 - tests/conftest.py | 12 +- tests/test_database/conftest.py | 2 +- .../test_database/test_crud/test_birthdays.py | 2 +- .../test_database/test_crud/test_currency.py | 2 +- .../test_crud/test_custom_commands.py | 2 +- .../test_database/test_crud/test_dad_jokes.py | 2 +- .../test_crud/test_game_stats.py | 63 ------- tests/test_database/test_crud/test_tasks.py | 2 +- .../test_crud/test_ufora_announcements.py | 2 +- .../test_crud/test_ufora_courses.py | 2 +- tests/test_database/test_crud/test_users.py | 2 +- tests/test_database/test_crud/test_wordle.py | 162 +++++++++--------- .../test_crud/test_wordle_stats.py | 72 ++++++++ tests/test_database/test_utils/test_caches.py | 2 +- 56 files changed, 406 insertions(+), 539 deletions(-) create mode 100644 alembic/versions/38b7c29f10ee_wordle.py delete mode 100644 database/crud/game_stats.py create mode 100644 database/crud/wordle_stats.py delete mode 100644 database/mongo_types.py rename database/{schemas/relational.py => schemas.py} (80%) delete mode 100644 database/schemas/__init__.py delete mode 100644 database/schemas/mongo/__init__.py delete mode 100644 database/schemas/mongo/common.py delete mode 100644 database/schemas/mongo/game_stats.py delete mode 100644 database/schemas/mongo/temporary_storage.py delete mode 100644 database/schemas/mongo/wordle.py rename didier/cogs/{test_cog.py => debug_cog.py} (73%) delete mode 100644 tests/test_database/test_crud/test_game_stats.py create mode 100644 tests/test_database/test_crud/test_wordle_stats.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7b320a4..10eabb9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -44,9 +44,3 @@ repos: - "flake8-eradicate" - "flake8-isort" - "flake8-simplify" - -- repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.961 - hooks: - - id: mypy - args: [--config, pyproject.toml] diff --git a/alembic/env.py b/alembic/env.py index beaa206..18b7c21 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine from alembic import context from database.engine import postgres_engine -from database.schemas.relational import Base +from database.schemas import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/38b7c29f10ee_wordle.py b/alembic/versions/38b7c29f10ee_wordle.py new file mode 100644 index 0000000..8fe53b2 --- /dev/null +++ b/alembic/versions/38b7c29f10ee_wordle.py @@ -0,0 +1,63 @@ +"""Wordle + +Revision ID: 38b7c29f10ee +Revises: 36300b558ef1 +Create Date: 2022-08-29 20:21:02.413631 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "38b7c29f10ee" +down_revision = "36300b558ef1" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "wordle_word", + sa.Column("word_id", sa.Integer(), nullable=False), + sa.Column("word", sa.Text(), nullable=False), + sa.Column("day", sa.Date(), nullable=False), + sa.PrimaryKeyConstraint("word_id"), + sa.UniqueConstraint("day"), + ) + op.create_table( + "wordle_guesses", + sa.Column("wordle_guess_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("guess", sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("wordle_guess_id"), + ) + op.create_table( + "wordle_stats", + sa.Column("wordle_stats_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.Column("last_win", sa.Date(), nullable=True), + sa.Column("games", sa.Integer(), server_default="0", nullable=False), + sa.Column("wins", sa.Integer(), server_default="0", nullable=False), + sa.Column("current_streak", sa.Integer(), server_default="0", nullable=False), + sa.Column("highest_streak", sa.Integer(), server_default="0", nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("wordle_stats_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("wordle_stats") + op.drop_table("wordle_guesses") + op.drop_table("wordle_word") + # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index f078488..229ef89 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from database.crud import users -from database.schemas.relational import Birthday, User +from database.schemas import Birthday, User __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] diff --git a/database/crud/currency.py b/database/crud/currency.py index 382801d..f720c69 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.schemas.relational import Bank, NightlyData +from database.schemas import Bank, NightlyData from database.utils.math.currency import ( capacity_upgrade_price, interest_upgrade_price, diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index d0e86a1..bb6ac0c 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.schemas.relational import CustomCommand, CustomCommandAlias +from database.schemas import CustomCommand, CustomCommandAlias __all__ = [ "clean_name", diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 3d673de..30697bd 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException -from database.schemas.relational import DadJoke +from database.schemas import DadJoke __all__ = ["add_dad_joke", "get_random_dad_joke"] diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index c1b2885..ce7fba5 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from database.schemas.relational import Deadline, UforaCourse +from database.schemas import Deadline, UforaCourse __all__ = ["add_deadline", "get_deadlines"] diff --git a/database/crud/game_stats.py b/database/crud/game_stats.py deleted file mode 100644 index 5634b49..0000000 --- a/database/crud/game_stats.py +++ /dev/null @@ -1,59 +0,0 @@ -import datetime -from typing import Union - -from database.mongo_types import MongoDatabase -from database.schemas.mongo.game_stats import GameStats - -__all__ = ["get_game_stats", "complete_wordle_game"] - -from database.utils.datetime import today_only_date - - -async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats: - """Get a user's game stats - - If no entry is found, it is first created - """ - collection = database[GameStats.collection()] - stats = await collection.find_one({"user_id": user_id}) - if stats is not None: - return GameStats(**stats) - - stats = GameStats(user_id=user_id) - await collection.insert_one(stats.dict(by_alias=True)) - return stats - - -async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0): - """Update the user's Wordle stats""" - stats = await get_game_stats(database, user_id) - - update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}} - - if win: - update["$inc"]["wordle.wins"] = 1 - update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1 - - # Update streak - today = today_only_date() - last_win = stats.wordle.last_win - update["$set"]["wordle.last_win"] = today - - if last_win is None or (today - last_win).days > 1: - # Never won a game before or streak is over - update["$set"]["wordle.current_streak"] = 1 - stats.wordle.current_streak = 1 - else: - # On a streak: increase counter - update["$inc"]["wordle.current_streak"] = 1 - stats.wordle.current_streak += 1 - - # Update max streak if necessary - if stats.wordle.current_streak > stats.wordle.max_streak: - update["$set"]["wordle.max_streak"] = stats.wordle.current_streak - else: - # Streak is over - update["$set"]["wordle.current_streak"] = 0 - - collection = database[GameStats.collection()] - await collection.update_one({"_id": stats.id}, update) diff --git a/database/crud/links.py b/database/crud/links.py index e97c328..495e0f3 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -4,7 +4,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions import NoResultFoundException -from database.schemas.relational import Link +from database.schemas import Link __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] diff --git a/database/crud/memes.py b/database/crud/memes.py index f92f6ef..ab288aa 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -4,7 +4,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas.relational import MemeTemplate +from database.schemas import MemeTemplate __all__ = ["add_meme", "get_all_memes", "get_meme_by_name"] diff --git a/database/crud/tasks.py b/database/crud/tasks.py index a3b6f38..f66ffc5 100644 --- a/database/crud/tasks.py +++ b/database/crud/tasks.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.enums import TaskType -from database.schemas.relational import Task +from database.schemas import Task from database.utils.datetime import LOCAL_TIMEZONE __all__ = ["get_task_by_enum", "set_last_task_execution_time"] diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index e2dbd16..688bcc7 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas.relational import UforaAnnouncement, UforaCourse +from database.schemas import UforaAnnouncement, UforaCourse __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index f6dd853..19369c1 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas.relational import UforaCourse, UforaCourseAlias +from database.schemas import UforaCourse, UforaCourseAlias __all__ = ["get_all_courses", "get_course_by_name"] diff --git a/database/crud/users.py b/database/crud/users.py index 3024f26..8f885b6 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas.relational import Bank, NightlyData, User +from database.schemas import Bank, NightlyData, User __all__ = [ "get_or_add", diff --git a/database/crud/wordle.py b/database/crud/wordle.py index 4b2a312..ea8b892 100644 --- a/database/crud/wordle.py +++ b/database/crud/wordle.py @@ -1,56 +1,45 @@ +import datetime from typing import Optional -from database.enums import TempStorageKey -from database.mongo_types import MongoDatabase -from database.schemas.mongo.temporary_storage import TemporaryStorage -from database.schemas.mongo.wordle import WordleGame -from database.utils.datetime import today_only_date +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import WordleGuess, WordleWord __all__ = [ "get_active_wordle_game", "make_wordle_guess", - "start_new_wordle_game", "set_daily_word", "reset_wordle_games", ] -async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]: +async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]: """Find a player's active game""" - collection = database[WordleGame.collection()] - result = await collection.find_one({"user_id": user_id}) - if result is None: - return None - - return WordleGame(**result) + statement = select(WordleGuess).where(WordleGuess.user_id == user_id) + guesses = (await session.execute(statement)).scalars().all() + return guesses -async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame: - """Start a new game""" - collection = database[WordleGame.collection()] - game = WordleGame(user_id=user_id) - await collection.insert_one(game.dict(by_alias=True)) - return game - - -async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str): +async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str): """Make a guess in your current game""" - collection = database[WordleGame.collection()] - await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}}) + guess_instance = WordleGuess(user_id=user_id, guess=guess) + session.add(guess_instance) + await session.commit() -async def get_daily_word(database: MongoDatabase) -> Optional[str]: +async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]: """Get the word of today""" - collection = database[TemporaryStorage.collection()] + statement = select(WordleWord).where(WordleWord.day == datetime.date.today()) + row = (await session.execute(statement)).scalar_one_or_none() - result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()}) - if result is None: + if row is None: return None - return result["word"] + return row -async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str: +async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str: """Set the word of today This does NOT overwrite the existing word if there is one, so that it can safely run @@ -60,23 +49,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F Returns the word that was chosen. If one already existed, return that instead. """ - collection = database[TemporaryStorage.collection()] + current_word = await get_daily_word(session) - current_word = None if forced else await get_daily_word(database) - if current_word is not None: - return current_word + if current_word is None: + current_word = WordleWord(word=word, day=datetime.date.today()) + session.add(current_word) + await session.commit() - await collection.update_one( - {"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True - ) + # Remove all active games + await reset_wordle_games(session) + elif forced: + current_word.word = word + current_word.day = datetime.date.today() + session.add(current_word) + await session.commit() - # Remove all active games - await reset_wordle_games(database) + # Remove all active games + await reset_wordle_games(session) - return word + return current_word.word -async def reset_wordle_games(database: MongoDatabase): +async def reset_wordle_games(session: AsyncSession): """Reset all active games""" - collection = database[WordleGame.collection()] - await collection.drop() + statement = delete(WordleGuess) + await session.execute(statement) diff --git a/database/crud/wordle_stats.py b/database/crud/wordle_stats.py new file mode 100644 index 0000000..9d7a77d --- /dev/null +++ b/database/crud/wordle_stats.py @@ -0,0 +1,57 @@ +from datetime import date + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import WordleStats + +__all__ = ["get_wordle_stats", "complete_wordle_game"] + + +async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats: + """Get a user's wordle stats + + If no entry is found, it is first created + """ + statement = select(WordleStats).where(WordleStats.user_id == user_id) + stats = (await session.execute(statement)).scalar_one_or_none() + if stats is not None: + return stats + + stats = WordleStats(user_id=user_id) + session.add(stats) + await session.commit() + await session.refresh(stats) + + return stats + + +async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool): + """Update the user's Wordle stats""" + stats = await get_wordle_stats(session, user_id) + stats.games += 1 + + if win: + stats.wins += 1 + + # Update streak + today = date.today() + last_win = stats.last_win + stats.last_win = today + + if last_win is None or (today - last_win).days > 1: + # Never won a game before or streak is over + stats.current_streak = 1 + else: + # On a streak: increase counter + stats.current_streak += 1 + + # Update max streak if necessary + if stats.current_streak > stats.highest_streak: + stats.highest_streak = stats.current_streak + else: + # Streak is over + stats.current_streak = 0 + + session.add(stats) + await session.commit() diff --git a/database/engine.py b/database/engine.py index b3e3947..23e5b89 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,6 +1,5 @@ from urllib.parse import quote_plus -import motor.motor_asyncio from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -26,16 +25,3 @@ postgres_engine = create_async_engine( DBSession = sessionmaker( autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False ) - -# MongoDB client -if not settings.TESTING: # pragma: no cover - encoded_mongo_username = quote_plus(settings.MONGO_USER) - encoded_mongo_password = quote_plus(settings.MONGO_PASS) - mongo_url = ( - f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/" - ) -else: - # Require no authentication when testing - mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/" - -mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url) diff --git a/database/enums.py b/database/enums.py index f64998e..3f75130 100644 --- a/database/enums.py +++ b/database/enums.py @@ -1,6 +1,6 @@ import enum -__all__ = ["TaskType", "TempStorageKey"] +__all__ = ["TaskType"] # There is a bug in typeshed that causes an incorrect PyCharm warning @@ -11,10 +11,3 @@ class TaskType(enum.IntEnum): BIRTHDAYS = enum.auto() UFORA_ANNOUNCEMENTS = enum.auto() - - -@enum.unique -class TempStorageKey(str, enum.Enum): - """Enum for keys to distinguish the TemporaryStorage rows""" - - WORDLE_WORD = "wordle_word" diff --git a/database/mongo_types.py b/database/mongo_types.py deleted file mode 100644 index 11f5b7a..0000000 --- a/database/mongo_types.py +++ /dev/null @@ -1,6 +0,0 @@ -import motor.motor_asyncio - -# Type aliases for the Motor types, which are way too long -MongoClient = motor.motor_asyncio.AsyncIOMotorClient -MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase -MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection diff --git a/database/schemas/relational.py b/database/schemas.py similarity index 80% rename from database/schemas/relational.py rename to database/schemas.py index 904459e..182653f 100644 --- a/database/schemas/relational.py +++ b/database/schemas.py @@ -37,6 +37,9 @@ __all__ = [ "UforaCourse", "UforaCourseAlias", "User", + "WordleGuess", + "WordleStats", + "WordleWord", ] @@ -231,3 +234,47 @@ class User(Base): nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) + wordle_guesses: list[WordleGuess] = relationship( + "WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + ) + wordle_stats: WordleStats = relationship( + "WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + ) + + +class WordleGuess(Base): + """A user's Wordle guesses for today""" + + __tablename__ = "wordle_guesses" + + wordle_guess_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + guess: str = Column(Text, nullable=False) + + user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin") + + +class WordleStats(Base): + """Stats about a user's wordle performance""" + + __tablename__ = "wordle_stats" + + wordle_stats_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + last_win: Optional[date] = Column(Date, nullable=True) + games: int = Column(Integer, server_default="0", nullable=False) + wins: int = Column(Integer, server_default="0", nullable=False) + current_streak: int = Column(Integer, server_default="0", nullable=False) + highest_streak: int = Column(Integer, server_default="0", nullable=False) + + user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin") + + +class WordleWord(Base): + """The current Wordle word""" + + __tablename__ = "wordle_word" + + word_id: int = Column(Integer, primary_key=True) + word: str = Column(Text, nullable=False) + day: date = Column(Date, nullable=False, unique=True) diff --git a/database/schemas/__init__.py b/database/schemas/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/database/schemas/mongo/__init__.py b/database/schemas/mongo/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/database/schemas/mongo/common.py b/database/schemas/mongo/common.py deleted file mode 100644 index dcddb54..0000000 --- a/database/schemas/mongo/common.py +++ /dev/null @@ -1,53 +0,0 @@ -from abc import ABC, abstractmethod - -from bson import ObjectId -from pydantic import BaseModel, Field - -__all__ = ["PyObjectId", "MongoBase", "MongoCollection"] - - -class PyObjectId(ObjectId): - """Custom type for bson ObjectIds""" - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value: str): - """Check that a string is a valid bson ObjectId""" - if not ObjectId.is_valid(value): - raise ValueError(f"Invalid ObjectId: '{value}'") - - return ObjectId(value) - - @classmethod - def __modify_schema__(cls, field_schema: dict): - field_schema.update(type="string") - - -class MongoBase(BaseModel): - """Base model that properly sets the _id field, and adds one by default""" - - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") - - class Config: - """Configuration for encoding and construction""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True - json_encoders = {ObjectId: str, PyObjectId: str} - use_enum_values = True - - -class MongoCollection(MongoBase, ABC): - """Base model for the 'main class' in a collection - - This field stores the name of the collection to avoid making typos against it - """ - - @staticmethod - @abstractmethod - def collection() -> str: - """Getter for the name of the collection, in order to avoid typos""" - raise NotImplementedError diff --git a/database/schemas/mongo/game_stats.py b/database/schemas/mongo/game_stats.py deleted file mode 100644 index af08117..0000000 --- a/database/schemas/mongo/game_stats.py +++ /dev/null @@ -1,40 +0,0 @@ -import datetime -from typing import Optional - -from overrides import overrides -from pydantic import BaseModel, Field, validator - -from database.schemas.mongo.common import MongoCollection - -__all__ = ["GameStats", "WordleStats"] - - -class WordleStats(BaseModel): - """Model that holds stats about a player's Wordle performance""" - - guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) - last_win: Optional[datetime.datetime] = None - wins: int = 0 - games: int = 0 - current_streak: int = 0 - max_streak: int = 0 - - @validator("guess_distribution") - def validate_guesses_length(cls, value: list[int]): - """Check that the distribution of guesses is of the correct length""" - if len(value) != 6: - raise ValueError(f"guess_distribution must be length 6, found {len(value)}") - - return value - - -class GameStats(MongoCollection): - """Collection that holds stats about how well a user has performed in games""" - - user_id: int - wordle: WordleStats = WordleStats() - - @staticmethod - @overrides - def collection() -> str: - return "game_stats" diff --git a/database/schemas/mongo/temporary_storage.py b/database/schemas/mongo/temporary_storage.py deleted file mode 100644 index deb444b..0000000 --- a/database/schemas/mongo/temporary_storage.py +++ /dev/null @@ -1,16 +0,0 @@ -from overrides import overrides - -from database.schemas.mongo.common import MongoCollection - -__all__ = ["TemporaryStorage"] - - -class TemporaryStorage(MongoCollection): - """Collection for lots of random things that don't belong in a full-blown collection""" - - key: str - - @staticmethod - @overrides - def collection() -> str: - return "temporary" diff --git a/database/schemas/mongo/wordle.py b/database/schemas/mongo/wordle.py deleted file mode 100644 index b03189b..0000000 --- a/database/schemas/mongo/wordle.py +++ /dev/null @@ -1,44 +0,0 @@ -import datetime - -from overrides import overrides -from pydantic import Field, validator - -from database.constants import WORDLE_GUESS_COUNT -from database.schemas.mongo.common import MongoCollection -from database.utils.datetime import today_only_date - -__all__ = ["WordleGame"] - - -class WordleGame(MongoCollection): - """Collection that holds people's active Wordle games""" - - day: datetime.datetime = Field(default_factory=lambda: today_only_date()) - guesses: list[str] = Field(default_factory=list) - user_id: int - - @staticmethod - @overrides - def collection() -> str: - return "wordle" - - @validator("guesses") - def validate_guesses_length(cls, value: list[int]): - """Check that the amount of guesses is of the correct length""" - if len(value) > 6: - raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") - - return value - - def is_game_over(self, word: str) -> bool: - """Check if the current game is over""" - # No guesses yet - if not self.guesses: - return False - - # Max amount of guesses allowed - if len(self.guesses) == WORDLE_GUESS_COUNT: - return True - - # Found the correct word - return self.guesses[-1] == word diff --git a/database/utils/caches.py b/database/utils/caches.py index ba26ab2..3911a0e 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,19 +1,15 @@ from abc import ABC, abstractmethod -from typing import Generic, TypeVar from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession from database.crud import links, memes, ufora_courses, wordle -from database.mongo_types import MongoDatabase __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] -T = TypeVar("T") - -class DatabaseCache(ABC, Generic[T]): +class DatabaseCache(ABC): """Base class for a simple cache-like structure The goal of this class is to store data for Discord auto-completion results @@ -25,7 +21,7 @@ class DatabaseCache(ABC, Generic[T]): Considering the fact that a user isn't obligated to choose something from the suggestions, chances are high we have to go to the database for the final action either way. - Also stores the data in lowercase to allow fast searching + Also stores the data in lowercase to allow fast searching. """ data: list[str] = [] @@ -36,7 +32,7 @@ class DatabaseCache(ABC, Generic[T]): self.data.clear() @abstractmethod - async def invalidate(self, database_session: T): + async def invalidate(self, database_session: AsyncSession): """Invalidate the data stored in this cache""" def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: @@ -48,7 +44,7 @@ class DatabaseCache(ABC, Generic[T]): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] -class LinkCache(DatabaseCache[AsyncSession]): +class LinkCache(DatabaseCache): """Cache to store the names of links""" @overrides @@ -61,7 +57,7 @@ class LinkCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) -class MemeCache(DatabaseCache[AsyncSession]): +class MemeCache(DatabaseCache): """Cache to store the names of meme templates""" @overrides @@ -74,7 +70,7 @@ class MemeCache(DatabaseCache[AsyncSession]): self.data_transformed = list(map(str.lower, self.data)) -class UforaCourseCache(DatabaseCache[AsyncSession]): +class UforaCourseCache(DatabaseCache): """Cache to store the names of Ufora courses""" # Also store the aliases to add additional support @@ -119,10 +115,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] -class WordleCache(DatabaseCache[MongoDatabase]): +class WordleCache(DatabaseCache): """Cache to store the current daily Wordle word""" - async def invalidate(self, database_session: MongoDatabase): + async def invalidate(self, database_session: AsyncSession): word = await wordle.get_daily_word(database_session) if word is not None: self.data = [word] @@ -142,9 +138,9 @@ class CacheManager: self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() - async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): + async def initialize_caches(self, postgres_session: AsyncSession): """Initialize the contents of all caches""" await self.links.invalidate(postgres_session) await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) - await self.wordle_word.invalidate(mongo_db) + await self.wordle_word.invalidate(postgres_session) diff --git a/database/utils/datetime.py b/database/utils/datetime.py index 0952f63..8450e84 100644 --- a/database/utils/datetime.py +++ b/database/utils/datetime.py @@ -1,15 +1,5 @@ -import datetime import zoneinfo -__all__ = ["LOCAL_TIMEZONE", "today_only_date"] +__all__ = ["LOCAL_TIMEZONE"] LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") - - -def today_only_date() -> datetime.datetime: - """Mongo can't handle datetime.date, so we need a datetime instance - - We do, however, only care about the date, so remove all the rest - """ - today = datetime.date.today() - return datetime.datetime(year=today.year, month=today.month, day=today.day) diff --git a/didier/cogs/test_cog.py b/didier/cogs/debug_cog.py similarity index 73% rename from didier/cogs/test_cog.py rename to didier/cogs/debug_cog.py index c304f67..2d03b9f 100644 --- a/didier/cogs/test_cog.py +++ b/didier/cogs/debug_cog.py @@ -4,7 +4,7 @@ from overrides import overrides from didier import Didier -class TestCog(commands.Cog): +class DebugCog(commands.Cog): """Testing cog for dev purposes""" client: Didier @@ -16,11 +16,11 @@ class TestCog(commands.Cog): async def cog_check(self, ctx: commands.Context) -> bool: return await self.client.is_owner(ctx.author) - @commands.command() - async def test(self, ctx: commands.Context): + @commands.command(aliases=["Dev"]) + async def debug(self, ctx: commands.Context): """Debugging command""" async def setup(client: Didier): """Load the cog""" - await client.add_cog(TestCog(client)) + await client.add_cog(DebugCog(client)) diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 870c63b..dea9cd2 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -5,7 +5,7 @@ from discord import app_commands from discord.ext import commands from database.crud.links import get_link_by_name -from database.schemas.relational import Link +from database.schemas import Link from didier import Didier from didier.data.apis import urban_dictionary from didier.data.embeds.google import GoogleSearch diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py index b373897..c44eea6 100644 --- a/didier/data/apis/imgflip.py +++ b/didier/data/apis/imgflip.py @@ -1,7 +1,7 @@ from aiohttp import ClientSession import settings -from database.schemas.relational import MemeTemplate +from database.schemas import MemeTemplate from didier.exceptions.missing_env import MissingEnvironmentVariable from didier.utils.http.requests import ensure_post diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py index 371eee9..fe3c988 100644 --- a/didier/data/embeds/deadlines.py +++ b/didier/data/embeds/deadlines.py @@ -4,7 +4,7 @@ from datetime import datetime import discord from overrides import overrides -from database.schemas.relational import Deadline +from database.schemas import Deadline from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import tz_aware_now from didier.utils.types.string import get_edu_year_name diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index 59e8290..31b4950 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud -from database.schemas.relational import UforaCourse +from database.schemas import UforaCourse from didier.data.embeds.base import EmbedBaseModel from didier.utils.discord.colours import ghent_university_blue from didier.utils.types.datetime import int_to_weekday diff --git a/didier/didier.py b/didier/didier.py index 337b342..12bd5b4 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -2,7 +2,6 @@ import logging import os import discord -import motor.motor_asyncio from aiohttp import ClientSession from discord.app_commands import AppCommandError from discord.ext import commands @@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import custom_commands -from database.engine import DBSession, mongo_client +from database.engine import DBSession from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.exceptions import HTTPException, NoMatch @@ -55,11 +54,6 @@ class Didier(commands.Bot): """Obtain a session for the PostgreSQL database""" return DBSession() - @property - def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: - """Obtain a reference to the MongoDB database""" - return mongo_client[settings.MONGO_DB] - async def setup_hook(self) -> None: """Do some initial setup @@ -71,7 +65,7 @@ class Didier(commands.Bot): # Initialize caches self.database_caches = CacheManager() async with self.postgres_session as session: - await self.database_caches.initialize_caches(session, self.mongo_db) + await self.database_caches.initialize_caches(session) # Load extensions await self._load_initial_extensions() diff --git a/didier/views/modals/deadlines.py b/didier/views/modals/deadlines.py index cd2a26c..972cfe1 100644 --- a/didier/views/modals/deadlines.py +++ b/didier/views/modals/deadlines.py @@ -5,7 +5,7 @@ from discord import Interaction from overrides import overrides from database.crud.deadlines import add_deadline -from database.schemas.relational import UforaCourse +from database.schemas import UforaCourse __all__ = ["AddDeadline"] diff --git a/didier/views/modals/memes.py b/didier/views/modals/memes.py index c98e17f..4f5518e 100644 --- a/didier/views/modals/memes.py +++ b/didier/views/modals/memes.py @@ -3,7 +3,7 @@ import traceback import discord.ui from overrides import overrides -from database.schemas.relational import MemeTemplate +from database.schemas import MemeTemplate from didier import Didier from didier.data.apis.imgflip import generate_meme diff --git a/docker-compose.test.yml b/docker-compose.test.yml index 4033ad4..d841c33 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -10,10 +10,3 @@ services: - POSTGRES_PASSWORD=pytest ports: - "5433:5432" - mongo-pytest: - image: mongo:5.0 - restart: always - environment: - - MONGO_INITDB_DATABASE=didier_pytest - ports: - - "27018:27017" diff --git a/docker-compose.yml b/docker-compose.yml index 4af0a0e..b9c5ee3 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,18 +12,5 @@ services: - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}" volumes: - postgres:/var/lib/postgresql/data - mongo: - image: mongo:5.0 - restart: always - environment: - - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root} - - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root} - - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev} - command: [--auth] - ports: - - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}" - volumes: - - mongo:/data/db volumes: postgres: - mongo: diff --git a/pyproject.toml b/pyproject.toml index 45c7c9a..dcb0d24 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,9 +44,6 @@ ignore_missing_imports = true asyncio_mode = "auto" env = [ "TESTING = 1", - "MONGO_DB = didier_pytest", - "MONGO_HOST = localhost", - "MONGO_PORT = 27018", "POSTGRES_DB = didier_pytest", "POSTGRES_USER = pytest", "POSTGRES_PASS = pytest", @@ -55,6 +52,5 @@ env = [ "DISCORD_TOKEN = token" ] markers = [ - "mongo: tests that use MongoDB", "postgres: tests that use PostgreSQL" ] diff --git a/requirements.txt b/requirements.txt index 748c632..759cb54 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,7 +7,6 @@ git+https://github.com/Rapptz/discord-ext-menus@8686b5d environs==9.5.0 feedparser==6.0.10 markdownify==0.11.2 -motor==3.0.0 overrides==6.1.0 pydantic==1.9.1 python-dateutil==2.8.2 diff --git a/settings.py b/settings.py index 349c6e2..7698cc9 100644 --- a/settings.py +++ b/settings.py @@ -37,13 +37,6 @@ SEMESTER: int = env.int("SEMESTER", 2) YEAR: int = env.int("YEAR", 3) """Database""" -# MongoDB -MONGO_DB: str = env.str("MONGO_DB", "didier") -MONGO_USER: str = env.str("MONGO_USER", "root") -MONGO_PASS: str = env.str("MONGO_PASS", "root") -MONGO_HOST: str = env.str("MONGO_HOST", "localhost") -MONGO_PORT: int = env.int("MONGO_PORT", "27017") - # PostgreSQL POSTGRES_DB: str = env.str("POSTGRES_DB", "didier") POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres") diff --git a/tests/conftest.py b/tests/conftest.py index 55919fd..c218524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,12 +2,10 @@ import asyncio from typing import AsyncGenerator, Generator from unittest.mock import MagicMock -import motor.motor_asyncio import pytest from sqlalchemy.ext.asyncio import AsyncSession -import settings -from database.engine import mongo_client, postgres_engine +from database.engine import postgres_engine from database.migrations import ensure_latest_migration, migrate from didier import Didier @@ -56,14 +54,6 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]: await connection.close() -@pytest.fixture -async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase: - """Fixture to get a MongoDB connection""" - database = mongo_client[settings.MONGO_DB] - yield database - mongo_client.drop_database(settings.MONGO_DB) - - @pytest.fixture def mock_client() -> Didier: """Fixture to get a mock Didier instance diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index dc5ce2a..8383b05 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users -from database.schemas.relational import ( +from database.schemas import ( Bank, UforaAnnouncement, UforaCourse, diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index e7f2242..86740d5 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud from database.crud import users -from database.schemas.relational import User +from database.schemas import User async def test_add_birthday_not_present(postgres: AsyncSession, user: User): diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index 8bd7e8f..1beddc6 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud from database.exceptions import currency as exceptions -from database.schemas.relational import Bank +from database.schemas import Bank async def test_add_dinks(postgres: AsyncSession, bank: Bank): diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index 6f141bc..8155068 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.schemas.relational import CustomCommand +from database.schemas import CustomCommand async def test_create_command_non_existing(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index f34d0fa..99b82de 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.crud import dad_jokes as crud -from database.schemas.relational import DadJoke +from database.schemas import DadJoke async def test_add_dad_joke(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_game_stats.py b/tests/test_database/test_crud/test_game_stats.py deleted file mode 100644 index 4b7606f..0000000 --- a/tests/test_database/test_crud/test_game_stats.py +++ /dev/null @@ -1,63 +0,0 @@ -import pytest -from freezegun import freeze_time - -from database.crud import game_stats as crud -from database.mongo_types import MongoDatabase -from database.schemas.mongo.game_stats import GameStats -from database.utils.datetime import today_only_date - - -async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats): - """Helper function to insert some stats""" - collection = mongodb[GameStats.collection()] - await collection.insert_one(stats.dict(by_alias=True)) - - -@pytest.mark.mongo -async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int): - """Test getting a user's stats when the db is empty""" - collection = mongodb[GameStats.collection()] - assert await collection.find_one({"user_id": test_user_id}) is None - await crud.get_game_stats(mongodb, test_user_id) - assert await collection.find_one({"user_id": test_user_id}) is not None - - -@pytest.mark.mongo -async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int): - """Test getting a user's stats when there's already an entry present""" - stats = GameStats(user_id=test_user_id) - stats.wordle.games = 20 - await insert_game_stats(mongodb, stats) - found_stats = await crud.get_game_stats(mongodb, test_user_id) - assert found_stats.wordle.games == 20 - - -@pytest.mark.mongo -@freeze_time("2022-07-30") -async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int): - """Test completing a wordle game when you win""" - await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2) - stats = await crud.get_game_stats(mongodb, test_user_id) - assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0] - assert stats.wordle.games == 1 - assert stats.wordle.wins == 1 - assert stats.wordle.current_streak == 1 - assert stats.wordle.max_streak == 1 - assert stats.wordle.last_win == today_only_date() - - -@pytest.mark.mongo -@freeze_time("2022-07-30") -async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int): - """Test completing a wordle game when you lose""" - stats = GameStats(user_id=test_user_id) - stats.wordle.current_streak = 10 - await insert_game_stats(mongodb, stats) - - await crud.complete_wordle_game(mongodb, test_user_id, win=False) - stats = await crud.get_game_stats(mongodb, test_user_id) - - # Check that streak was broken - assert stats.wordle.current_streak == 0 - assert stats.wordle.games == 1 - assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0] diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index b13b221..f3adede 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import tasks as crud from database.enums import TaskType -from database.schemas.relational import Task +from database.schemas import Task @pytest.fixture diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index 34f4222..3621a82 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_announcements as crud -from database.schemas.relational import UforaAnnouncement, UforaCourse +from database.schemas import UforaAnnouncement, UforaCourse async def test_get_courses_with_announcements_none(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index 140bc4a..601ba8c 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,7 +1,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_courses as crud -from database.schemas.relational import UforaCourse +from database.schemas import UforaCourse async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse): diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index 96d3383..b726fab 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -2,7 +2,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users as crud -from database.schemas.relational import User +from database.schemas import User async def test_get_or_add_non_existing(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_wordle.py b/tests/test_database/test_crud/test_wordle.py index 3ddc979..f3aeb2a 100644 --- a/tests/test_database/test_crud/test_wordle.py +++ b/tests/test_database/test_crud/test_wordle.py @@ -1,136 +1,140 @@ -from datetime import datetime, timedelta +from datetime import date, timedelta import pytest from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import wordle as crud -from database.enums import TempStorageKey -from database.mongo_types import MongoCollection, MongoDatabase -from database.schemas.mongo.temporary_storage import TemporaryStorage -from database.schemas.mongo.wordle import WordleGame +from database.schemas import User, WordleGuess, WordleWord @pytest.fixture -async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection: - """Fixture to get a reference to the wordle collection""" - yield mongodb[WordleGame.collection()] +async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]: + """Fixture to generate some guesses""" + guesses = [] + + for guess in ["TEST", "WORDLE", "WORDS"]: + guess = WordleGuess(user_id=user.user_id, guess=guess) + postgres.add(guess) + await postgres.commit() + + guesses.append(guess) + + return guesses -@pytest.fixture -async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame: - """Fixture to create a new game""" - game = WordleGame(user_id=test_user_id) - await wordle_collection.insert_one(game.dict(by_alias=True)) - yield game - - -@pytest.mark.mongo -async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int): - """Test starting a new game""" - result = await wordle_collection.find_one({"user_id": test_user_id}) - assert result is None - - await crud.start_new_wordle_game(mongodb, test_user_id) - - result = await wordle_collection.find_one({"user_id": test_user_id}) - assert result is not None - - -@pytest.mark.mongo -async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int): +@pytest.mark.postgres +async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User): """Test getting an active game when there is none""" - result = await crud.get_active_wordle_game(mongodb, test_user_id) - assert result is None + result = await crud.get_active_wordle_game(postgres, user.user_id) + assert not result -@pytest.mark.mongo -async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame): +@pytest.mark.postgres +async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]): """Test getting an active game when there is one""" - result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id) - assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True) + result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id) + assert result == wordle_guesses -@pytest.mark.mongo -async def test_get_daily_word_none(mongodb: MongoDatabase): +@pytest.mark.postgres +async def test_get_daily_word_none(postgres: AsyncSession): """Test getting the daily word when the database is empty""" - result = await crud.get_daily_word(mongodb) + result = await crud.get_daily_word(postgres) assert result is None -@pytest.mark.mongo +@pytest.mark.postgres @freeze_time("2022-07-30") -async def test_get_daily_word_not_today(mongodb: MongoDatabase): +async def test_get_daily_word_not_today(postgres: AsyncSession): """Test getting the daily word when there is an entry, but not for today""" - day = datetime.today() - timedelta(days=1) - collection = mongodb[TemporaryStorage.collection()] + day = date.today() - timedelta(days=1) word = "testword" - await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) + word_instance = WordleWord(word=word, day=day) + postgres.add(word_instance) + await postgres.commit() - assert await crud.get_daily_word(mongodb) is None + assert await crud.get_daily_word(postgres) is None -@pytest.mark.mongo +@pytest.mark.postgres @freeze_time("2022-07-30") -async def test_get_daily_word_present(mongodb: MongoDatabase): +async def test_get_daily_word_present(postgres: AsyncSession): """Test getting the daily word when there is one for today""" - day = datetime.today() - collection = mongodb[TemporaryStorage.collection()] + day = date.today() word = "testword" - await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) + word_instance = WordleWord(word=word, day=day) + postgres.add(word_instance) + await postgres.commit() - assert await crud.get_daily_word(mongodb) == word + daily_word = await crud.get_daily_word(postgres) + assert daily_word is not None + assert daily_word.word == word -@pytest.mark.mongo +@pytest.mark.postgres @freeze_time("2022-07-30") -async def test_set_daily_word_none_present(mongodb: MongoDatabase): +async def test_set_daily_word_none_present(postgres: AsyncSession): """Test setting the daily word when there is none""" - assert await crud.get_daily_word(mongodb) is None + assert await crud.get_daily_word(postgres) is None word = "testword" - await crud.set_daily_word(mongodb, word) - assert await crud.get_daily_word(mongodb) == word + await crud.set_daily_word(postgres, word) + + daily_word = await crud.get_daily_word(postgres) + assert daily_word is not None + assert daily_word.word == word -@pytest.mark.mongo +@pytest.mark.postgres @freeze_time("2022-07-30") -async def test_set_daily_word_present(mongodb: MongoDatabase): +async def test_set_daily_word_present(postgres: AsyncSession): """Test setting the daily word when there already is one""" word = "testword" - await crud.set_daily_word(mongodb, word) - await crud.set_daily_word(mongodb, "another word") - assert await crud.get_daily_word(mongodb) == word + await crud.set_daily_word(postgres, word) + await crud.set_daily_word(postgres, "another word") + + daily_word = await crud.get_daily_word(postgres) + assert daily_word is not None + assert daily_word.word == word -@pytest.mark.mongo +@pytest.mark.postgres @freeze_time("2022-07-30") -async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase): +async def test_set_daily_word_force_overwrite(postgres: AsyncSession): """Test setting the daily word when there already is one, but "forced" is set to True""" word = "testword" - await crud.set_daily_word(mongodb, word) + await crud.set_daily_word(postgres, word) word = "anotherword" - await crud.set_daily_word(mongodb, word, forced=True) - assert await crud.get_daily_word(mongodb) == word + await crud.set_daily_word(postgres, word, forced=True) + + daily_word = await crud.get_daily_word(postgres) + assert daily_word is not None + assert daily_word.word == word -@pytest.mark.mongo -async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): +@pytest.mark.postgres +async def test_make_wordle_guess(postgres: AsyncSession, user: User): """Test making a guess in your current game""" + test_user_id = user.user_id + guess = "guess" - await crud.make_wordle_guess(mongodb, test_user_id, guess) - wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) - assert wordle_game.guesses == [guess] + await crud.make_wordle_guess(postgres, test_user_id, guess) + wordle_guesses = await crud.get_active_wordle_game(postgres, test_user_id) + assert list(map(lambda x: x.guess, wordle_guesses)) == [guess] other_guess = "otherguess" - await crud.make_wordle_guess(mongodb, test_user_id, other_guess) - wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) - assert wordle_game.guesses == [guess, other_guess] + await crud.make_wordle_guess(postgres, test_user_id, other_guess) + wordle_guesses = await crud.get_active_wordle_game(postgres, test_user_id) + assert list(map(lambda x: x.guess, wordle_guesses)) == [guess, other_guess] -@pytest.mark.mongo -async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): +@pytest.mark.postgres +async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User): """Test dropping the collection of active games""" - assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None - await crud.reset_wordle_games(mongodb) - assert await crud.get_active_wordle_game(mongodb, test_user_id) is None + test_user_id = user.user_id + + assert await crud.get_active_wordle_game(postgres, test_user_id) + await crud.reset_wordle_games(postgres) + assert not await crud.get_active_wordle_game(postgres, test_user_id) diff --git a/tests/test_database/test_crud/test_wordle_stats.py b/tests/test_database/test_crud/test_wordle_stats.py new file mode 100644 index 0000000..925e5e8 --- /dev/null +++ b/tests/test_database/test_crud/test_wordle_stats.py @@ -0,0 +1,72 @@ +import datetime + +import pytest +from freezegun import freeze_time +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.crud import wordle_stats as crud +from database.schemas import User, WordleStats + + +async def insert_game_stats(postgres: AsyncSession, stats: WordleStats): + """Helper function to insert some stats""" + postgres.add(stats) + await postgres.commit() + + +@pytest.mark.postgres +async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User): + """Test getting a user's stats when the db is empty""" + test_user_id = user.user_id + + statement = select(WordleStats).where(WordleStats.user_id == test_user_id) + assert (await postgres.execute(statement)).scalar_one_or_none() is None + + await crud.get_wordle_stats(postgres, test_user_id) + assert (await postgres.execute(statement)).scalar_one_or_none() is not None + + +@pytest.mark.postgres +async def test_get_stats_existing_returns(postgres: AsyncSession, user: User): + """Test getting a user's stats when there's already an entry present""" + test_user_id = user.user_id + + stats = WordleStats(user_id=test_user_id) + stats.games = 20 + await insert_game_stats(postgres, stats) + found_stats = await crud.get_wordle_stats(postgres, test_user_id) + assert found_stats.games == 20 + + +@pytest.mark.postgres +@freeze_time("2022-07-30") +async def test_complete_wordle_game_won(postgres: AsyncSession, user: User): + """Test completing a wordle game when you win""" + test_user_id = user.user_id + + await crud.complete_wordle_game(postgres, test_user_id, win=True) + stats = await crud.get_wordle_stats(postgres, test_user_id) + assert stats.games == 1 + assert stats.wins == 1 + assert stats.current_streak == 1 + assert stats.highest_streak == 1 + assert stats.last_win == datetime.date.today() + + +@pytest.mark.postgres +@freeze_time("2022-07-30") +async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User): + """Test completing a wordle game when you lose""" + test_user_id = user.user_id + + stats = WordleStats(user_id=test_user_id) + stats.current_streak = 10 + await insert_game_stats(postgres, stats) + + await crud.complete_wordle_game(postgres, test_user_id, win=False) + stats = await crud.get_wordle_stats(postgres, test_user_id) + + # Check that streak was broken + assert stats.current_streak == 0 + assert stats.games == 1 diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 3dc6adb..e62d7a3 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,6 +1,6 @@ from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas.relational import UforaCourse +from database.schemas import UforaCourse from database.utils.caches import UforaCourseCache