mirror of https://github.com/stijndcl/didier
				
				
				
			Make game stats crud functions, split mongo schemas out a bit
							parent
							
								
									e4e77502e8
								
							
						
					
					
						commit
						bf41acd9f4
					
				| 
						 | 
				
			
			@ -0,0 +1,59 @@
 | 
			
		|||
import datetime
 | 
			
		||||
from typing import Union
 | 
			
		||||
 | 
			
		||||
from database.mongo_types import MongoDatabase
 | 
			
		||||
from database.schemas.mongo.game_stats import GameStats
 | 
			
		||||
 | 
			
		||||
__all__ = ["get_game_stats", "complete_wordle_game"]
 | 
			
		||||
 | 
			
		||||
from database.utils.datetime import today_only_date
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
 | 
			
		||||
    """Get a user's game stats
 | 
			
		||||
 | 
			
		||||
    If no entry is found, it is first created
 | 
			
		||||
    """
 | 
			
		||||
    collection = database[GameStats.collection()]
 | 
			
		||||
    stats = await collection.find_one({"user_id": user_id})
 | 
			
		||||
    if stats is not None:
 | 
			
		||||
        return GameStats(**stats)
 | 
			
		||||
 | 
			
		||||
    stats = GameStats(user_id=user_id)
 | 
			
		||||
    await collection.insert_one(stats.dict(by_alias=True))
 | 
			
		||||
    return stats
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int):
 | 
			
		||||
    """Update the user's Wordle stats"""
 | 
			
		||||
    stats = await get_game_stats(database, user_id)
 | 
			
		||||
 | 
			
		||||
    update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}}
 | 
			
		||||
 | 
			
		||||
    if win:
 | 
			
		||||
        update["$inc"]["wordle.wins"] = 1
 | 
			
		||||
        update["$inc"][f"wordle.guess_distribution.{guesses}"] = 1
 | 
			
		||||
 | 
			
		||||
        # Update streak
 | 
			
		||||
        today = today_only_date()
 | 
			
		||||
        last_win = stats.wordle.last_win
 | 
			
		||||
        update["$set"]["wordle.last_win"] = today
 | 
			
		||||
 | 
			
		||||
        if last_win is None or (today - last_win).days > 1:
 | 
			
		||||
            # Never won a game before or streak is over
 | 
			
		||||
            update["$set"]["wordle.current_streak"] = 1
 | 
			
		||||
            stats.wordle.current_streak = 1
 | 
			
		||||
        else:
 | 
			
		||||
            # On a streak: increase counter
 | 
			
		||||
            update["$inc"]["wordle.current_streak"] = 1
 | 
			
		||||
            stats.wordle.current_streak += 1
 | 
			
		||||
 | 
			
		||||
        # Update max streak if necessary
 | 
			
		||||
        if stats.wordle.current_streak > stats.wordle.max_streak:
 | 
			
		||||
            update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
 | 
			
		||||
    else:
 | 
			
		||||
        # Streak is over
 | 
			
		||||
        update["$set"]["wordle.current_streak"] = 0
 | 
			
		||||
 | 
			
		||||
    collection = database[GameStats.collection()]
 | 
			
		||||
    collection.update_one({"_id": stats.id}, update)
 | 
			
		||||
| 
						 | 
				
			
			@ -2,7 +2,8 @@ from typing import Optional
 | 
			
		|||
 | 
			
		||||
from database.enums import TempStorageKey
 | 
			
		||||
from database.mongo_types import MongoDatabase
 | 
			
		||||
from database.schemas.mongo import TemporaryStorage, WordleGame
 | 
			
		||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
 | 
			
		||||
from database.schemas.mongo.wordle import WordleGame
 | 
			
		||||
from database.utils.datetime import today_only_date
 | 
			
		||||
 | 
			
		||||
