Merge pull request #129 from stijndcl/remove-mongo

Remove Mongo again
pull/130/head
Stijn De Clercq 2022-08-29 21:30:40 +02:00 committed by GitHub
commit 8308b4ad9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
60 changed files with 487 additions and 603 deletions

View File

@ -38,17 +38,6 @@ jobs:
POSTGRES_DB: didier_pytest POSTGRES_DB: didier_pytest
POSTGRES_USER: pytest POSTGRES_USER: pytest
POSTGRES_PASSWORD: pytest POSTGRES_PASSWORD: pytest
mongo:
image: mongo:5.0
options: >-
--health-cmd mongo
--health-interval 10s
--health-timeout 5s
--health-retries 5
ports:
- 27018:27017
env:
MONGO_DB: didier_pytest
steps: steps:
- uses: actions/checkout@v3 - uses: actions/checkout@v3
- name: Setup Python - name: Setup Python

View File

@ -44,9 +44,3 @@ repos:
- "flake8-eradicate" - "flake8-eradicate"
- "flake8-isort" - "flake8-isort"
- "flake8-simplify" - "flake8-simplify"
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
hooks:
- id: mypy
args: [--config, pyproject.toml]

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context from alembic import context
from database.engine import postgres_engine from database.engine import postgres_engine
from database.schemas.relational import Base from database.schemas import Base
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
# access to the values within the .ini file in use. # access to the values within the .ini file in use.

View File

@ -0,0 +1,63 @@
"""Wordle
Revision ID: 38b7c29f10ee
Revises: 36300b558ef1
Create Date: 2022-08-29 20:21:02.413631
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "38b7c29f10ee"
down_revision = "36300b558ef1"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"wordle_word",
sa.Column("word_id", sa.Integer(), nullable=False),
sa.Column("word", sa.Text(), nullable=False),
sa.Column("day", sa.Date(), nullable=False),
sa.PrimaryKeyConstraint("word_id"),
sa.UniqueConstraint("day"),
)
op.create_table(
"wordle_guesses",
sa.Column("wordle_guess_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("guess", sa.Text(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("wordle_guess_id"),
)
op.create_table(
"wordle_stats",
sa.Column("wordle_stats_id", sa.Integer(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("last_win", sa.Date(), nullable=True),
sa.Column("games", sa.Integer(), server_default="0", nullable=False),
sa.Column("wins", sa.Integer(), server_default="0", nullable=False),
sa.Column("current_streak", sa.Integer(), server_default="0", nullable=False),
sa.Column("highest_streak", sa.Integer(), server_default="0", nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("wordle_stats_id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("wordle_stats")
op.drop_table("wordle_guesses")
op.drop_table("wordle_word")
# ### end Alembic commands ###

View File

@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from database.crud import users from database.crud import users
from database.schemas.relational import Birthday, User from database.schemas import Birthday, User
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"] __all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
@ -17,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
If already present, overwrites the existing one If already present, overwrites the existing one
""" """
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) user = await users.get_or_add_user(session, user_id, options=[selectinload(User.birthday)])
if user.birthday is not None: if user.birthday is not None:
bd = user.birthday bd = user.birthday

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users from database.crud import users
from database.exceptions import currency as exceptions from database.exceptions import currency as exceptions
from database.schemas.relational import Bank, NightlyData from database.schemas import Bank, NightlyData
from database.utils.math.currency import ( from database.utils.math.currency import (
capacity_upgrade_price, capacity_upgrade_price,
interest_upgrade_price, interest_upgrade_price,
@ -29,13 +29,13 @@ NIGHTLY_AMOUNT = 420
async def get_bank(session: AsyncSession, user_id: int) -> Bank: async def get_bank(session: AsyncSession, user_id: int) -> Bank:
"""Get a user's bank info""" """Get a user's bank info"""
user = await users.get_or_add(session, user_id) user = await users.get_or_add_user(session, user_id)
return user.bank return user.bank
async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData: async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData:
"""Get a user's nightly info""" """Get a user's nightly info"""
user = await users.get_or_add(session, user_id) user = await users.get_or_add_user(session, user_id)
return user.nightly_data return user.nightly_data

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.schemas.relational import CustomCommand, CustomCommandAlias from database.schemas import CustomCommand, CustomCommandAlias
__all__ = [ __all__ = [
"clean_name", "clean_name",

View File

@ -2,7 +2,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.schemas.relational import DadJoke from database.schemas import DadJoke
__all__ = ["add_dad_joke", "get_random_dad_joke"] __all__ = ["add_dad_joke", "get_random_dad_joke"]

View File

@ -6,7 +6,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
from database.schemas.relational import Deadline, UforaCourse from database.schemas import Deadline, UforaCourse
__all__ = ["add_deadline", "get_deadlines"] __all__ = ["add_deadline", "get_deadlines"]

View File

@ -1,59 +0,0 @@
import datetime
from typing import Union
from database.mongo_types import MongoDatabase
from database.schemas.mongo.game_stats import GameStats
__all__ = ["get_game_stats", "complete_wordle_game"]
from database.utils.datetime import today_only_date
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
"""Get a user's game stats
If no entry is found, it is first created
"""
collection = database[GameStats.collection()]
stats = await collection.find_one({"user_id": user_id})
if stats is not None:
return GameStats(**stats)
stats = GameStats(user_id=user_id)
await collection.insert_one(stats.dict(by_alias=True))
return stats
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
"""Update the user's Wordle stats"""
stats = await get_game_stats(database, user_id)
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
if win:
update["$inc"]["wordle.wins"] = 1
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
# Update streak
today = today_only_date()
last_win = stats.wordle.last_win
update["$set"]["wordle.last_win"] = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
update["$set"]["wordle.current_streak"] = 1
stats.wordle.current_streak = 1
else:
# On a streak: increase counter
update["$inc"]["wordle.current_streak"] = 1
stats.wordle.current_streak += 1
# Update max streak if necessary
if stats.wordle.current_streak > stats.wordle.max_streak:
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
else:
# Streak is over
update["$set"]["wordle.current_streak"] = 0
collection = database[GameStats.collection()]
await collection.update_one({"_id": stats.id}, update)

View File

@ -4,7 +4,7 @@ from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions import NoResultFoundException from database.exceptions import NoResultFoundException
from database.schemas.relational import Link from database.schemas import Link
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]

View File

@ -4,7 +4,7 @@ from sqlalchemy import select
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import MemeTemplate from database.schemas import MemeTemplate
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"] __all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]

View File

@ -5,7 +5,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.enums import TaskType from database.enums import TaskType
from database.schemas.relational import Task from database.schemas import Task
from database.utils.datetime import LOCAL_TIMEZONE from database.utils.datetime import LOCAL_TIMEZONE
__all__ = ["get_task_by_enum", "set_last_task_execution_time"] __all__ = ["get_task_by_enum", "set_last_task_execution_time"]

View File

@ -3,7 +3,7 @@ import datetime
from sqlalchemy import delete, select from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import UforaAnnouncement, UforaCourse from database.schemas import UforaAnnouncement, UforaCourse
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"] __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]

View File

@ -3,7 +3,7 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import UforaCourse, UforaCourseAlias from database.schemas import UforaCourse, UforaCourseAlias
__all__ = ["get_all_courses", "get_course_by_name"] __all__ = ["get_all_courses", "get_course_by_name"]

View File

@ -3,14 +3,14 @@ from typing import Optional
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import Bank, NightlyData, User from database.schemas import Bank, NightlyData, User
__all__ = [ __all__ = [
"get_or_add", "get_or_add_user",
] ]
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: async def get_or_add_user(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
"""Get a user's profile """Get a user's profile
If it doesn't exist yet, create it (along with all linked datastructures) If it doesn't exist yet, create it (along with all linked datastructures)

View File

@ -1,56 +1,54 @@
import datetime
from typing import Optional from typing import Optional
from database.enums import TempStorageKey from sqlalchemy import delete, select
from database.mongo_types import MongoDatabase from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.mongo.temporary_storage import TemporaryStorage
from database.schemas.mongo.wordle import WordleGame from database.crud.users import get_or_add_user
from database.utils.datetime import today_only_date from database.schemas import WordleGuess, WordleWord
__all__ = [ __all__ = [
"get_active_wordle_game", "get_active_wordle_game",
"get_wordle_guesses",
"make_wordle_guess", "make_wordle_guess",
"start_new_wordle_game",
"set_daily_word", "set_daily_word",
"reset_wordle_games", "reset_wordle_games",
] ]
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]: async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
"""Find a player's active game""" """Find a player's active game"""
collection = database[WordleGame.collection()] await get_or_add_user(session, user_id)
result = await collection.find_one({"user_id": user_id}) statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
if result is None: guesses = (await session.execute(statement)).scalars().all()
return None return guesses
return WordleGame(**result)
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame: async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]:
"""Start a new game""" """Get the strings of a player's guesses"""
collection = database[WordleGame.collection()] active_game = await get_active_wordle_game(session, user_id)
game = WordleGame(user_id=user_id) return list(map(lambda g: g.guess.lower(), active_game))
await collection.insert_one(game.dict(by_alias=True))
return game
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str): async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
"""Make a guess in your current game""" """Make a guess in your current game"""
collection = database[WordleGame.collection()] guess_instance = WordleGuess(user_id=user_id, guess=guess)
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}}) session.add(guess_instance)
await session.commit()
async def get_daily_word(database: MongoDatabase) -> Optional[str]: async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
"""Get the word of today""" """Get the word of today"""
collection = database[TemporaryStorage.collection()] statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
row = (await session.execute(statement)).scalar_one_or_none()
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()}) if row is None:
if result is None:
return None return None
return result["word"] return row
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str: async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
"""Set the word of today """Set the word of today
This does NOT overwrite the existing word if there is one, so that it can safely run This does NOT overwrite the existing word if there is one, so that it can safely run
@ -60,23 +58,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F
Returns the word that was chosen. If one already existed, return that instead. Returns the word that was chosen. If one already existed, return that instead.
""" """
collection = database[TemporaryStorage.collection()] current_word = await get_daily_word(session)
current_word = None if forced else await get_daily_word(database) if current_word is None:
if current_word is not None: current_word = WordleWord(word=word, day=datetime.date.today())
return current_word session.add(current_word)
await session.commit()
await collection.update_one(
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
)
# Remove all active games # Remove all active games
await reset_wordle_games(database) await reset_wordle_games(session)
elif forced:
current_word.word = word
session.add(current_word)
await session.commit()
return word # Remove all active games
await reset_wordle_games(session)
return current_word.word
async def reset_wordle_games(database: MongoDatabase): async def reset_wordle_games(session: AsyncSession):
"""Reset all active games""" """Reset all active games"""
collection = database[WordleGame.collection()] statement = delete(WordleGuess)
await collection.drop() await session.execute(statement)
await session.commit()

View File

@ -0,0 +1,60 @@
from datetime import date
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import WordleStats
__all__ = ["get_wordle_stats", "complete_wordle_game"]
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
"""Get a user's wordle stats
If no entry is found, it is first created
"""
await get_or_add_user(session, user_id)
statement = select(WordleStats).where(WordleStats.user_id == user_id)
stats = (await session.execute(statement)).scalar_one_or_none()
if stats is not None:
return stats
stats = WordleStats(user_id=user_id)
session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
"""Update the user's Wordle stats"""
stats = await get_wordle_stats(session, user_id)
stats.games += 1
if win:
stats.wins += 1
# Update streak
today = date.today()
last_win = stats.last_win
stats.last_win = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
stats.current_streak = 1
else:
# On a streak: increase counter
stats.current_streak += 1
# Update max streak if necessary
if stats.current_streak > stats.highest_streak:
stats.highest_streak = stats.current_streak
else:
# Streak is over
stats.current_streak = 0
session.add(stats)
await session.commit()

View File

@ -1,6 +1,5 @@
from urllib.parse import quote_plus from urllib.parse import quote_plus
import motor.motor_asyncio
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -26,16 +25,3 @@ postgres_engine = create_async_engine(
DBSession = sessionmaker( DBSession = sessionmaker(
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
) )
# MongoDB client
if not settings.TESTING: # pragma: no cover
encoded_mongo_username = quote_plus(settings.MONGO_USER)
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
mongo_url = (
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
)
else:
# Require no authentication when testing
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)

View File

@ -1,6 +1,6 @@
import enum import enum
__all__ = ["TaskType", "TempStorageKey"] __all__ = ["TaskType"]
# There is a bug in typeshed that causes an incorrect PyCharm warning # There is a bug in typeshed that causes an incorrect PyCharm warning
@ -11,10 +11,3 @@ class TaskType(enum.IntEnum):
BIRTHDAYS = enum.auto() BIRTHDAYS = enum.auto()
UFORA_ANNOUNCEMENTS = enum.auto() UFORA_ANNOUNCEMENTS = enum.auto()
@enum.unique
class TempStorageKey(str, enum.Enum):
"""Enum for keys to distinguish the TemporaryStorage rows"""
WORDLE_WORD = "wordle_word"

View File

@ -1,6 +0,0 @@
import motor.motor_asyncio
# Type aliases for the Motor types, which are way too long
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection

View File

@ -37,6 +37,9 @@ __all__ = [
"UforaCourse", "UforaCourse",
"UforaCourseAlias", "UforaCourseAlias",
"User", "User",
"WordleGuess",
"WordleStats",
"WordleWord",
] ]
@ -231,3 +234,47 @@ class User(Base):
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_stats: WordleStats = relationship(
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
class WordleGuess(Base):
"""A user's Wordle guesses for today"""
__tablename__ = "wordle_guesses"
wordle_guess_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
guess: str = Column(Text, nullable=False)
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
class WordleStats(Base):
"""Stats about a user's wordle performance"""
__tablename__ = "wordle_stats"
wordle_stats_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_win: Optional[date] = Column(Date, nullable=True)
games: int = Column(Integer, server_default="0", nullable=False)
wins: int = Column(Integer, server_default="0", nullable=False)
current_streak: int = Column(Integer, server_default="0", nullable=False)
highest_streak: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
class WordleWord(Base):
"""The current Wordle word"""
__tablename__ = "wordle_word"
word_id: int = Column(Integer, primary_key=True)
word: str = Column(Text, nullable=False)
day: date = Column(Date, nullable=False, unique=True)

View File

@ -1,53 +0,0 @@
from abc import ABC, abstractmethod
from bson import ObjectId
from pydantic import BaseModel, Field
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
class PyObjectId(ObjectId):
"""Custom type for bson ObjectIds"""
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, value: str):
"""Check that a string is a valid bson ObjectId"""
if not ObjectId.is_valid(value):
raise ValueError(f"Invalid ObjectId: '{value}'")
return ObjectId(value)
@classmethod
def __modify_schema__(cls, field_schema: dict):
field_schema.update(type="string")
class MongoBase(BaseModel):
"""Base model that properly sets the _id field, and adds one by default"""
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
class Config:
"""Configuration for encoding and construction"""
allow_population_by_field_name = True
arbitrary_types_allowed = True
json_encoders = {ObjectId: str, PyObjectId: str}
use_enum_values = True
class MongoCollection(MongoBase, ABC):
"""Base model for the 'main class' in a collection
This field stores the name of the collection to avoid making typos against it
"""
@staticmethod
@abstractmethod
def collection() -> str:
"""Getter for the name of the collection, in order to avoid typos"""
raise NotImplementedError

View File

@ -1,40 +0,0 @@
import datetime
from typing import Optional
from overrides import overrides
from pydantic import BaseModel, Field, validator
from database.schemas.mongo.common import MongoCollection
__all__ = ["GameStats", "WordleStats"]
class WordleStats(BaseModel):
"""Model that holds stats about a player's Wordle performance"""
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
last_win: Optional[datetime.datetime] = None
wins: int = 0
games: int = 0
current_streak: int = 0
max_streak: int = 0
@validator("guess_distribution")
def validate_guesses_length(cls, value: list[int]):
"""Check that the distribution of guesses is of the correct length"""
if len(value) != 6:
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
return value
class GameStats(MongoCollection):
"""Collection that holds stats about how well a user has performed in games"""
user_id: int
wordle: WordleStats = WordleStats()
@staticmethod
@overrides
def collection() -> str:
return "game_stats"

View File

@ -1,16 +0,0 @@
from overrides import overrides
from database.schemas.mongo.common import MongoCollection
__all__ = ["TemporaryStorage"]
class TemporaryStorage(MongoCollection):
"""Collection for lots of random things that don't belong in a full-blown collection"""
key: str
@staticmethod
@overrides
def collection() -> str:
return "temporary"

View File

@ -1,44 +0,0 @@
import datetime
from overrides import overrides
from pydantic import Field, validator
from database.constants import WORDLE_GUESS_COUNT
from database.schemas.mongo.common import MongoCollection
from database.utils.datetime import today_only_date
__all__ = ["WordleGame"]
class WordleGame(MongoCollection):
"""Collection that holds people's active Wordle games"""
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
guesses: list[str] = Field(default_factory=list)
user_id: int
@staticmethod
@overrides
def collection() -> str:
return "wordle"
@validator("guesses")
def validate_guesses_length(cls, value: list[int]):
"""Check that the amount of guesses is of the correct length"""
if len(value) > 6:
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
return value
def is_game_over(self, word: str) -> bool:
"""Check if the current game is over"""
# No guesses yet
if not self.guesses:
return False
# Max amount of guesses allowed
if len(self.guesses) == WORDLE_GUESS_COUNT:
return True
# Found the correct word
return self.guesses[-1] == word

View File

@ -1,19 +1,17 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Generic, TypeVar
from discord import app_commands from discord import app_commands
from overrides import overrides from overrides import overrides
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import links, memes, ufora_courses, wordle from database.crud import links, memes, ufora_courses, wordle
from database.mongo_types import MongoDatabase
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
T = TypeVar("T") from database.schemas import WordleWord
class DatabaseCache(ABC, Generic[T]): class DatabaseCache(ABC):
"""Base class for a simple cache-like structure """Base class for a simple cache-like structure
The goal of this class is to store data for Discord auto-completion results The goal of this class is to store data for Discord auto-completion results
@ -25,7 +23,7 @@ class DatabaseCache(ABC, Generic[T]):
Considering the fact that a user isn't obligated to choose something from the suggestions, Considering the fact that a user isn't obligated to choose something from the suggestions,
chances are high we have to go to the database for the final action either way. chances are high we have to go to the database for the final action either way.
Also stores the data in lowercase to allow fast searching Also stores the data in lowercase to allow fast searching.
""" """
data: list[str] = [] data: list[str] = []
@ -36,7 +34,7 @@ class DatabaseCache(ABC, Generic[T]):
self.data.clear() self.data.clear()
@abstractmethod @abstractmethod
async def invalidate(self, database_session: T): async def invalidate(self, database_session: AsyncSession):
"""Invalidate the data stored in this cache""" """Invalidate the data stored in this cache"""
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]: def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
@ -48,7 +46,7 @@ class DatabaseCache(ABC, Generic[T]):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class LinkCache(DatabaseCache[AsyncSession]): class LinkCache(DatabaseCache):
"""Cache to store the names of links""" """Cache to store the names of links"""
@overrides @overrides
@ -61,7 +59,7 @@ class LinkCache(DatabaseCache[AsyncSession]):
self.data_transformed = list(map(str.lower, self.data)) self.data_transformed = list(map(str.lower, self.data))
class MemeCache(DatabaseCache[AsyncSession]): class MemeCache(DatabaseCache):
"""Cache to store the names of meme templates""" """Cache to store the names of meme templates"""
@overrides @overrides
@ -74,7 +72,7 @@ class MemeCache(DatabaseCache[AsyncSession]):
self.data_transformed = list(map(str.lower, self.data)) self.data_transformed = list(map(str.lower, self.data))
class UforaCourseCache(DatabaseCache[AsyncSession]): class UforaCourseCache(DatabaseCache):
"""Cache to store the names of Ufora courses""" """Cache to store the names of Ufora courses"""
# Also store the aliases to add additional support # Also store the aliases to add additional support
@ -119,13 +117,15 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class WordleCache(DatabaseCache[MongoDatabase]): class WordleCache(DatabaseCache):
"""Cache to store the current daily Wordle word""" """Cache to store the current daily Wordle word"""
async def invalidate(self, database_session: MongoDatabase): word: WordleWord
async def invalidate(self, database_session: AsyncSession):
word = await wordle.get_daily_word(database_session) word = await wordle.get_daily_word(database_session)
if word is not None: if word is not None:
self.data = [word] self.word = word
class CacheManager: class CacheManager:
@ -142,9 +142,9 @@ class CacheManager:
self.ufora_courses = UforaCourseCache() self.ufora_courses = UforaCourseCache()
self.wordle_word = WordleCache() self.wordle_word = WordleCache()
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): async def initialize_caches(self, postgres_session: AsyncSession):
"""Initialize the contents of all caches""" """Initialize the contents of all caches"""
await self.links.invalidate(postgres_session) await self.links.invalidate(postgres_session)
await self.memes.invalidate(postgres_session) await self.memes.invalidate(postgres_session)
await self.ufora_courses.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session)
await self.wordle_word.invalidate(mongo_db) await self.wordle_word.invalidate(postgres_session)

View File

@ -1,15 +1,5 @@
import datetime
import zoneinfo import zoneinfo
__all__ = ["LOCAL_TIMEZONE", "today_only_date"] __all__ = ["LOCAL_TIMEZONE"]
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
def today_only_date() -> datetime.datetime:
"""Mongo can't handle datetime.date, so we need a datetime instance
We do, however, only care about the date, so remove all the rest
"""
today = datetime.date.today()
return datetime.datetime(year=today.year, month=today.month, day=today.day)

View File

@ -4,7 +4,7 @@ from overrides import overrides
from didier import Didier from didier import Didier
class TestCog(commands.Cog): class DebugCog(commands.Cog):
"""Testing cog for dev purposes""" """Testing cog for dev purposes"""
client: Didier client: Didier
@ -16,11 +16,11 @@ class TestCog(commands.Cog):
async def cog_check(self, ctx: commands.Context) -> bool: async def cog_check(self, ctx: commands.Context) -> bool:
return await self.client.is_owner(ctx.author) return await self.client.is_owner(ctx.author)
@commands.command() @commands.command(aliases=["Dev"])
async def test(self, ctx: commands.Context): async def debug(self, ctx: commands.Context):
"""Debugging command""" """Debugging command"""
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""
await client.add_cog(TestCog(client)) await client.add_cog(DebugCog(client))

View File

@ -4,14 +4,11 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH from database.constants import WORDLE_WORD_LENGTH
from database.crud.wordle import ( from database.crud.wordle import get_wordle_guesses, make_wordle_guess
get_active_wordle_game, from database.crud.wordle_stats import complete_wordle_game
make_wordle_guess,
start_new_wordle_game,
)
from didier import Didier from didier import Didier
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over
class Games(commands.Cog): class Games(commands.Cog):
@ -35,13 +32,16 @@ class Games(commands.Cog):
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed() embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
return await interaction.followup.send(embed=embed) return await interaction.followup.send(embed=embed)
active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id) word_instance = self.client.database_caches.wordle_word.word
if active_game is None:
active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id) async with self.client.postgres_session as session:
guesses = await get_wordle_guesses(session, interaction.user.id)
# Trying to guess with a complete game # Trying to guess with a complete game
if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess: if is_wordle_game_over(guesses, word_instance.word):
embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed() embed = WordleErrorEmbed(
message="You've already completed today's Wordle.\nTry again tomorrow!"
).to_embed()
return await interaction.followup.send(embed=embed) return await interaction.followup.send(embed=embed)
# Make a guess # Make a guess
@ -52,15 +52,20 @@ class Games(commands.Cog):
return await interaction.followup.send(embed=embed) return await interaction.followup.send(embed=embed)
guess = guess.lower() guess = guess.lower()
await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess) await make_wordle_guess(session, interaction.user.id, guess)
# Don't re-request the game, we already have it # Don't re-request the game, we already have it
# just append locally # just append locally
active_game.guesses.append(guess) guesses.append(guess)
embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed() embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed()
await interaction.followup.send(embed=embed) await interaction.followup.send(embed=embed)
# After responding to the interaction: update stats in the background
game_over = is_wordle_game_over(guesses, word_instance.word)
if game_over:
await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses)
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -5,7 +5,7 @@ from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud.links import get_link_by_name from database.crud.links import get_link_by_name
from database.schemas.relational import Link from database.schemas import Link
from didier import Didier from didier import Didier
from didier.data.apis import urban_dictionary from didier.data.apis import urban_dictionary
from didier.data.embeds.google import GoogleSearch from didier.data.embeds.google import GoogleSearch

View File

@ -140,9 +140,9 @@ class Tasks(commands.Cog):
@tasks.loop(time=DAILY_RESET_TIME) @tasks.loop(time=DAILY_RESET_TIME)
async def reset_wordle_word(self, forced: bool = False): async def reset_wordle_word(self, forced: bool = False):
"""Reset the daily Wordle word""" """Reset the daily Wordle word"""
db = self.client.mongo_db async with self.client.postgres_session as session:
word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)), forced=forced) await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced)
self.client.database_caches.wordle_word.data = [word] await self.client.database_caches.wordle_word.invalidate(session)
@reset_wordle_word.before_loop @reset_wordle_word.before_loop
async def _before_reset_wordle_word(self): async def _before_reset_wordle_word(self):

View File

@ -1,7 +1,7 @@
from aiohttp import ClientSession from aiohttp import ClientSession
import settings import settings
from database.schemas.relational import MemeTemplate from database.schemas import MemeTemplate
from didier.exceptions.missing_env import MissingEnvironmentVariable from didier.exceptions.missing_env import MissingEnvironmentVariable
from didier.utils.http.requests import ensure_post from didier.utils.http.requests import ensure_post

View File

@ -4,7 +4,7 @@ from datetime import datetime
import discord import discord
from overrides import overrides from overrides import overrides
from database.schemas.relational import Deadline from database.schemas import Deadline
from didier.data.embeds.base import EmbedBaseModel from didier.data.embeds.base import EmbedBaseModel
from didier.utils.types.datetime import tz_aware_now from didier.utils.types.datetime import tz_aware_now
from didier.utils.types.string import get_edu_year_name from didier.utils.types.string import get_edu_year_name

View File

@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud
from database.schemas.relational import UforaCourse from database.schemas import UforaCourse
from didier.data.embeds.base import EmbedBaseModel from didier.data.embeds.base import EmbedBaseModel
from didier.utils.discord.colours import ghent_university_blue from didier.utils.discord.colours import ghent_university_blue
from didier.utils.types.datetime import int_to_weekday from didier.utils.types.datetime import int_to_weekday

View File

@ -1,16 +1,26 @@
import enum import enum
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional
import discord import discord
from overrides import overrides from overrides import overrides
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
from database.schemas.mongo.wordle import WordleGame from database.schemas import WordleWord
from didier.data.embeds.base import EmbedBaseModel from didier.data.embeds.base import EmbedBaseModel
from didier.utils.types.datetime import int_to_weekday, tz_aware_now from didier.utils.types.datetime import int_to_weekday, tz_aware_now
__all__ = ["WordleEmbed", "WordleErrorEmbed"] __all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"]
def is_wordle_game_over(guesses: list[str], word: str) -> bool:
"""Check if the current game is over or not"""
if not guesses:
return False
if len(guesses) == WORDLE_GUESS_COUNT:
return True
return word.lower() in guesses
def footer() -> str: def footer() -> str:
@ -32,18 +42,18 @@ class WordleColour(enum.IntEnum):
class WordleEmbed(EmbedBaseModel): class WordleEmbed(EmbedBaseModel):
"""Embed for a Wordle game""" """Embed for a Wordle game"""
game: Optional[WordleGame] guesses: list[str]
word: str word: WordleWord
def _letter_colour(self, guess: str, index: int) -> WordleColour: def _letter_colour(self, guess: str, index: int) -> WordleColour:
"""Get the colour for a guess at a given position""" """Get the colour for a guess at a given position"""
if guess[index] == self.word[index]: if guess[index] == self.word.word[index]:
return WordleColour.CORRECT return WordleColour.CORRECT
wrong_letter = 0 wrong_letter = 0
wrong_position = 0 wrong_position = 0
for i, letter in enumerate(self.word): for i, letter in enumerate(self.word.word):
if letter == guess[index] and guess[i] != guess[index]: if letter == guess[index] and guess[i] != guess[index]:
wrong_letter += 1 wrong_letter += 1
@ -68,8 +78,7 @@ class WordleEmbed(EmbedBaseModel):
colours = [] colours = []
# Add all the guesses # Add all the guesses
if self.game is not None: for guess in self.guesses:
for guess in self.game.guesses:
colours.append(self._guess_colours(guess)) colours.append(self._guess_colours(guess))
# Fill the rest with empty spots # Fill the rest with empty spots
@ -99,19 +108,19 @@ class WordleEmbed(EmbedBaseModel):
colours = self.colour_code_game() colours = self.colour_code_game()
embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle") embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}")
emojis = self._colours_to_emojis(colours) emojis = self._colours_to_emojis(colours)
rows = [" ".join(row) for row in emojis] rows = [" ".join(row) for row in emojis]
# Don't reveal anything if we only want to show the colours # Don't reveal anything if we only want to show the colours
if not only_colours and self.game is not None: if not only_colours and self.guesses:
for i, guess in enumerate(self.game.guesses): for i, guess in enumerate(self.guesses):
rows[i] += f" ||{guess.upper()}||" rows[i] += f" ||{guess.upper()}||"
# If the game is over, reveal the word # If the game is over, reveal the word
if self.game.is_game_over(self.word): if is_wordle_game_over(self.guesses, self.word.word):
rows.append(f"\n\nThe word was **{self.word.upper()}**!") rows.append(f"\n\nThe word was **{self.word.word.upper()}**!")
embed.description = "\n\n".join(rows) embed.description = "\n\n".join(rows)
embed.set_footer(text=footer()) embed.set_footer(text=footer())

View File

