mirror of https://github.com/stijndcl/didier
120 lines
3.2 KiB
Python
120 lines
3.2 KiB
Python
import datetime
|
|
from abc import ABC, abstractmethod
|
|
from typing import Optional
|
|
|
|
from bson import ObjectId
|
|
from overrides import overrides
|
|
from pydantic import BaseModel, Field, validator
|
|
|
|
__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"]
|
|
|
|
from database.utils.datetime import today_only_date
|
|
|
|
|
|
class PyObjectId(ObjectId):
|
|
"""Custom type for bson ObjectIds"""
|
|
|
|
@classmethod
|
|
def __get_validators__(cls):
|
|
yield cls.validate
|
|
|
|
@classmethod
|
|
def validate(cls, value: str):
|
|
"""Check that a string is a valid bson ObjectId"""
|
|
if not ObjectId.is_valid(value):
|
|
raise ValueError(f"Invalid ObjectId: '{value}'")
|
|
|
|
return ObjectId(value)
|
|
|
|
@classmethod
|
|
def __modify_schema__(cls, field_schema: dict):
|
|
field_schema.update(type="string")
|
|
|
|
|
|
class MongoBase(BaseModel):
|
|
"""Base model that properly sets the _id field, and adds one by default"""
|
|
|
|
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
|
|
|
class Config:
|
|
"""Configuration for encoding and construction"""
|
|
|
|
allow_population_by_field_name = True
|
|
arbitrary_types_allowed = True
|
|
json_encoders = {ObjectId: str, PyObjectId: str}
|
|
use_enum_values = True
|
|
|
|
|
|
class MongoCollection(MongoBase, ABC):
|
|
"""Base model for the 'main class' in a collection
|
|
|
|
This field stores the name of the collection to avoid making typos against it
|
|
"""
|
|
|
|
@staticmethod
|
|
@abstractmethod
|
|
def collection() -> str:
|
|
raise NotImplementedError
|
|
|
|
|
|
class TemporaryStorage(MongoCollection):
|
|
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
|
|
|
key: str
|
|
|
|
@staticmethod
|
|
@overrides
|
|
def collection() -> str:
|
|
return "temporary"
|
|
|
|
|
|
class WordleStats(BaseModel):
|
|
"""Model that holds stats about a player's Wordle performance"""
|
|
|
|
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
|
last_guess: Optional[datetime.date] = None
|
|
win_rate: float = 0
|
|
current_streak: int = 0
|
|
max_streak: int = 0
|
|
|
|
@validator("guess_distribution")
|
|
def validate_guesses_length(cls, value: list[int]):
|
|
"""Check that the distribution of guesses is of the correct length"""
|
|
if len(value) != 6:
|
|
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
|
|
|
return value
|
|
|
|
|
|
class GameStats(MongoCollection):
|
|
"""Collection that holds stats about how well a user has performed in games"""
|
|
|
|
user_id: int
|
|
wordle: Optional[WordleStats] = None
|
|
|
|
@staticmethod
|
|
@overrides
|
|
def collection() -> str:
|
|
return "game_stats"
|
|
|
|
|
|
class WordleGame(MongoCollection):
|
|
"""Collection that holds people's active Wordle games"""
|
|
|
|
day: datetime.date = Field(default_factory=lambda: today_only_date())
|
|
guesses: list[str] = Field(default_factory=list)
|
|
user_id: int
|
|
|
|
@staticmethod
|
|
@overrides
|
|
def collection() -> str:
|
|
return "wordle"
|
|
|
|
@validator("guesses")
|
|
def validate_guesses_length(cls, value: list[int]):
|
|
"""Check that the amount of guesses is of the correct length"""
|
|
if len(value) > 6:
|
|
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
|
|
|
return value
|