__all__ = [
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,134 +0,0 @@
 | 
			
		|||
import datetime
 | 
			
		||||
from abc import ABC, abstractmethod
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from bson import ObjectId
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
from pydantic import BaseModel, Field, validator
 | 
			
		||||
 | 
			
		||||
from database.constants import WORDLE_GUESS_COUNT
 | 
			
		||||
 | 
			
		||||
__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"]
 | 
			
		||||
 | 
			
		||||
from database.utils.datetime import today_only_date
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PyObjectId(ObjectId):
 | 
			
		||||
    """Custom type for bson ObjectIds"""
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __get_validators__(cls):
 | 
			
		||||
        yield cls.validate
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def validate(cls, value: str):
 | 
			
		||||
        """Check that a string is a valid bson ObjectId"""
 | 
			
		||||
        if not ObjectId.is_valid(value):
 | 
			
		||||
            raise ValueError(f"Invalid ObjectId: '{value}'")
 | 
			
		||||
 | 
			
		||||
        return ObjectId(value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __modify_schema__(cls, field_schema: dict):
 | 
			
		||||
        field_schema.update(type="string")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MongoBase(BaseModel):
 | 
			
		||||
    """Base model that properly sets the _id field, and adds one by default"""
 | 
			
		||||
 | 
			
		||||
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        """Configuration for encoding and construction"""
 | 
			
		||||
 | 
			
		||||
        allow_population_by_field_name = True
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
        json_encoders = {ObjectId: str, PyObjectId: str}
 | 
			
		||||
        use_enum_values = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MongoCollection(MongoBase, ABC):
 | 
			
		||||
    """Base model for the 'main class' in a collection
 | 
			
		||||
 | 
			
		||||
    This field stores the name of the collection to avoid making typos against it
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TemporaryStorage(MongoCollection):
 | 
			
		||||
    """Collection for lots of random things that don't belong in a full-blown collection"""
 | 
			
		||||
 | 
			
		||||
    key: str
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "temporary"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WordleStats(BaseModel):
 | 
			
		||||
    """Model that holds stats about a player's Wordle performance"""
 | 
			
		||||
 | 
			
		||||
    guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
 | 
			
		||||
    last_guess: Optional[datetime.date] = None
 | 
			
		||||
    win_rate: float = 0
 | 
			
		||||
    current_streak: int = 0
 | 
			
		||||
    max_streak: int = 0
 | 
			
		||||
 | 
			
		||||
    @validator("guess_distribution")
 | 
			
		||||
    def validate_guesses_length(cls, value: list[int]):
 | 
			
		||||
        """Check that the distribution of guesses is of the correct length"""
 | 
			
		||||
        if len(value) != 6:
 | 
			
		||||
            raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GameStats(MongoCollection):
 | 
			
		||||
    """Collection that holds stats about how well a user has performed in games"""
 | 
			
		||||
 | 
			
		||||
    user_id: int
 | 
			
		||||
    wordle: Optional[WordleStats] = None
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "game_stats"
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WordleGame(MongoCollection):
 | 
			
		||||
    """Collection that holds people's active Wordle games"""
 | 
			
		||||
 | 
			
		||||
    day: datetime.datetime = Field(default_factory=lambda: today_only_date())
 | 
			
		||||
    guesses: list[str] = Field(default_factory=list)
 | 
			
		||||
    user_id: int
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "wordle"
 | 
			
		||||
 | 
			
		||||
    @validator("guesses")
 | 
			
		||||
    def validate_guesses_length(cls, value: list[int]):
 | 
			
		||||
        """Check that the amount of guesses is of the correct length"""
 | 
			
		||||
        if len(value) > 6:
 | 
			
		||||
            raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def is_game_over(self, word: str) -> bool:
 | 
			
		||||
        """Check if the current game is over"""
 | 
			
		||||
        # No guesses yet
 | 
			
		||||
        if not self.guesses:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Max amount of guesses allowed
 | 
			
		||||
        if len(self.guesses) == WORDLE_GUESS_COUNT:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        # Found the correct word
 | 
			
		||||
        return self.guesses[-1] == word
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,53 @@
 | 
			
		|||
from abc import ABC, abstractmethod
 | 
			
		||||
 | 
			
		||||
from bson import ObjectId
 | 
			
		||||
from pydantic import BaseModel, Field
 | 
			
		||||
 | 
			
		||||
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PyObjectId(ObjectId):
 | 
			
		||||
    """Custom type for bson ObjectIds"""
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __get_validators__(cls):
 | 
			
		||||
        yield cls.validate
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def validate(cls, value: str):
 | 
			
		||||
        """Check that a string is a valid bson ObjectId"""
 | 
			
		||||
        if not ObjectId.is_valid(value):
 | 
			
		||||
            raise ValueError(f"Invalid ObjectId: '{value}'")
 | 
			
		||||
 | 
			
		||||
        return ObjectId(value)
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def __modify_schema__(cls, field_schema: dict):
 | 
			
		||||
        field_schema.update(type="string")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MongoBase(BaseModel):
 | 
			
		||||
    """Base model that properly sets the _id field, and adds one by default"""
 | 
			
		||||
 | 
			
		||||
    id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
 | 
			
		||||
 | 
			
		||||
    class Config:
 | 
			
		||||
        """Configuration for encoding and construction"""
 | 
			
		||||
 | 
			
		||||
        allow_population_by_field_name = True
 | 
			
		||||
        arbitrary_types_allowed = True
 | 
			
		||||
        json_encoders = {ObjectId: str, PyObjectId: str}
 | 
			
		||||
        use_enum_values = True
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class MongoCollection(MongoBase, ABC):
 | 
			
		||||
    """Base model for the 'main class' in a collection
 | 
			
		||||
 | 
			
		||||
    This field stores the name of the collection to avoid making typos against it
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        """Getter for the name of the collection, in order to avoid typos"""
 | 
			
		||||
        raise NotImplementedError
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,40 @@
 | 
			
		|||
import datetime
 | 
			
		||||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
from pydantic import BaseModel, Field, validator
 | 
			
		||||
 | 
			
		||||
from database.schemas.mongo.common import MongoCollection
 | 
			
		||||
 | 
			
		||||
__all__ = ["GameStats", "WordleStats"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WordleStats(BaseModel):
 | 
			
		||||
    """Model that holds stats about a player's Wordle performance"""
 | 
			
		||||
 | 
			
		||||
    guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
 | 
			
		||||
    last_win: Optional[datetime.datetime] = None
 | 
			
		||||
    wins: int = 0
 | 
			
		||||
    games: int = 0
 | 
			
		||||
    current_streak: int = 0
 | 
			
		||||
    max_streak: int = 0
 | 
			
		||||
 | 
			
		||||
    @validator("guess_distribution")
 | 
			
		||||
    def validate_guesses_length(cls, value: list[int]):
 | 
			
		||||
        """Check that the distribution of guesses is of the correct length"""
 | 
			
		||||
        if len(value) != 6:
 | 
			
		||||
            raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class GameStats(MongoCollection):
 | 
			
		||||
    """Collection that holds stats about how well a user has performed in games"""
 | 
			
		||||
 | 
			
		||||
    user_id: int
 | 
			
		||||
    wordle: WordleStats = WordleStats()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "game_stats"
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,16 @@
 | 
			
		|||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.schemas.mongo.common import MongoCollection
 | 
			
		||||
 | 
			
		||||
__all__ = ["TemporaryStorage"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class TemporaryStorage(MongoCollection):
 | 
			
		||||
    """Collection for lots of random things that don't belong in a full-blown collection"""
 | 
			
		||||
 | 
			
		||||
    key: str
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "temporary"
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,44 @@
 | 
			
		|||
import datetime
 | 
			
		||||
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
from pydantic import Field, validator
 | 
			
		||||
 | 
			
		||||
from database.constants import WORDLE_GUESS_COUNT
 | 
			
		||||
from database.schemas.mongo.common import MongoCollection
 | 
			
		||||
from database.utils.datetime import today_only_date
 | 
			
		||||
 | 
			
		||||
__all__ = ["WordleGame"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class WordleGame(MongoCollection):
 | 
			
		||||
    """Collection that holds people's active Wordle games"""
 | 
			
		||||
 | 
			
		||||
    day: datetime.datetime = Field(default_factory=lambda: today_only_date())
 | 
			
		||||
    guesses: list[str] = Field(default_factory=list)
 | 
			
		||||
    user_id: int
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    @overrides
 | 
			
		||||
    def collection() -> str:
 | 
			
		||||
        return "wordle"
 | 
			
		||||
 | 
			
		||||
    @validator("guesses")
 | 
			
		||||
    def validate_guesses_length(cls, value: list[int]):
 | 
			
		||||
        """Check that the amount of guesses is of the correct length"""
 | 
			
		||||
        if len(value) > 6:
 | 
			
		||||
            raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
 | 
			
		||||
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    def is_game_over(self, word: str) -> bool:
 | 
			
		||||
        """Check if the current game is over"""
 | 
			
		||||
        # No guesses yet
 | 
			
		||||
        if not self.guesses:
 | 
			
		||||
            return False
 | 
			
		||||
 | 
			
		||||
        # Max amount of guesses allowed
 | 
			
		||||
        if len(self.guesses) == WORDLE_GUESS_COUNT:
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
        # Found the correct word
 | 
			
		||||
        return self.guesses[-1] == word
 | 
			
		||||
| 
						 | 
				
			
			@ -6,7 +6,7 @@ import discord
 | 
			
		|||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
 | 
			
		||||
from database.schemas.mongo import WordleGame
 | 
			
		||||
from database.schemas.mongo.wordle import WordleGame
 | 
			
		||||
from didier.data.embeds.base import EmbedBaseModel
 | 
			
		||||
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,8 @@
 | 
			
		|||
import pytest
 | 
			
		||||
 | 
			
		||||
from database.mongo_types import MongoDatabase
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.mongo
 | 
			
		||||
async def test_get_stats_non_existent(mongodb: MongoDatabase, test_user_id: int):
 | 
			
		||||
    """Test getting a user's stats when the db is empty"""
 | 
			
		||||
| 
						 | 
				
			
			@ -6,7 +6,8 @@ from freezegun import freeze_time
 | 
			
		|||
from database.crud import wordle as crud
 | 
			
		||||
from database.enums import TempStorageKey
 | 
			
		||||
from database.mongo_types import MongoCollection, MongoDatabase
 | 
			
		||||
from database.schemas.mongo import TemporaryStorage, WordleGame
 | 
			
		||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
 | 
			
		||||
from database.schemas.mongo.wordle import WordleGame
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.fixture
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue