mirror of https://github.com/stijndcl/didier
				
				
				
			Make game stats crud functions, split mongo schemas out a bit
							parent
							
								
									e4e77502e8
								
							
						
					
					
						commit
						bf41acd9f4
					
				|  | @ -0,0 +1,59 @@ | ||||||
|  | import datetime | ||||||
|  | from typing import Union | ||||||
|  | 
 | ||||||
|  | from database.mongo_types import MongoDatabase | ||||||
|  | from database.schemas.mongo.game_stats import GameStats | ||||||
|  | 
 | ||||||
|  | __all__ = ["get_game_stats", "complete_wordle_game"] | ||||||
|  | 
 | ||||||
|  | from database.utils.datetime import today_only_date | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats: | ||||||
|  |     """Get a user's game stats | ||||||
|  | 
 | ||||||
|  |     If no entry is found, it is first created | ||||||
|  |     """ | ||||||
|  |     collection = database[GameStats.collection()] | ||||||
|  |     stats = await collection.find_one({"user_id": user_id}) | ||||||
|  |     if stats is not None: | ||||||
|  |         return GameStats(**stats) | ||||||
|  | 
 | ||||||
|  |     stats = GameStats(user_id=user_id) | ||||||
|  |     await collection.insert_one(stats.dict(by_alias=True)) | ||||||
|  |     return stats | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int): | ||||||
|  |     """Update the user's Wordle stats""" | ||||||
|  |     stats = await get_game_stats(database, user_id) | ||||||
|  | 
 | ||||||
|  |     update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}} | ||||||
|  | 
 | ||||||
|  |     if win: | ||||||
|  |         update["$inc"]["wordle.wins"] = 1 | ||||||
|  |         update["$inc"][f"wordle.guess_distribution.{guesses}"] = 1 | ||||||
|  | 
 | ||||||
|  |         # Update streak | ||||||
|  |         today = today_only_date() | ||||||
|  |         last_win = stats.wordle.last_win | ||||||
|  |         update["$set"]["wordle.last_win"] = today | ||||||
|  | 
 | ||||||
|  |         if last_win is None or (today - last_win).days > 1: | ||||||
|  |             # Never won a game before or streak is over | ||||||
|  |             update["$set"]["wordle.current_streak"] = 1 | ||||||
|  |             stats.wordle.current_streak = 1 | ||||||
|  |         else: | ||||||
|  |             # On a streak: increase counter | ||||||
|  |             update["$inc"]["wordle.current_streak"] = 1 | ||||||
|  |             stats.wordle.current_streak += 1 | ||||||
|  | 
 | ||||||
|  |         # Update max streak if necessary | ||||||
|  |         if stats.wordle.current_streak > stats.wordle.max_streak: | ||||||
|  |             update["$set"]["wordle.max_streak"] = stats.wordle.current_streak | ||||||
|  |     else: | ||||||
|  |         # Streak is over | ||||||
|  |         update["$set"]["wordle.current_streak"] = 0 | ||||||
|  | 
 | ||||||
|  |     collection = database[GameStats.collection()] | ||||||
|  |     collection.update_one({"_id": stats.id}, update) | ||||||
|  | @ -2,7 +2,8 @@ from typing import Optional | ||||||
| 
 | 
 | ||||||
| from database.enums import TempStorageKey | from database.enums import TempStorageKey | ||||||
| from database.mongo_types import MongoDatabase | from database.mongo_types import MongoDatabase | ||||||
| from database.schemas.mongo import TemporaryStorage, WordleGame | from database.schemas.mongo.temporary_storage import TemporaryStorage | ||||||
|  | from database.schemas.mongo.wordle import WordleGame | ||||||
| from database.utils.datetime import today_only_date | from database.utils.datetime import today_only_date | ||||||
| 
 | 
 | ||||||
| __all__ = [ | __all__ = [ | ||||||
|  |  | ||||||
|  | @ -1,134 +0,0 @@ | ||||||
| import datetime |  | ||||||
| from abc import ABC, abstractmethod |  | ||||||
| from typing import Optional |  | ||||||
| 
 |  | ||||||
| from bson import ObjectId |  | ||||||
| from overrides import overrides |  | ||||||
| from pydantic import BaseModel, Field, validator |  | ||||||
| 
 |  | ||||||
| from database.constants import WORDLE_GUESS_COUNT |  | ||||||
| 
 |  | ||||||
| __all__ = ["MongoBase", "TemporaryStorage", "WordleGame"] |  | ||||||
| 
 |  | ||||||
| from database.utils.datetime import today_only_date |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class PyObjectId(ObjectId): |  | ||||||
|     """Custom type for bson ObjectIds""" |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def __get_validators__(cls): |  | ||||||
|         yield cls.validate |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def validate(cls, value: str): |  | ||||||
|         """Check that a string is a valid bson ObjectId""" |  | ||||||
|         if not ObjectId.is_valid(value): |  | ||||||
|             raise ValueError(f"Invalid ObjectId: '{value}'") |  | ||||||
| 
 |  | ||||||
|         return ObjectId(value) |  | ||||||
| 
 |  | ||||||
|     @classmethod |  | ||||||
|     def __modify_schema__(cls, field_schema: dict): |  | ||||||
|         field_schema.update(type="string") |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MongoBase(BaseModel): |  | ||||||
|     """Base model that properly sets the _id field, and adds one by default""" |  | ||||||
| 
 |  | ||||||
|     id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") |  | ||||||
| 
 |  | ||||||
|     class Config: |  | ||||||
|         """Configuration for encoding and construction""" |  | ||||||
| 
 |  | ||||||
|         allow_population_by_field_name = True |  | ||||||
|         arbitrary_types_allowed = True |  | ||||||
|         json_encoders = {ObjectId: str, PyObjectId: str} |  | ||||||
|         use_enum_values = True |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class MongoCollection(MongoBase, ABC): |  | ||||||
|     """Base model for the 'main class' in a collection |  | ||||||
| 
 |  | ||||||
|     This field stores the name of the collection to avoid making typos against it |  | ||||||
|     """ |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     @abstractmethod |  | ||||||
|     def collection() -> str: |  | ||||||
|         raise NotImplementedError |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class TemporaryStorage(MongoCollection): |  | ||||||
|     """Collection for lots of random things that don't belong in a full-blown collection""" |  | ||||||
| 
 |  | ||||||
|     key: str |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     @overrides |  | ||||||
|     def collection() -> str: |  | ||||||
|         return "temporary" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class WordleStats(BaseModel): |  | ||||||
|     """Model that holds stats about a player's Wordle performance""" |  | ||||||
| 
 |  | ||||||
|     guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) |  | ||||||
|     last_guess: Optional[datetime.date] = None |  | ||||||
|     win_rate: float = 0 |  | ||||||
|     current_streak: int = 0 |  | ||||||
|     max_streak: int = 0 |  | ||||||
| 
 |  | ||||||
|     @validator("guess_distribution") |  | ||||||
|     def validate_guesses_length(cls, value: list[int]): |  | ||||||
|         """Check that the distribution of guesses is of the correct length""" |  | ||||||
|         if len(value) != 6: |  | ||||||
|             raise ValueError(f"guess_distribution must be length 6, found {len(value)}") |  | ||||||
| 
 |  | ||||||
|         return value |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class GameStats(MongoCollection): |  | ||||||
|     """Collection that holds stats about how well a user has performed in games""" |  | ||||||
| 
 |  | ||||||
|     user_id: int |  | ||||||
|     wordle: Optional[WordleStats] = None |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     @overrides |  | ||||||
|     def collection() -> str: |  | ||||||
|         return "game_stats" |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| class WordleGame(MongoCollection): |  | ||||||
|     """Collection that holds people's active Wordle games""" |  | ||||||
| 
 |  | ||||||
|     day: datetime.datetime = Field(default_factory=lambda: today_only_date()) |  | ||||||
|     guesses: list[str] = Field(default_factory=list) |  | ||||||
|     user_id: int |  | ||||||
| 
 |  | ||||||
|     @staticmethod |  | ||||||
|     @overrides |  | ||||||
|     def collection() -> str: |  | ||||||
|         return "wordle" |  | ||||||
| 
 |  | ||||||
|     @validator("guesses") |  | ||||||
|     def validate_guesses_length(cls, value: list[int]): |  | ||||||
|         """Check that the amount of guesses is of the correct length""" |  | ||||||
|         if len(value) > 6: |  | ||||||
|             raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") |  | ||||||
| 
 |  | ||||||
|         return value |  | ||||||
| 
 |  | ||||||
|     def is_game_over(self, word: str) -> bool: |  | ||||||
|         """Check if the current game is over""" |  | ||||||
|         # No guesses yet |  | ||||||
|         if not self.guesses: |  | ||||||
|             return False |  | ||||||
| 
 |  | ||||||
|         # Max amount of guesses allowed |  | ||||||
|         if len(self.guesses) == WORDLE_GUESS_COUNT: |  | ||||||
|             return True |  | ||||||
| 
 |  | ||||||
|         # Found the correct word |  | ||||||
|         return self.guesses[-1] == word |  | ||||||
|  | @ -0,0 +1,53 @@ | ||||||
|  | from abc import ABC, abstractmethod | ||||||
|  | 
 | ||||||
|  | from bson import ObjectId | ||||||
|  | from pydantic import BaseModel, Field | ||||||
|  | 
 | ||||||
|  | __all__ = ["PyObjectId", "MongoBase", "MongoCollection"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class PyObjectId(ObjectId): | ||||||
|  |     """Custom type for bson ObjectIds""" | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def __get_validators__(cls): | ||||||
|  |         yield cls.validate | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def validate(cls, value: str): | ||||||
|  |         """Check that a string is a valid bson ObjectId""" | ||||||
|  |         if not ObjectId.is_valid(value): | ||||||
|  |             raise ValueError(f"Invalid ObjectId: '{value}'") | ||||||
|  | 
 | ||||||
|  |         return ObjectId(value) | ||||||
|  | 
 | ||||||
|  |     @classmethod | ||||||
|  |     def __modify_schema__(cls, field_schema: dict): | ||||||
|  |         field_schema.update(type="string") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MongoBase(BaseModel): | ||||||
|  |     """Base model that properly sets the _id field, and adds one by default""" | ||||||
|  | 
 | ||||||
|  |     id: PyObjectId = Field(default_factory=PyObjectId, alias="_id") | ||||||
|  | 
 | ||||||
|  |     class Config: | ||||||
|  |         """Configuration for encoding and construction""" | ||||||
|  | 
 | ||||||
|  |         allow_population_by_field_name = True | ||||||
|  |         arbitrary_types_allowed = True | ||||||
|  |         json_encoders = {ObjectId: str, PyObjectId: str} | ||||||
|  |         use_enum_values = True | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class MongoCollection(MongoBase, ABC): | ||||||
|  |     """Base model for the 'main class' in a collection | ||||||
|  | 
 | ||||||
|  |     This field stores the name of the collection to avoid making typos against it | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     @abstractmethod | ||||||
|  |     def collection() -> str: | ||||||
|  |         """Getter for the name of the collection, in order to avoid typos""" | ||||||
|  |         raise NotImplementedError | ||||||
|  | @ -0,0 +1,40 @@ | ||||||
|  | import datetime | ||||||
|  | from typing import Optional | ||||||
|  | 
 | ||||||
|  | from overrides import overrides | ||||||
|  | from pydantic import BaseModel, Field, validator | ||||||
|  | 
 | ||||||
|  | from database.schemas.mongo.common import MongoCollection | ||||||
|  | 
 | ||||||
|  | __all__ = ["GameStats", "WordleStats"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class WordleStats(BaseModel): | ||||||
|  |     """Model that holds stats about a player's Wordle performance""" | ||||||
|  | 
 | ||||||
|  |     guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) | ||||||
|  |     last_win: Optional[datetime.datetime] = None | ||||||
|  |     wins: int = 0 | ||||||
|  |     games: int = 0 | ||||||
|  |     current_streak: int = 0 | ||||||
|  |     max_streak: int = 0 | ||||||
|  | 
 | ||||||
|  |     @validator("guess_distribution") | ||||||
|  |     def validate_guesses_length(cls, value: list[int]): | ||||||
|  |         """Check that the distribution of guesses is of the correct length""" | ||||||
|  |         if len(value) != 6: | ||||||
|  |             raise ValueError(f"guess_distribution must be length 6, found {len(value)}") | ||||||
|  | 
 | ||||||
|  |         return value | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class GameStats(MongoCollection): | ||||||
|  |     """Collection that holds stats about how well a user has performed in games""" | ||||||
|  | 
 | ||||||
|  |     user_id: int | ||||||
|  |     wordle: WordleStats = WordleStats() | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     @overrides | ||||||
|  |     def collection() -> str: | ||||||
|  |         return "game_stats" | ||||||
|  | @ -0,0 +1,16 @@ | ||||||
|  | from overrides import overrides | ||||||
|  | 
 | ||||||
|  | from database.schemas.mongo.common import MongoCollection | ||||||
|  | 
 | ||||||
|  | __all__ = ["TemporaryStorage"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class TemporaryStorage(MongoCollection): | ||||||
|  |     """Collection for lots of random things that don't belong in a full-blown collection""" | ||||||
|  | 
 | ||||||
|  |     key: str | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     @overrides | ||||||
|  |     def collection() -> str: | ||||||
|  |         return "temporary" | ||||||
|  | @ -0,0 +1,44 @@ | ||||||
|  | import datetime | ||||||
|  | 
 | ||||||
|  | from overrides import overrides | ||||||
|  | from pydantic import Field, validator | ||||||
|  | 
 | ||||||
|  | from database.constants import WORDLE_GUESS_COUNT | ||||||
|  | from database.schemas.mongo.common import MongoCollection | ||||||
|  | from database.utils.datetime import today_only_date | ||||||
|  | 
 | ||||||
|  | __all__ = ["WordleGame"] | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | class WordleGame(MongoCollection): | ||||||
|  |     """Collection that holds people's active Wordle games""" | ||||||
|  | 
 | ||||||
|  |     day: datetime.datetime = Field(default_factory=lambda: today_only_date()) | ||||||
|  |     guesses: list[str] = Field(default_factory=list) | ||||||
|  |     user_id: int | ||||||
|  | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     @overrides | ||||||
|  |     def collection() -> str: | ||||||
|  |         return "wordle" | ||||||
|  | 
 | ||||||
|  |     @validator("guesses") | ||||||
|  |     def validate_guesses_length(cls, value: list[int]): | ||||||
|  |         """Check that the amount of guesses is of the correct length""" | ||||||
|  |         if len(value) > 6: | ||||||
|  |             raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") | ||||||
|  | 
 | ||||||
|  |         return value | ||||||
|  | 
 | ||||||
|  |     def is_game_over(self, word: str) -> bool: | ||||||
|  |         """Check if the current game is over""" | ||||||
|  |         # No guesses yet | ||||||
|  |         if not self.guesses: | ||||||
|  |             return False | ||||||
|  | 
 | ||||||
|  |         # Max amount of guesses allowed | ||||||
|  |         if len(self.guesses) == WORDLE_GUESS_COUNT: | ||||||
|  |             return True | ||||||
|  | 
 | ||||||
|  |         # Found the correct word | ||||||
|  |         return self.guesses[-1] == word | ||||||
|  | @ -6,7 +6,7 @@ import discord | ||||||
| from overrides import overrides | from overrides import overrides | ||||||
| 
 | 
 | ||||||
| from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH | from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH | ||||||
| from database.schemas.mongo import WordleGame | from database.schemas.mongo.wordle import WordleGame | ||||||
| from didier.data.embeds.base import EmbedBaseModel | from didier.data.embeds.base import EmbedBaseModel | ||||||
| from didier.utils.types.datetime import int_to_weekday, tz_aware_now | from didier.utils.types.datetime import int_to_weekday, tz_aware_now | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -0,0 +1,8 @@ | ||||||
|  | import pytest | ||||||
|  | 
 | ||||||
|  | from database.mongo_types import MongoDatabase | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @pytest.mark.mongo | ||||||
|  | async def test_get_stats_non_existent(mongodb: MongoDatabase, test_user_id: int): | ||||||
|  |     """Test getting a user's stats when the db is empty""" | ||||||
|  | @ -6,7 +6,8 @@ from freezegun import freeze_time | ||||||
| from database.crud import wordle as crud | from database.crud import wordle as crud | ||||||
| from database.enums import TempStorageKey | from database.enums import TempStorageKey | ||||||
| from database.mongo_types import MongoCollection, MongoDatabase | from database.mongo_types import MongoCollection, MongoDatabase | ||||||
| from database.schemas.mongo import TemporaryStorage, WordleGame | from database.schemas.mongo.temporary_storage import TemporaryStorage | ||||||
|  | from database.schemas.mongo.wordle import WordleGame | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| @pytest.fixture | @pytest.fixture | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue