diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index 23384df..572b40e 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -38,6 +38,17 @@ jobs: POSTGRES_DB: didier_pytest POSTGRES_USER: pytest POSTGRES_PASSWORD: pytest + mongo: + image: mongo:5.0 + options: >- + --health-cmd mongo + --health-interval 10s + --health-timeout 5s + --health-retries 5 + ports: + - 27018:27017 + env: + MONGO_DB: didier_pytest steps: - uses: actions/checkout@v3 - name: Setup Python diff --git a/alembic/env.py b/alembic/env.py index 18b7c21..beaa206 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine from alembic import context from database.engine import postgres_engine -from database.schemas import Base +from database.schemas.relational import Base # this is the Alembic Config object, which provides # access to the values within the .ini file in use. diff --git a/alembic/versions/38b7c29f10ee_wordle.py b/alembic/versions/38b7c29f10ee_wordle.py deleted file mode 100644 index 8fe53b2..0000000 --- a/alembic/versions/38b7c29f10ee_wordle.py +++ /dev/null @@ -1,63 +0,0 @@ -"""Wordle - -Revision ID: 38b7c29f10ee -Revises: 36300b558ef1 -Create Date: 2022-08-29 20:21:02.413631 - -""" -import sqlalchemy as sa - -from alembic import op - -# revision identifiers, used by Alembic. -revision = "38b7c29f10ee" -down_revision = "36300b558ef1" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "wordle_word", - sa.Column("word_id", sa.Integer(), nullable=False), - sa.Column("word", sa.Text(), nullable=False), - sa.Column("day", sa.Date(), nullable=False), - sa.PrimaryKeyConstraint("word_id"), - sa.UniqueConstraint("day"), - ) - op.create_table( - "wordle_guesses", - sa.Column("wordle_guess_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("guess", sa.Text(), nullable=False), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("wordle_guess_id"), - ) - op.create_table( - "wordle_stats", - sa.Column("wordle_stats_id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.BigInteger(), nullable=True), - sa.Column("last_win", sa.Date(), nullable=True), - sa.Column("games", sa.Integer(), server_default="0", nullable=False), - sa.Column("wins", sa.Integer(), server_default="0", nullable=False), - sa.Column("current_streak", sa.Integer(), server_default="0", nullable=False), - sa.Column("highest_streak", sa.Integer(), server_default="0", nullable=False), - sa.ForeignKeyConstraint( - ["user_id"], - ["users.user_id"], - ), - sa.PrimaryKeyConstraint("wordle_stats_id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("wordle_stats") - op.drop_table("wordle_guesses") - op.drop_table("wordle_word") - # ### end Alembic commands ### diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 7ffa308..f078488 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from database.crud import users -from database.schemas import Birthday, User +from database.schemas.relational import Birthday, User __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] @@ -17,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date): If already present, overwrites the existing one """ - user = await users.get_or_add_user(session, user_id, options=[selectinload(User.birthday)]) + user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) if user.birthday is not None: bd = user.birthday diff --git a/database/crud/currency.py b/database/crud/currency.py index da0ff84..382801d 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.exceptions import currency as exceptions -from database.schemas import Bank, NightlyData +from database.schemas.relational import Bank, NightlyData from database.utils.math.currency import ( capacity_upgrade_price, interest_upgrade_price, @@ -29,13 +29,13 @@ NIGHTLY_AMOUNT = 420 async def get_bank(session: AsyncSession, user_id: int) -> Bank: """Get a user's bank info""" - user = await users.get_or_add_user(session, user_id) + user = await users.get_or_add(session, user_id) return user.bank async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData: """Get a user's nightly info""" - user = await users.get_or_add_user(session, user_id) + user = await users.get_or_add(session, user_id) return user.nightly_data diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index bb6ac0c..d0e86a1 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.schemas import CustomCommand, CustomCommandAlias +from database.schemas.relational import CustomCommand, CustomCommandAlias __all__ = [ "clean_name", diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 30697bd..3d673de 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException -from database.schemas import DadJoke +from database.schemas.relational import DadJoke __all__ = ["add_dad_joke", "get_random_dad_joke"] diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index ce7fba5..c1b2885 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -6,7 +6,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload -from database.schemas import Deadline, UforaCourse +from database.schemas.relational import Deadline, UforaCourse __all__ = ["add_deadline", "get_deadlines"] diff --git a/database/crud/game_stats.py b/database/crud/game_stats.py new file mode 100644 index 0000000..5634b49 --- /dev/null +++ b/database/crud/game_stats.py @@ -0,0 +1,59 @@ +import datetime +from typing import Union + +from database.mongo_types import MongoDatabase +from database.schemas.mongo.game_stats import GameStats + +__all__ = ["get_game_stats", "complete_wordle_game"] + +from database.utils.datetime import today_only_date + + +async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats: + """Get a user's game stats + + If no entry is found, it is first created + """ + collection = database[GameStats.collection()] + stats = await collection.find_one({"user_id": user_id}) + if stats is not None: + return GameStats(**stats) + + stats = GameStats(user_id=user_id) + await collection.insert_one(stats.dict(by_alias=True)) + return stats + + +async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0): + """Update the user's Wordle stats""" + stats = await get_game_stats(database, user_id) + + update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}} + + if win: + update["$inc"]["wordle.wins"] = 1 + update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1 + + # Update streak + today = today_only_date() + last_win = stats.wordle.last_win + update["$set"]["wordle.last_win"] = today + + if last_win is None or (today - last_win).days > 1: + # Never won a game before or streak is over + update["$set"]["wordle.current_streak"] = 1 + stats.wordle.current_streak = 1 + else: + # On a streak: increase counter + update["$inc"]["wordle.current_streak"] = 1 + stats.wordle.current_streak += 1 + + # Update max streak if necessary + if stats.wordle.current_streak > stats.wordle.max_streak: + update["$set"]["wordle.max_streak"] = stats.wordle.current_streak + else: + # Streak is over + update["$set"]["wordle.current_streak"] = 0 + + collection = database[GameStats.collection()] + await collection.update_one({"_id": stats.id}, update) diff --git a/database/crud/links.py b/database/crud/links.py index 495e0f3..e97c328 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -4,7 +4,7 @@ from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions import NoResultFoundException -from database.schemas import Link +from database.schemas.relational import Link __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] diff --git a/database/crud/memes.py b/database/crud/memes.py index ab288aa..f92f6ef 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -4,7 +4,7 @@ from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas import MemeTemplate +from database.schemas.relational import MemeTemplate __all__ = ["add_meme", "get_all_memes", "get_meme_by_name"] diff --git a/database/crud/tasks.py b/database/crud/tasks.py index f66ffc5..a3b6f38 100644 --- a/database/crud/tasks.py +++ b/database/crud/tasks.py @@ -5,7 +5,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.enums import TaskType -from database.schemas import Task +from database.schemas.relational import Task from database.utils.datetime import LOCAL_TIMEZONE __all__ = ["get_task_by_enum", "set_last_task_execution_time"] diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 688bcc7..e2dbd16 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index 19369c1..f6dd853 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -3,7 +3,7 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas import UforaCourse, UforaCourseAlias +from database.schemas.relational import UforaCourse, UforaCourseAlias __all__ = ["get_all_courses", "get_course_by_name"] diff --git a/database/crud/users.py b/database/crud/users.py index bd4f2ad..3024f26 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -3,14 +3,14 @@ from typing import Optional from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas import Bank, NightlyData, User +from database.schemas.relational import Bank, NightlyData, User __all__ = [ - "get_or_add_user", + "get_or_add", ] -async def get_or_add_user(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: +async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: """Get a user's profile If it doesn't exist yet, create it (along with all linked datastructures) diff --git a/database/crud/wordle.py b/database/crud/wordle.py index a978fdc..4b2a312 100644 --- a/database/crud/wordle.py +++ b/database/crud/wordle.py @@ -1,54 +1,56 @@ -import datetime from typing import Optional -from sqlalchemy import delete, select -from sqlalchemy.ext.asyncio import AsyncSession - -from database.crud.users import get_or_add_user -from database.schemas import WordleGuess, WordleWord +from database.enums import TempStorageKey +from database.mongo_types import MongoDatabase +from database.schemas.mongo.temporary_storage import TemporaryStorage +from database.schemas.mongo.wordle import WordleGame +from database.utils.datetime import today_only_date __all__ = [ "get_active_wordle_game", - "get_wordle_guesses", "make_wordle_guess", + "start_new_wordle_game", "set_daily_word", "reset_wordle_games", ] -async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]: +async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]: """Find a player's active game""" - await get_or_add_user(session, user_id) - statement = select(WordleGuess).where(WordleGuess.user_id == user_id) - guesses = (await session.execute(statement)).scalars().all() - return guesses - - -async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]: - """Get the strings of a player's guesses""" - active_game = await get_active_wordle_game(session, user_id) - return list(map(lambda g: g.guess.lower(), active_game)) - - -async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str): - """Make a guess in your current game""" - guess_instance = WordleGuess(user_id=user_id, guess=guess) - session.add(guess_instance) - await session.commit() - - -async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]: - """Get the word of today""" - statement = select(WordleWord).where(WordleWord.day == datetime.date.today()) - row = (await session.execute(statement)).scalar_one_or_none() - - if row is None: + collection = database[WordleGame.collection()] + result = await collection.find_one({"user_id": user_id}) + if result is None: return None - return row + return WordleGame(**result) -async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str: +async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame: + """Start a new game""" + collection = database[WordleGame.collection()] + game = WordleGame(user_id=user_id) + await collection.insert_one(game.dict(by_alias=True)) + return game + + +async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str): + """Make a guess in your current game""" + collection = database[WordleGame.collection()] + await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}}) + + +async def get_daily_word(database: MongoDatabase) -> Optional[str]: + """Get the word of today""" + collection = database[TemporaryStorage.collection()] + + result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()}) + if result is None: + return None + + return result["word"] + + +async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str: """Set the word of today This does NOT overwrite the existing word if there is one, so that it can safely run @@ -58,28 +60,23 @@ async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = Fal Returns the word that was chosen. If one already existed, return that instead. """ - current_word = await get_daily_word(session) + collection = database[TemporaryStorage.collection()] - if current_word is None: - current_word = WordleWord(word=word, day=datetime.date.today()) - session.add(current_word) - await session.commit() + current_word = None if forced else await get_daily_word(database) + if current_word is not None: + return current_word - # Remove all active games - await reset_wordle_games(session) - elif forced: - current_word.word = word - session.add(current_word) - await session.commit() + await collection.update_one( + {"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True + ) - # Remove all active games - await reset_wordle_games(session) + # Remove all active games + await reset_wordle_games(database) - return current_word.word + return word -async def reset_wordle_games(session: AsyncSession): +async def reset_wordle_games(database: MongoDatabase): """Reset all active games""" - statement = delete(WordleGuess) - await session.execute(statement) - await session.commit() + collection = database[WordleGame.collection()] + await collection.drop() diff --git a/database/crud/wordle_stats.py b/database/crud/wordle_stats.py deleted file mode 100644 index 71c123e..0000000 --- a/database/crud/wordle_stats.py +++ /dev/null @@ -1,60 +0,0 @@ -from datetime import date - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from database.crud.users import get_or_add_user -from database.schemas import WordleStats - -__all__ = ["get_wordle_stats", "complete_wordle_game"] - - -async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats: - """Get a user's wordle stats - - If no entry is found, it is first created - """ - await get_or_add_user(session, user_id) - - statement = select(WordleStats).where(WordleStats.user_id == user_id) - stats = (await session.execute(statement)).scalar_one_or_none() - if stats is not None: - return stats - - stats = WordleStats(user_id=user_id) - session.add(stats) - await session.commit() - await session.refresh(stats) - - return stats - - -async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool): - """Update the user's Wordle stats""" - stats = await get_wordle_stats(session, user_id) - stats.games += 1 - - if win: - stats.wins += 1 - - # Update streak - today = date.today() - last_win = stats.last_win - stats.last_win = today - - if last_win is None or (today - last_win).days > 1: - # Never won a game before or streak is over - stats.current_streak = 1 - else: - # On a streak: increase counter - stats.current_streak += 1 - - # Update max streak if necessary - if stats.current_streak > stats.highest_streak: - stats.highest_streak = stats.current_streak - else: - # Streak is over - stats.current_streak = 0 - - session.add(stats) - await session.commit() diff --git a/database/engine.py b/database/engine.py index 23e5b89..b3e3947 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,5 +1,6 @@ from urllib.parse import quote_plus +import motor.motor_asyncio from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -25,3 +26,16 @@ postgres_engine = create_async_engine( DBSession = sessionmaker( autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False ) + +# MongoDB client +if not settings.TESTING: # pragma: no cover + encoded_mongo_username = quote_plus(settings.MONGO_USER) + encoded_mongo_password = quote_plus(settings.MONGO_PASS) + mongo_url = ( + f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/" + ) +else: + # Require no authentication when testing + mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/" + +mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url) diff --git a/database/enums.py b/database/enums.py index 3f75130..f64998e 100644 --- a/database/enums.py +++ b/database/enums.py @@ -1,6 +1,6 @@ import enum -__all__ = ["TaskType"] +__all__ = ["TaskType", "TempStorageKey"] # There is a bug in typeshed that causes an incorrect PyCharm warning @@ -11,3 +11,10 @@ class TaskType(enum.IntEnum): BIRTHDAYS = enum.auto() UFORA_ANNOUNCEMENTS = enum.auto() + + +@enum.unique +class TempStorageKey(str, enum.Enum): + """Enum for keys to distinguish the TemporaryStorage rows""" + + WORDLE_WORD = "wordle_word" diff --git a/database/mongo_types.py b/database/mongo_types.py new file mode 100644 index 0000000..11f5b7a --- /dev/null +++ b/database/mongo_types.py @@ -0,0 +1,6 @@ +import motor.motor_asyncio + +# Type aliases for the Motor types, which are way too long +MongoClient = motor.motor_asyncio.AsyncIOMotorClient +MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase +MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection diff --git a/database/schemas/__init__.py b/database/schemas/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/schemas/mongo/__init__.py b/database/schemas/mongo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/schemas/mongo/common.py b/database/schemas/mongo/common.py new file mode 100644 index 0000000..dcddb54 --- /dev/null +++ b/database/schemas/mongo/common.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod + +from bson import ObjectId +from pydantic import BaseModel, Field + +__all__ = ["PyObjectId", "MongoBase", "MongoCollection"] + + +class PyObjectId(ObjectId): + """Custom type for bson ObjectIds""" + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: str): + """Check that a string is a valid bson ObjectId""" + if not ObjectId.is_valid(value): + raise ValueError(f"Invalid ObjectId: '{value}'") + + return ObjectId(value) + + @classmethod + def __modify_schema__(cls, field_schema: dict): + field_schema.update(type="string") + + +class MongoBase(BaseModel): + """Base model that properly sets the _id field, and adds one by default""" + + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + + class Config: + """Configuration for encoding and construction""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {ObjectId: str, PyObjectId: str} + use_enum_values = True + + +class MongoCollection(MongoBase, ABC): + """Base model for the 'main class' in a collection + + This field stores the name of the collection to avoid making typos against it + """ + + @staticmethod + @abstractmethod + def collection() -> str: + """Getter for the name of the collection, in order to avoid typos""" + raise NotImplementedError diff --git a/database/schemas/mongo/game_stats.py b/database/schemas/mongo/game_stats.py new file mode 100644 index 0000000..af08117 --- /dev/null +++ b/database/schemas/mongo/game_stats.py @@ -0,0 +1,40 @@ +import datetime +from typing import Optional + +from overrides import overrides +from pydantic import BaseModel, Field, validator + +from database.schemas.mongo.common import MongoCollection + +__all__ = ["GameStats", "WordleStats"] + + +class WordleStats(BaseModel): + """Model that holds stats about a player's Wordle performance""" + + guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) + last_win: Optional[datetime.datetime] = None + wins: int = 0 + games: int = 0 + current_streak: int = 0 + max_streak: int = 0 + + @validator("guess_distribution") + def validate_guesses_length(cls, value: list[int]): + """Check that the distribution of guesses is of the correct length""" + if len(value) != 6: + raise ValueError(f"guess_distribution must be length 6, found {len(value)}") + + return value + + +class GameStats(MongoCollection): + """Collection that holds stats about how well a user has performed in games""" + + user_id: int + wordle: WordleStats = WordleStats() + + @staticmethod + @overrides + def collection() -> str: + return "game_stats" diff --git a/database/schemas/mongo/temporary_storage.py b/database/schemas/mongo/temporary_storage.py new file mode 100644 index 0000000..deb444b --- /dev/null +++ b/database/schemas/mongo/temporary_storage.py @@ -0,0 +1,16 @@ +from overrides import overrides + +from database.schemas.mongo.common import MongoCollection + +__all__ = ["TemporaryStorage"] + + +class TemporaryStorage(MongoCollection): + """Collection for lots of random things that don't belong in a full-blown collection""" + + key: str + + @staticmethod + @overrides + def collection() -> str: + return "temporary" diff --git a/database/schemas/mongo/wordle.py b/database/schemas/mongo/wordle.py new file mode 100644 index 0000000..b03189b --- /dev/null +++ b/database/schemas/mongo/wordle.py @@ -0,0 +1,44 @@ +import datetime + +from overrides import overrides +from pydantic import Field, validator + +from database.constants import WORDLE_GUESS_COUNT +from database.schemas.mongo.common import MongoCollection +from database.utils.datetime import today_only_date + +__all__ = ["WordleGame"] + + +class WordleGame(MongoCollection): + """Collection that holds people's active Wordle games""" + + day: datetime.datetime = Field(default_factory=lambda: today_only_date()) + guesses: list[str] = Field(default_factory=list) + user_id: int + + @staticmethod + @overrides + def collection() -> str: + return "wordle" + + @validator("guesses") + def validate_guesses_length(cls, value: list[int]): + """Check that the amount of guesses is of the correct length""" + if len(value) > 6: + raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") + + return value + + def is_game_over(self, word: str) -> bool: + """Check if the current game is over""" + # No guesses yet + if not self.guesses: + return False + + # Max amount of guesses allowed + if len(self.guesses) == WORDLE_GUESS_COUNT: + return True + + # Found the correct word + return self.guesses[-1] == word diff --git a/database/schemas.py b/database/schemas/relational.py similarity index 80% rename from database/schemas.py rename to database/schemas/relational.py index 182653f..904459e 100644 --- a/database/schemas.py +++ b/database/schemas/relational.py @@ -37,9 +37,6 @@ __all__ = [ "UforaCourse", "UforaCourseAlias", "User", - "WordleGuess", - "WordleStats", - "WordleWord", ] @@ -234,47 +231,3 @@ class User(Base): nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - wordle_guesses: list[WordleGuess] = relationship( - "WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" - ) - wordle_stats: WordleStats = relationship( - "WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" - ) - - -class WordleGuess(Base): - """A user's Wordle guesses for today""" - - __tablename__ = "wordle_guesses" - - wordle_guess_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - guess: str = Column(Text, nullable=False) - - user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin") - - -class WordleStats(Base): - """Stats about a user's wordle performance""" - - __tablename__ = "wordle_stats" - - wordle_stats_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - last_win: Optional[date] = Column(Date, nullable=True) - games: int = Column(Integer, server_default="0", nullable=False) - wins: int = Column(Integer, server_default="0", nullable=False) - current_streak: int = Column(Integer, server_default="0", nullable=False) - highest_streak: int = Column(Integer, server_default="0", nullable=False) - - user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin") - - -class WordleWord(Base): - """The current Wordle word""" - - __tablename__ = "wordle_word" - - word_id: int = Column(Integer, primary_key=True) - word: str = Column(Text, nullable=False) - day: date = Column(Date, nullable=False, unique=True) diff --git a/database/utils/caches.py b/database/utils/caches.py index 25c9cb5..ba26ab2 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,17 +1,19 @@ from abc import ABC, abstractmethod +from typing import Generic, TypeVar from discord import app_commands from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession from database.crud import links, memes, ufora_courses, wordle +from database.mongo_types import MongoDatabase __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] -from database.schemas import WordleWord +T = TypeVar("T") -class DatabaseCache(ABC): +class DatabaseCache(ABC, Generic[T]): """Base class for a simple cache-like structure The goal of this class is to store data for Discord auto-completion results @@ -23,7 +25,7 @@ class DatabaseCache(ABC): Considering the fact that a user isn't obligated to choose something from the suggestions, chances are high we have to go to the database for the final action either way. - Also stores the data in lowercase to allow fast searching. + Also stores the data in lowercase to allow fast searching """ data: list[str] = [] @@ -34,7 +36,7 @@ class DatabaseCache(ABC): self.data.clear() @abstractmethod - async def invalidate(self, database_session: AsyncSession): + async def invalidate(self, database_session: T): """Invalidate the data stored in this cache""" def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: @@ -46,7 +48,7 @@ class DatabaseCache(ABC): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] -class LinkCache(DatabaseCache): +class LinkCache(DatabaseCache[AsyncSession]): """Cache to store the names of links""" @overrides @@ -59,7 +61,7 @@ class LinkCache(DatabaseCache): self.data_transformed = list(map(str.lower, self.data)) -class MemeCache(DatabaseCache): +class MemeCache(DatabaseCache[AsyncSession]): """Cache to store the names of meme templates""" @overrides @@ -72,7 +74,7 @@ class MemeCache(DatabaseCache): self.data_transformed = list(map(str.lower, self.data)) -class UforaCourseCache(DatabaseCache): +class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" # Also store the aliases to add additional support @@ -117,15 +119,13 @@ class UforaCourseCache(DatabaseCache): return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] -class WordleCache(DatabaseCache): +class WordleCache(DatabaseCache[MongoDatabase]): """Cache to store the current daily Wordle word""" - word: WordleWord - - async def invalidate(self, database_session: AsyncSession): + async def invalidate(self, database_session: MongoDatabase): word = await wordle.get_daily_word(database_session) if word is not None: - self.word = word + self.data = [word] class CacheManager: @@ -142,9 +142,9 @@ class CacheManager: self.ufora_courses = UforaCourseCache() self.wordle_word = WordleCache() - async def initialize_caches(self, postgres_session: AsyncSession): + async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" await self.links.invalidate(postgres_session) await self.memes.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session) - await self.wordle_word.invalidate(postgres_session) + await self.wordle_word.invalidate(mongo_db) diff --git a/database/utils/datetime.py b/database/utils/datetime.py index 8450e84..0952f63 100644 --- a/database/utils/datetime.py +++ b/database/utils/datetime.py @@ -1,5 +1,15 @@ +import datetime import zoneinfo -__all__ = ["LOCAL_TIMEZONE"] +__all__ = ["LOCAL_TIMEZONE", "today_only_date"] LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") + + +def today_only_date() -> datetime.datetime: + """Mongo can't handle datetime.date, so we need a datetime instance + + We do, however, only care about the date, so remove all the rest + """ + today = datetime.date.today() + return datetime.datetime(year=today.year, month=today.month, day=today.day) diff --git a/didier/cogs/debug_cog.py b/didier/cogs/debug_cog.py deleted file mode 100644 index 2d03b9f..0000000 --- a/didier/cogs/debug_cog.py +++ /dev/null @@ -1,26 +0,0 @@ -from discord.ext import commands -from overrides import overrides - -from didier import Didier - - -class DebugCog(commands.Cog): - """Testing cog for dev purposes""" - - client: Didier - - def __init__(self, client: Didier): - self.client = client - - @overrides - async def cog_check(self, ctx: commands.Context) -> bool: - return await self.client.is_owner(ctx.author) - - @commands.command(aliases=["Dev"]) - async def debug(self, ctx: commands.Context): - """Debugging command""" - - -async def setup(client: Didier): - """Load the cog""" - await client.add_cog(DebugCog(client)) diff --git a/didier/cogs/games.py b/didier/cogs/games.py index 6114c77..1a4453a 100644 --- a/didier/cogs/games.py +++ b/didier/cogs/games.py @@ -4,11 +4,14 @@ import discord from discord import app_commands from discord.ext import commands -from database.constants import WORDLE_WORD_LENGTH -from database.crud.wordle import get_wordle_guesses, make_wordle_guess -from database.crud.wordle_stats import complete_wordle_game +from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH +from database.crud.wordle import ( + get_active_wordle_game, + make_wordle_guess, + start_new_wordle_game, +) from didier import Didier -from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over +from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed class Games(commands.Cog): @@ -32,39 +35,31 @@ class Games(commands.Cog): embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed() return await interaction.followup.send(embed=embed) - word_instance = self.client.database_caches.wordle_word.word + active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id) + if active_game is None: + active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id) - async with self.client.postgres_session as session: - guesses = await get_wordle_guesses(session, interaction.user.id) + # Trying to guess with a complete game + if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess: + embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed() + return await interaction.followup.send(embed=embed) - # Trying to guess with a complete game - if is_wordle_game_over(guesses, word_instance.word): - embed = WordleErrorEmbed( - message="You've already completed today's Wordle.\nTry again tomorrow!" - ).to_embed() + # Make a guess + if guess: + # The guess is not a real word + if guess.lower() not in self.client.wordle_words: + embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed() return await interaction.followup.send(embed=embed) - # Make a guess - if guess: - # The guess is not a real word - if guess.lower() not in self.client.wordle_words: - embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed() - return await interaction.followup.send(embed=embed) + guess = guess.lower() + await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess) - guess = guess.lower() - await make_wordle_guess(session, interaction.user.id, guess) + # Don't re-request the game, we already have it + # just append locally + active_game.guesses.append(guess) - # Don't re-request the game, we already have it - # just append locally - guesses.append(guess) - - embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed() - await interaction.followup.send(embed=embed) - - # After responding to the interaction: update stats in the background - game_over = is_wordle_game_over(guesses, word_instance.word) - if game_over: - await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses) + embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed() + await interaction.followup.send(embed=embed) async def setup(client: Didier): diff --git a/didier/cogs/other.py b/didier/cogs/other.py index dea9cd2..870c63b 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -5,7 +5,7 @@ from discord import app_commands from discord.ext import commands from database.crud.links import get_link_by_name -from database.schemas import Link +from database.schemas.relational import Link from didier import Didier from didier.data.apis import urban_dictionary from didier.data.embeds.google import GoogleSearch diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 64f5501..901aac6 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -140,9 +140,9 @@ class Tasks(commands.Cog): @tasks.loop(time=DAILY_RESET_TIME) async def reset_wordle_word(self, forced: bool = False): """Reset the daily Wordle word""" - async with self.client.postgres_session as session: - await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced) - await self.client.database_caches.wordle_word.invalidate(session) + db = self.client.mongo_db + word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)), forced=forced) + self.client.database_caches.wordle_word.data = [word] @reset_wordle_word.before_loop async def _before_reset_wordle_word(self): diff --git a/didier/data/apis/imgflip.py b/didier/data/apis/imgflip.py index c44eea6..b373897 100644 --- a/didier/data/apis/imgflip.py +++ b/didier/data/apis/imgflip.py @@ -1,7 +1,7 @@ from aiohttp import ClientSession import settings -from database.schemas import MemeTemplate +from database.schemas.relational import MemeTemplate from didier.exceptions.missing_env import MissingEnvironmentVariable from didier.utils.http.requests import ensure_post diff --git a/didier/data/embeds/deadlines.py b/didier/data/embeds/deadlines.py index fe3c988..371eee9 100644 --- a/didier/data/embeds/deadlines.py +++ b/didier/data/embeds/deadlines.py @@ -4,7 +4,7 @@ from datetime import datetime import discord from overrides import overrides -from database.schemas import Deadline +from database.schemas.relational import Deadline from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import tz_aware_now from didier.utils.types.string import get_edu_year_name diff --git a/didier/data/embeds/ufora/announcements.py b/didier/data/embeds/ufora/announcements.py index 31b4950..59e8290 100644 --- a/didier/data/embeds/ufora/announcements.py +++ b/didier/data/embeds/ufora/announcements.py @@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import ufora_announcements as crud -from database.schemas import UforaCourse +from database.schemas.relational import UforaCourse from didier.data.embeds.base import EmbedBaseModel from didier.utils.discord.colours import ghent_university_blue from didier.utils.types.datetime import int_to_weekday diff --git a/didier/data/embeds/wordle.py b/didier/data/embeds/wordle.py index 959ad62..d29a29f 100644 --- a/didier/data/embeds/wordle.py +++ b/didier/data/embeds/wordle.py @@ -1,26 +1,16 @@ import enum from dataclasses import dataclass +from typing import Optional import discord from overrides import overrides from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH -from database.schemas import WordleWord +from database.schemas.mongo.wordle import WordleGame from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import int_to_weekday, tz_aware_now -__all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"] - - -def is_wordle_game_over(guesses: list[str], word: str) -> bool: - """Check if the current game is over or not""" - if not guesses: - return False - - if len(guesses) == WORDLE_GUESS_COUNT: - return True - - return word.lower() in guesses +__all__ = ["WordleEmbed", "WordleErrorEmbed"] def footer() -> str: @@ -42,18 +32,18 @@ class WordleColour(enum.IntEnum): class WordleEmbed(EmbedBaseModel): """Embed for a Wordle game""" - guesses: list[str] - word: WordleWord + game: Optional[WordleGame] + word: str def _letter_colour(self, guess: str, index: int) -> WordleColour: """Get the colour for a guess at a given position""" - if guess[index] == self.word.word[index]: + if guess[index] == self.word[index]: return WordleColour.CORRECT wrong_letter = 0 wrong_position = 0 - for i, letter in enumerate(self.word.word): + for i, letter in enumerate(self.word): if letter == guess[index] and guess[i] != guess[index]: wrong_letter += 1 @@ -78,8 +68,9 @@ class WordleEmbed(EmbedBaseModel): colours = [] # Add all the guesses - for guess in self.guesses: - colours.append(self._guess_colours(guess)) + if self.game is not None: + for guess in self.game.guesses: + colours.append(self._guess_colours(guess)) # Fill the rest with empty spots for _ in range(WORDLE_GUESS_COUNT - len(colours)): @@ -108,19 +99,19 @@ class WordleEmbed(EmbedBaseModel): colours = self.colour_code_game() - embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}") + embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle") emojis = self._colours_to_emojis(colours) rows = [" ".join(row) for row in emojis] # Don't reveal anything if we only want to show the colours - if not only_colours and self.guesses: - for i, guess in enumerate(self.guesses): + if not only_colours and self.game is not None: + for i, guess in enumerate(self.game.guesses): rows[i] += f" ||{guess.upper()}||" # If the game is over, reveal the word - if is_wordle_game_over(self.guesses, self.word.word): - rows.append(f"\n\nThe word was **{self.word.word.upper()}**!") + if self.game.is_game_over(self.word): + rows.append(f"\n\nThe word was **{self.word.upper()}**!") embed.description = "\n\n".join(rows) embed.set_footer(text=footer()) diff --git a/didier/didier.py b/didier/didier.py index 12bd5b4..337b342 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -2,6 +2,7 @@ import logging import os import discord +import motor.motor_asyncio from aiohttp import ClientSession from discord.app_commands import AppCommandError from discord.ext import commands @@ -9,7 +10,7 @@ from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import custom_commands -from database.engine import DBSession +from database.engine import DBSession, mongo_client from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.exceptions import HTTPException, NoMatch @@ -54,6 +55,11 @@ class Didier(commands.Bot): """Obtain a session for the PostgreSQL database""" return DBSession() + @property + def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: + """Obtain a reference to the MongoDB database""" + return mongo_client[settings.MONGO_DB] + async def setup_hook(self) -> None: """Do some initial setup @@ -65,7 +71,7 @@ class Didier(commands.Bot): # Initialize caches self.database_caches = CacheManager() async with self.postgres_session as session: - await self.database_caches.initialize_caches(session) + await self.database_caches.initialize_caches(session, self.mongo_db) # Load extensions await self._load_initial_extensions() diff --git a/didier/utils/discord/menus/common.py b/didier/utils/discord/menus/common.py deleted file mode 100644 index 983723d..0000000 --- a/didier/utils/discord/menus/common.py +++ /dev/null @@ -1,152 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Generic, Optional, TypeVar, cast - -import discord -from discord.ext import commands -from overrides import overrides - -import settings - -__all__ = ["Menu", "PageSource"] - - -T = TypeVar("T") - - -class PageSource(ABC, Generic[T]): - """Base class that handles the embeds displayed in a menu""" - - dataset: list[T] - embeds: list[discord.Embed] = [] - page_count: int - per_page: int - - def __init__(self, dataset: list[T], *, per_page: int = 10): - self.dataset = dataset - self.per_page = per_page - self.page_count = self._get_page_count() - self.create_embeds() - self._add_embed_page_footers() - - def _get_page_count(self) -> int: - """Calculate the amount of pages required""" - if len(self.dataset) % self.per_page == 0: - return len(self.dataset) // self.per_page - - return (len(self.dataset) // self.per_page) + 1 - - def __getitem__(self, index: int) -> discord.Embed: - return self.embeds[index] - - def __len__(self): - return self.page_count - - def _add_embed_page_footers(self): - """Add the current page in the footer of every embed""" - for i, embed in enumerate(self.embeds): - embed.set_footer(text=f"{i + 1}/{self.page_count}") - - @abstractmethod - def create_embeds(self): - """Method that builds the list of embeds from the input data""" - raise NotImplementedError - - -class Menu(discord.ui.View): - """Base class for a menu""" - - ctx: commands.Context - current_page: int = 0 - ephemeral: bool - message: Optional[discord.Message] = None - source: PageSource - - def __init__(self, source: PageSource, *, ephemeral: bool = False, timeout: Optional[int] = None): - super().__init__(timeout=timeout or settings.MENU_TIMEOUT) - self.ephemeral = ephemeral - self.source = source - - def do_button_disabling(self): - """Disable buttons depending on the current page""" - first_page = cast(discord.ui.Button, self.children[0]) - first_page.disabled = self.current_page == 0 - - previous_page = cast(discord.ui.Button, self.children[1]) - previous_page.disabled = self.current_page == 0 - - next_page = cast(discord.ui.Button, self.children[3]) - next_page.disabled = self.current_page == len(self.source) - 1 - - last_page = cast(discord.ui.Button, self.children[4]) - last_page.disabled = self.current_page == len(self.source) - 1 - - async def display_current_state(self, interaction: Optional[discord.Interaction] = None): - """Display the current state of the view - - Enable/disable buttons, show a different embed, ... - """ - self.do_button_disabling() - - print(self.current_page, self.source[self.current_page].footer.text) - - # Send the initial message if there is none yet, else edit the existing one - if self.message is None: - self.message = await self.ctx.reply( - embed=self.source[self.current_page], view=self, mention_author=False, ephemeral=self.ephemeral - ) - elif interaction is not None: - await interaction.response.edit_message(embed=self.source[self.current_page], view=self) - - async def start(self, ctx: commands.Context): - """Send the initial message with this menu""" - self.ctx = ctx - await self.display_current_state() - - async def stop_view(self, interaction: Optional[discord.Interaction] = None): - """Stop the view & clear all the items""" - self.stop() - self.clear_items() - - if interaction is not None: - await interaction.response.edit_message(view=self) - elif self.message is not None: - await self.message.edit(view=self) - - @overrides - async def interaction_check(self, interaction: discord.Interaction, /) -> bool: - """Only allow the person that started the menu to use the menu""" - return interaction.user == self.ctx.author - - @overrides - async def on_timeout(self) -> None: - """Remove all buttons when the view times out""" - await self.stop_view() - - @discord.ui.button(label="<<", style=discord.ButtonStyle.primary, disabled=True) - async def first_page(self, interaction: discord.Interaction, button: discord.ui.Button): - """Button to go back to the first page""" - self.current_page = 0 - await self.display_current_state(interaction) - - @discord.ui.button(label="<", style=discord.ButtonStyle.primary, disabled=True) - async def previous_page(self, interaction: discord.Interaction, button: discord.ui.Button): - """Button to go back to the previous page""" - self.current_page -= 1 - await self.display_current_state(interaction) - - @discord.ui.button(label="Stop", style=discord.ButtonStyle.red) - async def stop_pages(self, interaction: discord.Interaction, button: discord.ui.Button): - """Button to stop the view""" - await self.stop_view(interaction) - - @discord.ui.button(label=">", style=discord.ButtonStyle.primary) - async def next_page(self, interaction: discord.Interaction, button: discord.ui.Button): - """Button to show the next page""" - self.current_page += 1 - await self.display_current_state(interaction) - - @discord.ui.button(label=">>", style=discord.ButtonStyle.primary) - async def last_page(self, interaction: discord.Interaction, button: discord.ui.Button): - """Button to show the last page""" - self.current_page = len(self.source) - 1 - await self.display_current_state(interaction) diff --git a/didier/views/modals/deadlines.py b/didier/views/modals/deadlines.py index 972cfe1..cd2a26c 100644 --- a/didier/views/modals/deadlines.py +++ b/didier/views/modals/deadlines.py @@ -5,7 +5,7 @@ from discord import Interaction from overrides import overrides from database.crud.deadlines import add_deadline -from database.schemas import UforaCourse +from database.schemas.relational import UforaCourse __all__ = ["AddDeadline"] diff --git a/didier/views/modals/memes.py b/didier/views/modals/memes.py index 4f5518e..c98e17f 100644 --- a/didier/views/modals/memes.py +++ b/didier/views/modals/memes.py @@ -3,7 +3,7 @@ import traceback import discord.ui from overrides import overrides -from database.schemas import MemeTemplate +from database.schemas.relational import MemeTemplate from didier import Didier from didier.data.apis.imgflip import generate_meme diff --git a/docker-compose.test.yml b/docker-compose.test.yml index d841c33..4033ad4 100644 --- a/docker-compose.test.yml +++ b/docker-compose.test.yml @@ -10,3 +10,10 @@ services: - POSTGRES_PASSWORD=pytest ports: - "5433:5432" + mongo-pytest: + image: mongo:5.0 + restart: always + environment: + - MONGO_INITDB_DATABASE=didier_pytest + ports: + - "27018:27017" diff --git a/docker-compose.yml b/docker-compose.yml index b9c5ee3..4af0a0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -12,5 +12,18 @@ services: - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}" volumes: - postgres:/var/lib/postgresql/data + mongo: + image: mongo:5.0 + restart: always + environment: + - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root} + - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root} + - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev} + command: [--auth] + ports: + - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}" + volumes: + - mongo:/data/db volumes: postgres: + mongo: diff --git a/pyproject.toml b/pyproject.toml index dcb0d24..45c7c9a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,9 @@ ignore_missing_imports = true asyncio_mode = "auto" env = [ "TESTING = 1", + "MONGO_DB = didier_pytest", + "MONGO_HOST = localhost", + "MONGO_PORT = 27018", "POSTGRES_DB = didier_pytest", "POSTGRES_USER = pytest", "POSTGRES_PASS = pytest", @@ -52,5 +55,6 @@ env = [ "DISCORD_TOKEN = token" ] markers = [ + "mongo: tests that use MongoDB", "postgres: tests that use PostgreSQL" ] diff --git a/requirements.txt b/requirements.txt index 759cb54..b064107 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,10 +3,10 @@ alembic==1.8.0 asyncpg==0.25.0 beautifulsoup4==4.11.1 discord.py==2.0.1 -git+https://github.com/Rapptz/discord-ext-menus@8686b5d environs==9.5.0 feedparser==6.0.10 markdownify==0.11.2 +motor==3.0.0 overrides==6.1.0 pydantic==1.9.1 python-dateutil==2.8.2 diff --git a/settings.py b/settings.py index 7698cc9..a041397 100644 --- a/settings.py +++ b/settings.py @@ -37,6 +37,13 @@ SEMESTER: int = env.int("SEMESTER", 2) YEAR: int = env.int("YEAR", 3) """Database""" +# MongoDB +MONGO_DB: str = env.str("MONGO_DB", "didier") +MONGO_USER: str = env.str("MONGO_USER", "root") +MONGO_PASS: str = env.str("MONGO_PASS", "root") +MONGO_HOST: str = env.str("MONGO_HOST", "localhost") +MONGO_PORT: int = env.int("MONGO_PORT", "27017") + # PostgreSQL POSTGRES_DB: str = env.str("POSTGRES_DB", "didier") POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres") @@ -56,9 +63,6 @@ BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CH ERRORS_CHANNEL: Optional[int] = env.int("ERRORS_CHANNEL", None) UFORA_ANNOUNCEMENTS_CHANNEL: Optional[int] = env.int("UFORA_ANNOUNCEMENTS_CHANNEL", None) -""""General config""" -MENU_TIMEOUT: int = env.int("MENU_TIMEOUT", 30) - """API Keys""" UFORA_RSS_TOKEN: Optional[str] = env.str("UFORA_RSS_TOKEN", None) URBAN_DICTIONARY_TOKEN: Optional[str] = env.str("URBAN_DICTIONARY_TOKEN", None) diff --git a/tests/conftest.py b/tests/conftest.py index c218524..55919fd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,10 +2,12 @@ import asyncio from typing import AsyncGenerator, Generator from unittest.mock import MagicMock +import motor.motor_asyncio import pytest from sqlalchemy.ext.asyncio import AsyncSession -from database.engine import postgres_engine +import settings +from database.engine import mongo_client, postgres_engine from database.migrations import ensure_latest_migration, migrate from didier import Didier @@ -54,6 +56,14 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]: await connection.close() +@pytest.fixture +async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase: + """Fixture to get a MongoDB connection""" + database = mongo_client[settings.MONGO_DB] + yield database + mongo_client.drop_database(settings.MONGO_DB) + + @pytest.fixture def mock_client() -> Didier: """Fixture to get a mock Didier instance diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index a675a35..dc5ce2a 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -4,7 +4,7 @@ import pytest from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users -from database.schemas import ( +from database.schemas.relational import ( Bank, UforaAnnouncement, UforaCourse, @@ -25,7 +25,7 @@ def test_user_id() -> int: @pytest.fixture async def user(postgres: AsyncSession, test_user_id: int) -> User: """Fixture to create a user""" - _user = await users.get_or_add_user(postgres, test_user_id) + _user = await users.get_or_add(postgres, test_user_id) await postgres.refresh(_user) return _user diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 29a3791..e7f2242 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud from database.crud import users -from database.schemas import User +from database.schemas.relational import User async def test_add_birthday_not_present(postgres: AsyncSession, user: User): @@ -54,7 +54,7 @@ async def test_get_birthdays_on_day(postgres: AsyncSession, user: User): """Test getting all birthdays on a given day""" await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) - user_2 = await users.get_or_add_user(postgres, user.user_id + 1) + user_2 = await users.get_or_add(postgres, user.user_id + 1) await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1)) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 1 diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index 1beddc6..8bd7e8f 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud from database.exceptions import currency as exceptions -from database.schemas import Bank +from database.schemas.relational import Bank async def test_add_dinks(postgres: AsyncSession, bank: Bank): diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index 8155068..6f141bc 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException from database.exceptions.not_found import NoResultFoundException -from database.schemas import CustomCommand +from database.schemas.relational import CustomCommand async def test_create_command_non_existing(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index 99b82de..f34d0fa 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -2,7 +2,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.crud import dad_jokes as crud -from database.schemas import DadJoke +from database.schemas.relational import DadJoke async def test_add_dad_joke(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_game_stats.py b/tests/test_database/test_crud/test_game_stats.py new file mode 100644 index 0000000..4b7606f --- /dev/null +++ b/tests/test_database/test_crud/test_game_stats.py @@ -0,0 +1,63 @@ +import pytest +from freezegun import freeze_time + +from database.crud import game_stats as crud +from database.mongo_types import MongoDatabase +from database.schemas.mongo.game_stats import GameStats +from database.utils.datetime import today_only_date + + +async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats): + """Helper function to insert some stats""" + collection = mongodb[GameStats.collection()] + await collection.insert_one(stats.dict(by_alias=True)) + + +@pytest.mark.mongo +async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int): + """Test getting a user's stats when the db is empty""" + collection = mongodb[GameStats.collection()] + assert await collection.find_one({"user_id": test_user_id}) is None + await crud.get_game_stats(mongodb, test_user_id) + assert await collection.find_one({"user_id": test_user_id}) is not None + + +@pytest.mark.mongo +async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int): + """Test getting a user's stats when there's already an entry present""" + stats = GameStats(user_id=test_user_id) + stats.wordle.games = 20 + await insert_game_stats(mongodb, stats) + found_stats = await crud.get_game_stats(mongodb, test_user_id) + assert found_stats.wordle.games == 20 + + +@pytest.mark.mongo +@freeze_time("2022-07-30") +async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int): + """Test completing a wordle game when you win""" + await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2) + stats = await crud.get_game_stats(mongodb, test_user_id) + assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0] + assert stats.wordle.games == 1 + assert stats.wordle.wins == 1 + assert stats.wordle.current_streak == 1 + assert stats.wordle.max_streak == 1 + assert stats.wordle.last_win == today_only_date() + + +@pytest.mark.mongo +@freeze_time("2022-07-30") +async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int): + """Test completing a wordle game when you lose""" + stats = GameStats(user_id=test_user_id) + stats.wordle.current_streak = 10 + await insert_game_stats(mongodb, stats) + + await crud.complete_wordle_game(mongodb, test_user_id, win=False) + stats = await crud.get_game_stats(mongodb, test_user_id) + + # Check that streak was broken + assert stats.wordle.current_streak == 0 + assert stats.wordle.games == 1 + assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0] diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index f3adede..b13b221 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import tasks as crud from database.enums import TaskType -from database.schemas import Task +from database.schemas.relational import Task @pytest.fixture diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index 3621a82..34f4222 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -3,7 +3,7 @@ import datetime from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_announcements as crud -from database.schemas import UforaAnnouncement, UforaCourse +from database.schemas.relational import UforaAnnouncement, UforaCourse async def test_get_courses_with_announcements_none(postgres: AsyncSession): diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index 601ba8c..140bc4a 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,7 +1,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.crud import ufora_courses as crud -from database.schemas import UforaCourse +from database.schemas.relational import UforaCourse async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse): diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index b5c2fc6..96d3383 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -2,12 +2,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users as crud -from database.schemas import User +from database.schemas.relational import User async def test_get_or_add_non_existing(postgres: AsyncSession): """Test get_or_add for a user that doesn't exist""" - await crud.get_or_add_user(postgres, 1) + await crud.get_or_add(postgres, 1) statement = select(User) res = (await postgres.execute(statement)).scalars().all() @@ -18,8 +18,8 @@ async def test_get_or_add_non_existing(postgres: AsyncSession): async def test_get_or_add_existing(postgres: AsyncSession): """Test get_or_add for a user that does exist""" - user = await crud.get_or_add_user(postgres, 1) + user = await crud.get_or_add(postgres, 1) bank = user.bank - assert await crud.get_or_add_user(postgres, 1) == user - assert (await crud.get_or_add_user(postgres, 1)).bank == bank + assert await crud.get_or_add(postgres, 1) == user + assert (await crud.get_or_add(postgres, 1)).bank == bank diff --git a/tests/test_database/test_crud/test_wordle.py b/tests/test_database/test_crud/test_wordle.py index 82204f9..3ddc979 100644 --- a/tests/test_database/test_crud/test_wordle.py +++ b/tests/test_database/test_crud/test_wordle.py @@ -1,138 +1,136 @@ -from datetime import date, timedelta +from datetime import datetime, timedelta import pytest from freezegun import freeze_time -from sqlalchemy.ext.asyncio import AsyncSession from database.crud import wordle as crud -from database.schemas import User, WordleGuess, WordleWord +from database.enums import TempStorageKey +from database.mongo_types import MongoCollection, MongoDatabase +from database.schemas.mongo.temporary_storage import TemporaryStorage +from database.schemas.mongo.wordle import WordleGame @pytest.fixture -async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]: - """Fixture to generate some guesses""" - guesses = [] - - for guess in ["TEST", "WORDLE", "WORDS"]: - guess = WordleGuess(user_id=user.user_id, guess=guess) - postgres.add(guess) - await postgres.commit() - - guesses.append(guess) - - return guesses +async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection: + """Fixture to get a reference to the wordle collection""" + yield mongodb[WordleGame.collection()] -@pytest.mark.postgres -async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User): +@pytest.fixture +async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame: + """Fixture to create a new game""" + game = WordleGame(user_id=test_user_id) + await wordle_collection.insert_one(game.dict(by_alias=True)) + yield game + + +@pytest.mark.mongo +async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int): + """Test starting a new game""" + result = await wordle_collection.find_one({"user_id": test_user_id}) + assert result is None + + await crud.start_new_wordle_game(mongodb, test_user_id) + + result = await wordle_collection.find_one({"user_id": test_user_id}) + assert result is not None + + +@pytest.mark.mongo +async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int): """Test getting an active game when there is none""" - result = await crud.get_active_wordle_game(postgres, user.user_id) - assert not result - - -@pytest.mark.postgres -async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]): - """Test getting an active game when there is one""" - result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id) - assert result == wordle_guesses - - -@pytest.mark.postgres -async def test_get_daily_word_none(postgres: AsyncSession): - """Test getting the daily word when the database is empty""" - result = await crud.get_daily_word(postgres) + result = await crud.get_active_wordle_game(mongodb, test_user_id) assert result is None -@pytest.mark.postgres +@pytest.mark.mongo +async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame): + """Test getting an active game when there is one""" + result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id) + assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True) + + +@pytest.mark.mongo +async def test_get_daily_word_none(mongodb: MongoDatabase): + """Test getting the daily word when the database is empty""" + result = await crud.get_daily_word(mongodb) + assert result is None + + +@pytest.mark.mongo @freeze_time("2022-07-30") -async def test_get_daily_word_not_today(postgres: AsyncSession): +async def test_get_daily_word_not_today(mongodb: MongoDatabase): """Test getting the daily word when there is an entry, but not for today""" - day = date.today() - timedelta(days=1) + day = datetime.today() - timedelta(days=1) + collection = mongodb[TemporaryStorage.collection()] word = "testword" - word_instance = WordleWord(word=word, day=day) - postgres.add(word_instance) - await postgres.commit() + await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) - assert await crud.get_daily_word(postgres) is None + assert await crud.get_daily_word(mongodb) is None -@pytest.mark.postgres +@pytest.mark.mongo @freeze_time("2022-07-30") -async def test_get_daily_word_present(postgres: AsyncSession): +async def test_get_daily_word_present(mongodb: MongoDatabase): """Test getting the daily word when there is one for today""" - day = date.today() + day = datetime.today() + collection = mongodb[TemporaryStorage.collection()] word = "testword" - word_instance = WordleWord(word=word, day=day) - postgres.add(word_instance) - await postgres.commit() + await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) - daily_word = await crud.get_daily_word(postgres) - assert daily_word is not None - assert daily_word.word == word + assert await crud.get_daily_word(mongodb) == word -@pytest.mark.postgres +@pytest.mark.mongo @freeze_time("2022-07-30") -async def test_set_daily_word_none_present(postgres: AsyncSession): +async def test_set_daily_word_none_present(mongodb: MongoDatabase): """Test setting the daily word when there is none""" - assert await crud.get_daily_word(postgres) is None + assert await crud.get_daily_word(mongodb) is None word = "testword" - await crud.set_daily_word(postgres, word) - - daily_word = await crud.get_daily_word(postgres) - assert daily_word is not None - assert daily_word.word == word + await crud.set_daily_word(mongodb, word) + assert await crud.get_daily_word(mongodb) == word -@pytest.mark.postgres +@pytest.mark.mongo @freeze_time("2022-07-30") -async def test_set_daily_word_present(postgres: AsyncSession): +async def test_set_daily_word_present(mongodb: MongoDatabase): """Test setting the daily word when there already is one""" word = "testword" - await crud.set_daily_word(postgres, word) - await crud.set_daily_word(postgres, "another word") - - daily_word = await crud.get_daily_word(postgres) - assert daily_word is not None - assert daily_word.word == word + await crud.set_daily_word(mongodb, word) + await crud.set_daily_word(mongodb, "another word") + assert await crud.get_daily_word(mongodb) == word -@pytest.mark.postgres +@pytest.mark.mongo @freeze_time("2022-07-30") -async def test_set_daily_word_force_overwrite(postgres: AsyncSession): +async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase): """Test setting the daily word when there already is one, but "forced" is set to True""" word = "testword" - await crud.set_daily_word(postgres, word) + await crud.set_daily_word(mongodb, word) word = "anotherword" - await crud.set_daily_word(postgres, word, forced=True) - - daily_word = await crud.get_daily_word(postgres) - assert daily_word is not None - assert daily_word.word == word + await crud.set_daily_word(mongodb, word, forced=True) + assert await crud.get_daily_word(mongodb) == word -@pytest.mark.postgres -async def test_make_wordle_guess(postgres: AsyncSession, user: User): +@pytest.mark.mongo +async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): """Test making a guess in your current game""" - test_user_id = user.user_id - guess = "guess" - await crud.make_wordle_guess(postgres, test_user_id, guess) - assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess] + await crud.make_wordle_guess(mongodb, test_user_id, guess) + wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) + assert wordle_game.guesses == [guess] other_guess = "otherguess" - await crud.make_wordle_guess(postgres, test_user_id, other_guess) - assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess] + await crud.make_wordle_guess(mongodb, test_user_id, other_guess) + wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) + assert wordle_game.guesses == [guess, other_guess] -@pytest.mark.postgres -async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User): +@pytest.mark.mongo +async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): """Test dropping the collection of active games""" - test_user_id = user.user_id - - assert await crud.get_active_wordle_game(postgres, test_user_id) - await crud.reset_wordle_games(postgres) - assert not await crud.get_active_wordle_game(postgres, test_user_id) + assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None + await crud.reset_wordle_games(mongodb) + assert await crud.get_active_wordle_game(mongodb, test_user_id) is None diff --git a/tests/test_database/test_crud/test_wordle_stats.py b/tests/test_database/test_crud/test_wordle_stats.py deleted file mode 100644 index 925e5e8..0000000 --- a/tests/test_database/test_crud/test_wordle_stats.py +++ /dev/null @@ -1,72 +0,0 @@ -import datetime - -import pytest -from freezegun import freeze_time -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from database.crud import wordle_stats as crud -from database.schemas import User, WordleStats - - -async def insert_game_stats(postgres: AsyncSession, stats: WordleStats): - """Helper function to insert some stats""" - postgres.add(stats) - await postgres.commit() - - -@pytest.mark.postgres -async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User): - """Test getting a user's stats when the db is empty""" - test_user_id = user.user_id - - statement = select(WordleStats).where(WordleStats.user_id == test_user_id) - assert (await postgres.execute(statement)).scalar_one_or_none() is None - - await crud.get_wordle_stats(postgres, test_user_id) - assert (await postgres.execute(statement)).scalar_one_or_none() is not None - - -@pytest.mark.postgres -async def test_get_stats_existing_returns(postgres: AsyncSession, user: User): - """Test getting a user's stats when there's already an entry present""" - test_user_id = user.user_id - - stats = WordleStats(user_id=test_user_id) - stats.games = 20 - await insert_game_stats(postgres, stats) - found_stats = await crud.get_wordle_stats(postgres, test_user_id) - assert found_stats.games == 20 - - -@pytest.mark.postgres -@freeze_time("2022-07-30") -async def test_complete_wordle_game_won(postgres: AsyncSession, user: User): - """Test completing a wordle game when you win""" - test_user_id = user.user_id - - await crud.complete_wordle_game(postgres, test_user_id, win=True) - stats = await crud.get_wordle_stats(postgres, test_user_id) - assert stats.games == 1 - assert stats.wins == 1 - assert stats.current_streak == 1 - assert stats.highest_streak == 1 - assert stats.last_win == datetime.date.today() - - -@pytest.mark.postgres -@freeze_time("2022-07-30") -async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User): - """Test completing a wordle game when you lose""" - test_user_id = user.user_id - - stats = WordleStats(user_id=test_user_id) - stats.current_streak = 10 - await insert_game_stats(postgres, stats) - - await crud.complete_wordle_game(postgres, test_user_id, win=False) - stats = await crud.get_wordle_stats(postgres, test_user_id) - - # Check that streak was broken - assert stats.current_streak == 0 - assert stats.games == 1 diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index e62d7a3..3dc6adb 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,6 +1,6 @@ from sqlalchemy.ext.asyncio import AsyncSession -from database.schemas import UforaCourse +from database.schemas.relational import UforaCourse from database.utils.caches import UforaCourseCache