mirror of https://github.com/stijndcl/didier
commit
8308b4ad9a
|
@ -38,17 +38,6 @@ jobs:
|
|||
POSTGRES_DB: didier_pytest
|
||||
POSTGRES_USER: pytest
|
||||
POSTGRES_PASSWORD: pytest
|
||||
mongo:
|
||||
image: mongo:5.0
|
||||
options: >-
|
||||
--health-cmd mongo
|
||||
--health-interval 10s
|
||||
--health-timeout 5s
|
||||
--health-retries 5
|
||||
ports:
|
||||
- 27018:27017
|
||||
env:
|
||||
MONGO_DB: didier_pytest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup Python
|
||||
|
|
|
@ -44,9 +44,3 @@ repos:
|
|||
- "flake8-eradicate"
|
||||
- "flake8-isort"
|
||||
- "flake8-simplify"
|
||||
|
||||
- repo: https://github.com/pre-commit/mirrors-mypy
|
||||
rev: v0.961
|
||||
hooks:
|
||||
- id: mypy
|
||||
args: [--config, pyproject.toml]
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncEngine
|
|||
|
||||
from alembic import context
|
||||
from database.engine import postgres_engine
|
||||
from database.schemas.relational import Base
|
||||
from database.schemas import Base
|
||||
|
||||
# this is the Alembic Config object, which provides
|
||||
# access to the values within the .ini file in use.
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
"""Wordle
|
||||
|
||||
Revision ID: 38b7c29f10ee
|
||||
Revises: 36300b558ef1
|
||||
Create Date: 2022-08-29 20:21:02.413631
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "38b7c29f10ee"
|
||||
down_revision = "36300b558ef1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"wordle_word",
|
||||
sa.Column("word_id", sa.Integer(), nullable=False),
|
||||
sa.Column("word", sa.Text(), nullable=False),
|
||||
sa.Column("day", sa.Date(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("word_id"),
|
||||
sa.UniqueConstraint("day"),
|
||||
)
|
||||
op.create_table(
|
||||
"wordle_guesses",
|
||||
sa.Column("wordle_guess_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("guess", sa.Text(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("wordle_guess_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"wordle_stats",
|
||||
sa.Column("wordle_stats_id", sa.Integer(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("last_win", sa.Date(), nullable=True),
|
||||
sa.Column("games", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("wins", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("current_streak", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.Column("highest_streak", sa.Integer(), server_default="0", nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("wordle_stats_id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("wordle_stats")
|
||||
op.drop_table("wordle_guesses")
|
||||
op.drop_table("wordle_word")
|
||||
# ### end Alembic commands ###
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.crud import users
|
||||
from database.schemas.relational import Birthday, User
|
||||
from database.schemas import Birthday, User
|
||||
|
||||
__all__ = ["add_birthday", "get_birthday_for_user", "get_birthdays_on_day"]
|
||||
|
||||
|
@ -17,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date):
|
|||
|
||||
If already present, overwrites the existing one
|
||||
"""
|
||||
user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)])
|
||||
user = await users.get_or_add_user(session, user_id, options=[selectinload(User.birthday)])
|
||||
|
||||
if user.birthday is not None:
|
||||
bd = user.birthday
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import users
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.schemas.relational import Bank, NightlyData
|
||||
from database.schemas import Bank, NightlyData
|
||||
from database.utils.math.currency import (
|
||||
capacity_upgrade_price,
|
||||
interest_upgrade_price,
|
||||
|
@ -29,13 +29,13 @@ NIGHTLY_AMOUNT = 420
|
|||
|
||||
async def get_bank(session: AsyncSession, user_id: int) -> Bank:
|
||||
"""Get a user's bank info"""
|
||||
user = await users.get_or_add(session, user_id)
|
||||
user = await users.get_or_add_user(session, user_id)
|
||||
return user.bank
|
||||
|
||||
|
||||
async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData:
|
||||
"""Get a user's nightly info"""
|
||||
user = await users.get_or_add(session, user_id)
|
||||
user = await users.get_or_add_user(session, user_id)
|
||||
return user.nightly_data
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.schemas.relational import CustomCommand, CustomCommandAlias
|
||||
from database.schemas import CustomCommand, CustomCommandAlias
|
||||
|
||||
__all__ = [
|
||||
"clean_name",
|
||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.schemas.relational import DadJoke
|
||||
from database.schemas import DadJoke
|
||||
|
||||
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.schemas.relational import Deadline, UforaCourse
|
||||
from database.schemas import Deadline, UforaCourse
|
||||
|
||||
__all__ = ["add_deadline", "get_deadlines"]
|
||||
|
||||
|
|
|
@ -1,59 +0,0 @@
|
|||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.game_stats import GameStats
|
||||
|
||||
__all__ = ["get_game_stats", "complete_wordle_game"]
|
||||
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
|
||||
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
|
||||
"""Get a user's game stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
collection = database[GameStats.collection()]
|
||||
stats = await collection.find_one({"user_id": user_id})
|
||||
if stats is not None:
|
||||
return GameStats(**stats)
|
||||
|
||||
stats = GameStats(user_id=user_id)
|
||||
await collection.insert_one(stats.dict(by_alias=True))
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_game_stats(database, user_id)
|
||||
|
||||
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
|
||||
|
||||
if win:
|
||||
update["$inc"]["wordle.wins"] = 1
|
||||
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
|
||||
|
||||
# Update streak
|
||||
today = today_only_date()
|
||||
last_win = stats.wordle.last_win
|
||||
update["$set"]["wordle.last_win"] = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
update["$set"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
update["$inc"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.wordle.current_streak > stats.wordle.max_streak:
|
||||
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
update["$set"]["wordle.current_streak"] = 0
|
||||
|
||||
collection = database[GameStats.collection()]
|
||||
await collection.update_one({"_id": stats.id}, update)
|
|
@ -4,7 +4,7 @@ from sqlalchemy import func, select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions import NoResultFoundException
|
||||
from database.schemas.relational import Link
|
||||
from database.schemas import Link
|
||||
|
||||
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import MemeTemplate
|
||||
from database.schemas import MemeTemplate
|
||||
|
||||
__all__ = ["add_meme", "get_all_memes", "get_meme_by_name"]
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.enums import TaskType
|
||||
from database.schemas.relational import Task
|
||||
from database.schemas import Task
|
||||
from database.utils.datetime import LOCAL_TIMEZONE
|
||||
|
||||
__all__ = ["get_task_by_enum", "set_last_task_execution_time"]
|
||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
|||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
from database.schemas import UforaAnnouncement, UforaCourse
|
||||
|
||||
__all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_old_announcements"]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaCourse, UforaCourseAlias
|
||||
from database.schemas import UforaCourse, UforaCourseAlias
|
||||
|
||||
__all__ = ["get_all_courses", "get_course_by_name"]
|
||||
|
||||
|
|
|
@ -3,14 +3,14 @@ from typing import Optional
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import Bank, NightlyData, User
|
||||
from database.schemas import Bank, NightlyData, User
|
||||
|
||||
__all__ = [
|
||||
"get_or_add",
|
||||
"get_or_add_user",
|
||||
]
|
||||
|
||||
|
||||
async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
|
||||
async def get_or_add_user(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User:
|
||||
"""Get a user's profile
|
||||
|
||||
If it doesn't exist yet, create it (along with all linked datastructures)
|
||||
|
|
|
@ -1,56 +1,54 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from database.enums import TempStorageKey
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from database.utils.datetime import today_only_date
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.schemas import WordleGuess, WordleWord
|
||||
|
||||
__all__ = [
|
||||
"get_active_wordle_game",
|
||||
"get_wordle_guesses",
|
||||
"make_wordle_guess",
|
||||
"start_new_wordle_game",
|
||||
"set_daily_word",
|
||||
"reset_wordle_games",
|
||||
]
|
||||
|
||||
|
||||
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
|
||||
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
|
||||
"""Find a player's active game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
result = await collection.find_one({"user_id": user_id})
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return WordleGame(**result)
|
||||
await get_or_add_user(session, user_id)
|
||||
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
|
||||
guesses = (await session.execute(statement)).scalars().all()
|
||||
return guesses
|
||||
|
||||
|
||||
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
|
||||
"""Start a new game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
game = WordleGame(user_id=user_id)
|
||||
await collection.insert_one(game.dict(by_alias=True))
|
||||
return game
|
||||
async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]:
|
||||
"""Get the strings of a player's guesses"""
|
||||
active_game = await get_active_wordle_game(session, user_id)
|
||||
return list(map(lambda g: g.guess.lower(), active_game))
|
||||
|
||||
|
||||
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
|
||||
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
|
||||
"""Make a guess in your current game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
||||
guess_instance = WordleGuess(user_id=user_id, guess=guess)
|
||||
session.add(guess_instance)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
|
||||
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
|
||||
"""Get the word of today"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
|
||||
row = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
||||
if result is None:
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return result["word"]
|
||||
return row
|
||||
|
||||
|
||||
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
|
||||
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
|
||||
"""Set the word of today
|
||||
|
||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||
|
@ -60,23 +58,28 @@ async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = F
|
|||
|
||||
Returns the word that was chosen. If one already existed, return that instead.
|
||||
"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
current_word = await get_daily_word(session)
|
||||
|
||||
current_word = None if forced else await get_daily_word(database)
|
||||
if current_word is not None:
|
||||
return current_word
|
||||
if current_word is None:
|
||||
current_word = WordleWord(word=word, day=datetime.date.today())
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
await collection.update_one(
|
||||
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
|
||||
)
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
elif forced:
|
||||
current_word.word = word
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
# Remove all active games
|
||||
await reset_wordle_games(database)
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
|
||||
return word
|
||||
return current_word.word
|
||||
|
||||
|
||||
async def reset_wordle_games(database: MongoDatabase):
|
||||
async def reset_wordle_games(session: AsyncSession):
|
||||
"""Reset all active games"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.drop()
|
||||
statement = delete(WordleGuess)
|
||||
await session.execute(statement)
|
||||
await session.commit()
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
from datetime import date
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.schemas import WordleStats
|
||||
|
||||
__all__ = ["get_wordle_stats", "complete_wordle_game"]
|
||||
|
||||
|
||||
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
|
||||
"""Get a user's wordle stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
await get_or_add_user(session, user_id)
|
||||
|
||||
statement = select(WordleStats).where(WordleStats.user_id == user_id)
|
||||
stats = (await session.execute(statement)).scalar_one_or_none()
|
||||
if stats is not None:
|
||||
return stats
|
||||
|
||||
stats = WordleStats(user_id=user_id)
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_wordle_stats(session, user_id)
|
||||
stats.games += 1
|
||||
|
||||
if win:
|
||||
stats.wins += 1
|
||||
|
||||
# Update streak
|
||||
today = date.today()
|
||||
last_win = stats.last_win
|
||||
stats.last_win = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
stats.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
stats.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.current_streak > stats.highest_streak:
|
||||
stats.highest_streak = stats.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
stats.current_streak = 0
|
||||
|
||||
session.add(stats)
|
||||
await session.commit()
|
|
@ -1,6 +1,5 @@
|
|||
from urllib.parse import quote_plus
|
||||
|
||||
import motor.motor_asyncio
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
@ -26,16 +25,3 @@ postgres_engine = create_async_engine(
|
|||
DBSession = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
|
||||
# MongoDB client
|
||||
if not settings.TESTING: # pragma: no cover
|
||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
||||
mongo_url = (
|
||||
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
)
|
||||
else:
|
||||
# Require no authentication when testing
|
||||
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
|
||||
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import enum
|
||||
|
||||
__all__ = ["TaskType", "TempStorageKey"]
|
||||
__all__ = ["TaskType"]
|
||||
|
||||
|
||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||
|
@ -11,10 +11,3 @@ class TaskType(enum.IntEnum):
|
|||
|
||||
BIRTHDAYS = enum.auto()
|
||||
UFORA_ANNOUNCEMENTS = enum.auto()
|
||||
|
||||
|
||||
@enum.unique
|
||||
class TempStorageKey(str, enum.Enum):
|
||||
"""Enum for keys to distinguish the TemporaryStorage rows"""
|
||||
|
||||
WORDLE_WORD = "wordle_word"
|
||||
|
|
|
@ -1,6 +0,0 @@
|
|||
import motor.motor_asyncio
|
||||
|
||||
# Type aliases for the Motor types, which are way too long
|
||||
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
|
||||
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
|
||||
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection
|
|
@ -37,6 +37,9 @@ __all__ = [
|
|||
"UforaCourse",
|
||||
"UforaCourseAlias",
|
||||
"User",
|
||||
"WordleGuess",
|
||||
"WordleStats",
|
||||
"WordleWord",
|
||||
]
|
||||
|
||||
|
||||
|
@ -231,3 +234,47 @@ class User(Base):
|
|||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_guesses: list[WordleGuess] = relationship(
|
||||
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_stats: WordleStats = relationship(
|
||||
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class WordleGuess(Base):
|
||||
"""A user's Wordle guesses for today"""
|
||||
|
||||
__tablename__ = "wordle_guesses"
|
||||
|
||||
wordle_guess_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
guess: str = Column(Text, nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleStats(Base):
|
||||
"""Stats about a user's wordle performance"""
|
||||
|
||||
__tablename__ = "wordle_stats"
|
||||
|
||||
wordle_stats_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_win: Optional[date] = Column(Date, nullable=True)
|
||||
games: int = Column(Integer, server_default="0", nullable=False)
|
||||
wins: int = Column(Integer, server_default="0", nullable=False)
|
||||
current_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
highest_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleWord(Base):
|
||||
"""The current Wordle word"""
|
||||
|
||||
__tablename__ = "wordle_word"
|
||||
|
||||
word_id: int = Column(Integer, primary_key=True)
|
||||
word: str = Column(Text, nullable=False)
|
||||
day: date = Column(Date, nullable=False, unique=True)
|
|
@ -1,53 +0,0 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
|
||||
|
||||
|
||||
class PyObjectId(ObjectId):
|
||||
"""Custom type for bson ObjectIds"""
|
||||
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, value: str):
|
||||
"""Check that a string is a valid bson ObjectId"""
|
||||
if not ObjectId.is_valid(value):
|
||||
raise ValueError(f"Invalid ObjectId: '{value}'")
|
||||
|
||||
return ObjectId(value)
|
||||
|
||||
@classmethod
|
||||
def __modify_schema__(cls, field_schema: dict):
|
||||
field_schema.update(type="string")
|
||||
|
||||
|
||||
class MongoBase(BaseModel):
|
||||
"""Base model that properly sets the _id field, and adds one by default"""
|
||||
|
||||
id: PyObjectId = Field(default_factory=PyObjectId, alias="_id")
|
||||
|
||||
class Config:
|
||||
"""Configuration for encoding and construction"""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class MongoCollection(MongoBase, ABC):
|
||||
"""Base model for the 'main class' in a collection
|
||||
|
||||
This field stores the name of the collection to avoid making typos against it
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def collection() -> str:
|
||||
"""Getter for the name of the collection, in order to avoid typos"""
|
||||
raise NotImplementedError
|
|
@ -1,40 +0,0 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["GameStats", "WordleStats"]
|
||||
|
||||
|
||||
class WordleStats(BaseModel):
|
||||
"""Model that holds stats about a player's Wordle performance"""
|
||||
|
||||
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
||||
last_win: Optional[datetime.datetime] = None
|
||||
wins: int = 0
|
||||
games: int = 0
|
||||
current_streak: int = 0
|
||||
max_streak: int = 0
|
||||
|
||||
@validator("guess_distribution")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the distribution of guesses is of the correct length"""
|
||||
if len(value) != 6:
|
||||
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class GameStats(MongoCollection):
|
||||
"""Collection that holds stats about how well a user has performed in games"""
|
||||
|
||||
user_id: int
|
||||
wordle: WordleStats = WordleStats()
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "game_stats"
|
|
@ -1,16 +0,0 @@
|
|||
from overrides import overrides
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["TemporaryStorage"]
|
||||
|
||||
|
||||
class TemporaryStorage(MongoCollection):
|
||||
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
||||
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "temporary"
|
|
@ -1,44 +0,0 @@
|
|||
import datetime
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import Field, validator
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
__all__ = ["WordleGame"]
|
||||
|
||||
|
||||
class WordleGame(MongoCollection):
|
||||
"""Collection that holds people's active Wordle games"""
|
||||
|
||||
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
||||
guesses: list[str] = Field(default_factory=list)
|
||||
user_id: int
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "wordle"
|
||||
|
||||
@validator("guesses")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the amount of guesses is of the correct length"""
|
||||
if len(value) > 6:
|
||||
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
def is_game_over(self, word: str) -> bool:
|
||||
"""Check if the current game is over"""
|
||||
# No guesses yet
|
||||
if not self.guesses:
|
||||
return False
|
||||
|
||||
# Max amount of guesses allowed
|
||||
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
||||
return True
|
||||
|
||||
# Found the correct word
|
||||
return self.guesses[-1] == word
|
|
@ -1,19 +1,17 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from discord import app_commands
|
||||
from overrides import overrides
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import links, memes, ufora_courses, wordle
|
||||
from database.mongo_types import MongoDatabase
|
||||
|
||||
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
|
||||
|
||||
T = TypeVar("T")
|
||||
from database.schemas import WordleWord
|
||||
|
||||
|
||||
class DatabaseCache(ABC, Generic[T]):
|
||||
class DatabaseCache(ABC):
|
||||
"""Base class for a simple cache-like structure
|
||||
|
||||
The goal of this class is to store data for Discord auto-completion results
|
||||
|
@ -25,7 +23,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
Considering the fact that a user isn't obligated to choose something from the suggestions,
|
||||
chances are high we have to go to the database for the final action either way.
|
||||
|
||||
Also stores the data in lowercase to allow fast searching
|
||||
Also stores the data in lowercase to allow fast searching.
|
||||
"""
|
||||
|
||||
data: list[str] = []
|
||||
|
@ -36,7 +34,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
self.data.clear()
|
||||
|
||||
@abstractmethod
|
||||
async def invalidate(self, database_session: T):
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
"""Invalidate the data stored in this cache"""
|
||||
|
||||
def get_autocomplete_suggestions(self, query: str) -> list[app_commands.Choice[str]]:
|
||||
|
@ -48,7 +46,7 @@ class DatabaseCache(ABC, Generic[T]):
|
|||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||
|
||||
|
||||
class LinkCache(DatabaseCache[AsyncSession]):
|
||||
class LinkCache(DatabaseCache):
|
||||
"""Cache to store the names of links"""
|
||||
|
||||
@overrides
|
||||
|
@ -61,7 +59,7 @@ class LinkCache(DatabaseCache[AsyncSession]):
|
|||
self.data_transformed = list(map(str.lower, self.data))
|
||||
|
||||
|
||||
class MemeCache(DatabaseCache[AsyncSession]):
|
||||
class MemeCache(DatabaseCache):
|
||||
"""Cache to store the names of meme templates"""
|
||||
|
||||
@overrides
|
||||
|
@ -74,7 +72,7 @@ class MemeCache(DatabaseCache[AsyncSession]):
|
|||
self.data_transformed = list(map(str.lower, self.data))
|
||||
|
||||
|
||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||
class UforaCourseCache(DatabaseCache):
|
||||
"""Cache to store the names of Ufora courses"""
|
||||
|
||||
# Also store the aliases to add additional support
|
||||
|
@ -119,13 +117,15 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
|
|||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||
|
||||
|
||||
class WordleCache(DatabaseCache[MongoDatabase]):
|
||||
class WordleCache(DatabaseCache):
|
||||
"""Cache to store the current daily Wordle word"""
|
||||
|
||||
async def invalidate(self, database_session: MongoDatabase):
|
||||
word: WordleWord
|
||||
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
word = await wordle.get_daily_word(database_session)
|
||||
if word is not None:
|
||||
self.data = [word]
|
||||
self.word = word
|
||||
|
||||
|
||||
class CacheManager:
|
||||
|
@ -142,9 +142,9 @@ class CacheManager:
|
|||
self.ufora_courses = UforaCourseCache()
|
||||
self.wordle_word = WordleCache()
|
||||
|
||||
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
||||
async def initialize_caches(self, postgres_session: AsyncSession):
|
||||
"""Initialize the contents of all caches"""
|
||||
await self.links.invalidate(postgres_session)
|
||||
await self.memes.invalidate(postgres_session)
|
||||
await self.ufora_courses.invalidate(postgres_session)
|
||||
await self.wordle_word.invalidate(mongo_db)
|
||||
await self.wordle_word.invalidate(postgres_session)
|
||||
|
|
|
@ -1,15 +1,5 @@
|
|||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
__all__ = ["LOCAL_TIMEZONE", "today_only_date"]
|
||||
__all__ = ["LOCAL_TIMEZONE"]
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
||||
|
||||
def today_only_date() -> datetime.datetime:
|
||||
"""Mongo can't handle datetime.date, so we need a datetime instance
|
||||
|
||||
We do, however, only care about the date, so remove all the rest
|
||||
"""
|
||||
today = datetime.date.today()
|
||||
return datetime.datetime(year=today.year, month=today.month, day=today.day)
|
||||
|
|
|
@ -4,7 +4,7 @@ from overrides import overrides
|
|||
from didier import Didier
|
||||
|
||||
|
||||
class TestCog(commands.Cog):
|
||||
class DebugCog(commands.Cog):
|
||||
"""Testing cog for dev purposes"""
|
||||
|
||||
client: Didier
|
||||
|
@ -16,11 +16,11 @@ class TestCog(commands.Cog):
|
|||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
return await self.client.is_owner(ctx.author)
|
||||
|
||||
@commands.command()
|
||||
async def test(self, ctx: commands.Context):
|
||||
@commands.command(aliases=["Dev"])
|
||||
async def debug(self, ctx: commands.Context):
|
||||
"""Debugging command"""
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(TestCog(client))
|
||||
await client.add_cog(DebugCog(client))
|
|
@ -4,14 +4,11 @@ import discord
|
|||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||
from database.crud.wordle import (
|
||||
get_active_wordle_game,
|
||||
make_wordle_guess,
|
||||
start_new_wordle_game,
|
||||
)
|
||||
from database.constants import WORDLE_WORD_LENGTH
|
||||
from database.crud.wordle import get_wordle_guesses, make_wordle_guess
|
||||
from database.crud.wordle_stats import complete_wordle_game
|
||||
from didier import Didier
|
||||
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed
|
||||
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over
|
||||
|
||||
|
||||
class Games(commands.Cog):
|
||||
|
@ -35,31 +32,39 @@ class Games(commands.Cog):
|
|||
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||
if active_game is None:
|
||||
active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||
word_instance = self.client.database_caches.wordle_word.word
|
||||
|
||||
# Trying to guess with a complete game
|
||||
if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess:
|
||||
embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
async with self.client.postgres_session as session:
|
||||
guesses = await get_wordle_guesses(session, interaction.user.id)
|
||||
|
||||
# Make a guess
|
||||
if guess:
|
||||
# The guess is not a real word
|
||||
if guess.lower() not in self.client.wordle_words:
|
||||
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
|
||||
# Trying to guess with a complete game
|
||||
if is_wordle_game_over(guesses, word_instance.word):
|
||||
embed = WordleErrorEmbed(
|
||||
message="You've already completed today's Wordle.\nTry again tomorrow!"
|
||||
).to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
guess = guess.lower()
|
||||
await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess)
|
||||
# Make a guess
|
||||
if guess:
|
||||
# The guess is not a real word
|
||||
if guess.lower() not in self.client.wordle_words:
|
||||
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
# Don't re-request the game, we already have it
|
||||
# just append locally
|
||||
active_game.guesses.append(guess)
|
||||
guess = guess.lower()
|
||||
await make_wordle_guess(session, interaction.user.id, guess)
|
||||
|
||||
embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed()
|
||||
await interaction.followup.send(embed=embed)
|
||||
# Don't re-request the game, we already have it
|
||||
# just append locally
|
||||
guesses.append(guess)
|
||||
|
||||
embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed()
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
# After responding to the interaction: update stats in the background
|
||||
game_over = is_wordle_game_over(guesses, word_instance.word)
|
||||
if game_over:
|
||||
await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
|
|
|
@ -5,7 +5,7 @@ from discord import app_commands
|
|||
from discord.ext import commands
|
||||
|
||||
from database.crud.links import get_link_by_name
|
||||
from database.schemas.relational import Link
|
||||
from database.schemas import Link
|
||||
from didier import Didier
|
||||
from didier.data.apis import urban_dictionary
|
||||
from didier.data.embeds.google import GoogleSearch
|
||||
|
|
|
@ -140,9 +140,9 @@ class Tasks(commands.Cog):
|
|||
@tasks.loop(time=DAILY_RESET_TIME)
|
||||
async def reset_wordle_word(self, forced: bool = False):
|
||||
"""Reset the daily Wordle word"""
|
||||
db = self.client.mongo_db
|
||||
word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)), forced=forced)
|
||||
self.client.database_caches.wordle_word.data = [word]
|
||||
async with self.client.postgres_session as session:
|
||||
await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced)
|
||||
await self.client.database_caches.wordle_word.invalidate(session)
|
||||
|
||||
@reset_wordle_word.before_loop
|
||||
async def _before_reset_wordle_word(self):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from aiohttp import ClientSession
|
||||
|
||||
import settings
|
||||
from database.schemas.relational import MemeTemplate
|
||||
from database.schemas import MemeTemplate
|
||||
from didier.exceptions.missing_env import MissingEnvironmentVariable
|
||||
from didier.utils.http.requests import ensure_post
|
||||
|
||||
|
|
|
@ -4,7 +4,7 @@ from datetime import datetime
|
|||
import discord
|
||||
from overrides import overrides
|
||||
|
||||
from database.schemas.relational import Deadline
|
||||
from database.schemas import Deadline
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.types.datetime import tz_aware_now
|
||||
from didier.utils.types.string import get_edu_year_name
|
||||
|
|
|
@ -13,7 +13,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
import settings
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.schemas import UforaCourse
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.discord.colours import ghent_university_blue
|
||||
from didier.utils.types.datetime import int_to_weekday
|
||||
|
|
|
@ -1,16 +1,26 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from overrides import overrides
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from database.schemas import WordleWord
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
||||
|
||||
__all__ = ["WordleEmbed", "WordleErrorEmbed"]
|
||||
__all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"]
|
||||
|
||||
|
||||
def is_wordle_game_over(guesses: list[str], word: str) -> bool:
|
||||
"""Check if the current game is over or not"""
|
||||
if not guesses:
|
||||
return False
|
||||
|
||||
if len(guesses) == WORDLE_GUESS_COUNT:
|
||||
return True
|
||||
|
||||
return word.lower() in guesses
|
||||
|
||||
|
||||
def footer() -> str:
|
||||
|
@ -32,18 +42,18 @@ class WordleColour(enum.IntEnum):
|
|||
class WordleEmbed(EmbedBaseModel):
|
||||
"""Embed for a Wordle game"""
|
||||
|
||||
game: Optional[WordleGame]
|
||||
word: str
|
||||
guesses: list[str]
|
||||
word: WordleWord
|
||||
|
||||
def _letter_colour(self, guess: str, index: int) -> WordleColour:
|
||||
"""Get the colour for a guess at a given position"""
|
||||
if guess[index] == self.word[index]:
|
||||
if guess[index] == self.word.word[index]:
|
||||
return WordleColour.CORRECT
|
||||
|
||||
wrong_letter = 0
|
||||
wrong_position = 0
|
||||
|
||||
for i, letter in enumerate(self.word):
|
||||
for i, letter in enumerate(self.word.word):
|
||||
if letter == guess[index] and guess[i] != guess[index]:
|
||||
wrong_letter += 1
|
||||
|
||||
|
@ -68,9 +78,8 @@ class WordleEmbed(EmbedBaseModel):
|
|||
colours = []
|
||||
|
||||
# Add all the guesses
|
||||
if self.game is not None:
|
||||
for guess in self.game.guesses:
|
||||
colours.append(self._guess_colours(guess))
|
||||
for guess in self.guesses:
|
||||
colours.append(self._guess_colours(guess))
|
||||
|
||||
# Fill the rest with empty spots
|
||||
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
|
||||
|
@ -99,19 +108,19 @@ class WordleEmbed(EmbedBaseModel):
|
|||
|
||||
colours = self.colour_code_game()
|
||||
|
||||
embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle")
|
||||
embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}")
|
||||
emojis = self._colours_to_emojis(colours)
|
||||
|
||||
rows = [" ".join(row) for row in emojis]
|
||||
|
||||
# Don't reveal anything if we only want to show the colours
|
||||
if not only_colours and self.game is not None:
|
||||
for i, guess in enumerate(self.game.guesses):
|
||||
if not only_colours and self.guesses:
|
||||
for i, guess in enumerate(self.guesses):
|
||||
rows[i] += f" ||{guess.upper()}||"
|
||||
|
||||
# If the game is over, reveal the word
|
||||
if self.game.is_game_over(self.word):
|
||||
rows.append(f"\n\nThe word was **{self.word.upper()}**!")
|
||||
if is_wordle_game_over(self.guesses, self.word.word):
|
||||
rows.append(f"\n\nThe word was **{self.word.word.upper()}**!")
|
||||
|
||||
embed.description = "\n\n".join(rows)
|
||||
embed.set_footer(text=footer())
|
||||
|
|
|
@ -2,7 +2,6 @@ import logging
|
|||
import os
|
||||
|
||||
import discord
|
||||
import motor.motor_asyncio
|
||||
from aiohttp import ClientSession
|
||||
from discord.app_commands import AppCommandError
|
||||
from discord.ext import commands
|
||||
|
@ -10,7 +9,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
import settings
|
||||
from database.crud import custom_commands
|
||||
from database.engine import DBSession, mongo_client
|
||||
from database.engine import DBSession
|
||||
from database.utils.caches import CacheManager
|
||||
from didier.data.embeds.error_embed import create_error_embed
|
||||
from didier.exceptions import HTTPException, NoMatch
|
||||
|
@ -55,11 +54,6 @@ class Didier(commands.Bot):
|
|||
"""Obtain a session for the PostgreSQL database"""
|
||||
return DBSession()
|
||||
|
||||
@property
|
||||
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
||||
"""Obtain a reference to the MongoDB database"""
|
||||
return mongo_client[settings.MONGO_DB]
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Do some initial setup
|
||||
|
||||
|
@ -71,7 +65,7 @@ class Didier(commands.Bot):
|
|||
# Initialize caches
|
||||
self.database_caches = CacheManager()
|
||||
async with self.postgres_session as session:
|
||||
await self.database_caches.initialize_caches(session, self.mongo_db)
|
||||
await self.database_caches.initialize_caches(session)
|
||||
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
|
|
|
@ -5,7 +5,7 @@ from discord import Interaction
|
|||
from overrides import overrides
|
||||
|
||||
from database.crud.deadlines import add_deadline
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.schemas import UforaCourse
|
||||
|
||||
__all__ = ["AddDeadline"]
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@ import traceback
|
|||
import discord.ui
|
||||
from overrides import overrides
|
||||
|
||||
from database.schemas.relational import MemeTemplate
|
||||
from database.schemas import MemeTemplate
|
||||
from didier import Didier
|
||||
from didier.data.apis.imgflip import generate_meme
|
||||
|
||||
|
|
|
@ -10,10 +10,3 @@ services:
|
|||
- POSTGRES_PASSWORD=pytest
|
||||
ports:
|
||||
- "5433:5432"
|
||||
mongo-pytest:
|
||||
image: mongo:5.0
|
||||
restart: always
|
||||
environment:
|
||||
- MONGO_INITDB_DATABASE=didier_pytest
|
||||
ports:
|
||||
- "27018:27017"
|
||||
|
|
|
@ -12,18 +12,5 @@ services:
|
|||
- "${POSTGRES_PORT:-5432}:${POSTGRES_PORT:-5432}"
|
||||
volumes:
|
||||
- postgres:/var/lib/postgresql/data
|
||||
mongo:
|
||||
image: mongo:5.0
|
||||
restart: always
|
||||
environment:
|
||||
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
|
||||
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
|
||||
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
|
||||
command: [--auth]
|
||||
ports:
|
||||
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
|
||||
volumes:
|
||||
- mongo:/data/db
|
||||
volumes:
|
||||
postgres:
|
||||
mongo:
|
||||
|
|
|
@ -44,9 +44,6 @@ ignore_missing_imports = true
|
|||
asyncio_mode = "auto"
|
||||
env = [
|
||||
"TESTING = 1",
|
||||
"MONGO_DB = didier_pytest",
|
||||
"MONGO_HOST = localhost",
|
||||
"MONGO_PORT = 27018",
|
||||
"POSTGRES_DB = didier_pytest",
|
||||
"POSTGRES_USER = pytest",
|
||||
"POSTGRES_PASS = pytest",
|
||||
|
@ -55,6 +52,5 @@ env = [
|
|||
"DISCORD_TOKEN = token"
|
||||
]
|
||||
markers = [
|
||||
"mongo: tests that use MongoDB",
|
||||
"postgres: tests that use PostgreSQL"
|
||||
]
|
||||
|
|
|
@ -7,7 +7,6 @@ git+https://github.com/Rapptz/discord-ext-menus@8686b5d
|
|||
environs==9.5.0
|
||||
feedparser==6.0.10
|
||||
markdownify==0.11.2
|
||||
motor==3.0.0
|
||||
overrides==6.1.0
|
||||
pydantic==1.9.1
|
||||
python-dateutil==2.8.2
|
||||
|
|
|
@ -37,13 +37,6 @@ SEMESTER: int = env.int("SEMESTER", 2)
|
|||
YEAR: int = env.int("YEAR", 3)
|
||||
|
||||
"""Database"""
|
||||
# MongoDB
|
||||
MONGO_DB: str = env.str("MONGO_DB", "didier")
|
||||
MONGO_USER: str = env.str("MONGO_USER", "root")
|
||||
MONGO_PASS: str = env.str("MONGO_PASS", "root")
|
||||
MONGO_HOST: str = env.str("MONGO_HOST", "localhost")
|
||||
MONGO_PORT: int = env.int("MONGO_PORT", "27017")
|
||||
|
||||
# PostgreSQL
|
||||
POSTGRES_DB: str = env.str("POSTGRES_DB", "didier")
|
||||
POSTGRES_USER: str = env.str("POSTGRES_USER", "postgres")
|
||||
|
|
|
@ -2,12 +2,10 @@ import asyncio
|
|||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import motor.motor_asyncio
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import settings
|
||||
from database.engine import mongo_client, postgres_engine
|
||||
from database.engine import postgres_engine
|
||||
from database.migrations import ensure_latest_migration, migrate
|
||||
from didier import Didier
|
||||
|
||||
|
@ -56,14 +54,6 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
|
|||
await connection.close()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def mongodb() -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
||||
"""Fixture to get a MongoDB connection"""
|
||||
database = mongo_client[settings.MONGO_DB]
|
||||
yield database
|
||||
mongo_client.drop_database(settings.MONGO_DB)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_client() -> Didier:
|
||||
"""Fixture to get a mock Didier instance
|
||||
|
|
|
@ -4,7 +4,7 @@ import pytest
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users
|
||||
from database.schemas.relational import (
|
||||
from database.schemas import (
|
||||
Bank,
|
||||
UforaAnnouncement,
|
||||
UforaCourse,
|
||||
|
@ -25,7 +25,7 @@ def test_user_id() -> int:
|
|||
@pytest.fixture
|
||||
async def user(postgres: AsyncSession, test_user_id: int) -> User:
|
||||
"""Fixture to create a user"""
|
||||
_user = await users.get_or_add(postgres, test_user_id)
|
||||
_user = await users.get_or_add_user(postgres, test_user_id)
|
||||
await postgres.refresh(_user)
|
||||
return _user
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import birthdays as crud
|
||||
from database.crud import users
|
||||
from database.schemas.relational import User
|
||||
from database.schemas import User
|
||||
|
||||
|
||||
async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
|
||||
|
@ -54,7 +54,7 @@ async def test_get_birthdays_on_day(postgres: AsyncSession, user: User):
|
|||
"""Test getting all birthdays on a given day"""
|
||||
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
|
||||
|
||||
user_2 = await users.get_or_add(postgres, user.user_id + 1)
|
||||
user_2 = await users.get_or_add_user(postgres, user.user_id + 1)
|
||||
await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1))
|
||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||
assert len(birthdays) == 1
|
||||
|
|
|
@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import currency as crud
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.schemas.relational import Bank
|
||||
from database.schemas import Bank
|
||||
|
||||
|
||||
async def test_add_dinks(postgres: AsyncSession, bank: Bank):
|
||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
from database.crud import custom_commands as crud
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from database.schemas.relational import CustomCommand
|
||||
from database.schemas import CustomCommand
|
||||
|
||||
|
||||
async def test_create_command_non_existing(postgres: AsyncSession):
|
||||
|
|
|
@ -2,7 +2,7 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import dad_jokes as crud
|
||||
from database.schemas.relational import DadJoke
|
||||
from database.schemas import DadJoke
|
||||
|
||||
|
||||
async def test_add_dad_joke(postgres: AsyncSession):
|
||||
|
|
|
@ -1,63 +0,0 @@
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from database.crud import game_stats as crud
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.game_stats import GameStats
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
|
||||
async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats):
|
||||
"""Helper function to insert some stats"""
|
||||
collection = mongodb[GameStats.collection()]
|
||||
await collection.insert_one(stats.dict(by_alias=True))
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test getting a user's stats when the db is empty"""
|
||||
collection = mongodb[GameStats.collection()]
|
||||
assert await collection.find_one({"user_id": test_user_id}) is None
|
||||
await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert await collection.find_one({"user_id": test_user_id}) is not None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test getting a user's stats when there's already an entry present"""
|
||||
stats = GameStats(user_id=test_user_id)
|
||||
stats.wordle.games = 20
|
||||
await insert_game_stats(mongodb, stats)
|
||||
found_stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert found_stats.wordle.games == 20
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test completing a wordle game when you win"""
|
||||
await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2)
|
||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0]
|
||||
assert stats.wordle.games == 1
|
||||
assert stats.wordle.wins == 1
|
||||
assert stats.wordle.current_streak == 1
|
||||
assert stats.wordle.max_streak == 1
|
||||
assert stats.wordle.last_win == today_only_date()
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test completing a wordle game when you lose"""
|
||||
stats = GameStats(user_id=test_user_id)
|
||||
stats.wordle.current_streak = 10
|
||||
await insert_game_stats(mongodb, stats)
|
||||
|
||||
await crud.complete_wordle_game(mongodb, test_user_id, win=False)
|
||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
|
||||
# Check that streak was broken
|
||||
assert stats.wordle.current_streak == 0
|
||||
assert stats.wordle.games == 1
|
||||
assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0]
|
|
@ -7,7 +7,7 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||
|
||||
from database.crud import tasks as crud
|
||||
from database.enums import TaskType
|
||||
from database.schemas.relational import Task
|
||||
from database.schemas import Task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
|
@ -3,7 +3,7 @@ import datetime
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
from database.schemas import UforaAnnouncement, UforaCourse
|
||||
|
||||
|
||||
async def test_get_courses_with_announcements_none(postgres: AsyncSession):
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_courses as crud
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.schemas import UforaCourse
|
||||
|
||||
|
||||
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||
|
|
|
@ -2,12 +2,12 @@ from sqlalchemy import select
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users as crud
|
||||
from database.schemas.relational import User
|
||||
from database.schemas import User
|
||||
|
||||
|
||||
async def test_get_or_add_non_existing(postgres: AsyncSession):
|
||||
"""Test get_or_add for a user that doesn't exist"""
|
||||
await crud.get_or_add(postgres, 1)
|
||||
await crud.get_or_add_user(postgres, 1)
|
||||
statement = select(User)
|
||||
res = (await postgres.execute(statement)).scalars().all()
|
||||
|
||||
|
@ -18,8 +18,8 @@ async def test_get_or_add_non_existing(postgres: AsyncSession):
|
|||
|
||||
async def test_get_or_add_existing(postgres: AsyncSession):
|
||||
"""Test get_or_add for a user that does exist"""
|
||||
user = await crud.get_or_add(postgres, 1)
|
||||
user = await crud.get_or_add_user(postgres, 1)
|
||||
bank = user.bank
|
||||
|
||||
assert await crud.get_or_add(postgres, 1) == user
|
||||
assert (await crud.get_or_add(postgres, 1)).bank == bank
|
||||
assert await crud.get_or_add_user(postgres, 1) == user
|
||||
assert (await crud.get_or_add_user(postgres, 1)).bank == bank
|
||||
|
|
|
@ -1,136 +1,138 @@
|
|||
from datetime import datetime, timedelta
|
||||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import wordle as crud
|
||||
from database.enums import TempStorageKey
|
||||
from database.mongo_types import MongoCollection, MongoDatabase
|
||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from database.schemas import User, WordleGuess, WordleWord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection:
|
||||
"""Fixture to get a reference to the wordle collection"""
|
||||
yield mongodb[WordleGame.collection()]
|
||||
async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
|
||||
"""Fixture to generate some guesses"""
|
||||
guesses = []
|
||||
|
||||
for guess in ["TEST", "WORDLE", "WORDS"]:
|
||||
guess = WordleGuess(user_id=user.user_id, guess=guess)
|
||||
postgres.add(guess)
|
||||
await postgres.commit()
|
||||
|
||||
guesses.append(guess)
|
||||
|
||||
return guesses
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame:
|
||||
"""Fixture to create a new game"""
|
||||
game = WordleGame(user_id=test_user_id)
|
||||
await wordle_collection.insert_one(game.dict(by_alias=True))
|
||||
yield game
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int):
|
||||
"""Test starting a new game"""
|
||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||
assert result is None
|
||||
|
||||
await crud.start_new_wordle_game(mongodb, test_user_id)
|
||||
|
||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int):
|
||||
@pytest.mark.postgres
|
||||
async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
|
||||
"""Test getting an active game when there is none"""
|
||||
result = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert result is None
|
||||
result = await crud.get_active_wordle_game(postgres, user.user_id)
|
||||
assert not result
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame):
|
||||
@pytest.mark.postgres
|
||||
async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
|
||||
"""Test getting an active game when there is one"""
|
||||
result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id)
|
||||
assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True)
|
||||
result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
|
||||
assert result == wordle_guesses
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_daily_word_none(mongodb: MongoDatabase):
|
||||
@pytest.mark.postgres
|
||||
async def test_get_daily_word_none(postgres: AsyncSession):
|
||||
"""Test getting the daily word when the database is empty"""
|
||||
result = await crud.get_daily_word(mongodb)
|
||||
result = await crud.get_daily_word(postgres)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_not_today(mongodb: MongoDatabase):
|
||||
async def test_get_daily_word_not_today(postgres: AsyncSession):
|
||||
"""Test getting the daily word when there is an entry, but not for today"""
|
||||
day = datetime.today() - timedelta(days=1)
|
||||
collection = mongodb[TemporaryStorage.collection()]
|
||||
day = date.today() - timedelta(days=1)
|
||||
|
||||
word = "testword"
|
||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
||||
word_instance = WordleWord(word=word, day=day)
|
||||
postgres.add(word_instance)
|
||||
await postgres.commit()
|
||||
|
||||
assert await crud.get_daily_word(mongodb) is None
|
||||
assert await crud.get_daily_word(postgres) is None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_present(mongodb: MongoDatabase):
|
||||
async def test_get_daily_word_present(postgres: AsyncSession):
|
||||
"""Test getting the daily word when there is one for today"""
|
||||
day = datetime.today()
|
||||
collection = mongodb[TemporaryStorage.collection()]
|
||||
day = date.today()
|
||||
|
||||
word = "testword"
|
||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
||||
word_instance = WordleWord(word=word, day=day)
|
||||
postgres.add(word_instance)
|
||||
await postgres.commit()
|
||||
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_none_present(mongodb: MongoDatabase):
|
||||
async def test_set_daily_word_none_present(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there is none"""
|
||||
assert await crud.get_daily_word(mongodb) is None
|
||||
assert await crud.get_daily_word(postgres) is None
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
await crud.set_daily_word(postgres, word)
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_present(mongodb: MongoDatabase):
|
||||
async def test_set_daily_word_present(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there already is one"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
await crud.set_daily_word(mongodb, "another word")
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
await crud.set_daily_word(postgres, word)
|
||||
await crud.set_daily_word(postgres, "another word")
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase):
|
||||
async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there already is one, but "forced" is set to True"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
await crud.set_daily_word(postgres, word)
|
||||
word = "anotherword"
|
||||
await crud.set_daily_word(mongodb, word, forced=True)
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
await crud.set_daily_word(postgres, word, forced=True)
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
||||
@pytest.mark.postgres
|
||||
async def test_make_wordle_guess(postgres: AsyncSession, user: User):
|
||||
"""Test making a guess in your current game"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
guess = "guess"
|
||||
await crud.make_wordle_guess(mongodb, test_user_id, guess)
|
||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert wordle_game.guesses == [guess]
|
||||
await crud.make_wordle_guess(postgres, test_user_id, guess)
|
||||
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess]
|
||||
|
||||
other_guess = "otherguess"
|
||||
await crud.make_wordle_guess(mongodb, test_user_id, other_guess)
|
||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert wordle_game.guesses == [guess, other_guess]
|
||||
await crud.make_wordle_guess(postgres, test_user_id, other_guess)
|
||||
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess]
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
||||
@pytest.mark.postgres
|
||||
async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
|
||||
"""Test dropping the collection of active games"""
|
||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None
|
||||
await crud.reset_wordle_games(mongodb)
|
||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is None
|
||||
test_user_id = user.user_id
|
||||
|
||||
assert await crud.get_active_wordle_game(postgres, test_user_id)
|
||||
await crud.reset_wordle_games(postgres)
|
||||
assert not await crud.get_active_wordle_game(postgres, test_user_id)
|
||||
|
|
|
@ -0,0 +1,72 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import wordle_stats as crud
|
||||
from database.schemas import User, WordleStats
|
||||
|
||||
|
||||
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
|
||||
"""Helper function to insert some stats"""
|
||||
postgres.add(stats)
|
||||
await postgres.commit()
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's stats when the db is empty"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
|
||||
assert (await postgres.execute(statement)).scalar_one_or_none() is None
|
||||
|
||||
await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's stats when there's already an entry present"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
stats = WordleStats(user_id=test_user_id)
|
||||
stats.games = 20
|
||||
await insert_game_stats(postgres, stats)
|
||||
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert found_stats.games == 20
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
|
||||
"""Test completing a wordle game when you win"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
await crud.complete_wordle_game(postgres, test_user_id, win=True)
|
||||
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert stats.games == 1
|
||||
assert stats.wins == 1
|
||||
assert stats.current_streak == 1
|
||||
assert stats.highest_streak == 1
|
||||
assert stats.last_win == datetime.date.today()
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
|
||||
"""Test completing a wordle game when you lose"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
stats = WordleStats(user_id=test_user_id)
|
||||
stats.current_streak = 10
|
||||
await insert_game_stats(postgres, stats)
|
||||
|
||||
await crud.complete_wordle_game(postgres, test_user_id, win=False)
|
||||
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
|
||||
# Check that streak was broken
|
||||
assert stats.current_streak == 0
|
||||
assert stats.games == 1
|
|
@ -1,6 +1,6 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.schemas import UforaCourse
|
||||
from database.utils.caches import UforaCourseCache
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue