mirror of
https://github.com/stijndcl/didier.git
synced 2026-04-07 23:55:46 +02:00
Remove mongo & fix tests
This commit is contained in:
parent
7b2109fb07
commit
8a4baf6bb8
56 changed files with 406 additions and 539 deletions
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.crud import users
|
||||
from database.schemas.relational import Birthday, User
|
||||
from database.schemas import Birthday, User
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import users
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.schemas.relational import Bank, NightlyData
|
||||
from database.schemas import Bank, NightlyData
|
||||
from database.utils.math.currency import (
|
||||
capacity_upgrade_price,
|
||||
interest_upgrade_price,
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.schemas.relational import CustomCommand, CustomCommandAlias
|
||||
from database.schemas import CustomCommand, CustomCommandAlias
|
||||
|
||||
__all__ = [
|
||||
"clean_name",
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.schemas.relational import DadJoke
|
||||
from database.schemas import DadJoke
|
||||
|
||||
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
||||
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.schemas.relational import Deadline, UforaCourse
|
||||
from database.schemas import Deadline, UforaCourse
|
||||
|
||||
__all__ = ["add_deadline", "get_deadlines"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,59 +0,0 @@
|
|||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.game_stats import GameStats
|
||||
|
||||
__all__ = ["get_game_stats", "complete_wordle_game"]
|
||||
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
|
||||
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
|
||||
"""Get a user's game stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
collection = database[GameStats.collection()]
|
||||
stats = await collection.find_one({"user_id": user_id})
|
||||
if stats is not None:
|
||||
return GameStats(**stats)
|
||||
|
||||
stats = GameStats(user_id=user_id)
|
||||
await collection.insert_one(stats.dict(by_alias=True))
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_game_stats(database, user_id)
|
||||
|
||||
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
|
||||
|
||||
if win:
|
||||
update["$inc"]["wordle.wins"] = 1
|
||||
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
|
||||
|
||||
# Update streak
|
||||
today = today_only_date()
|
||||
last_win = stats.wordle.last_win
|
||||
update["$set"]["wordle.last_win"] = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
update["$set"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
update["$inc"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.wordle.current_streak > stats.wordle.max_streak:
|
||||
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
update["$set"]["wordle.current_streak"] = 0
|
||||
|
||||
collection = database[GameStats.collection()]
|
||||
await collection.update_one({"_id": stats.id}, update)
|
||||
|
|
@ -4,7 +4,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions import NoResultFoundException
|
||||
from database.schemas.relational import Link
|
||||
from database.schemas import Link
|
||||
|
||||
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
||||
|
||||
|
|
|
|||
|
|
@ -4,7 +4,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import MemeTemplate
|
||||
from database.schemas import MemeTemplate
|
||||
|
||||
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.enums import TaskType
|
||||
from database.schemas.relational import Task
|
||||
from database.schemas import Task
|
||||
from database.utils.datetime import LOCAL_TIMEZONE
|
||||
|
||||
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import datetime
|
|||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
from database.schemas import UforaAnnouncement, UforaCourse
|
||||
|
||||
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaCourse, UforaCourseAlias
|
||||
from database.schemas import UforaCourse, UforaCourseAlias
|
||||
|
||||
__all__ = ["get_all_courses", "get_course_by_name"]
|
||||
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import Bank, NightlyData, User
|
||||
from database.schemas import Bank, NightlyData, User
|
||||
|
||||
__all__ = [
|
||||
"get_or_add",
|
||||
|
|
|
|||
|
|
@ -1,56 +1,45 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from database.enums import TempStorageKey
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from database.utils.datetime import today_only_date
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas import WordleGuess, WordleWord
|
||||
|
||||
__all__ = [
|
||||
"get_active_wordle_game",
|
||||
"make_wordle_guess",
|
||||
"start_new_wordle_game",
|
||||
"set_daily_word",
|
||||
"reset_wordle_games",
|
||||
]
|
||||
|
||||
|
||||
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
|
||||
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
|
||||
"""Find a player's active game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
result = await collection.find_one({"user_id": user_id})
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return WordleGame(**result)
|
||||
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
|
||||
guesses = (await session.execute(statement)).scalars().all()
|
||||
return guesses
|
||||
|
||||
|
||||
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
|
||||
"""Start a new game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
game = WordleGame(user_id=user_id)
|
||||
await collection.insert_one(game.dict(by_alias=True))
|
||||
return game
|
||||
|
||||
|
||||
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
|
||||
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
|
||||
"""Make a guess in your current game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
||||
guess_instance = WordleGuess(user_id=user_id, guess=guess)
|
||||
session.add(guess_instance)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
|
||||
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
|
||||
"""Get the word of today"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
|
||||
row = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
||||
if result is None:
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return result["word"]
|
||||
return row
|
||||
|
||||
|
||||
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
|
||||
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
|
||||
"""Set the word of today
|
||||
|
||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||
|
|
@ -60,23 +49,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F
|
|||
|
||||
Returns the word that was chosen. If one already existed, return that instead.
|
||||
"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
current_word = await get_daily_word(session)
|
||||
|
||||
current_word = None if forced else await get_daily_word(database)
|
||||
if current_word is not None:
|
||||
return current_word
|
||||
if current_word is None:
|
||||
current_word = WordleWord(word=word, day=datetime.date.today())
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
await collection.update_one(
|
||||
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
|
||||
)
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
elif forced:
|
||||
current_word.word = word
|
||||
current_word.day = datetime.date.today()
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
# Remove all active games
|
||||
await reset_wordle_games(database)
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
|
||||
return word
|
||||
return current_word.word
|
||||
|
||||
|
||||
async def reset_wordle_games(database: MongoDatabase):
|
||||
async def reset_wordle_games(session: AsyncSession):
|
||||
"""Reset all active games"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.drop()
|
||||
statement = delete(WordleGuess)
|
||||
await session.execute(statement)
|
||||
|
|
|
|||
57
database/crud/wordle_stats.py
Normal file
57
database/crud/wordle_stats.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
from datetime import date
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas import WordleStats
|
||||
|
||||
__all__ = ["get_wordle_stats", "complete_wordle_game"]
|
||||
|
||||
|
||||
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
|
||||
"""Get a user's wordle stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
statement = select(WordleStats).where(WordleStats.user_id == user_id)
|
||||
stats = (await session.execute(statement)).scalar_one_or_none()
|
||||
if stats is not None:
|
||||
return stats
|
||||
|
||||
stats = WordleStats(user_id=user_id)
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_wordle_stats(session, user_id)
|
||||
stats.games += 1
|
||||
|
||||
if win:
|
||||
stats.wins += 1
|
||||
|
||||
# Update streak
|
||||
today = date.today()
|
||||
last_win = stats.last_win
|
||||
stats.last_win = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
stats.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
stats.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.current_streak > stats.highest_streak:
|
||||
stats.highest_streak = stats.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
stats.current_streak = 0
|
||||
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
|
|
@ -1,6 +1,5 @@
|
|||
from urllib.parse import quote_plus
|
||||
|
||||
import motor.motor_asyncio
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
|
@ -26,16 +25,3 @@ postgres_engine = create_async_engine(
|
|||
DBSession = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
# MongoDB client
|
||||
if not settings.TESTING: # pragma: no cover
|
||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
||||
mongo_url = (
|
||||
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
)
|
||||
else:
|
||||
# Require no authentication when testing
|
||||
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
|
||||
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
import enum
|
||||
|
||||
__all__ = ["TaskType", "TempStorageKey"]
|
||||
__all__ = ["TaskType"]
|
||||
|
||||
|
||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||
|
|
@ -11,10 +11,3 @@ class TaskType(enum.IntEnum):
|
|||
|
||||
BIRTHDAYS = enum.auto()
|
||||
UFORA_ANNOUNCEMENTS = enum.auto()
|
||||
|
||||
|
||||
@enum.unique
|
||||
class TempStorageKey(str, enum.Enum):
|
||||
"""Enum for keys to distinguish the TemporaryStorage rows"""
|
||||
|
||||
WORDLE_WORD = "wordle_word"
|
||||
|
|
|
|||
|
|
@ -1,6 +0,0 @@
|
|||
import motor.motor_asyncio
|
||||
|
||||
# Type aliases for the Motor types, which are way too long
|
||||
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
|
||||
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
|
||||
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection
|
||||
|
|
@ -37,6 +37,9 @@ __all__ = [
|
|||
"UforaCourse",
|
||||
"UforaCourseAlias",
|
||||
"User",
|
||||
"WordleGuess",
|
||||
"WordleStats",
|
||||
"WordleWord",
|
||||
]
|
||||
|
||||
|
||||
|
|
@ -231,3 +234,47 @@ class User(Base):
|
|||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_guesses: list[WordleGuess] = relationship(
|
||||
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_stats: WordleStats = relationship(
|
||||
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class WordleGuess(Base):
|
||||
"""A user's Wordle guesses for today"""
|
||||
|
||||
__tablename__ = "wordle_guesses"
|
||||
|
||||
wordle_guess_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
guess: str = Column(Text, nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleStats(Base):
|
||||
"""Stats about a user's wordle performance"""
|
||||
|
||||
__tablename__ = "wordle_stats"
|
||||
|
||||
wordle_stats_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_win: Optional[date] = Column(Date, nullable=True)
|
||||
games: int = Column(Integer, server_default="0", nullable=False)
|
||||
wins: int = Column(Integer, server_default="0", nullable=False)
|
||||
current_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
highest_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleWord(Base):
|
||||
"""The current Wordle word"""
|
||||
|
||||
__tablename__ = "wordle_word"
|
||||
|
||||
word_id: int = Column(Integer, primary_key=True)
|
||||
word: str = Column(Text, nullable=False)
|
||||
day: date = Column(Date, nullable=False, unique=True)
|
||||
|
|
@ -1,53 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
|
||||
|
||||
|
||||
class PyObjectId(ObjectId):
|
||||
"""Custom type for bson ObjectIds"""
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: str):
|
||||
"""Check that a string is a valid bson ObjectId"""
|
||||
if not ObjectId.is_valid(value):
|
||||
raise ValueError(f"Invalid ObjectId: '{value}'")
|
||||
|
||||
return ObjectId(value)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: dict):
|
||||
field_schema.update(type="string")
|
||||
|
||||
|
||||
class MongoBase(BaseModel):
|
||||
"""Base model that properly sets the _id field, and adds one by default"""
|
||||
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
|
||||
class Config:
|
||||
"""Configuration for encoding and construction"""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class MongoCollection(MongoBase, ABC):
|
||||
"""Base model for the 'main class' in a collection
|
||||
|
||||
This field stores the name of the collection to avoid making typos against it
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def collection() -> str:
|
||||
"""Getter for the name of the collection, in order to avoid typos"""
|
||||
raise NotImplementedError
|
||||
|
|
@ -1,40 +0,0 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["GameStats", "WordleStats"]
|
||||
|
||||
|
||||
class WordleStats(BaseModel):
|
||||
"""Model that holds stats about a player's Wordle performance"""
|
||||
|
||||
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
||||
last_win: Optional[datetime.datetime] = None
|
||||
wins: int = 0
|
||||
games: int = 0
|
||||
current_streak: int = 0
|
||||
max_streak: int = 0
|
||||
|
||||
@validator("guess_distribution")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the distribution of guesses is of the correct length"""
|
||||
if len(value) != 6:
|
||||
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class GameStats(MongoCollection):
|
||||
"""Collection that holds stats about how well a user has performed in games"""
|
||||
|
||||
user_id: int
|
||||
wordle: WordleStats = WordleStats()
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "game_stats"
|
||||
|
|
@ -1,16 +0,0 @@
|
|||
from overrides import overrides
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["TemporaryStorage"]
|
||||
|
||||
|
||||
class TemporaryStorage(MongoCollection):
|
||||
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
||||
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "temporary"
|
||||
|
|
@ -1,44 +0,0 @@
|
|||
import datetime
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import Field, validator
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
__all__ = ["WordleGame"]
|
||||
|
||||
|
||||
class WordleGame(MongoCollection):
|
||||
"""Collection that holds people's active Wordle games"""
|
||||
|
||||
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
||||
guesses: list[str] = Field(default_factory=list)
|
||||
user_id: int
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "wordle"
|
||||
|
||||
@validator("guesses")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the amount of guesses is of the correct length"""
|
||||
if len(value) > 6:
|
||||
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
def is_game_over(self, word: str) -> bool:
|
||||
"""Check if the current game is over"""
|
||||
# No guesses yet
|
||||
if not self.guesses:
|
||||
return False
|
||||
|
||||
# Max amount of guesses allowed
|
||||
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
||||
return True
|
||||
|
||||
# Found the correct word
|
||||
return self.guesses[-1] == word
|
||||
|
|
@ -1,19 +1,15 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from discord import app_commands
|
||||
from overrides import overrides
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import links, memes, ufora_courses, wordle
|
||||
from database.mongo_types import MongoDatabase
|
||||
|
||||
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DatabaseCache(ABC, Generic[T]):
|
||||
class DatabaseCache(ABC):
|
||||
"""Base class for a simple cache-like structure
|
||||
|
||||
The goal of this class is to store data for Discord auto-completion results
|
||||
|
|
@ -25,7 +21,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
Considering the fact that a user isn't obligated to choose something from the suggestions,
|
||||
chances are high we have to go to the database for the final action either way.
|
||||
|
||||
Also stores the data in lowercase to allow fast searching
|
||||
Also stores the data in lowercase to allow fast searching.
|
||||
"""
|
||||
|
||||
data: list[str] = []
|
||||
|
|
@ -36,7 +32,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
self.data.clear()
|
||||
|
||||
@abstractmethod
|
||||
async def invalidate(self, database_session: T):
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
"""Invalidate the data stored in this cache"""
|
||||
|
||||
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
|
||||
|
|
@ -48,7 +44,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||
|
||||
|
||||
class LinkCache(DatabaseCache[AsyncSession]):
|
||||
class LinkCache(DatabaseCache):
|
||||
"""Cache to store the names of links"""
|
||||
|
||||
@overrides
|
||||
|
|
@ -61,7 +57,7 @@ class LinkCache(DatabaseCache[AsyncSession]):
|
|||
self.data_transformed = list(map(str.lower, self.data))
|
||||
|
||||
|
||||
class MemeCache(DatabaseCache[AsyncSession]):
|
||||
class MemeCache(DatabaseCache):
|
||||
"""Cache to store the names of meme templates"""
|
||||
|
||||
@overrides
|
||||
|
|
@ -74,7 +70,7 @@ class MemeCache(DatabaseCache[AsyncSession]):
|
|||
self.data_transformed = list(map(str.lower, self.data))
|
||||
|
||||
|
||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||
class UforaCourseCache(DatabaseCache):
|
||||
"""Cache to store the names of Ufora courses"""
|
||||
|
||||
# Also store the aliases to add additional support
|
||||
|
|
@ -119,10 +115,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
|
|||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||
|
||||
|
||||
class WordleCache(DatabaseCache[MongoDatabase]):
|
||||
class WordleCache(DatabaseCache):
|
||||
"""Cache to store the current daily Wordle word"""
|
||||
|
||||
async def invalidate(self, database_session: MongoDatabase):
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
word = await wordle.get_daily_word(database_session)
|
||||
if word is not None:
|
||||
self.data = [word]
|
||||
|
|
@ -142,9 +138,9 @@ class CacheManager:
|
|||
self.ufora_courses = UforaCourseCache()
|
||||
self.wordle_word = WordleCache()
|
||||
|
||||
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
||||
async def initialize_caches(self, postgres_session: AsyncSession):
|
||||
"""Initialize the contents of all caches"""
|
||||
await self.links.invalidate(postgres_session)
|
||||
await self.memes.invalidate(postgres_session)
|
||||
await self.ufora_courses.invalidate(postgres_session)
|
||||
await self.wordle_word.invalidate(mongo_db)
|
||||
await self.wordle_word.invalidate(postgres_session)
|
||||
|
|
|
|||
|
|
@ -1,15 +1,5 @@
|
|||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
__all__ = ["LOCAL_TIMEZONE", "today_only_date"]
|
||||
__all__ = ["LOCAL_TIMEZONE"]
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
||||
|
||||
def today_only_date() -> datetime.datetime:
|
||||
"""Mongo can't handle datetime.date, so we need a datetime instance
|
||||
|
||||
We do, however, only care about the date, so remove all the rest
|
||||
"""
|
||||
today = datetime.date.today()
|
||||
return datetime.datetime(year=today.year, month=today.month, day=today.day)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue