mirror of https://github.com/stijndcl/didier
Remove mongo & fix tests
parent
7b2109fb07
commit
8a4baf6bb8
|
@ -44,9 +44,3 @@ repos:
|
||||||
- "flake8-eradicate"
|
- "flake8-eradicate"
|
||||||
- "flake8-isort"
|
- "flake8-isort"
|
||||||
- "flake8-simplify"
|
- "flake8-simplify"
|
||||||
|
|
||||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
|
||||||
rev: v0.961
|
|
||||||
hooks:
|
|
||||||
- id: mypy
|
|
||||||
args: [--config, pyproject.toml]
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from database.engine import postgres_engine
|
from database.engine import postgres_engine
|
||||||
from database.schemas.relational import Base
|
from database.schemas import Base
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
# access to the values within the .ini file in use.
|
# access to the values within the .ini file in use.
|
||||||
|
|
|
@ -0,0 +1,63 @@
|
||||||
|
"""Wordle
|
||||||
|
|
||||||
|
Revision ID: 38b7c29f10ee
|
||||||
|
Revises: 36300b558ef1
|
||||||
|
Create Date: 2022-08-29 20:21:02.413631
|
||||||
|
|
||||||
|
"""
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision = "38b7c29f10ee"
|
||||||
|
down_revision = "36300b558ef1"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.create_table(
|
||||||
|
"wordle_word",
|
||||||
|
sa.Column("word_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("word", sa.Text(), nullable=False),
|
||||||
|
sa.Column("day", sa.Date(), nullable=False),
|
||||||
|
sa.PrimaryKeyConstraint("word_id"),
|
||||||
|
sa.UniqueConstraint("day"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"wordle_guesses",
|
||||||
|
sa.Column("wordle_guess_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("guess", sa.Text(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.user_id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("wordle_guess_id"),
|
||||||
|
)
|
||||||
|
op.create_table(
|
||||||
|
"wordle_stats",
|
||||||
|
sa.Column("wordle_stats_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.Column("last_win", sa.Date(), nullable=True),
|
||||||
|
sa.Column("games", sa.Integer(), server_default="0", nullable=False),
|
||||||
|
sa.Column("wins", sa.Integer(), server_default="0", nullable=False),
|
||||||
|
sa.Column("current_streak", sa.Integer(), server_default="0", nullable=False),
|
||||||
|
sa.Column("highest_streak", sa.Integer(), server_default="0", nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.user_id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("wordle_stats_id"),
|
||||||
|
)
|
||||||
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("wordle_stats")
|
||||||
|
op.drop_table("wordle_guesses")
|
||||||
|
op.drop_table("wordle_word")
|
||||||
|
# ### end Alembic commands ###
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.schemas.relational import Birthday, User
|
from database.schemas import Birthday, User
|
||||||
|
|
||||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.exceptions import currency as exceptions
|
from database.exceptions import currency as exceptions
|
||||||
from database.schemas.relational import Bank, NightlyData
|
from database.schemas import Bank, NightlyData
|
||||||
from database.utils.math.currency import (
|
from database.utils.math.currency import (
|
||||||
capacity_upgrade_price,
|
capacity_upgrade_price,
|
||||||
interest_upgrade_price,
|
interest_upgrade_price,
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from database.schemas.relational import CustomCommand, CustomCommandAlias
|
from database.schemas import CustomCommand, CustomCommandAlias
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"clean_name",
|
"clean_name",
|
||||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from database.schemas.relational import DadJoke
|
from database.schemas import DadJoke
|
||||||
|
|
||||||
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from sqlalchemy.orm import selectinload
|
from sqlalchemy.orm import selectinload
|
||||||
|
|
||||||
from database.schemas.relational import Deadline, UforaCourse
|
from database.schemas import Deadline, UforaCourse
|
||||||
|
|
||||||
__all__ = ["add_deadline", "get_deadlines"]
|
__all__ = ["add_deadline", "get_deadlines"]
|
||||||
|
|
||||||
|
|
|
@ -1,59 +0,0 @@
|
||||||
import datetime
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
from database.mongo_types import MongoDatabase
|
|
||||||
from database.schemas.mongo.game_stats import GameStats
|
|
||||||
|
|
||||||
__all__ = ["get_game_stats", "complete_wordle_game"]
|
|
||||||
|
|
||||||
from database.utils.datetime import today_only_date
|
|
||||||
|
|
||||||
|
|
||||||
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
|
|
||||||
"""Get a user's game stats
|
|
||||||
|
|
||||||
If no entry is found, it is first created
|
|
||||||
"""
|
|
||||||
collection = database[GameStats.collection()]
|
|
||||||
stats = await collection.find_one({"user_id": user_id})
|
|
||||||
if stats is not None:
|
|
||||||
return GameStats(**stats)
|
|
||||||
|
|
||||||
stats = GameStats(user_id=user_id)
|
|
||||||
await collection.insert_one(stats.dict(by_alias=True))
|
|
||||||
return stats
|
|
||||||
|
|
||||||
|
|
||||||
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
|
|
||||||
"""Update the user's Wordle stats"""
|
|
||||||
stats = await get_game_stats(database, user_id)
|
|
||||||
|
|
||||||
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
|
|
||||||
|
|
||||||
if win:
|
|
||||||
update["$inc"]["wordle.wins"] = 1
|
|
||||||
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
|
|
||||||
|
|
||||||
# Update streak
|
|
||||||
today = today_only_date()
|
|
||||||
last_win = stats.wordle.last_win
|
|
||||||
update["$set"]["wordle.last_win"] = today
|
|
||||||
|
|
||||||
if last_win is None or (today - last_win).days > 1:
|
|
||||||
# Never won a game before or streak is over
|
|
||||||
update["$set"]["wordle.current_streak"] = 1
|
|
||||||
stats.wordle.current_streak = 1
|
|
||||||
else:
|
|
||||||
# On a streak: increase counter
|
|
||||||
update["$inc"]["wordle.current_streak"] = 1
|
|
||||||
stats.wordle.current_streak += 1
|
|
||||||
|
|
||||||
# Update max streak if necessary
|
|
||||||
if stats.wordle.current_streak > stats.wordle.max_streak:
|
|
||||||
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
|
|
||||||
else:
|
|
||||||
# Streak is over
|
|
||||||
update["$set"]["wordle.current_streak"] = 0
|
|
||||||
|
|
||||||
collection = database[GameStats.collection()]
|
|
||||||
await collection.update_one({"_id": stats.id}, update)
|
|
|
@ -4,7 +4,7 @@ from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.exceptions import NoResultFoundException
|
from database.exceptions import NoResultFoundException
|
||||||
from database.schemas.relational import Link
|
from database.schemas import Link
|
||||||
|
|
||||||
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import MemeTemplate
|
from database.schemas import MemeTemplate
|
||||||
|
|
||||||
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]
|
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]
|
||||||
|
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.enums import TaskType
|
from database.enums import TaskType
|
||||||
from database.schemas.relational import Task
|
from database.schemas import Task
|
||||||
from database.utils.datetime import LOCAL_TIMEZONE
|
from database.utils.datetime import LOCAL_TIMEZONE
|
||||||
|
|
||||||
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
||||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
||||||
from sqlalchemy import delete, select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
from database.schemas import UforaAnnouncement, UforaCourse
|
||||||
|
|
||||||
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]
|
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import UforaCourse, UforaCourseAlias
|
from database.schemas import UforaCourse, UforaCourseAlias
|
||||||
|
|
||||||
__all__ = ["get_all_courses", "get_course_by_name"]
|
__all__ = ["get_all_courses", "get_course_by_name"]
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import Bank, NightlyData, User
|
from database.schemas import Bank, NightlyData, User
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_or_add",
|
"get_or_add",
|
||||||
|
|
|
@ -1,56 +1,45 @@
|
||||||
|
import datetime
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from database.enums import TempStorageKey
|
from sqlalchemy import delete, select
|
||||||
from database.mongo_types import MongoDatabase
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
|
||||||
from database.schemas.mongo.wordle import WordleGame
|
from database.schemas import WordleGuess, WordleWord
|
||||||
from database.utils.datetime import today_only_date
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_active_wordle_game",
|
"get_active_wordle_game",
|
||||||
"make_wordle_guess",
|
"make_wordle_guess",
|
||||||
"start_new_wordle_game",
|
|
||||||
"set_daily_word",
|
"set_daily_word",
|
||||||
"reset_wordle_games",
|
"reset_wordle_games",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
|
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
|
||||||
"""Find a player's active game"""
|
"""Find a player's active game"""
|
||||||
collection = database[WordleGame.collection()]
|
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
|
||||||
result = await collection.find_one({"user_id": user_id})
|
guesses = (await session.execute(statement)).scalars().all()
|
||||||
if result is None:
|
return guesses
|
||||||
return None
|
|
||||||
|
|
||||||
return WordleGame(**result)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
|
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
|
||||||
"""Start a new game"""
|
|
||||||
collection = database[WordleGame.collection()]
|
|
||||||
game = WordleGame(user_id=user_id)
|
|
||||||
await collection.insert_one(game.dict(by_alias=True))
|
|
||||||
return game
|
|
||||||
|
|
||||||
|
|
||||||
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
|
|
||||||
"""Make a guess in your current game"""
|
"""Make a guess in your current game"""
|
||||||
collection = database[WordleGame.collection()]
|
guess_instance = WordleGuess(user_id=user_id, guess=guess)
|
||||||
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
session.add(guess_instance)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
|
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
|
||||||
"""Get the word of today"""
|
"""Get the word of today"""
|
||||||
collection = database[TemporaryStorage.collection()]
|
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
|
||||||
|
row = (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
|
||||||
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
if row is None:
|
||||||
if result is None:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return result["word"]
|
return row
|
||||||
|
|
||||||
|
|
||||||
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
|
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
|
||||||
"""Set the word of today
|
"""Set the word of today
|
||||||
|
|
||||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||||
|
@ -60,23 +49,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F
|
||||||
|
|
||||||
Returns the word that was chosen. If one already existed, return that instead.
|
Returns the word that was chosen. If one already existed, return that instead.
|
||||||
"""
|
"""
|
||||||
collection = database[TemporaryStorage.collection()]
|
current_word = await get_daily_word(session)
|
||||||
|
|
||||||
current_word = None if forced else await get_daily_word(database)
|
if current_word is None:
|
||||||
if current_word is not None:
|
current_word = WordleWord(word=word, day=datetime.date.today())
|
||||||
return current_word
|
session.add(current_word)
|
||||||
|
await session.commit()
|
||||||
await collection.update_one(
|
|
||||||
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
|
|
||||||
)
|
|
||||||
|
|
||||||
# Remove all active games
|
# Remove all active games
|
||||||
await reset_wordle_games(database)
|
await reset_wordle_games(session)
|
||||||
|
elif forced:
|
||||||
|
current_word.word = word
|
||||||
|
current_word.day = datetime.date.today()
|
||||||
|
session.add(current_word)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
return word
|
# Remove all active games
|
||||||
|
await reset_wordle_games(session)
|
||||||
|
|
||||||
|
return current_word.word
|
||||||
|
|
||||||
|
|
||||||
async def reset_wordle_games(database: MongoDatabase):
|
async def reset_wordle_games(session: AsyncSession):
|
||||||
"""Reset all active games"""
|
"""Reset all active games"""
|
||||||
collection = database[WordleGame.collection()]
|
statement = delete(WordleGuess)
|
||||||
await collection.drop()
|
await session.execute(statement)
|
||||||
|
|
|
@ -0,0 +1,57 @@
|
||||||
|
from datetime import date
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.schemas import WordleStats
|
||||||
|
|
||||||
|
__all__ = ["get_wordle_stats", "complete_wordle_game"]
|
||||||
|
|
||||||
|
|
||||||
|
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
|
||||||
|
"""Get a user's wordle stats
|
||||||
|
|
||||||
|
If no entry is found, it is first created
|
||||||
|
"""
|
||||||
|
statement = select(WordleStats).where(WordleStats.user_id == user_id)
|
||||||
|
stats = (await session.execute(statement)).scalar_one_or_none()
|
||||||
|
if stats is not None:
|
||||||
|
return stats
|
||||||
|
|
||||||
|
stats = WordleStats(user_id=user_id)
|
||||||
|
session.add(stats)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(stats)
|
||||||
|
|
||||||
|
return stats
|
||||||
|
|
||||||
|
|
||||||
|
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
|
||||||
|
"""Update the user's Wordle stats"""
|
||||||
|
stats = await get_wordle_stats(session, user_id)
|
||||||
|
stats.games += 1
|
||||||
|
|
||||||
|
if win:
|
||||||
|
stats.wins += 1
|
||||||
|
|
||||||
|
# Update streak
|
||||||
|
today = date.today()
|
||||||
|
last_win = stats.last_win
|
||||||
|
stats.last_win = today
|
||||||
|
|
||||||
|
if last_win is None or (today - last_win).days > 1:
|
||||||
|
# Never won a game before or streak is over
|
||||||
|
stats.current_streak = 1
|
||||||
|
else:
|
||||||
|
# On a streak: increase counter
|
||||||
|
stats.current_streak += 1
|
||||||
|
|
||||||
|
# Update max streak if necessary
|
||||||
|
if stats.current_streak > stats.highest_streak:
|
||||||
|
stats.highest_streak = stats.current_streak
|
||||||
|
else:
|
||||||
|
# Streak is over
|
||||||
|
stats.current_streak = 0
|
||||||
|
|
||||||
|
session.add(stats)
|
||||||
|
await session.commit()
|
|
@ -1,6 +1,5 @@
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import motor.motor_asyncio
|
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
@ -26,16 +25,3 @@ postgres_engine = create_async_engine(
|
||||||
DBSession = sessionmaker(
|
DBSession = sessionmaker(
|
||||||
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
)
|
)
|
||||||
|
|
||||||
# MongoDB client
|
|
||||||
if not settings.TESTING: # pragma: no cover
|
|
||||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
|
||||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
|
||||||
mongo_url = (
|
|
||||||
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# Require no authentication when testing
|
|
||||||
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
|
||||||
|
|
||||||
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
import enum
|
import enum
|
||||||
|
|
||||||
__all__ = ["TaskType", "TempStorageKey"]
|
__all__ = ["TaskType"]
|
||||||
|
|
||||||
|
|
||||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||||
|
@ -11,10 +11,3 @@ class TaskType(enum.IntEnum):
|
||||||
|
|
||||||
BIRTHDAYS = enum.auto()
|
BIRTHDAYS = enum.auto()
|
||||||
UFORA_ANNOUNCEMENTS = enum.auto()
|
UFORA_ANNOUNCEMENTS = enum.auto()
|
||||||
|
|
||||||
|
|
||||||
@enum.unique
|
|
||||||
class TempStorageKey(str, enum.Enum):
|
|
||||||
"""Enum for keys to distinguish the TemporaryStorage rows"""
|
|
||||||
|
|
||||||
WORDLE_WORD = "wordle_word"
|
|
||||||
|
|
|
@ -1,6 +0,0 @@
|
||||||
import motor.motor_asyncio
|
|
||||||
|
|
||||||
# Type aliases for the Motor types, which are way too long
|
|
||||||
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
|
|
||||||
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
|
|
||||||
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection
|
|
|
@ -37,6 +37,9 @@ __all__ = [
|
||||||
"UforaCourse",
|
"UforaCourse",
|
||||||
"UforaCourseAlias",
|
"UforaCourseAlias",
|
||||||
"User",
|
"User",
|
||||||
|
"WordleGuess",
|
||||||
|
"WordleStats",
|
||||||
|
"WordleWord",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@ -231,3 +234,47 @@ class User(Base):
|
||||||
nightly_data: NightlyData = relationship(
|
nightly_data: NightlyData = relationship(
|
||||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
wordle_guesses: list[WordleGuess] = relationship(
|
||||||
|
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
wordle_stats: WordleStats = relationship(
|
||||||
|
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class WordleGuess(Base):
|
||||||
|
"""A user's Wordle guesses for today"""
|
||||||
|
|
||||||
|
__tablename__ = "wordle_guesses"
|
||||||
|
|
||||||
|
wordle_guess_id: int = Column(Integer, primary_key=True)
|
||||||
|
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||||
|
guess: str = Column(Text, nullable=False)
|
||||||
|
|
||||||
|
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
|
class WordleStats(Base):
|
||||||
|
"""Stats about a user's wordle performance"""
|
||||||
|
|
||||||
|
__tablename__ = "wordle_stats"
|
||||||
|
|
||||||
|
wordle_stats_id: int = Column(Integer, primary_key=True)
|
||||||
|
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||||
|
last_win: Optional[date] = Column(Date, nullable=True)
|
||||||
|
games: int = Column(Integer, server_default="0", nullable=False)
|
||||||
|
wins: int = Column(Integer, server_default="0", nullable=False)
|
||||||
|
current_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||||
|
highest_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||||
|
|
||||||
|
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
|
class WordleWord(Base):
|
||||||
|
"""The current Wordle word"""
|
||||||
|
|
||||||
|
__tablename__ = "wordle_word"
|
||||||
|
|
||||||
|
word_id: int = Column(Integer, primary_key=True)
|
||||||
|
word: str = Column(Text, nullable=False)
|
||||||
|
day: date = Column(Date, nullable=False, unique=True)
|
|
@ -1,53 +0,0 @@
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
|
|
||||||
from bson import ObjectId
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
|
|
||||||
|
|
||||||
|
|
||||||
class PyObjectId(ObjectId):
|
|
||||||
"""Custom type for bson ObjectIds"""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
yield cls.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate(cls, value: str):
|
|
||||||
"""Check that a string is a valid bson ObjectId"""
|
|
||||||
if not ObjectId.is_valid(value):
|
|
||||||
raise ValueError(f"Invalid ObjectId: '{value}'")
|
|
||||||
|
|
||||||
return ObjectId(value)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __modify_schema__(cls, field_schema: dict):
|
|
||||||
field_schema.update(type="string")
|
|
||||||
|
|
||||||
|
|
||||||
class MongoBase(BaseModel):
|
|
||||||
"""Base model that properly sets the _id field, and adds one by default"""
|
|
||||||
|
|
||||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
"""Configuration for encoding and construction"""
|
|
||||||
|
|
||||||
allow_population_by_field_name = True
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
|
||||||
use_enum_values = True
|
|
||||||
|
|
||||||
|
|
||||||
class MongoCollection(MongoBase, ABC):
|
|
||||||
"""Base model for the 'main class' in a collection
|
|
||||||
|
|
||||||
This field stores the name of the collection to avoid making typos against it
|
|
||||||
"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@abstractmethod
|
|
||||||
def collection() -> str:
|
|
||||||
"""Getter for the name of the collection, in order to avoid typos"""
|
|
||||||
raise NotImplementedError
|
|
|
@ -1,40 +0,0 @@
|
||||||
import datetime
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from overrides import overrides
|
|
||||||
from pydantic import BaseModel, Field, validator
|
|
||||||
|
|
||||||
from database.schemas.mongo.common import MongoCollection
|
|
||||||
|
|
||||||
__all__ = ["GameStats", "WordleStats"]
|
|
||||||
|
|
||||||
|
|
||||||
class WordleStats(BaseModel):
|
|
||||||
"""Model that holds stats about a player's Wordle performance"""
|
|
||||||
|
|
||||||
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
|
||||||
last_win: Optional[datetime.datetime] = None
|
|
||||||
wins: int = 0
|
|
||||||
games: int = 0
|
|
||||||
current_streak: int = 0
|
|
||||||
max_streak: int = 0
|
|
||||||
|
|
||||||
@validator("guess_distribution")
|
|
||||||
def validate_guesses_length(cls, value: list[int]):
|
|
||||||
"""Check that the distribution of guesses is of the correct length"""
|
|
||||||
if len(value) != 6:
|
|
||||||
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
class GameStats(MongoCollection):
|
|
||||||
"""Collection that holds stats about how well a user has performed in games"""
|
|
||||||
|
|
||||||
user_id: int
|
|
||||||
wordle: WordleStats = WordleStats()
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "game_stats"
|
|
|
@ -1,16 +0,0 @@
|
||||||
from overrides import overrides
|
|
||||||
|
|
||||||
from database.schemas.mongo.common import MongoCollection
|
|
||||||
|
|
||||||
__all__ = ["TemporaryStorage"]
|
|
||||||
|
|
||||||
|
|
||||||
class TemporaryStorage(MongoCollection):
|
|
||||||
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
|
||||||
|
|
||||||
key: str
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "temporary"
|
|
|
@ -1,44 +0,0 @@
|
||||||
import datetime
|
|
||||||
|
|
||||||
from overrides import overrides
|
|
||||||
from pydantic import Field, validator
|
|
||||||
|
|
||||||
from database.constants import WORDLE_GUESS_COUNT
|
|
||||||
from database.schemas.mongo.common import MongoCollection
|
|
||||||
from database.utils.datetime import today_only_date
|
|
||||||
|
|
||||||
__all__ = ["WordleGame"]
|
|
||||||
|
|
||||||
|
|
||||||
class WordleGame(MongoCollection):
|
|
||||||
"""Collection that holds people's active Wordle games"""
|
|
||||||
|
|
||||||
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
|
||||||
guesses: list[str] = Field(default_factory=list)
|
|
||||||
user_id: int
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@overrides
|
|
||||||
def collection() -> str:
|
|
||||||
return "wordle"
|
|
||||||
|
|
||||||
@validator("guesses")
|
|
||||||
def validate_guesses_length(cls, value: list[int]):
|
|
||||||
"""Check that the amount of guesses is of the correct length"""
|
|
||||||
if len(value) > 6:
|
|
||||||
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def is_game_over(self, word: str) -> bool:
|
|
||||||
"""Check if the current game is over"""
|
|
||||||
# No guesses yet
|
|
||||||
if not self.guesses:
|
|
||||||
return False
|
|
||||||
|
|
||||||
# Max amount of guesses allowed
|
|
||||||
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Found the correct word
|
|
||||||
return self.guesses[-1] == word
|
|
|
@ -1,19 +1,15 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Generic, TypeVar
|
|
||||||
|
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import links, memes, ufora_courses, wordle
|
from database.crud import links, memes, ufora_courses, wordle
|
||||||
from database.mongo_types import MongoDatabase
|
|
||||||
|
|
||||||
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
class DatabaseCache(ABC):
|
||||||
class DatabaseCache(ABC, Generic[T]):
|
|
||||||
"""Base class for a simple cache-like structure
|
"""Base class for a simple cache-like structure
|
||||||
|
|
||||||
The goal of this class is to store data for Discord auto-completion results
|
The goal of this class is to store data for Discord auto-completion results
|
||||||
|
@ -25,7 +21,7 @@ class DatabaseCache(ABC, Generic[T]):
|
||||||
Considering the fact that a user isn't obligated to choose something from the suggestions,
|
Considering the fact that a user isn't obligated to choose something from the suggestions,
|
||||||
chances are high we have to go to the database for the final action either way.
|
chances are high we have to go to the database for the final action either way.
|
||||||
|
|
||||||
Also stores the data in lowercase to allow fast searching
|
Also stores the data in lowercase to allow fast searching.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
data: list[str] = []
|
data: list[str] = []
|
||||||
|
@ -36,7 +32,7 @@ class DatabaseCache(ABC, Generic[T]):
|
||||||
self.data.clear()
|
self.data.clear()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def invalidate(self, database_session: T):
|
async def invalidate(self, database_session: AsyncSession):
|
||||||
"""Invalidate the data stored in this cache"""
|
"""Invalidate the data stored in this cache"""
|
||||||
|
|
||||||
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
|
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
|
||||||
|
@ -48,7 +44,7 @@ class DatabaseCache(ABC, Generic[T]):
|
||||||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||||
|
|
||||||
|
|
||||||
class LinkCache(DatabaseCache[AsyncSession]):
|
class LinkCache(DatabaseCache):
|
||||||
"""Cache to store the names of links"""
|
"""Cache to store the names of links"""
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
@ -61,7 +57,7 @@ class LinkCache(DatabaseCache[AsyncSession]):
|
||||||
self.data_transformed = list(map(str.lower, self.data))
|
self.data_transformed = list(map(str.lower, self.data))
|
||||||
|
|
||||||
|
|
||||||
class MemeCache(DatabaseCache[AsyncSession]):
|
class MemeCache(DatabaseCache):
|
||||||
"""Cache to store the names of meme templates"""
|
"""Cache to store the names of meme templates"""
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
|
@ -74,7 +70,7 @@ class MemeCache(DatabaseCache[AsyncSession]):
|
||||||
self.data_transformed = list(map(str.lower, self.data))
|
self.data_transformed = list(map(str.lower, self.data))
|
||||||
|
|
||||||
|
|
||||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
class UforaCourseCache(DatabaseCache):
|
||||||
"""Cache to store the names of Ufora courses"""
|
"""Cache to store the names of Ufora courses"""
|
||||||
|
|
||||||
# Also store the aliases to add additional support
|
# Also store the aliases to add additional support
|
||||||
|
@ -119,10 +115,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||||
|
|
||||||
|
|
||||||
class WordleCache(DatabaseCache[MongoDatabase]):
|
class WordleCache(DatabaseCache):
|
||||||
"""Cache to store the current daily Wordle word"""
|
"""Cache to store the current daily Wordle word"""
|
||||||
|
|
||||||
async def invalidate(self, database_session: MongoDatabase):
|
async def invalidate(self, database_session: AsyncSession):
|
||||||
word = await wordle.get_daily_word(database_session)
|
word = await wordle.get_daily_word(database_session)
|
||||||
if word is not None:
|
if word is not None:
|
||||||
self.data = [word]
|
self.data = [word]
|
||||||
|
@ -142,9 +138,9 @@ class CacheManager:
|
||||||
self.ufora_courses = UforaCourseCache()
|
self.ufora_courses = UforaCourseCache()
|
||||||
self.wordle_word = WordleCache()
|
self.wordle_word = WordleCache()
|
||||||
|
|
||||||
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
async def initialize_caches(self, postgres_session: AsyncSession):
|
||||||
"""Initialize the contents of all caches"""
|
"""Initialize the contents of all caches"""
|
||||||
await self.links.invalidate(postgres_session)
|
await self.links.invalidate(postgres_session)
|
||||||
await self.memes.invalidate(postgres_session)
|
await self.memes.invalidate(postgres_session)
|
||||||
await self.ufora_courses.invalidate(postgres_session)
|
await self.ufora_courses.invalidate(postgres_session)
|
||||||
await self.wordle_word.invalidate(mongo_db)
|
await self.wordle_word.invalidate(postgres_session)
|
||||||
|
|
|
@ -1,15 +1,5 @@
|
||||||
import datetime
|
|
||||||
import zoneinfo
|
import zoneinfo
|
||||||
|
|
||||||
__all__ = ["LOCAL_TIMEZONE", "today_only_date"]
|
__all__ = ["LOCAL_TIMEZONE"]
|
||||||
|
|
||||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||||
|
|
||||||
|
|
||||||
def today_only_date() -> datetime.datetime:
|
|
||||||
"""Mongo can't handle datetime.date, so we need a datetime instance
|
|
||||||
|
|
||||||
We do, however, only care about the date, so remove all the rest
|
|
||||||
"""
|
|
||||||
today = datetime.date.today()
|
|
||||||
return datetime.datetime(year=today.year, month=today.month, day=today.day)
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from overrides import overrides
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
|
||||||
|
|
||||||
class TestCog(commands.Cog):
|
class DebugCog(commands.Cog):
|
||||||
"""Testing cog for dev purposes"""
|
"""Testing cog for dev purposes"""
|
||||||
|
|
||||||
client: Didier
|
client: Didier
|
||||||
|
@ -16,11 +16,11 @@ class TestCog(commands.Cog):
|
||||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||||
return await self.client.is_owner(ctx.author)
|
return await self.client.is_owner(ctx.author)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command(aliases=["Dev"])
|
||||||
async def test(self, ctx: commands.Context):
|
async def debug(self, ctx: commands.Context):
|
||||||
"""Debugging command"""
|
"""Debugging command"""
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog"""
|
"""Load the cog"""
|
||||||
await client.add_cog(TestCog(client))
|
await client.add_cog(DebugCog(client))
|
|
@ -5,7 +5,7 @@ from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from database.crud.links import get_link_by_name
|
from database.crud.links import get_link_by_name
|
||||||
from database.schemas.relational import Link
|
from database.schemas import Link
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.apis import urban_dictionary
|
from didier.data.apis import urban_dictionary
|
||||||
from didier.data.embeds.google import GoogleSearch
|
from didier.data.embeds.google import GoogleSearch
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.schemas.relational import MemeTemplate
|
from database.schemas import MemeTemplate
|
||||||
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||||
from didier.utils.http.requests import ensure_post
|
from didier.utils.http.requests import ensure_post
|
||||||
|
|
||||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
||||||
import discord
|
import discord
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from database.schemas.relational import Deadline
|
from database.schemas import Deadline
|
||||||
from didier.data.embeds.base import EmbedBaseModel
|
from didier.data.embeds.base import EmbedBaseModel
|
||||||
from didier.utils.types.datetime import tz_aware_now
|
from didier.utils.types.datetime import tz_aware_now
|
||||||
from didier.utils.types.string import get_edu_year_name
|
from didier.utils.types.string import get_edu_year_name
|
||||||
|
|
|
@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.crud import ufora_announcements as crud
|
from database.crud import ufora_announcements as crud
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas import UforaCourse
|
||||||
from didier.data.embeds.base import EmbedBaseModel
|
from didier.data.embeds.base import EmbedBaseModel
|
||||||
from didier.utils.discord.colours import ghent_university_blue
|
from didier.utils.discord.colours import ghent_university_blue
|
||||||
from didier.utils.types.datetime import int_to_weekday
|
from didier.utils.types.datetime import int_to_weekday
|
||||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
import motor.motor_asyncio
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from discord.app_commands import AppCommandError
|
from discord.app_commands import AppCommandError
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.crud import custom_commands
|
from database.crud import custom_commands
|
||||||
from database.engine import DBSession, mongo_client
|
from database.engine import DBSession
|
||||||
from database.utils.caches import CacheManager
|
from database.utils.caches import CacheManager
|
||||||
from didier.data.embeds.error_embed import create_error_embed
|
from didier.data.embeds.error_embed import create_error_embed
|
||||||
from didier.exceptions import HTTPException, NoMatch
|
from didier.exceptions import HTTPException, NoMatch
|
||||||
|
@ -55,11 +54,6 @@ class Didier(commands.Bot):
|
||||||
"""Obtain a session for the PostgreSQL database"""
|
"""Obtain a session for the PostgreSQL database"""
|
||||||
return DBSession()
|
return DBSession()
|
||||||
|
|
||||||
@property
|
|
||||||
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
|
||||||
"""Obtain a reference to the MongoDB database"""
|
|
||||||
return mongo_client[settings.MONGO_DB]
|
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Do some initial setup
|
"""Do some initial setup
|
||||||
|
|
||||||
|
@ -71,7 +65,7 @@ class Didier(commands.Bot):
|
||||||
# Initialize caches
|
# Initialize caches
|
||||||
self.database_caches = CacheManager()
|
self.database_caches = CacheManager()
|
||||||
async with self.postgres_session as session:
|
async with self.postgres_session as session:
|
||||||
await self.database_caches.initialize_caches(session, self.mongo_db)
|
await self.database_caches.initialize_caches(session)
|
||||||
|
|
||||||
# Load extensions
|
# Load extensions
|
||||||
await self._load_initial_extensions()
|
await self._load_initial_extensions()
|
||||||
|
|
|
@ -5,7 +5,7 @@ from discord import Interaction
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from database.crud.deadlines import add_deadline
|
from database.crud.deadlines import add_deadline
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas import UforaCourse
|
||||||
|
|
||||||
__all__ = ["AddDeadline"]
|
__all__ = ["AddDeadline"]
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@ import traceback
|
||||||
import discord.ui
|
import discord.ui
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
|
|
||||||
from database.schemas.relational import MemeTemplate
|
from database.schemas import MemeTemplate
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.apis.imgflip import generate_meme
|
from didier.data.apis.imgflip import generate_meme
|
||||||
|
|
||||||
|
|
|
@ -10,10 +10,3 @@ services:
|
||||||
- POSTGRES_PASSWORD=pytest
|
- POSTGRES_PASSWORD=pytest
|
||||||
ports:
|
ports:
|
||||||
- "5433:5432"
|
- "5433:5432"
|
||||||
mongo-pytest:
|
|
||||||
image: mongo:5.0
|
|
||||||
restart: always
|
|
||||||
environment:
|
|
||||||
- MONGO_INITDB_DATABASE=didier_pytest
|
|
||||||
ports:
|
|
||||||
- "27018:27017"
|
|
||||||
|
|
|
@ -12,18 +12,5 @@ services:
|
||||||
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
|
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
|
||||||
volumes:
|
volumes:
|
||||||
- postgres:/var/lib/postgresql/data
|
- postgres:/var/lib/postgresql/data
|
||||||
mongo:
|
|
||||||
image: mongo:5.0
|
|
||||||
restart: always
|
|
||||||
environment:
|
|
||||||
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
|
|
||||||
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
|
|
||||||
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
|
|
||||||
command: [--auth]
|
|
||||||
ports:
|
|
||||||
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
|
|
||||||
volumes:
|
|
||||||
- mongo:/data/db
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres:
|
postgres:
|
||||||
mongo:
|
|
||||||
|
|
|
@ -44,9 +44,6 @@ ignore_missing_imports = true
|
||||||
asyncio_mode = "auto"
|
asyncio_mode = "auto"
|
||||||
env = [
|
env = [
|
||||||
"TESTING = 1",
|
"TESTING = 1",
|
||||||
"MONGO_DB = didier_pytest",
|
|
||||||
"MONGO_HOST = localhost",
|
|
||||||
"MONGO_PORT = 27018",
|
|
||||||
"POSTGRES_DB = didier_pytest",
|
"POSTGRES_DB = didier_pytest",
|
||||||
"POSTGRES_USER = pytest",
|
"POSTGRES_USER = pytest",
|
||||||
"POSTGRES_PASS = pytest",
|
"POSTGRES_PASS = pytest",
|
||||||
|
@ -55,6 +52,5 @@ env = [
|
||||||
"DISCORD_TOKEN = token"
|
"DISCORD_TOKEN = token"
|
||||||
]
|
]
|
||||||
markers = [
|
markers = [
|
||||||
"mongo: tests that use MongoDB",
|
|
||||||
"postgres: tests that use PostgreSQL"
|
"postgres: tests that use PostgreSQL"
|
||||||
]
|
]
|
||||||
|
|
|
@ -7,7 +7,6 @@ git+https://github.com/Rapptz/discord-ext-menus@8686b5d
|
||||||
environs==9.5.0
|
environs==9.5.0
|
||||||
feedparser==6.0.10
|
feedparser==6.0.10
|
||||||
markdownify==0.11.2
|
markdownify==0.11.2
|
||||||
motor==3.0.0
|
|
||||||
overrides==6.1.0
|
overrides==6.1.0
|
||||||
pydantic==1.9.1
|
pydantic==1.9.1
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
|
|
|
@ -37,13 +37,6 @@ SEMESTER: int = env.int("SEMESTER", 2)
|
||||||
YEAR: int = env.int("YEAR", 3)
|
YEAR: int = env.int("YEAR", 3)
|
||||||
|
|
||||||
"""Database"""
|
"""Database"""
|
||||||
# MongoDB
|
|
||||||
MONGO_DB: str = env.str("MONGO_DB", "didier")
|
|
||||||
MONGO_USER: str = env.str("MONGO_USER", "root")
|
|
||||||
MONGO_PASS: str = env.str("MONGO_PASS", "root")
|
|
||||||
MONGO_HOST: str = env.str("MONGO_HOST", "localhost")
|
|
||||||
MONGO_PORT: int = env.int("MONGO_PORT", "27017")
|
|
||||||
|
|
||||||
# PostgreSQL
|
# PostgreSQL
|
||||||
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
|
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
|
||||||
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")
|
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")
|
||||||
|
|
|
@ -2,12 +2,10 @@ import asyncio
|
||||||
from typing import AsyncGenerator, Generator
|
from typing import AsyncGenerator, Generator
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
import motor.motor_asyncio
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import settings
|
from database.engine import postgres_engine
|
||||||
from database.engine import mongo_client, postgres_engine
|
|
||||||
from database.migrations import ensure_latest_migration, migrate
|
from database.migrations import ensure_latest_migration, migrate
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
|
||||||
|
@ -56,14 +54,6 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||||
await connection.close()
|
await connection.close()
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
|
||||||
"""Fixture to get a MongoDB connection"""
|
|
||||||
database = mongo_client[settings.MONGO_DB]
|
|
||||||
yield database
|
|
||||||
mongo_client.drop_database(settings.MONGO_DB)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_client() -> Didier:
|
def mock_client() -> Didier:
|
||||||
"""Fixture to get a mock Didier instance
|
"""Fixture to get a mock Didier instance
|
||||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.schemas.relational import (
|
from database.schemas import (
|
||||||
Bank,
|
Bank,
|
||||||
UforaAnnouncement,
|
UforaAnnouncement,
|
||||||
UforaCourse,
|
UforaCourse,
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import birthdays as crud
|
from database.crud import birthdays as crud
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.schemas.relational import User
|
from database.schemas import User
|
||||||
|
|
||||||
|
|
||||||
async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
|
async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
|
||||||
|
|
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import currency as crud
|
from database.crud import currency as crud
|
||||||
from database.exceptions import currency as exceptions
|
from database.exceptions import currency as exceptions
|
||||||
from database.schemas.relational import Bank
|
from database.schemas import Bank
|
||||||
|
|
||||||
|
|
||||||
async def test_add_dinks(postgres: AsyncSession, bank: Bank):
|
async def test_add_dinks(postgres: AsyncSession, bank: Bank):
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from database.crud import custom_commands as crud
|
from database.crud import custom_commands as crud
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from database.schemas.relational import CustomCommand
|
from database.schemas import CustomCommand
|
||||||
|
|
||||||
|
|
||||||
async def test_create_command_non_existing(postgres: AsyncSession):
|
async def test_create_command_non_existing(postgres: AsyncSession):
|
||||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import dad_jokes as crud
|
from database.crud import dad_jokes as crud
|
||||||
from database.schemas.relational import DadJoke
|
from database.schemas import DadJoke
|
||||||
|
|
||||||
|
|
||||||
async def test_add_dad_joke(postgres: AsyncSession):
|
async def test_add_dad_joke(postgres: AsyncSession):
|
||||||
|
|
|
@ -1,63 +0,0 @@
|
||||||
import pytest
|
|
||||||
from freezegun import freeze_time
|
|
||||||
|
|
||||||
from database.crud import game_stats as crud
|
|
||||||
from database.mongo_types import MongoDatabase
|
|
||||||
from database.schemas.mongo.game_stats import GameStats
|
|
||||||
from database.utils.datetime import today_only_date
|
|
||||||
|
|
||||||
|
|
||||||
async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats):
|
|
||||||
"""Helper function to insert some stats"""
|
|
||||||
collection = mongodb[GameStats.collection()]
|
|
||||||
await collection.insert_one(stats.dict(by_alias=True))
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int):
|
|
||||||
"""Test getting a user's stats when the db is empty"""
|
|
||||||
collection = mongodb[GameStats.collection()]
|
|
||||||
assert await collection.find_one({"user_id": test_user_id}) is None
|
|
||||||
await crud.get_game_stats(mongodb, test_user_id)
|
|
||||||
assert await collection.find_one({"user_id": test_user_id}) is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int):
|
|
||||||
"""Test getting a user's stats when there's already an entry present"""
|
|
||||||
stats = GameStats(user_id=test_user_id)
|
|
||||||
stats.wordle.games = 20
|
|
||||||
await insert_game_stats(mongodb, stats)
|
|
||||||
found_stats = await crud.get_game_stats(mongodb, test_user_id)
|
|
||||||
assert found_stats.wordle.games == 20
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
@freeze_time("2022-07-30")
|
|
||||||
async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int):
|
|
||||||
"""Test completing a wordle game when you win"""
|
|
||||||
await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2)
|
|
||||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
|
||||||
assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0]
|
|
||||||
assert stats.wordle.games == 1
|
|
||||||
assert stats.wordle.wins == 1
|
|
||||||
assert stats.wordle.current_streak == 1
|
|
||||||
assert stats.wordle.max_streak == 1
|
|
||||||
assert stats.wordle.last_win == today_only_date()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
@freeze_time("2022-07-30")
|
|
||||||
async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int):
|
|
||||||
"""Test completing a wordle game when you lose"""
|
|
||||||
stats = GameStats(user_id=test_user_id)
|
|
||||||
stats.wordle.current_streak = 10
|
|
||||||
await insert_game_stats(mongodb, stats)
|
|
||||||
|
|
||||||
await crud.complete_wordle_game(mongodb, test_user_id, win=False)
|
|
||||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
|
||||||
|
|
||||||
# Check that streak was broken
|
|
||||||
assert stats.wordle.current_streak == 0
|
|
||||||
assert stats.wordle.games == 1
|
|
||||||
assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0]
|
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import tasks as crud
|
from database.crud import tasks as crud
|
||||||
from database.enums import TaskType
|
from database.enums import TaskType
|
||||||
from database.schemas.relational import Task
|
from database.schemas import Task
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_announcements as crud
|
from database.crud import ufora_announcements as crud
|
||||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
from database.schemas import UforaAnnouncement, UforaCourse
|
||||||
|
|
||||||
|
|
||||||
async def test_get_courses_with_announcements_none(postgres: AsyncSession):
|
async def test_get_courses_with_announcements_none(postgres: AsyncSession):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_courses as crud
|
from database.crud import ufora_courses as crud
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas import UforaCourse
|
||||||
|
|
||||||
|
|
||||||
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):
|
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users as crud
|
from database.crud import users as crud
|
||||||
from database.schemas.relational import User
|
from database.schemas import User
|
||||||
|
|
||||||
|
|
||||||
async def test_get_or_add_non_existing(postgres: AsyncSession):
|
async def test_get_or_add_non_existing(postgres: AsyncSession):
|
||||||
|
|
|
@ -1,136 +1,140 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import date, timedelta
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import wordle as crud
|
from database.crud import wordle as crud
|
||||||
from database.enums import TempStorageKey
|
from database.schemas import User, WordleGuess, WordleWord
|
||||||
from database.mongo_types import MongoCollection, MongoDatabase
|
|
||||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
|
||||||
from database.schemas.mongo.wordle import WordleGame
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection:
|
async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
|
||||||
"""Fixture to get a reference to the wordle collection"""
|
"""Fixture to generate some guesses"""
|
||||||
yield mongodb[WordleGame.collection()]
|
guesses = []
|
||||||
|
|
||||||
|
for guess in ["TEST", "WORDLE", "WORDS"]:
|
||||||
|
guess = WordleGuess(user_id=user.user_id, guess=guess)
|
||||||
|
postgres.add(guess)
|
||||||
|
await postgres.commit()
|
||||||
|
|
||||||
|
guesses.append(guess)
|
||||||
|
|
||||||
|
return guesses
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.mark.postgres
|
||||||
async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame:
|
async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
|
||||||
"""Fixture to create a new game"""
|
|
||||||
game = WordleGame(user_id=test_user_id)
|
|
||||||
await wordle_collection.insert_one(game.dict(by_alias=True))
|
|
||||||
yield game
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int):
|
|
||||||
"""Test starting a new game"""
|
|
||||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
await crud.start_new_wordle_game(mongodb, test_user_id)
|
|
||||||
|
|
||||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
|
||||||
assert result is not None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
|
||||||
async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int):
|
|
||||||
"""Test getting an active game when there is none"""
|
"""Test getting an active game when there is none"""
|
||||||
result = await crud.get_active_wordle_game(mongodb, test_user_id)
|
result = await crud.get_active_wordle_game(postgres, user.user_id)
|
||||||
assert result is None
|
assert not result
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame):
|
async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
|
||||||
"""Test getting an active game when there is one"""
|
"""Test getting an active game when there is one"""
|
||||||
result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id)
|
result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
|
||||||
assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True)
|
assert result == wordle_guesses
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
async def test_get_daily_word_none(mongodb: MongoDatabase):
|
async def test_get_daily_word_none(postgres: AsyncSession):
|
||||||
"""Test getting the daily word when the database is empty"""
|
"""Test getting the daily word when the database is empty"""
|
||||||
result = await crud.get_daily_word(mongodb)
|
result = await crud.get_daily_word(postgres)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
@freeze_time("2022-07-30")
|
@freeze_time("2022-07-30")
|
||||||
async def test_get_daily_word_not_today(mongodb: MongoDatabase):
|
async def test_get_daily_word_not_today(postgres: AsyncSession):
|
||||||
"""Test getting the daily word when there is an entry, but not for today"""
|
"""Test getting the daily word when there is an entry, but not for today"""
|
||||||
day = datetime.today() - timedelta(days=1)
|
day = date.today() - timedelta(days=1)
|
||||||
collection = mongodb[TemporaryStorage.collection()]
|
|
||||||
|
|
||||||
word = "testword"
|
word = "testword"
|
||||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
word_instance = WordleWord(word=word, day=day)
|
||||||
|
postgres.add(word_instance)
|
||||||
|
await postgres.commit()
|
||||||
|
|
||||||
assert await crud.get_daily_word(mongodb) is None
|
assert await crud.get_daily_word(postgres) is None
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
@freeze_time("2022-07-30")
|
@freeze_time("2022-07-30")
|
||||||
async def test_get_daily_word_present(mongodb: MongoDatabase):
|
async def test_get_daily_word_present(postgres: AsyncSession):
|
||||||
"""Test getting the daily word when there is one for today"""
|
"""Test getting the daily word when there is one for today"""
|
||||||
day = datetime.today()
|
day = date.today()
|
||||||
collection = mongodb[TemporaryStorage.collection()]
|
|
||||||
|
|
||||||
word = "testword"
|
word = "testword"
|
||||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
word_instance = WordleWord(word=word, day=day)
|
||||||
|
postgres.add(word_instance)
|
||||||
|
await postgres.commit()
|
||||||
|
|
||||||
assert await crud.get_daily_word(mongodb) == word
|
daily_word = await crud.get_daily_word(postgres)
|
||||||
|
assert daily_word is not None
|
||||||
|
assert daily_word.word == word
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
@freeze_time("2022-07-30")
|
@freeze_time("2022-07-30")
|
||||||
async def test_set_daily_word_none_present(mongodb: MongoDatabase):
|
async def test_set_daily_word_none_present(postgres: AsyncSession):
|
||||||
"""Test setting the daily word when there is none"""
|
"""Test setting the daily word when there is none"""
|
||||||
assert await crud.get_daily_word(mongodb) is None
|
assert await crud.get_daily_word(postgres) is None
|
||||||
word = "testword"
|
word = "testword"
|
||||||
await crud.set_daily_word(mongodb, word)
|
await crud.set_daily_word(postgres, word)
|
||||||
assert await crud.get_daily_word(mongodb) == word
|
|
||||||
|
daily_word = await crud.get_daily_word(postgres)
|
||||||
|
assert daily_word is not None
|
||||||
|
assert daily_word.word == word
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
@freeze_time("2022-07-30")
|
@freeze_time("2022-07-30")
|
||||||
async def test_set_daily_word_present(mongodb: MongoDatabase):
|
async def test_set_daily_word_present(postgres: AsyncSession):
|
||||||
"""Test setting the daily word when there already is one"""
|
"""Test setting the daily word when there already is one"""
|
||||||
word = "testword"
|
word = "testword"
|
||||||
await crud.set_daily_word(mongodb, word)
|
await crud.set_daily_word(postgres, word)
|
||||||
await crud.set_daily_word(mongodb, "another word")
|
await crud.set_daily_word(postgres, "another word")
|
||||||
assert await crud.get_daily_word(mongodb) == word
|
|
||||||
|
daily_word = await crud.get_daily_word(postgres)
|
||||||
|
assert daily_word is not None
|
||||||
|
assert daily_word.word == word
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
@freeze_time("2022-07-30")
|
@freeze_time("2022-07-30")
|
||||||
async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase):
|
async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
|
||||||
"""Test setting the daily word when there already is one, but "forced" is set to True"""
|
"""Test setting the daily word when there already is one, but "forced" is set to True"""
|
||||||
word = "testword"
|
word = "testword"
|
||||||
await crud.set_daily_word(mongodb, word)
|
await crud.set_daily_word(postgres, word)
|
||||||
word = "anotherword"
|
word = "anotherword"
|
||||||
await crud.set_daily_word(mongodb, word, forced=True)
|
await crud.set_daily_word(postgres, word, forced=True)
|
||||||
assert await crud.get_daily_word(mongodb) == word
|
|
||||||
|
daily_word = await crud.get_daily_word(postgres)
|
||||||
|
assert daily_word is not None
|
||||||
|
assert daily_word.word == word
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
async def test_make_wordle_guess(postgres: AsyncSession, user: User):
|
||||||
"""Test making a guess in your current game"""
|
"""Test making a guess in your current game"""
|
||||||
|
test_user_id = user.user_id
|
||||||
|
|
||||||
guess = "guess"
|
guess = "guess"
|
||||||
await crud.make_wordle_guess(mongodb, test_user_id, guess)
|
await crud.make_wordle_guess(postgres, test_user_id, guess)
|
||||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
wordle_guesses = await crud.get_active_wordle_game(postgres, test_user_id)
|
||||||
assert wordle_game.guesses == [guess]
|
assert list(map(lambda x: x.guess, wordle_guesses)) == [guess]
|
||||||
|
|
||||||
other_guess = "otherguess"
|
other_guess = "otherguess"
|
||||||
await crud.make_wordle_guess(mongodb, test_user_id, other_guess)
|
await crud.make_wordle_guess(postgres, test_user_id, other_guess)
|
||||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
wordle_guesses = await crud.get_active_wordle_game(postgres, test_user_id)
|
||||||
assert wordle_game.guesses == [guess, other_guess]
|
assert list(map(lambda x: x.guess, wordle_guesses)) == [guess, other_guess]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.mongo
|
@pytest.mark.postgres
|
||||||
async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
|
||||||
"""Test dropping the collection of active games"""
|
"""Test dropping the collection of active games"""
|
||||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None
|
test_user_id = user.user_id
|
||||||
await crud.reset_wordle_games(mongodb)
|
|
||||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is None
|
assert await crud.get_active_wordle_game(postgres, test_user_id)
|
||||||
|
await crud.reset_wordle_games(postgres)
|
||||||
|
assert not await crud.get_active_wordle_game(postgres, test_user_id)
|
||||||
|
|
|
@ -0,0 +1,72 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import wordle_stats as crud
|
||||||
|
from database.schemas import User, WordleStats
|
||||||
|
|
||||||
|
|
||||||
|
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
|
||||||
|
"""Helper function to insert some stats"""
|
||||||
|
postgres.add(stats)
|
||||||
|
await postgres.commit()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.postgres
|
||||||
|
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
|
||||||
|
"""Test getting a user's stats when the db is empty"""
|
||||||
|
test_user_id = user.user_id
|
||||||
|
|
||||||
|
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
|
||||||
|
assert (await postgres.execute(statement)).scalar_one_or_none() is None
|
||||||
|
|
||||||
|
await crud.get_wordle_stats(postgres, test_user_id)
|
||||||
|
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.postgres
|
||||||
|
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
|
||||||
|
"""Test getting a user's stats when there's already an entry present"""
|
||||||
|
test_user_id = user.user_id
|
||||||
|
|
||||||
|
stats = WordleStats(user_id=test_user_id)
|
||||||
|
stats.games = 20
|
||||||
|
await insert_game_stats(postgres, stats)
|
||||||
|
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||||
|
assert found_stats.games == 20
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.postgres
|
||||||
|
@freeze_time("2022-07-30")
|
||||||
|
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
|
||||||
|
"""Test completing a wordle game when you win"""
|
||||||
|
test_user_id = user.user_id
|
||||||
|
|
||||||
|
await crud.complete_wordle_game(postgres, test_user_id, win=True)
|
||||||
|
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||||
|
assert stats.games == 1
|
||||||
|
assert stats.wins == 1
|
||||||
|
assert stats.current_streak == 1
|
||||||
|
assert stats.highest_streak == 1
|
||||||
|
assert stats.last_win == datetime.date.today()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.postgres
|
||||||
|
@freeze_time("2022-07-30")
|
||||||
|
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
|
||||||
|
"""Test completing a wordle game when you lose"""
|
||||||
|
test_user_id = user.user_id
|
||||||
|
|
||||||
|
stats = WordleStats(user_id=test_user_id)
|
||||||
|
stats.current_streak = 10
|
||||||
|
await insert_game_stats(postgres, stats)
|
||||||
|
|
||||||
|
await crud.complete_wordle_game(postgres, test_user_id, win=False)
|
||||||
|
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||||
|
|
||||||
|
# Check that streak was broken
|
||||||
|
assert stats.current_streak == 0
|
||||||
|
assert stats.games == 1
|
|
@ -1,6 +1,6 @@
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas import UforaCourse
|
||||||
from database.utils.caches import UforaCourseCache
|
from database.utils.caches import UforaCourseCache
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue