diff --git a/database/crud/game_stats.py b/database/crud/game_stats.py new file mode 100644 index 0000000..90e2e98 --- /dev/null +++ b/database/crud/game_stats.py @@ -0,0 +1,59 @@ +import datetime +from typing import Union + +from database.mongo_types import MongoDatabase +from database.schemas.mongo.game_stats import GameStats + +__all__ = ["get_game_stats", "complete_wordle_game"] + +from database.utils.datetime import today_only_date + + +async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats: + """Get a user's game stats + + If no entry is found, it is first created + """ + collection = database[GameStats.collection()] + stats = await collection.find_one({"user_id": user_id}) + if stats is not None: + return GameStats(**stats) + + stats = GameStats(user_id=user_id) + await collection.insert_one(stats.dict(by_alias=True)) + return stats + + +async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int): + """Update the user's Wordle stats""" + stats = await get_game_stats(database, user_id) + + update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}} + + if win: + update["$inc"]["wordle.wins"] = 1 + update["$inc"][f"wordle.guess_distribution.{guesses}"] = 1 + + # Update streak + today = today_only_date() + last_win = stats.wordle.last_win + update["$set"]["wordle.last_win"] = today + + if last_win is None or (today - last_win).days > 1: + # Never won a game before or streak is over + update["$set"]["wordle.current_streak"] = 1 + stats.wordle.current_streak = 1 + else: + # On a streak: increase counter + update["$inc"]["wordle.current_streak"] = 1 + stats.wordle.current_streak += 1 + + # Update max streak if necessary + if stats.wordle.current_streak > stats.wordle.max_streak: + update["$set"]["wordle.max_streak"] = stats.wordle.current_streak + else: + # Streak is over + update["$set"]["wordle.current_streak"] = 0 + + collection = database[GameStats.collection()] + collection.update_one({"_id": stats.id}, update) diff --git a/database/crud/wordle.py b/database/crud/wordle.py index acd0242..4b2a312 100644 --- a/database/crud/wordle.py +++ b/database/crud/wordle.py @@ -2,7 +2,8 @@ from typing import Optional from database.enums import TempStorageKey from database.mongo_types import MongoDatabase -from database.schemas.mongo import TemporaryStorage, WordleGame +from database.schemas.mongo.temporary_storage import TemporaryStorage +from database.schemas.mongo.wordle import WordleGame from database.utils.datetime import today_only_date __all__ = [ diff --git a/database/schemas/mongo.py b/database/schemas/mongo.py deleted file mode 100644 index aeced32..0000000 --- a/database/schemas/mongo.py +++ /dev/null @@ -1,134 +0,0 @@ -import datetime -from abc import ABC, abstractmethod -from typing import Optional - -from bson import ObjectId -from overrides import overrides -from pydantic import BaseModel, Field, validator - -from database.constants import WORDLE_GUESS_COUNT - -__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"] - -from database.utils.datetime import today_only_date - - -class PyObjectId(ObjectId): - """Custom type for bson ObjectIds""" - - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, value: str): - """Check that a string is a valid bson ObjectId""" - if not ObjectId.is_valid(value): - raise ValueError(f"Invalid ObjectId: '{value}'") - - return ObjectId(value) - - @classmethod - def __modify_schema__(cls, field_schema: dict): - field_schema.update(type="string") - - -class MongoBase(BaseModel): - """Base model that properly sets the _id field, and adds one by default""" - - id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") - - class Config: - """Configuration for encoding and construction""" - - allow_population_by_field_name = True - arbitrary_types_allowed = True - json_encoders = {ObjectId: str, PyObjectId: str} - use_enum_values = True - - -class MongoCollection(MongoBase, ABC): - """Base model for the 'main class' in a collection - - This field stores the name of the collection to avoid making typos against it - """ - - @staticmethod - @abstractmethod - def collection() -> str: - raise NotImplementedError - - -class TemporaryStorage(MongoCollection): - """Collection for lots of random things that don't belong in a full-blown collection""" - - key: str - - @staticmethod - @overrides - def collection() -> str: - return "temporary" - - -class WordleStats(BaseModel): - """Model that holds stats about a player's Wordle performance""" - - guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) - last_guess: Optional[datetime.date] = None - win_rate: float = 0 - current_streak: int = 0 - max_streak: int = 0 - - @validator("guess_distribution") - def validate_guesses_length(cls, value: list[int]): - """Check that the distribution of guesses is of the correct length""" - if len(value) != 6: - raise ValueError(f"guess_distribution must be length 6, found {len(value)}") - - return value - - -class GameStats(MongoCollection): - """Collection that holds stats about how well a user has performed in games""" - - user_id: int - wordle: Optional[WordleStats] = None - - @staticmethod - @overrides - def collection() -> str: - return "game_stats" - - -class WordleGame(MongoCollection): - """Collection that holds people's active Wordle games""" - - day: datetime.datetime = Field(default_factory=lambda: today_only_date()) - guesses: list[str] = Field(default_factory=list) - user_id: int - - @staticmethod - @overrides - def collection() -> str: - return "wordle" - - @validator("guesses") - def validate_guesses_length(cls, value: list[int]): - """Check that the amount of guesses is of the correct length""" - if len(value) > 6: - raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") - - return value - - def is_game_over(self, word: str) -> bool: - """Check if the current game is over""" - # No guesses yet - if not self.guesses: - return False - - # Max amount of guesses allowed - if len(self.guesses) == WORDLE_GUESS_COUNT: - return True - - # Found the correct word - return self.guesses[-1] == word diff --git a/database/schemas/mongo/__init__.py b/database/schemas/mongo/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/database/schemas/mongo/common.py b/database/schemas/mongo/common.py new file mode 100644 index 0000000..dcddb54 --- /dev/null +++ b/database/schemas/mongo/common.py @@ -0,0 +1,53 @@ +from abc import ABC, abstractmethod + +from bson import ObjectId +from pydantic import BaseModel, Field + +__all__ = ["PyObjectId", "MongoBase", "MongoCollection"] + + +class PyObjectId(ObjectId): + """Custom type for bson ObjectIds""" + + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, value: str): + """Check that a string is a valid bson ObjectId""" + if not ObjectId.is_valid(value): + raise ValueError(f"Invalid ObjectId: '{value}'") + + return ObjectId(value) + + @classmethod + def __modify_schema__(cls, field_schema: dict): + field_schema.update(type="string") + + +class MongoBase(BaseModel): + """Base model that properly sets the _id field, and adds one by default""" + + id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") + + class Config: + """Configuration for encoding and construction""" + + allow_population_by_field_name = True + arbitrary_types_allowed = True + json_encoders = {ObjectId: str, PyObjectId: str} + use_enum_values = True + + +class MongoCollection(MongoBase, ABC): + """Base model for the 'main class' in a collection + + This field stores the name of the collection to avoid making typos against it + """ + + @staticmethod + @abstractmethod + def collection() -> str: + """Getter for the name of the collection, in order to avoid typos""" + raise NotImplementedError diff --git a/database/schemas/mongo/game_stats.py b/database/schemas/mongo/game_stats.py new file mode 100644 index 0000000..af08117 --- /dev/null +++ b/database/schemas/mongo/game_stats.py @@ -0,0 +1,40 @@ +import datetime +from typing import Optional + +from overrides import overrides +from pydantic import BaseModel, Field, validator + +from database.schemas.mongo.common import MongoCollection + +__all__ = ["GameStats", "WordleStats"] + + +class WordleStats(BaseModel): + """Model that holds stats about a player's Wordle performance""" + + guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) + last_win: Optional[datetime.datetime] = None + wins: int = 0 + games: int = 0 + current_streak: int = 0 + max_streak: int = 0 + + @validator("guess_distribution") + def validate_guesses_length(cls, value: list[int]): + """Check that the distribution of guesses is of the correct length""" + if len(value) != 6: + raise ValueError(f"guess_distribution must be length 6, found {len(value)}") + + return value + + +class GameStats(MongoCollection): + """Collection that holds stats about how well a user has performed in games""" + + user_id: int + wordle: WordleStats = WordleStats() + + @staticmethod + @overrides + def collection() -> str: + return "game_stats" diff --git a/database/schemas/mongo/temporary_storage.py b/database/schemas/mongo/temporary_storage.py new file mode 100644 index 0000000..deb444b --- /dev/null +++ b/database/schemas/mongo/temporary_storage.py @@ -0,0 +1,16 @@ +from overrides import overrides + +from database.schemas.mongo.common import MongoCollection + +__all__ = ["TemporaryStorage"] + + +class TemporaryStorage(MongoCollection): + """Collection for lots of random things that don't belong in a full-blown collection""" + + key: str + + @staticmethod + @overrides + def collection() -> str: + return "temporary" diff --git a/database/schemas/mongo/wordle.py b/database/schemas/mongo/wordle.py new file mode 100644 index 0000000..b03189b --- /dev/null +++ b/database/schemas/mongo/wordle.py @@ -0,0 +1,44 @@ +import datetime + +from overrides import overrides +from pydantic import Field, validator + +from database.constants import WORDLE_GUESS_COUNT +from database.schemas.mongo.common import MongoCollection +from database.utils.datetime import today_only_date + +__all__ = ["WordleGame"] + + +class WordleGame(MongoCollection): + """Collection that holds people's active Wordle games""" + + day: datetime.datetime = Field(default_factory=lambda: today_only_date()) + guesses: list[str] = Field(default_factory=list) + user_id: int + + @staticmethod + @overrides + def collection() -> str: + return "wordle" + + @validator("guesses") + def validate_guesses_length(cls, value: list[int]): + """Check that the amount of guesses is of the correct length""" + if len(value) > 6: + raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") + + return value + + def is_game_over(self, word: str) -> bool: + """Check if the current game is over""" + # No guesses yet + if not self.guesses: + return False + + # Max amount of guesses allowed + if len(self.guesses) == WORDLE_GUESS_COUNT: + return True + + # Found the correct word + return self.guesses[-1] == word diff --git a/didier/data/embeds/wordle.py b/didier/data/embeds/wordle.py index 3f54d1e..44439c0 100644 --- a/didier/data/embeds/wordle.py +++ b/didier/data/embeds/wordle.py @@ -6,7 +6,7 @@ import discord from overrides import overrides from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH -from database.schemas.mongo import WordleGame +from database.schemas.mongo.wordle import WordleGame from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import int_to_weekday, tz_aware_now diff --git a/tests/test_database/test_crud/test_game_stats.py b/tests/test_database/test_crud/test_game_stats.py new file mode 100644 index 0000000..c48570a --- /dev/null +++ b/tests/test_database/test_crud/test_game_stats.py @@ -0,0 +1,8 @@ +import pytest + +from database.mongo_types import MongoDatabase + + +@pytest.mark.mongo +async def test_get_stats_non_existent(mongodb: MongoDatabase, test_user_id: int): + """Test getting a user's stats when the db is empty""" diff --git a/tests/test_database/test_crud/test_wordle.py b/tests/test_database/test_crud/test_wordle.py index 7c61e84..3ddc979 100644 --- a/tests/test_database/test_crud/test_wordle.py +++ b/tests/test_database/test_crud/test_wordle.py @@ -6,7 +6,8 @@ from freezegun import freeze_time from database.crud import wordle as crud from database.enums import TempStorageKey from database.mongo_types import MongoCollection, MongoDatabase -from database.schemas.mongo import TemporaryStorage, WordleGame +from database.schemas.mongo.temporary_storage import TemporaryStorage +from database.schemas.mongo.wordle import WordleGame @pytest.fixture