mirror of https://github.com/stijndcl/didier
Make game stats crud functions, split mongo schemas out a bit
parent
e4e77502e8
commit
bf41acd9f4
|
@ -0,0 +1,59 @@
|
||||||
|
import datetime
|
||||||
|
from typing import Union
|
||||||
|
|
||||||
|
from database.mongo_types import MongoDatabase
|
||||||
|
from database.schemas.mongo.game_stats import GameStats
|
||||||
|
|
||||||
|
__all__ = ["get_game_stats", "complete_wordle_game"]
|
||||||
|
|
||||||
|
from database.utils.datetime import today_only_date
|
||||||
|
|
||||||
|
|
||||||
|
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
|
||||||
|
"""Get a user's game stats
|
||||||
|
|
||||||
|
If no entry is found, it is first created
|
||||||
|
"""
|
||||||
|
collection = database[GameStats.collection()]
|
||||||
|
stats = await collection.find_one({"user_id": user_id})
|
||||||
|
if stats is not None:
|
||||||
|
return GameStats(**stats)
|
||||||
|
|
||||||
|
stats = GameStats(user_id=user_id)
|
||||||
|
await collection.insert_one(stats.dict(by_alias=True))
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int):
|
||||||
|
"""Update the user's Wordle stats"""
|
||||||
|
stats = await get_game_stats(database, user_id)
|
||||||
|
|
||||||
|
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}}
|
||||||
|
|
||||||
|
if win:
|
||||||
|
update["$inc"]["wordle.wins"] = 1
|
||||||
|
update["$inc"][f"wordle.guess_distribution.{guesses}"] = 1
|
||||||
|
|
||||||
|
# Update streak
|
||||||
|
today = today_only_date()
|
||||||
|
last_win = stats.wordle.last_win
|
||||||
|
update["$set"]["wordle.last_win"] = today
|
||||||
|
|
||||||
|
if last_win is None or (today - last_win).days > 1:
|
||||||
|
# Never won a game before or streak is over
|
||||||
|
update["$set"]["wordle.current_streak"] = 1
|
||||||
|
stats.wordle.current_streak = 1
|
||||||
|
else:
|
||||||
|
# On a streak: increase counter
|
||||||
|
update["$inc"]["wordle.current_streak"] = 1
|
||||||
|
stats.wordle.current_streak += 1
|
||||||
|
|
||||||
|
# Update max streak if necessary
|
||||||
|
if stats.wordle.current_streak > stats.wordle.max_streak:
|
||||||
|
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
|
||||||
|
else:
|
||||||
|
# Streak is over
|
||||||
|
update["$set"]["wordle.current_streak"] = 0
|
||||||
|
|
||||||
|
collection = database[GameStats.collection()]
|
||||||
|
collection.update_one({"_id": stats.id}, update)
|
|
@ -2,7 +2,8 @@ from typing import Optional
|
||||||
|
|
||||||
from database.enums import TempStorageKey
|
from database.enums import TempStorageKey
|
||||||
from database.mongo_types import MongoDatabase
|
from database.mongo_types import MongoDatabase
|
||||||
from database.schemas.mongo import TemporaryStorage, WordleGame
|
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||||
|
from database.schemas.mongo.wordle import WordleGame
|
||||||
from database.utils.datetime import today_only_date
|
from database.utils.datetime import today_only_date
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
|
@ -1,134 +0,0 @@
|
||||||
import datetime
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from bson import ObjectId
|
|
||||||
from overrides import overrides
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
|
|
||||||
from database.constants import WORDLE_GUESS_COUNT
|
|
||||||
|
|
||||||
__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"]
|
|
||||||
|
|
||||||
from database.utils.datetime import today_only_date
|
|
||||||
|
|
||||||
|
|
||||||
class PyObjectId(ObjectId):
|
|
||||||
"""Custom type for bson ObjectIds"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
yield cls.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate(cls, value: str):
|
|
||||||
"""Check that a string is a valid bson ObjectId"""
|
|
||||||
if not ObjectId.is_valid(value):
|
|
||||||
raise ValueError(f"Invalid ObjectId: '{value}'")
|
|
||||||
|
|
||||||
return ObjectId(value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __modify_schema__(cls, field_schema: dict):
|
|
||||||
field_schema.update(type="string")
|
|
||||||
|
|
||||||
|
|
||||||
class MongoBase(BaseModel):
|
|
||||||
"""Base model that properly sets the _id field, and adds one by default"""
|
|
||||||
|
|
||||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for encoding and construction"""
|
|
||||||
|
|
||||||
allow_population_by_field_name = True
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
|
||||||
use_enum_values = True
|
|
||||||
|
|
||||||
|
|
||||||
class MongoCollection(MongoBase, ABC):
|
|
||||||
"""Base model for the 'main class' in a collection
|
|
||||||
|
|
||||||
This field stores the name of the collection to avoid making typos against it
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def collection() -> str:
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class TemporaryStorage(MongoCollection):
|
|
||||||
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "temporary"
|
|
||||||
|
|
||||||
|
|
||||||
class WordleStats(BaseModel):
|
|
||||||
"""Model that holds stats about a player's Wordle performance"""
|
|
||||||
|
|
||||||
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
|
||||||
last_guess: Optional[datetime.date] = None
|
|
||||||
win_rate: float = 0
|
|
||||||
current_streak: int = 0
|
|
||||||
max_streak: int = 0
|
|
||||||
|
|
||||||
@validator("guess_distribution")
|
|
||||||
def validate_guesses_length(cls, value: list[int]):
|
|
||||||
"""Check that the distribution of guesses is of the correct length"""
|
|
||||||
if len(value) != 6:
|
|
||||||
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class GameStats(MongoCollection):
|
|
||||||
"""Collection that holds stats about how well a user has performed in games"""
|
|
||||||
|
|
||||||
user_id: int
|
|
||||||
wordle: Optional[WordleStats] = None
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "game_stats"
|
|
||||||
|
|
||||||
|
|
||||||
class WordleGame(MongoCollection):
|
|
||||||
"""Collection that holds people's active Wordle games"""
|
|
||||||
|
|
||||||
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
|
||||||
guesses: list[str] = Field(default_factory=list)
|
|
||||||
user_id: int
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "wordle"
|
|
||||||
|
|
||||||
@validator("guesses")
|
|
||||||
def validate_guesses_length(cls, value: list[int]):
|
|
||||||
"""Check that the amount of guesses is of the correct length"""
|
|
||||||
if len(value) > 6:
|
|
||||||
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def is_game_over(self, word: str) -> bool:
|
|
||||||
"""Check if the current game is over"""
|
|
||||||
# No guesses yet
|
|
||||||
if not self.guesses:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Max amount of guesses allowed
|
|
||||||
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Found the correct word
|
|
||||||
return self.guesses[-1] == word
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
|
||||||
|
from bson import ObjectId
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
|
||||||
|
|
||||||
|
|
||||||
|
class PyObjectId(ObjectId):
|
||||||
|
"""Custom type for bson ObjectIds"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, value: str):
|
||||||
|
"""Check that a string is a valid bson ObjectId"""
|
||||||
|
if not ObjectId.is_valid(value):
|
||||||
|
raise ValueError(f"Invalid ObjectId: '{value}'")
|
||||||
|
|
||||||
|
return ObjectId(value)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __modify_schema__(cls, field_schema: dict):
|
||||||
|
field_schema.update(type="string")
|
||||||
|
|
||||||
|
|
||||||
|
class MongoBase(BaseModel):
|
||||||
|
"""Base model that properly sets the _id field, and adds one by default"""
|
||||||
|
|
||||||
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for encoding and construction"""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
json_encoders = {ObjectId: str, PyObjectId: str}
|
||||||
|
use_enum_values = True
|
||||||
|
|
||||||
|
|
||||||
|
class MongoCollection(MongoBase, ABC):
|
||||||
|
"""Base model for the 'main class' in a collection
|
||||||
|
|
||||||
|
This field stores the name of the collection to avoid making typos against it
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def collection() -> str:
|
||||||
|
"""Getter for the name of the collection, in order to avoid typos"""
|
||||||
|
raise NotImplementedError
|
|
@ -0,0 +1,40 @@
|
||||||
|
import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from overrides import overrides
|
||||||
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
|
from database.schemas.mongo.common import MongoCollection
|
||||||
|
|
||||||
|
__all__ = ["GameStats", "WordleStats"]
|
||||||
|
|
||||||
|
|
||||||
|
class WordleStats(BaseModel):
|
||||||
|
"""Model that holds stats about a player's Wordle performance"""
|
||||||
|
|
||||||
|
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
||||||
|
last_win: Optional[datetime.datetime] = None
|
||||||
|
wins: int = 0
|
||||||
|
games: int = 0
|
||||||
|
current_streak: int = 0
|
||||||
|
max_streak: int = 0
|
||||||
|
|
||||||
|
@validator("guess_distribution")
|
||||||
|
def validate_guesses_length(cls, value: list[int]):
|
||||||
|
"""Check that the distribution of guesses is of the correct length"""
|
||||||
|
if len(value) != 6:
|
||||||
|
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class GameStats(MongoCollection):
|
||||||
|
"""Collection that holds stats about how well a user has performed in games"""
|
||||||
|
|
||||||
|
user_id: int
|
||||||
|
wordle: WordleStats = WordleStats()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@overrides
|
||||||
|
def collection() -> str:
|
||||||
|
return "game_stats"
|
|
@ -0,0 +1,16 @@
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
from database.schemas.mongo.common import MongoCollection
|
||||||
|
|
||||||
|
__all__ = ["TemporaryStorage"]
|
||||||
|
|
||||||
|
|
||||||
|
class TemporaryStorage(MongoCollection):
|
||||||
|
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
||||||
|
|
||||||
|
key: str
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@overrides
|
||||||
|
def collection() -> str:
|
||||||
|
return "temporary"
|
|
@ -0,0 +1,44 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from overrides import overrides
|
||||||
|
from pydantic import Field, validator
|
||||||
|
|
||||||
|
from database.constants import WORDLE_GUESS_COUNT
|
||||||
|
from database.schemas.mongo.common import MongoCollection
|
||||||
|
from database.utils.datetime import today_only_date
|
||||||
|
|
||||||
|
__all__ = ["WordleGame"]
|
||||||
|
|
||||||
|
|
||||||
|
class WordleGame(MongoCollection):
|
||||||
|
"""Collection that holds people's active Wordle games"""
|
||||||
|
|
||||||
|
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
||||||
|
guesses: list[str] = Field(default_factory=list)
|
||||||
|
user_id: int
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@overrides
|
||||||
|
def collection() -> str:
|
||||||
|
return "wordle"
|
||||||
|
|
||||||
|
@validator("guesses")
|
||||||
|
def validate_guesses_length(cls, value: list[int]):
|
||||||
|
"""Check that the amount of guesses is of the correct length"""
|
||||||
|
if len(value) > 6:
|
||||||
|
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
def is_game_over(self, word: str) -> bool:
|
||||||
|
"""Check if the current game is over"""
|
||||||
|
# No guesses yet
|
||||||
|
if not self.guesses:
|
||||||
|
return False
|
||||||
|
|
||||||
|
# Max amount of guesses allowed
|
||||||
|
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
||||||
|
return True
|
||||||
|
|
||||||
|
# Found the correct word
|
||||||
|
return self.guesses[-1] == word
|
|
@ -6,7 +6,7 @@ import discord
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||||
from database.schemas.mongo import WordleGame
|
from database.schemas.mongo.wordle import WordleGame
|
||||||
from didier.data.embeds.base import EmbedBaseModel
|
from didier.data.embeds.base import EmbedBaseModel
|
||||||
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from database.mongo_types import MongoDatabase
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.mongo
|
||||||
|
async def test_get_stats_non_existent(mongodb: MongoDatabase, test_user_id: int):
|
||||||
|
"""Test getting a user's stats when the db is empty"""
|
|
@ -6,7 +6,8 @@ from freezegun import freeze_time
|
||||||
from database.crud import wordle as crud
|
from database.crud import wordle as crud
|
||||||
from database.enums import TempStorageKey
|
from database.enums import TempStorageKey
|
||||||
from database.mongo_types import MongoCollection, MongoDatabase
|
from database.mongo_types import MongoCollection, MongoDatabase
|
||||||
from database.schemas.mongo import TemporaryStorage, WordleGame
|
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||||
|
from database.schemas.mongo.wordle import WordleGame
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
Loading…
Reference in New Issue