Remove mongo & fix tests

This commit is contained in:
stijndcl 2022-08-29 20:24:42 +02:00
parent 7b2109fb07
commit 8a4baf6bb8
56 changed files with 406 additions and 539 deletions

View file

@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.crud import users
from database.schemas.relational import Birthday, User
from database.schemas import Birthday, User
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users
from database.exceptions import currency as exceptions
from database.schemas.relational import Bank, NightlyData
from database.schemas import Bank, NightlyData
from database.utils.math.currency import (
capacity_upgrade_price,
interest_upgrade_price,

View file

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
from database.schemas.relational import CustomCommand, CustomCommandAlias
from database.schemas import CustomCommand, CustomCommandAlias
__all__ = [
"clean_name",

View file

@ -2,7 +2,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.not_found import NoResultFoundException
from database.schemas.relational import DadJoke
from database.schemas import DadJoke
__all__ = ["add_dad_joke", "get_random_dad_joke"]

View file

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.schemas.relational import Deadline, UforaCourse
from database.schemas import Deadline, UforaCourse
__all__ = ["add_deadline", "get_deadlines"]

View file

@ -1,59 +0,0 @@
import datetime
from typing import Union
from database.mongo_types import MongoDatabase
from database.schemas.mongo.game_stats import GameStats
__all__ = ["get_game_stats", "complete_wordle_game"]
from database.utils.datetime import today_only_date
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
"""Get a user's game stats
If no entry is found, it is first created
"""
collection = database[GameStats.collection()]
stats = await collection.find_one({"user_id": user_id})
if stats is not None:
return GameStats(**stats)
stats = GameStats(user_id=user_id)
await collection.insert_one(stats.dict(by_alias=True))
return stats
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
"""Update the user's Wordle stats"""
stats = await get_game_stats(database, user_id)
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
if win:
update["$inc"]["wordle.wins"] = 1
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
# Update streak
today = today_only_date()
last_win = stats.wordle.last_win
update["$set"]["wordle.last_win"] = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
update["$set"]["wordle.current_streak"] = 1
stats.wordle.current_streak = 1
else:
# On a streak: increase counter
update["$inc"]["wordle.current_streak"] = 1
stats.wordle.current_streak += 1
# Update max streak if necessary
if stats.wordle.current_streak > stats.wordle.max_streak:
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
else:
# Streak is over
update["$set"]["wordle.current_streak"] = 0
collection = database[GameStats.collection()]
await collection.update_one({"_id": stats.id}, update)

View file

@ -4,7 +4,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions import NoResultFoundException
from database.schemas.relational import Link
from database.schemas import Link
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]

View file

@ -4,7 +4,7 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import MemeTemplate
from database.schemas import MemeTemplate
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]

View file

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType
from database.schemas.relational import Task
from database.schemas import Task
from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]

View file

@ -3,7 +3,7 @@ import datetime
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import UforaAnnouncement, UforaCourse
from database.schemas import UforaAnnouncement, UforaCourse
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]

View file

@ -3,7 +3,7 @@ from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import UforaCourse, UforaCourseAlias
from database.schemas import UforaCourse, UforaCourseAlias
__all__ = ["get_all_courses", "get_course_by_name"]

View file

@ -3,7 +3,7 @@ from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import Bank, NightlyData, User
from database.schemas import Bank, NightlyData, User
__all__ = [
"get_or_add",

View file

@ -1,56 +1,45 @@
import datetime
from typing import Optional
from database.enums import TempStorageKey
from database.mongo_types import MongoDatabase
from database.schemas.mongo.temporary_storage import TemporaryStorage
from database.schemas.mongo.wordle import WordleGame
from database.utils.datetime import today_only_date
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import WordleGuess, WordleWord
__all__ = [
"get_active_wordle_game",
"make_wordle_guess",
"start_new_wordle_game",
"set_daily_word",
"reset_wordle_games",
]
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
"""Find a player's active game"""
collection = database[WordleGame.collection()]
result = await collection.find_one({"user_id": user_id})
if result is None:
return None
return WordleGame(**result)
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
guesses = (await session.execute(statement)).scalars().all()
return guesses
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
"""Start a new game"""
collection = database[WordleGame.collection()]
game = WordleGame(user_id=user_id)
await collection.insert_one(game.dict(by_alias=True))
return game
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
"""Make a guess in your current game"""
collection = database[WordleGame.collection()]
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
guess_instance = WordleGuess(user_id=user_id, guess=guess)
session.add(guess_instance)
await session.commit()
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
"""Get the word of today"""
collection = database[TemporaryStorage.collection()]
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
row = (await session.execute(statement)).scalar_one_or_none()
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
if result is None:
if row is None:
return None
return result["word"]
return row
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
"""Set the word of today
This does NOT overwrite the existing word if there is one, so that it can safely run
@ -60,23 +49,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F
Returns the word that was chosen. If one already existed, return that instead.
"""
collection = database[TemporaryStorage.collection()]
current_word = await get_daily_word(session)
current_word = None if forced else await get_daily_word(database)
if current_word is not None:
return current_word
if current_word is None:
current_word = WordleWord(word=word, day=datetime.date.today())
session.add(current_word)
await session.commit()
await collection.update_one(
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
)
# Remove all active games
await reset_wordle_games(session)
elif forced:
current_word.word = word
current_word.day = datetime.date.today()
session.add(current_word)
await session.commit()
# Remove all active games
await reset_wordle_games(database)
# Remove all active games
await reset_wordle_games(session)
return word
return current_word.word
async def reset_wordle_games(database: MongoDatabase):
async def reset_wordle_games(session: AsyncSession):
"""Reset all active games"""
collection = database[WordleGame.collection()]
await collection.drop()
statement = delete(WordleGuess)
await session.execute(statement)

View file

@ -0,0 +1,57 @@
from datetime import date
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas import WordleStats
__all__ = ["get_wordle_stats", "complete_wordle_game"]
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
"""Get a user's wordle stats
If no entry is found, it is first created
"""
statement = select(WordleStats).where(WordleStats.user_id == user_id)
stats = (await session.execute(statement)).scalar_one_or_none()
if stats is not None:
return stats
stats = WordleStats(user_id=user_id)
session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
"""Update the user's Wordle stats"""
stats = await get_wordle_stats(session, user_id)
stats.games += 1
if win:
stats.wins += 1
# Update streak
today = date.today()
last_win = stats.last_win
stats.last_win = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
stats.current_streak = 1
else:
# On a streak: increase counter
stats.current_streak += 1
# Update max streak if necessary
if stats.current_streak > stats.highest_streak:
stats.highest_streak = stats.current_streak
else:
# Streak is over
stats.current_streak = 0
session.add(stats)
await session.commit()

View file

@ -1,6 +1,5 @@
from urllib.parse import quote_plus
import motor.motor_asyncio
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
@ -26,16 +25,3 @@ postgres_engine = create_async_engine(
DBSession = sessionmaker(
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
)
# MongoDB client
if not settings.TESTING: # pragma: no cover
encoded_mongo_username = quote_plus(settings.MONGO_USER)
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
mongo_url = (
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
)
else:
# Require no authentication when testing
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)

View file

@ -1,6 +1,6 @@
import enum
__all__ = ["TaskType", "TempStorageKey"]
__all__ = ["TaskType"]
# There is a bug in typeshed that causes an incorrect PyCharm warning
@ -11,10 +11,3 @@ class TaskType(enum.IntEnum):
BIRTHDAYS = enum.auto()
UFORA_ANNOUNCEMENTS = enum.auto()
@enum.unique
class TempStorageKey(str, enum.Enum):
"""Enum for keys to distinguish the TemporaryStorage rows"""
WORDLE_WORD = "wordle_word"

View file

@ -1,6 +0,0 @@
import motor.motor_asyncio
# Type aliases for the Motor types, which are way too long
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection

View file

@ -37,6 +37,9 @@ __all__ = [
"UforaCourse",
"UforaCourseAlias",
"User",
"WordleGuess",
"WordleStats",
"WordleWord",
]
@ -231,3 +234,47 @@ class User(Base):
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_stats: WordleStats = relationship(
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
class WordleGuess(Base):
"""A user's Wordle guesses for today"""
__tablename__ = "wordle_guesses"
wordle_guess_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
guess: str = Column(Text, nullable=False)
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
class WordleStats(Base):
"""Stats about a user's wordle performance"""
__tablename__ = "wordle_stats"
wordle_stats_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_win: Optional[date] = Column(Date, nullable=True)
games: int = Column(Integer, server_default="0", nullable=False)
wins: int = Column(Integer, server_default="0", nullable=False)
current_streak: int = Column(Integer, server_default="0", nullable=False)
highest_streak: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
class WordleWord(Base):
"""The current Wordle word"""
__tablename__ = "wordle_word"
word_id: int = Column(Integer, primary_key=True)
word: str = Column(Text, nullable=False)
day: date = Column(Date, nullable=False, unique=True)

View file

@ -1,53 +0,0 @@
from abc import ABC, abstractmethod
from bson import ObjectId
from pydantic import BaseModel, Field
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
class PyObjectId(ObjectId):
"""Custom type for bson ObjectIds"""
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, value: str):
"""Check that a string is a valid bson ObjectId"""
if not ObjectId.is_valid(value):
raise ValueError(f"Invalid ObjectId: '{value}'")
return ObjectId(value)
@classmethod
def __modify_schema__(cls, field_schema: dict):
field_schema.update(type="string")
class MongoBase(BaseModel):
"""Base model that properly sets the _id field, and adds one by default"""
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
class Config:
"""Configuration for encoding and construction"""
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str, PyObjectId: str}
use_enum_values = True
class MongoCollection(MongoBase, ABC):
"""Base model for the 'main class' in a collection
This field stores the name of the collection to avoid making typos against it
"""
@staticmethod
@abstractmethod
def collection() -> str:
"""Getter for the name of the collection, in order to avoid typos"""
raise NotImplementedError

View file

@ -1,40 +0,0 @@
import datetime
from typing import Optional
from overrides import overrides
from pydantic import BaseModel, Field, validator
from database.schemas.mongo.common import MongoCollection
__all__ = ["GameStats", "WordleStats"]
class WordleStats(BaseModel):
"""Model that holds stats about a player's Wordle performance"""
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
last_win: Optional[datetime.datetime] = None
wins: int = 0
games: int = 0
current_streak: int = 0
max_streak: int = 0
@validator("guess_distribution")
def validate_guesses_length(cls, value: list[int]):
"""Check that the distribution of guesses is of the correct length"""
if len(value) != 6:
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
return value
class GameStats(MongoCollection):
"""Collection that holds stats about how well a user has performed in games"""
user_id: int
wordle: WordleStats = WordleStats()
@staticmethod
@overrides
def collection() -> str:
return "game_stats"

View file

@ -1,16 +0,0 @@
from overrides import overrides
from database.schemas.mongo.common import MongoCollection
__all__ = ["TemporaryStorage"]
class TemporaryStorage(MongoCollection):
"""Collection for lots of random things that don't belong in a full-blown collection"""
key: str
@staticmethod
@overrides
def collection() -> str:
return "temporary"

View file

@ -1,44 +0,0 @@
import datetime
from overrides import overrides
from pydantic import Field, validator
from database.constants import WORDLE_GUESS_COUNT
from database.schemas.mongo.common import MongoCollection
from database.utils.datetime import today_only_date
__all__ = ["WordleGame"]
class WordleGame(MongoCollection):
"""Collection that holds people's active Wordle games"""
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
guesses: list[str] = Field(default_factory=list)
user_id: int
@staticmethod
@overrides
def collection() -> str:
return "wordle"
@validator("guesses")
def validate_guesses_length(cls, value: list[int]):
"""Check that the amount of guesses is of the correct length"""
if len(value) > 6:
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
return value
def is_game_over(self, word: str) -> bool:
"""Check if the current game is over"""
# No guesses yet
if not self.guesses:
return False
# Max amount of guesses allowed
if len(self.guesses) == WORDLE_GUESS_COUNT:
return True
# Found the correct word
return self.guesses[-1] == word

View file

@ -1,19 +1,15 @@
from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from discord import app_commands
from overrides import overrides
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import links, memes, ufora_courses, wordle
from database.mongo_types import MongoDatabase
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
T = TypeVar("T")
class DatabaseCache(ABC, Generic[T]):
class DatabaseCache(ABC):
"""Base class for a simple cache-like structure
The goal of this class is to store data for Discord auto-completion results
@ -25,7 +21,7 @@ class DatabaseCache(ABC, Generic[T]):
Considering the fact that a user isn't obligated to choose something from the suggestions,
chances are high we have to go to the database for the final action either way.
Also stores the data in lowercase to allow fast searching
Also stores the data in lowercase to allow fast searching.
"""
data: list[str] = []
@ -36,7 +32,7 @@ class DatabaseCache(ABC, Generic[T]):
self.data.clear()
@abstractmethod
async def invalidate(self, database_session: T):
async def invalidate(self, database_session: AsyncSession):
"""Invalidate the data stored in this cache"""
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
@ -48,7 +44,7 @@ class DatabaseCache(ABC, Generic[T]):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class LinkCache(DatabaseCache[AsyncSession]):
class LinkCache(DatabaseCache):
"""Cache to store the names of links"""
@overrides
@ -61,7 +57,7 @@ class LinkCache(DatabaseCache[AsyncSession]):
self.data_transformed = list(map(str.lower, self.data))
class MemeCache(DatabaseCache[AsyncSession]):
class MemeCache(DatabaseCache):
"""Cache to store the names of meme templates"""
@overrides
@ -74,7 +70,7 @@ class MemeCache(DatabaseCache[AsyncSession]):
self.data_transformed = list(map(str.lower, self.data))
class UforaCourseCache(DatabaseCache[AsyncSession]):
class UforaCourseCache(DatabaseCache):
"""Cache to store the names of Ufora courses"""
# Also store the aliases to add additional support
@ -119,10 +115,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class WordleCache(DatabaseCache[MongoDatabase]):
class WordleCache(DatabaseCache):
"""Cache to store the current daily Wordle word"""
async def invalidate(self, database_session: MongoDatabase):
async def invalidate(self, database_session: AsyncSession):
word = await wordle.get_daily_word(database_session)
if word is not None:
self.data = [word]
@ -142,9 +138,9 @@ class CacheManager:
self.ufora_courses = UforaCourseCache()
self.wordle_word = WordleCache()
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
async def initialize_caches(self, postgres_session: AsyncSession):
"""Initialize the contents of all caches"""
await self.links.invalidate(postgres_session)
await self.memes.invalidate(postgres_session)
await self.ufora_courses.invalidate(postgres_session)
await self.wordle_word.invalidate(mongo_db)
await self.wordle_word.invalidate(postgres_session)

View file

@ -1,15 +1,5 @@
import datetime
import zoneinfo
__all__ = ["LOCAL_TIMEZONE", "today_only_date"]
__all__ = ["LOCAL_TIMEZONE"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
def today_only_date() -> datetime.datetime:
"""Mongo can't handle datetime.date, so we need a datetime instance
We do, however, only care about the date, so remove all the rest
"""
today = datetime.date.today()
return datetime.datetime(year=today.year, month=today.month, day=today.day)