@ -2,7 +2,6 @@ import logging
import os import os
import discord import discord
import motor.motor_asyncio
from aiohttp import ClientSession from aiohttp import ClientSession
from discord.app_commands import AppCommandError from discord.app_commands import AppCommandError
from discord.ext import commands from discord.ext import commands
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import custom_commands from database.crud import custom_commands
from database.engine import DBSession, mongo_client from database.engine import DBSession
from database.utils.caches import CacheManager from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.error_embed import create_error_embed
from didier.exceptions import HTTPException, NoMatch from didier.exceptions import HTTPException, NoMatch
@ -55,11 +54,6 @@ class Didier(commands.Bot):
"""Obtain a session for the PostgreSQL database""" """Obtain a session for the PostgreSQL database"""
return DBSession() return DBSession()
@property
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
"""Obtain a reference to the MongoDB database"""
return mongo_client[settings.MONGO_DB]
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Do some initial setup """Do some initial setup
@ -71,7 +65,7 @@ class Didier(commands.Bot):
# Initialize caches # Initialize caches
self.database_caches = CacheManager() self.database_caches = CacheManager()
async with self.postgres_session as session: async with self.postgres_session as session:
await self.database_caches.initialize_caches(session, self.mongo_db) await self.database_caches.initialize_caches(session)
# Load extensions # Load extensions
await self._load_initial_extensions() await self._load_initial_extensions()

View File

@ -5,7 +5,7 @@ from discord import Interaction
from overrides import overrides from overrides import overrides
from database.crud.deadlines import add_deadline from database.crud.deadlines import add_deadline
from database.schemas.relational import UforaCourse from database.schemas import UforaCourse
__all__ = ["AddDeadline"] __all__ = ["AddDeadline"]

View File

@ -3,7 +3,7 @@ import traceback
import discord.ui import discord.ui
from overrides import overrides from overrides import overrides
from database.schemas.relational import MemeTemplate from database.schemas import MemeTemplate
from didier import Didier from didier import Didier
from didier.data.apis.imgflip import generate_meme from didier.data.apis.imgflip import generate_meme

View File

@ -10,10 +10,3 @@ services:
- POSTGRES_PASSWORD=pytest - POSTGRES_PASSWORD=pytest
ports: ports:
- "5433:5432" - "5433:5432"
mongo-pytest:
image: mongo:5.0
restart: always
environment:
- MONGO_INITDB_DATABASE=didier_pytest
ports:
- "27018:27017"

View File

@ -12,18 +12,5 @@ services:
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}" - "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
volumes: volumes:
- postgres:/var/lib/postgresql/data - postgres:/var/lib/postgresql/data
mongo:
image: mongo:5.0
restart: always
environment:
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
command: [--auth]
ports:
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
volumes:
- mongo:/data/db
volumes: volumes:
postgres: postgres:
mongo:

View File

@ -44,9 +44,6 @@ ignore_missing_imports = true
asyncio_mode = "auto" asyncio_mode = "auto"
env = [ env = [
"TESTING = 1", "TESTING = 1",
"MONGO_DB = didier_pytest",
"MONGO_HOST = localhost",
"MONGO_PORT = 27018",
"POSTGRES_DB = didier_pytest", "POSTGRES_DB = didier_pytest",
"POSTGRES_USER = pytest", "POSTGRES_USER = pytest",
"POSTGRES_PASS = pytest", "POSTGRES_PASS = pytest",
@ -55,6 +52,5 @@ env = [
"DISCORD_TOKEN = token" "DISCORD_TOKEN = token"
] ]
markers = [ markers = [
"mongo: tests that use MongoDB",
"postgres: tests that use PostgreSQL" "postgres: tests that use PostgreSQL"
] ]

View File

@ -7,7 +7,6 @@ git+https://github.com/Rapptz/discord-ext-menus@8686b5d
environs==9.5.0 environs==9.5.0
feedparser==6.0.10 feedparser==6.0.10
markdownify==0.11.2 markdownify==0.11.2
motor==3.0.0
overrides==6.1.0 overrides==6.1.0
pydantic==1.9.1 pydantic==1.9.1
python-dateutil==2.8.2 python-dateutil==2.8.2

View File

@ -37,13 +37,6 @@ SEMESTER: int = env.int("SEMESTER", 2)
YEAR: int = env.int("YEAR", 3) YEAR: int = env.int("YEAR", 3)
"""Database""" """Database"""
# MongoDB
MONGO_DB: str = env.str("MONGO_DB", "didier")
MONGO_USER: str = env.str("MONGO_USER", "root")
MONGO_PASS: str = env.str("MONGO_PASS", "root")
MONGO_HOST: str = env.str("MONGO_HOST", "localhost")
MONGO_PORT: int = env.int("MONGO_PORT", "27017")
# PostgreSQL # PostgreSQL
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier") POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres") POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")

View File

@ -2,12 +2,10 @@ import asyncio
from typing import AsyncGenerator, Generator from typing import AsyncGenerator, Generator
from unittest.mock import MagicMock from unittest.mock import MagicMock
import motor.motor_asyncio
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings from database.engine import postgres_engine
from database.engine import mongo_client, postgres_engine
from database.migrations import ensure_latest_migration, migrate from database.migrations import ensure_latest_migration, migrate
from didier import Didier from didier import Didier
@ -56,14 +54,6 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
await connection.close() await connection.close()
@pytest.fixture
async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase:
"""Fixture to get a MongoDB connection"""
database = mongo_client[settings.MONGO_DB]
yield database
mongo_client.drop_database(settings.MONGO_DB)
@pytest.fixture @pytest.fixture
def mock_client() -> Didier: def mock_client() -> Didier:
"""Fixture to get a mock Didier instance """Fixture to get a mock Didier instance

View File

@ -4,7 +4,7 @@ import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users from database.crud import users
from database.schemas.relational import ( from database.schemas import (
Bank, Bank,
UforaAnnouncement, UforaAnnouncement,
UforaCourse, UforaCourse,
@ -25,7 +25,7 @@ def test_user_id() -> int:
@pytest.fixture @pytest.fixture
async def user(postgres: AsyncSession, test_user_id: int) -> User: async def user(postgres: AsyncSession, test_user_id: int) -> User:
"""Fixture to create a user""" """Fixture to create a user"""
_user = await users.get_or_add(postgres, test_user_id) _user = await users.get_or_add_user(postgres, test_user_id)
await postgres.refresh(_user) await postgres.refresh(_user)
return _user return _user

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import birthdays as crud from database.crud import birthdays as crud
from database.crud import users from database.crud import users
from database.schemas.relational import User from database.schemas import User
async def test_add_birthday_not_present(postgres: AsyncSession, user: User): async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
@ -54,7 +54,7 @@ async def test_get_birthdays_on_day(postgres: AsyncSession, user: User):
"""Test getting all birthdays on a given day""" """Test getting all birthdays on a given day"""
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
user_2 = await users.get_or_add(postgres, user.user_id + 1) user_2 = await users.get_or_add_user(postgres, user.user_id + 1)
await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1)) await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1))
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
assert len(birthdays) == 1 assert len(birthdays) == 1

View File

@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import currency as crud from database.crud import currency as crud
from database.exceptions import currency as exceptions from database.exceptions import currency as exceptions
from database.schemas.relational import Bank from database.schemas import Bank
async def test_add_dinks(postgres: AsyncSession, bank: Bank): async def test_add_dinks(postgres: AsyncSession, bank: Bank):

View File

@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import custom_commands as crud from database.crud import custom_commands as crud
from database.exceptions.constraints import DuplicateInsertException from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException from database.exceptions.not_found import NoResultFoundException
from database.schemas.relational import CustomCommand from database.schemas import CustomCommand
async def test_create_command_non_existing(postgres: AsyncSession): async def test_create_command_non_existing(postgres: AsyncSession):

View File

@ -2,7 +2,7 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import dad_jokes as crud from database.crud import dad_jokes as crud
from database.schemas.relational import DadJoke from database.schemas import DadJoke
async def test_add_dad_joke(postgres: AsyncSession): async def test_add_dad_joke(postgres: AsyncSession):

View File

@ -1,63 +0,0 @@
import pytest
from freezegun import freeze_time
from database.crud import game_stats as crud
from database.mongo_types import MongoDatabase
from database.schemas.mongo.game_stats import GameStats
from database.utils.datetime import today_only_date
async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats):
"""Helper function to insert some stats"""
collection = mongodb[GameStats.collection()]
await collection.insert_one(stats.dict(by_alias=True))
@pytest.mark.mongo
async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int):
"""Test getting a user's stats when the db is empty"""
collection = mongodb[GameStats.collection()]
assert await collection.find_one({"user_id": test_user_id}) is None
await crud.get_game_stats(mongodb, test_user_id)
assert await collection.find_one({"user_id": test_user_id}) is not None
@pytest.mark.mongo
async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int):
"""Test getting a user's stats when there's already an entry present"""
stats = GameStats(user_id=test_user_id)
stats.wordle.games = 20
await insert_game_stats(mongodb, stats)
found_stats = await crud.get_game_stats(mongodb, test_user_id)
assert found_stats.wordle.games == 20
@pytest.mark.mongo
@freeze_time("2022-07-30")
async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int):
"""Test completing a wordle game when you win"""
await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2)
stats = await crud.get_game_stats(mongodb, test_user_id)
assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0]
assert stats.wordle.games == 1
assert stats.wordle.wins == 1
assert stats.wordle.current_streak == 1
assert stats.wordle.max_streak == 1
assert stats.wordle.last_win == today_only_date()
@pytest.mark.mongo
@freeze_time("2022-07-30")
async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int):
"""Test completing a wordle game when you lose"""
stats = GameStats(user_id=test_user_id)
stats.wordle.current_streak = 10
await insert_game_stats(mongodb, stats)
await crud.complete_wordle_game(mongodb, test_user_id, win=False)
stats = await crud.get_game_stats(mongodb, test_user_id)
# Check that streak was broken
assert stats.wordle.current_streak == 0
assert stats.wordle.games == 1
assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0]

View File

@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import tasks as crud from database.crud import tasks as crud
from database.enums import TaskType from database.enums import TaskType
from database.schemas.relational import Task from database.schemas import Task
@pytest.fixture @pytest.fixture

View File

@ -3,7 +3,7 @@ import datetime
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_announcements as crud from database.crud import ufora_announcements as crud
from database.schemas.relational import UforaAnnouncement, UforaCourse from database.schemas import UforaAnnouncement, UforaCourse
async def test_get_courses_with_announcements_none(postgres: AsyncSession): async def test_get_courses_with_announcements_none(postgres: AsyncSession):

View File

@ -1,7 +1,7 @@
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import ufora_courses as crud from database.crud import ufora_courses as crud
from database.schemas.relational import UforaCourse from database.schemas import UforaCourse
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse): async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):

View File

@ -2,12 +2,12 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import users as crud from database.crud import users as crud
from database.schemas.relational import User from database.schemas import User
async def test_get_or_add_non_existing(postgres: AsyncSession): async def test_get_or_add_non_existing(postgres: AsyncSession):
"""Test get_or_add for a user that doesn't exist""" """Test get_or_add for a user that doesn't exist"""
await crud.get_or_add(postgres, 1) await crud.get_or_add_user(postgres, 1)
statement = select(User) statement = select(User)
res = (await postgres.execute(statement)).scalars().all() res = (await postgres.execute(statement)).scalars().all()
@ -18,8 +18,8 @@ async def test_get_or_add_non_existing(postgres: AsyncSession):
async def test_get_or_add_existing(postgres: AsyncSession): async def test_get_or_add_existing(postgres: AsyncSession):
"""Test get_or_add for a user that does exist""" """Test get_or_add for a user that does exist"""
user = await crud.get_or_add(postgres, 1) user = await crud.get_or_add_user(postgres, 1)
bank = user.bank bank = user.bank
assert await crud.get_or_add(postgres, 1) == user assert await crud.get_or_add_user(postgres, 1) == user
assert (await crud.get_or_add(postgres, 1)).bank == bank assert (await crud.get_or_add_user(postgres, 1)).bank == bank

View File

@ -1,136 +1,138 @@
from datetime import datetime, timedelta from datetime import date, timedelta
import pytest import pytest
from freezegun import freeze_time from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle as crud from database.crud import wordle as crud
from database.enums import TempStorageKey from database.schemas import User, WordleGuess, WordleWord
from database.mongo_types import MongoCollection, MongoDatabase
from database.schemas.mongo.temporary_storage import TemporaryStorage
from database.schemas.mongo.wordle import WordleGame
@pytest.fixture @pytest.fixture
async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection: async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
"""Fixture to get a reference to the wordle collection""" """Fixture to generate some guesses"""
yield mongodb[WordleGame.collection()] guesses = []
for guess in ["TEST", "WORDLE", "WORDS"]:
guess = WordleGuess(user_id=user.user_id, guess=guess)
postgres.add(guess)
await postgres.commit()
guesses.append(guess)
return guesses
@pytest.fixture @pytest.mark.postgres
async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame: async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
"""Fixture to create a new game"""
game = WordleGame(user_id=test_user_id)
await wordle_collection.insert_one(game.dict(by_alias=True))
yield game
@pytest.mark.mongo
async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int):
"""Test starting a new game"""
result = await wordle_collection.find_one({"user_id": test_user_id})
assert result is None
await crud.start_new_wordle_game(mongodb, test_user_id)
result = await wordle_collection.find_one({"user_id": test_user_id})
assert result is not None
@pytest.mark.mongo
async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int):
"""Test getting an active game when there is none""" """Test getting an active game when there is none"""
result = await crud.get_active_wordle_game(mongodb, test_user_id) result = await crud.get_active_wordle_game(postgres, user.user_id)
assert result is None assert not result
@pytest.mark.mongo @pytest.mark.postgres
async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame): async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
"""Test getting an active game when there is one""" """Test getting an active game when there is one"""
result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id) result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True) assert result == wordle_guesses
@pytest.mark.mongo @pytest.mark.postgres
async def test_get_daily_word_none(mongodb: MongoDatabase): async def test_get_daily_word_none(postgres: AsyncSession):
"""Test getting the daily word when the database is empty""" """Test getting the daily word when the database is empty"""
result = await crud.get_daily_word(mongodb) result = await crud.get_daily_word(postgres)
assert result is None assert result is None
@pytest.mark.mongo @pytest.mark.postgres
@freeze_time("2022-07-30") @freeze_time("2022-07-30")
async def test_get_daily_word_not_today(mongodb: MongoDatabase): async def test_get_daily_word_not_today(postgres: AsyncSession):
"""Test getting the daily word when there is an entry, but not for today""" """Test getting the daily word when there is an entry, but not for today"""
day = datetime.today() - timedelta(days=1) day = date.today() - timedelta(days=1)
collection = mongodb[TemporaryStorage.collection()]
word = "testword" word = "testword"
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
assert await crud.get_daily_word(mongodb) is None assert await crud.get_daily_word(postgres) is None
@pytest.mark.mongo @pytest.mark.postgres
@freeze_time("2022-07-30") @freeze_time("2022-07-30")
async def test_get_daily_word_present(mongodb: MongoDatabase): async def test_get_daily_word_present(postgres: AsyncSession):
"""Test getting the daily word when there is one for today""" """Test getting the daily word when there is one for today"""
day = datetime.today() day = date.today()
collection = mongodb[TemporaryStorage.collection()]
word = "testword" word = "testword"
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word}) word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
assert await crud.get_daily_word(mongodb) == word daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.mongo @pytest.mark.postgres
@freeze_time("2022-07-30") @freeze_time("2022-07-30")
async def test_set_daily_word_none_present(mongodb: MongoDatabase): async def test_set_daily_word_none_present(postgres: AsyncSession):
"""Test setting the daily word when there is none""" """Test setting the daily word when there is none"""
assert await crud.get_daily_word(mongodb) is None assert await crud.get_daily_word(postgres) is None
word = "testword" word = "testword"
await crud.set_daily_word(mongodb, word) await crud.set_daily_word(postgres, word)
assert await crud.get_daily_word(mongodb) == word
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.mongo @pytest.mark.postgres
@freeze_time("2022-07-30") @freeze_time("2022-07-30")
async def test_set_daily_word_present(mongodb: MongoDatabase): async def test_set_daily_word_present(postgres: AsyncSession):
"""Test setting the daily word when there already is one""" """Test setting the daily word when there already is one"""
word = "testword" word = "testword"
await crud.set_daily_word(mongodb, word) await crud.set_daily_word(postgres, word)
await crud.set_daily_word(mongodb, "another word") await crud.set_daily_word(postgres, "another word")
assert await crud.get_daily_word(mongodb) == word
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.mongo @pytest.mark.postgres
@freeze_time("2022-07-30") @freeze_time("2022-07-30")
async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase): async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
"""Test setting the daily word when there already is one, but "forced" is set to True""" """Test setting the daily word when there already is one, but "forced" is set to True"""
word = "testword" word = "testword"
await crud.set_daily_word(mongodb, word) await crud.set_daily_word(postgres, word)
word = "anotherword" word = "anotherword"
await crud.set_daily_word(mongodb, word, forced=True) await crud.set_daily_word(postgres, word, forced=True)
assert await crud.get_daily_word(mongodb) == word
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.mongo @pytest.mark.postgres
async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): async def test_make_wordle_guess(postgres: AsyncSession, user: User):
"""Test making a guess in your current game""" """Test making a guess in your current game"""
test_user_id = user.user_id
guess = "guess" guess = "guess"
await crud.make_wordle_guess(mongodb, test_user_id, guess) await crud.make_wordle_guess(postgres, test_user_id, guess)
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess]
assert wordle_game.guesses == [guess]
other_guess = "otherguess" other_guess = "otherguess"
await crud.make_wordle_guess(mongodb, test_user_id, other_guess) await crud.make_wordle_guess(postgres, test_user_id, other_guess)
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id) assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess]
assert wordle_game.guesses == [guess, other_guess]
@pytest.mark.mongo @pytest.mark.postgres
async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int): async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
"""Test dropping the collection of active games""" """Test dropping the collection of active games"""
assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None test_user_id = user.user_id
await crud.reset_wordle_games(mongodb)
assert await crud.get_active_wordle_game(mongodb, test_user_id) is None assert await crud.get_active_wordle_game(postgres, test_user_id)
await crud.reset_wordle_games(postgres)
assert not await crud.get_active_wordle_game(postgres, test_user_id)

View File

@ -0,0 +1,72 @@
import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle_stats as crud
from database.schemas import User, WordleStats
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
"""Helper function to insert some stats"""
postgres.add(stats)
await postgres.commit()
@pytest.mark.postgres
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
"""Test getting a user's stats when the db is empty"""
test_user_id = user.user_id
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is None
await crud.get_wordle_stats(postgres, test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
@pytest.mark.postgres
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
"""Test getting a user's stats when there's already an entry present"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.games = 20
await insert_game_stats(postgres, stats)
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
assert found_stats.games == 20
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you win"""
test_user_id = user.user_id
await crud.complete_wordle_game(postgres, test_user_id, win=True)
stats = await crud.get_wordle_stats(postgres, test_user_id)
assert stats.games == 1
assert stats.wins == 1
assert stats.current_streak == 1
assert stats.highest_streak == 1
assert stats.last_win == datetime.date.today()
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you lose"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.current_streak = 10
await insert_game_stats(postgres, stats)
await crud.complete_wordle_game(postgres, test_user_id, win=False)
stats = await crud.get_wordle_stats(postgres, test_user_id)
# Check that streak was broken
assert stats.current_streak == 0
assert stats.games == 1

View File

@ -1,6 +1,6 @@
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.schemas.relational import UforaCourse from database.schemas import UforaCourse
from database.utils.caches import UforaCourseCache from database.utils.caches import UforaCourseCache