mirror of https://github.com/stijndcl/didier
commit
8aedde46de
|
@ -49,8 +49,6 @@ jobs:
|
|||
- 27018:27017
|
||||
env:
|
||||
MONGO_DB: didier_pytest
|
||||
MONGO_USER: pytest
|
||||
MONGO_PASSWORD: pytest
|
||||
steps:
|
||||
- uses: actions/checkout@v3
|
||||
- name: Setup Python
|
||||
|
|
|
@ -0,0 +1,2 @@
|
|||
WORDLE_GUESS_COUNT = 6
|
||||
WORDLE_WORD_LENGTH = 5
|
|
@ -16,7 +16,7 @@ async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
|
|||
return dad_joke
|
||||
|
||||
|
||||
async def get_random_dad_joke(session: AsyncSession) -> DadJoke:
|
||||
async def get_random_dad_joke(session: AsyncSession) -> DadJoke: # pragma: no cover # randomness is untestable
|
||||
"""Return a random database entry"""
|
||||
statement = select(DadJoke).order_by(func.random())
|
||||
row = (await session.execute(statement)).first()
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
import datetime
|
||||
from typing import Union
|
||||
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.game_stats import GameStats
|
||||
|
||||
__all__ = ["get_game_stats", "complete_wordle_game"]
|
||||
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
|
||||
async def get_game_stats(database: MongoDatabase, user_id: int) -> GameStats:
|
||||
"""Get a user's game stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
collection = database[GameStats.collection()]
|
||||
stats = await collection.find_one({"user_id": user_id})
|
||||
if stats is not None:
|
||||
return GameStats(**stats)
|
||||
|
||||
stats = GameStats(user_id=user_id)
|
||||
await collection.insert_one(stats.dict(by_alias=True))
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(database: MongoDatabase, user_id: int, win: bool, guesses: int = 0):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_game_stats(database, user_id)
|
||||
|
||||
update: dict[str, dict[str, Union[int, datetime.datetime]]] = {"$inc": {"wordle.games": 1}, "$set": {}}
|
||||
|
||||
if win:
|
||||
update["$inc"]["wordle.wins"] = 1
|
||||
update["$inc"][f"wordle.guess_distribution.{guesses - 1}"] = 1
|
||||
|
||||
# Update streak
|
||||
today = today_only_date()
|
||||
last_win = stats.wordle.last_win
|
||||
update["$set"]["wordle.last_win"] = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
update["$set"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
update["$inc"]["wordle.current_streak"] = 1
|
||||
stats.wordle.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.wordle.current_streak > stats.wordle.max_streak:
|
||||
update["$set"]["wordle.max_streak"] = stats.wordle.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
update["$set"]["wordle.current_streak"] = 0
|
||||
|
||||
collection = database[GameStats.collection()]
|
||||
await collection.update_one({"_id": stats.id}, update)
|
|
@ -0,0 +1,82 @@
|
|||
from typing import Optional
|
||||
|
||||
from database.enums import TempStorageKey
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
__all__ = [
|
||||
"get_active_wordle_game",
|
||||
"make_wordle_guess",
|
||||
"start_new_wordle_game",
|
||||
"set_daily_word",
|
||||
"reset_wordle_games",
|
||||
]
|
||||
|
||||
|
||||
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
|
||||
"""Find a player's active game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
result = await collection.find_one({"user_id": user_id})
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return WordleGame(**result)
|
||||
|
||||
|
||||
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
|
||||
"""Start a new game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
game = WordleGame(user_id=user_id)
|
||||
await collection.insert_one(game.dict(by_alias=True))
|
||||
return game
|
||||
|
||||
|
||||
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
|
||||
"""Make a guess in your current game"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
||||
|
||||
|
||||
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
|
||||
"""Get the word of today"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
|
||||
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
||||
if result is None:
|
||||
return None
|
||||
|
||||
return result["word"]
|
||||
|
||||
|
||||
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
|
||||
"""Set the word of today
|
||||
|
||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||
on startup every time.
|
||||
|
||||
In order to always overwrite the current word, set the "forced"-kwarg to True.
|
||||
|
||||
Returns the word that was chosen. If one already existed, return that instead.
|
||||
"""
|
||||
collection = database[TemporaryStorage.collection()]
|
||||
|
||||
current_word = None if forced else await get_daily_word(database)
|
||||
if current_word is not None:
|
||||
return current_word
|
||||
|
||||
await collection.update_one(
|
||||
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
|
||||
)
|
||||
|
||||
# Remove all active games
|
||||
await reset_wordle_games(database)
|
||||
|
||||
return word
|
||||
|
||||
|
||||
async def reset_wordle_games(database: MongoDatabase):
|
||||
"""Reset all active games"""
|
||||
collection = database[WordleGame.collection()]
|
||||
await collection.drop()
|
|
@ -28,7 +28,14 @@ DBSession = sessionmaker(
|
|||
)
|
||||
|
||||
# MongoDB client
|
||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
||||
mongo_url = f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
if not settings.TESTING: # pragma: no cover
|
||||
encoded_mongo_username = quote_plus(settings.MONGO_USER)
|
||||
encoded_mongo_password = quote_plus(settings.MONGO_PASS)
|
||||
mongo_url = (
|
||||
f"mongodb://{encoded_mongo_username}:{encoded_mongo_password}@{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
)
|
||||
else:
|
||||
# Require no authentication when testing
|
||||
mongo_url = f"mongodb://{settings.MONGO_HOST}:{settings.MONGO_PORT}/"
|
||||
|
||||
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(mongo_url)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import enum
|
||||
|
||||
__all__ = ["TaskType"]
|
||||
__all__ = ["TaskType", "TempStorageKey"]
|
||||
|
||||
|
||||
# There is a bug in typeshed that causes an incorrect PyCharm warning
|
||||
|
@ -11,3 +11,10 @@ class TaskType(enum.IntEnum):
|
|||
|
||||
BIRTHDAYS = enum.auto()
|
||||
UFORA_ANNOUNCEMENTS = enum.auto()
|
||||
|
||||
|
||||
@enum.unique
|
||||
class TempStorageKey(str, enum.Enum):
|
||||
"""Enum for keys to distinguish the TemporaryStorage rows"""
|
||||
|
||||
WORDLE_WORD = "wordle_word"
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
import motor.motor_asyncio
|
||||
|
||||
# Type aliases for the Motor types, which are way too long
|
||||
MongoClient = motor.motor_asyncio.AsyncIOMotorClient
|
||||
MongoDatabase = motor.motor_asyncio.AsyncIOMotorDatabase
|
||||
MongoCollection = motor.motor_asyncio.AsyncIOMotorCollection
|
|
@ -1,10 +1,12 @@
|
|||
from abc import ABC, abstractmethod
|
||||
|
||||
from bson import ObjectId
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
__all__ = ["MongoBase"]
|
||||
__all__ = ["PyObjectId", "MongoBase", "MongoCollection"]
|
||||
|
||||
|
||||
class PyObjectId(str):
|
||||
class PyObjectId(ObjectId):
|
||||
"""Custom type for bson ObjectIds"""
|
||||
|
||||
@classmethod
|
||||
|
@ -36,3 +38,16 @@ class MongoBase(BaseModel):
|
|||
arbitrary_types_allowed = True
|
||||
json_encoders = {ObjectId: str, PyObjectId: str}
|
||||
use_enum_values = True
|
||||
|
||||
|
||||
class MongoCollection(MongoBase, ABC):
|
||||
"""Base model for the 'main class' in a collection
|
||||
|
||||
This field stores the name of the collection to avoid making typos against it
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def collection() -> str:
|
||||
"""Getter for the name of the collection, in order to avoid typos"""
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,40 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["GameStats", "WordleStats"]
|
||||
|
||||
|
||||
class WordleStats(BaseModel):
|
||||
"""Model that holds stats about a player's Wordle performance"""
|
||||
|
||||
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
||||
last_win: Optional[datetime.datetime] = None
|
||||
wins: int = 0
|
||||
games: int = 0
|
||||
current_streak: int = 0
|
||||
max_streak: int = 0
|
||||
|
||||
@validator("guess_distribution")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the distribution of guesses is of the correct length"""
|
||||
if len(value) != 6:
|
||||
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class GameStats(MongoCollection):
|
||||
"""Collection that holds stats about how well a user has performed in games"""
|
||||
|
||||
user_id: int
|
||||
wordle: WordleStats = WordleStats()
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "game_stats"
|
|
@ -0,0 +1,16 @@
|
|||
from overrides import overrides
|
||||
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
|
||||
__all__ = ["TemporaryStorage"]
|
||||
|
||||
|
||||
class TemporaryStorage(MongoCollection):
|
||||
"""Collection for lots of random things that don't belong in a full-blown collection"""
|
||||
|
||||
key: str
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "temporary"
|
|
@ -0,0 +1,44 @@
|
|||
import datetime
|
||||
|
||||
from overrides import overrides
|
||||
from pydantic import Field, validator
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT
|
||||
from database.schemas.mongo.common import MongoCollection
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
__all__ = ["WordleGame"]
|
||||
|
||||
|
||||
class WordleGame(MongoCollection):
|
||||
"""Collection that holds people's active Wordle games"""
|
||||
|
||||
day: datetime.datetime = Field(default_factory=lambda: today_only_date())
|
||||
guesses: list[str] = Field(default_factory=list)
|
||||
user_id: int
|
||||
|
||||
@staticmethod
|
||||
@overrides
|
||||
def collection() -> str:
|
||||
return "wordle"
|
||||
|
||||
@validator("guesses")
|
||||
def validate_guesses_length(cls, value: list[int]):
|
||||
"""Check that the amount of guesses is of the correct length"""
|
||||
if len(value) > 6:
|
||||
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
||||
|
||||
return value
|
||||
|
||||
def is_game_over(self, word: str) -> bool:
|
||||
"""Check if the current game is over"""
|
||||
# No guesses yet
|
||||
if not self.guesses:
|
||||
return False
|
||||
|
||||
# Max amount of guesses allowed
|
||||
if len(self.guesses) == WORDLE_GUESS_COUNT:
|
||||
return True
|
||||
|
||||
# Found the correct word
|
||||
return self.guesses[-1] == word
|
|
@ -1,14 +1,18 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Generic, TypeVar
|
||||
|
||||
from overrides import overrides
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_courses
|
||||
from database.crud import ufora_courses, wordle
|
||||
from database.mongo_types import MongoDatabase
|
||||
|
||||
__all__ = ["CacheManager"]
|
||||
__all__ = ["CacheManager", "UforaCourseCache"]
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class DatabaseCache(ABC):
|
||||
class DatabaseCache(ABC, Generic[T]):
|
||||
"""Base class for a simple cache-like structure
|
||||
|
||||
The goal of this class is to store data for Discord auto-completion results
|
||||
|
@ -31,10 +35,10 @@ class DatabaseCache(ABC):
|
|||
self.data.clear()
|
||||
|
||||
@abstractmethod
|
||||
async def refresh(self, database_session: AsyncSession):
|
||||
async def refresh(self, database_session: T):
|
||||
"""Refresh the data stored in this cache"""
|
||||
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
async def invalidate(self, database_session: T):
|
||||
"""Invalidate the data stored in this cache"""
|
||||
await self.refresh(database_session)
|
||||
|
||||
|
@ -45,7 +49,7 @@ class DatabaseCache(ABC):
|
|||
return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
|
||||
|
||||
|
||||
class UforaCourseCache(DatabaseCache):
|
||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||
"""Cache to store the names of Ufora courses"""
|
||||
|
||||
# Also store the aliases to add additional support
|
||||
|
@ -90,14 +94,26 @@ class UforaCourseCache(DatabaseCache):
|
|||
return sorted(list(results))
|
||||
|
||||
|
||||
class WordleCache(DatabaseCache[MongoDatabase]):
|
||||
"""Cache to store the current daily Wordle word"""
|
||||
|
||||
async def refresh(self, database_session: MongoDatabase):
|
||||
word = await wordle.get_daily_word(database_session)
|
||||
if word is not None:
|
||||
self.data = [word]
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Class that keeps track of all caches"""
|
||||
|
||||
ufora_courses: UforaCourseCache
|
||||
wordle_word: WordleCache
|
||||
|
||||
def __init__(self):
|
||||
self.ufora_courses = UforaCourseCache()
|
||||
self.wordle_word = WordleCache()
|
||||
|
||||
async def initialize_caches(self, database_session: AsyncSession):
|
||||
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
||||
"""Initialize the contents of all caches"""
|
||||
await self.ufora_courses.refresh(database_session)
|
||||
await self.ufora_courses.refresh(postgres_session)
|
||||
await self.wordle_word.refresh(mongo_db)
|
||||
|
|
|
@ -1,5 +1,15 @@
|
|||
import datetime
|
||||
import zoneinfo
|
||||
|
||||
__all__ = ["LOCAL_TIMEZONE"]
|
||||
__all__ = ["LOCAL_TIMEZONE", "today_only_date"]
|
||||
|
||||
LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||
|
||||
|
||||
def today_only_date() -> datetime.datetime:
|
||||
"""Mongo can't handle datetime.date, so we need a datetime instance
|
||||
|
||||
We do, however, only care about the date, so remove all the rest
|
||||
"""
|
||||
today = datetime.date.today()
|
||||
return datetime.datetime(year=today.year, month=today.month, day=today.day)
|
||||
|
|
|
@ -0,0 +1,67 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||
from database.crud.wordle import (
|
||||
get_active_wordle_game,
|
||||
make_wordle_guess,
|
||||
start_new_wordle_game,
|
||||
)
|
||||
from didier import Didier
|
||||
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed
|
||||
|
||||
|
||||
class Games(commands.Cog):
|
||||
"""Cog for various games"""
|
||||
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@app_commands.command(name="wordle", description="Play Wordle!")
|
||||
async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None):
|
||||
"""View your active Wordle game
|
||||
|
||||
If an argument is provided, make a guess instead
|
||||
"""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Guess is wrong length
|
||||
if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH:
|
||||
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||
if active_game is None:
|
||||
active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||
|
||||
# Trying to guess with a complete game
|
||||
if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess:
|
||||
embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
# Make a guess
|
||||
if guess:
|
||||
# The guess is not a real word
|
||||
if guess.lower() not in self.client.wordle_words:
|
||||
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
guess = guess.lower()
|
||||
await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess)
|
||||
|
||||
# Don't re-request the game, we already have it
|
||||
# just append locally
|
||||
active_game.guesses.append(guess)
|
||||
|
||||
embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed()
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(Games(client))
|
|
@ -8,6 +8,7 @@ import settings
|
|||
from database import enums
|
||||
from database.crud.birthdays import get_birthdays_on_day
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
from database.crud.wordle import set_daily_word
|
||||
from didier import Didier
|
||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||
from didier.decorators.tasks import timed_task
|
||||
|
@ -46,7 +47,14 @@ class Tasks(commands.Cog):
|
|||
self.pull_ufora_announcements.start()
|
||||
self.remove_old_ufora_announcements.start()
|
||||
|
||||
self._tasks = {"birthdays": self.check_birthdays, "ufora": self.pull_ufora_announcements}
|
||||
# Start other tasks
|
||||
self.reset_wordle_word.start()
|
||||
|
||||
self._tasks = {
|
||||
"birthdays": self.check_birthdays,
|
||||
"ufora": self.pull_ufora_announcements,
|
||||
"wordle": self.reset_wordle_word,
|
||||
}
|
||||
|
||||
@commands.group(name="Tasks", aliases=["Task"], case_insensitive=True, invoke_without_command=True)
|
||||
@commands.check(is_owner)
|
||||
|
@ -65,12 +73,15 @@ class Tasks(commands.Cog):
|
|||
return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False)
|
||||
|
||||
task = self._tasks[name]
|
||||
await task()
|
||||
await task(forced=True)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
||||
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||
async def check_birthdays(self):
|
||||
async def check_birthdays(self, **kwargs):
|
||||
"""Check if it's currently anyone's birthday"""
|
||||
_ = kwargs
|
||||
|
||||
now = tz_aware_now().date()
|
||||
async with self.client.postgres_session as session:
|
||||
birthdays = await get_birthdays_on_day(session, now)
|
||||
|
@ -90,8 +101,10 @@ class Tasks(commands.Cog):
|
|||
|
||||
@tasks.loop(minutes=10)
|
||||
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
|
||||
async def pull_ufora_announcements(self):
|
||||
async def pull_ufora_announcements(self, **kwargs):
|
||||
"""Task that checks for new Ufora announcements & logs them in a channel"""
|
||||
_ = kwargs
|
||||
|
||||
# In theory this shouldn't happen but just to please Mypy
|
||||
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
||||
return
|
||||
|
@ -113,6 +126,17 @@ class Tasks(commands.Cog):
|
|||
async with self.client.postgres_session as session:
|
||||
await remove_old_announcements(session)
|
||||
|
||||
@tasks.loop(time=DAILY_RESET_TIME)
|
||||
async def reset_wordle_word(self, forced: bool = False):
|
||||
"""Reset the daily Wordle word"""
|
||||
db = self.client.mongo_db
|
||||
word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)), forced=forced)
|
||||
self.client.database_caches.wordle_word.data = [word]
|
||||
|
||||
@reset_wordle_word.before_loop
|
||||
async def _before_reset_wordle_word(self):
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
@check_birthdays.error
|
||||
@pull_ufora_announcements.error
|
||||
@remove_old_ufora_announcements.error
|
||||
|
@ -123,5 +147,11 @@ class Tasks(commands.Cog):
|
|||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(Tasks(client))
|
||||
"""Load the cog
|
||||
|
||||
Initially fetch the wordle word from the database, or reset it
|
||||
if there hasn't been a reset yet today
|
||||
"""
|
||||
cog = Tasks(client)
|
||||
await client.add_cog(cog)
|
||||
await cog.reset_wordle_word()
|
||||
|
|
|
@ -13,7 +13,7 @@ class EmbedBaseModel(ABC):
|
|||
"""Abstract base class for a model that can be turned into a Discord embed"""
|
||||
|
||||
@abstractmethod
|
||||
def to_embed(self) -> discord.Embed:
|
||||
def to_embed(self, **kwargs: dict) -> discord.Embed:
|
||||
"""Turn this model into a Discord embed"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
@ -33,7 +33,7 @@ class GoogleSearch(EmbedBaseModel):
|
|||
return embed
|
||||
|
||||
@overrides
|
||||
def to_embed(self) -> discord.Embed:
|
||||
def to_embed(self, **kwargs: dict) -> discord.Embed:
|
||||
if not self.data.results or self.data.status_code != HTTPStatus.OK:
|
||||
return self._error_embed()
|
||||
|
||||
|
|
|
@ -47,7 +47,7 @@ class UforaNotification(EmbedBaseModel):
|
|||
self.published_dt = self._published_datetime()
|
||||
self._published = self._get_published()
|
||||
|
||||
def to_embed(self) -> discord.Embed:
|
||||
def to_embed(self, **kwargs: dict) -> discord.Embed:
|
||||
"""Turn the notification into an embed"""
|
||||
embed = discord.Embed(colour=discord.Colour.from_rgb(30, 100, 200))
|
||||
|
||||
|
|
|
@ -46,7 +46,7 @@ class Definition(EmbedPydantic):
|
|||
return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH)
|
||||
|
||||
@overrides
|
||||
def to_embed(self) -> discord.Embed:
|
||||
def to_embed(self, **kwargs: dict) -> discord.Embed:
|
||||
embed = discord.Embed(colour=colours.urban_dictionary_green())
|
||||
embed.set_author(name="Urban Dictionary")
|
||||
|
||||
|
|
|
@ -0,0 +1,133 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from overrides import overrides
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
||||
|
||||
__all__ = ["WordleEmbed", "WordleErrorEmbed"]
|
||||
|
||||
|
||||
def footer() -> str:
|
||||
"""Create the footer to put on the embed"""
|
||||
today = tz_aware_now()
|
||||
return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}"
|
||||
|
||||
|
||||
class WordleColour(enum.IntEnum):
|
||||
"""Colours for the Wordle embed"""
|
||||
|
||||
EMPTY = 0
|
||||
WRONG_LETTER = 1
|
||||
WRONG_POSITION = 2
|
||||
CORRECT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordleEmbed(EmbedBaseModel):
|
||||
"""Embed for a Wordle game"""
|
||||
|
||||
game: Optional[WordleGame]
|
||||
word: str
|
||||
|
||||
def _letter_colour(self, guess: str, index: int) -> WordleColour:
|
||||
"""Get the colour for a guess at a given position"""
|
||||
if guess[index] == self.word[index]:
|
||||
return WordleColour.CORRECT
|
||||
|
||||
wrong_letter = 0
|
||||
wrong_position = 0
|
||||
|
||||
for i, letter in enumerate(self.word):
|
||||
if letter == guess[index] and guess[i] != guess[index]:
|
||||
wrong_letter += 1
|
||||
|
||||
if i <= index and guess[i] == guess[index] and letter != guess[index]:
|
||||
wrong_position += 1
|
||||
|
||||
if i >= index:
|
||||
if wrong_position == 0:
|
||||
break
|
||||
|
||||
if wrong_position <= wrong_letter:
|
||||
return WordleColour.WRONG_POSITION
|
||||
|
||||
return WordleColour.WRONG_LETTER
|
||||
|
||||
def _guess_colours(self, guess: str) -> list[WordleColour]:
|
||||
"""Create the colour codes for a specific guess"""
|
||||
return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)]
|
||||
|
||||
def colour_code_game(self) -> list[list[WordleColour]]:
|
||||
"""Create the colour codes for an entire game"""
|
||||
colours = []
|
||||
|
||||
# Add all the guesses
|
||||
if self.game is not None:
|
||||
for guess in self.game.guesses:
|
||||
colours.append(self._guess_colours(guess))
|
||||
|
||||
# Fill the rest with empty spots
|
||||
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
|
||||
colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH)
|
||||
|
||||
return colours
|
||||
|
||||
def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]:
|
||||
"""Turn the colours of the board into Discord emojis"""
|
||||
colour_map = {
|
||||
WordleColour.EMPTY: ":white_large_square:",
|
||||
WordleColour.WRONG_LETTER: ":black_large_square:",
|
||||
WordleColour.WRONG_POSITION: ":orange_square:",
|
||||
WordleColour.CORRECT: ":green_square:",
|
||||
}
|
||||
|
||||
emojis = []
|
||||
for row in colours:
|
||||
emojis.append(list(map(lambda char: colour_map[char], row)))
|
||||
|
||||
return emojis
|
||||
|
||||
@overrides
|
||||
def to_embed(self, **kwargs) -> discord.Embed:
|
||||
only_colours = kwargs.get("only_colours", False)
|
||||
|
||||
colours = self.colour_code_game()
|
||||
|
||||
embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle")
|
||||
emojis = self._colours_to_emojis(colours)
|
||||
|
||||
rows = [" ".join(row) for row in emojis]
|
||||
|
||||
# Don't reveal anything if we only want to show the colours
|
||||
if not only_colours and self.game is not None:
|
||||
for i, guess in enumerate(self.game.guesses):
|
||||
rows[i] += f" ||{guess.upper()}||"
|
||||
|
||||
# If the game is over, reveal the word
|
||||
if self.game.is_game_over(self.word):
|
||||
rows.append(f"\n\nThe word was **{self.word.upper()}**!")
|
||||
|
||||
embed.description = "\n\n".join(rows)
|
||||
embed.set_footer(text=footer())
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordleErrorEmbed(EmbedBaseModel):
|
||||
"""Embed to send error messages to the user"""
|
||||
|
||||
message: str
|
||||
|
||||
@overrides
|
||||
def to_embed(self, **kwargs: dict) -> discord.Embed:
|
||||
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
|
||||
embed.description = self.message
|
||||
embed.set_footer(text=footer())
|
||||
return embed
|
|
@ -27,6 +27,7 @@ class Didier(commands.Bot):
|
|||
error_channel: discord.abc.Messageable
|
||||
initial_extensions: tuple[str, ...] = ()
|
||||
http_session: ClientSession
|
||||
wordle_words: set[str] = set()
|
||||
|
||||
def __init__(self):
|
||||
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
|
||||
|
@ -60,14 +61,17 @@ class Didier(commands.Bot):
|
|||
|
||||
This hook is called once the bot is initialised
|
||||
"""
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
# Load the Wordle dictionary
|
||||
self._load_wordle_words()
|
||||
|
||||
# Initialize caches
|
||||
self.database_caches = CacheManager()
|
||||
async with self.postgres_session as session:
|
||||
await self.database_caches.initialize_caches(session)
|
||||
await self.database_caches.initialize_caches(session, self.mongo_db)
|
||||
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
|
||||
# Create aiohttp session
|
||||
self.http_session = ClientSession()
|
||||
|
@ -101,6 +105,12 @@ class Didier(commands.Bot):
|
|||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_extensions(new_path)
|
||||
|
||||
def _load_wordle_words(self):
|
||||
"""Load the dictionary of Wordle words"""
|
||||
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
|
||||
for line in fp:
|
||||
self.wordle_words.add(line.strip())
|
||||
|
||||
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||
"""Fetch a message from a reference"""
|
||||
# Message is in the cache, return it
|
||||
|
|
|
@ -10,7 +10,7 @@ LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
|||
|
||||
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
|
||||
"""Get the Dutch name of a weekday from the number"""
|
||||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
||||
return ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"][number]
|
||||
|
||||
|
||||
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
||||
|
|
|
@ -14,8 +14,6 @@ services:
|
|||
image: mongo:5.0
|
||||
restart: always
|
||||
environment:
|
||||
- MONGO_INITDB_ROOT_USERNAME=pytest
|
||||
- MONGO_INITDB_ROOT_PASSWORD=pytest
|
||||
- MONGO_INITDB_DATABASE=didier_pytest
|
||||
ports:
|
||||
- "27018:27017"
|
||||
|
|
File diff suppressed because it is too large
Load Diff
|
@ -20,6 +20,7 @@ omit = [
|
|||
"./didier/utils/discord/colours.py",
|
||||
"./didier/utils/discord/constants.py",
|
||||
"./didier/utils/discord/flags/*",
|
||||
"./didier/views/modals/*"
|
||||
]
|
||||
|
||||
[tool.isort]
|
||||
|
@ -42,9 +43,8 @@ ignore_missing_imports = true
|
|||
[tool.pytest.ini_options]
|
||||
asyncio_mode = "auto"
|
||||
env = [
|
||||
"TESTING = 1",
|
||||
"MONGO_DB = didier_pytest",
|
||||
"MONGO_USER = pytest",
|
||||
"MONGO_PASS = pytest",
|
||||
"MONGO_HOST = localhost",
|
||||
"MONGO_PORT = 27018",
|
||||
"POSTGRES_DB = didier_pytest",
|
||||
|
@ -54,3 +54,7 @@ env = [
|
|||
"POSTGRES_PORT = 5433",
|
||||
"DISCORD_TOKEN = token"
|
||||
]
|
||||
markers = [
|
||||
"mongo: tests that use MongoDB",
|
||||
"postgres: tests that use PostgreSQL"
|
||||
]
|
||||
|
|
|
@ -30,10 +30,10 @@ A separate database is used in the tests, as it would obviously not be ideal whe
|
|||
|
||||
```shell
|
||||
# Starting the database
|
||||
docker-compose up -d db
|
||||
docker compose up -d
|
||||
|
||||
# Starting the database used in tests
|
||||
docker-compose up -d db-pytest
|
||||
docker compose -f docker-compose.test.yml up -d
|
||||
```
|
||||
|
||||
### Commands
|
||||
|
|
|
@ -8,6 +8,7 @@ env.read_env()
|
|||
|
||||
__all__ = [
|
||||
"SANDBOX",
|
||||
"TESTING",
|
||||
"LOGFILE",
|
||||
"POSTGRES_DB",
|
||||
"POSTGRES_USER",
|
||||
|
@ -28,6 +29,7 @@ __all__ = [
|
|||
|
||||
"""General config"""
|
||||
SANDBOX: bool = env.bool("SANDBOX", True)
|
||||
TESTING: bool = env.bool("TESTING", False)
|
||||
LOGFILE: str = env.str("LOGFILE", "didier.log")
|
||||
SEMESTER: int = env.int("SEMESTER", 2)
|
||||
YEAR: int = env.int("YEAR", 3)
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users
|
||||
from database.schemas.relational import (
|
||||
|
@ -22,7 +23,7 @@ def test_user_id() -> int:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def user(postgres, test_user_id) -> User:
|
||||
async def user(postgres: AsyncSession, test_user_id: int) -> User:
|
||||
"""Fixture to create a user"""
|
||||
_user = await users.get_or_add(postgres, test_user_id)
|
||||
await postgres.refresh(_user)
|
||||
|
@ -30,7 +31,7 @@ async def user(postgres, test_user_id) -> User:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def bank(postgres, user: User) -> Bank:
|
||||
async def bank(postgres: AsyncSession, user: User) -> Bank:
|
||||
"""Fixture to fetch the test user's bank"""
|
||||
_bank = user.bank
|
||||
await postgres.refresh(_bank)
|
||||
|
@ -38,7 +39,7 @@ async def bank(postgres, user: User) -> Bank:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_course(postgres) -> UforaCourse:
|
||||
async def ufora_course(postgres: AsyncSession) -> UforaCourse:
|
||||
"""Fixture to create a course"""
|
||||
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||
postgres.add(course)
|
||||
|
@ -47,7 +48,7 @@ async def ufora_course(postgres) -> UforaCourse:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse:
|
||||
async def ufora_course_with_alias(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
|
||||
"""Fixture to create a course with an alias"""
|
||||
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
|
||||
postgres.add(alias)
|
||||
|
@ -57,7 +58,7 @@ async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaC
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement:
|
||||
async def ufora_announcement(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaAnnouncement:
|
||||
"""Fixture to create an announcement"""
|
||||
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
|
||||
postgres.add(announcement)
|
||||
|
|
|
@ -1,13 +1,14 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import birthdays as crud
|
||||
from database.crud import users
|
||||
from database.schemas.relational import User
|
||||
|
||||
|
||||
async def test_add_birthday_not_present(postgres, user: User):
|
||||
async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
|
||||
"""Test setting a user's birthday when it doesn't exist yet"""
|
||||
assert user.birthday is None
|
||||
|
||||
|
@ -18,7 +19,7 @@ async def test_add_birthday_not_present(postgres, user: User):
|
|||
assert user.birthday.birthday == bd_date
|
||||
|
||||
|
||||
async def test_add_birthday_overwrite(postgres, user: User):
|
||||
async def test_add_birthday_overwrite(postgres: AsyncSession, user: User):
|
||||
"""Test that setting a user's birthday when it already exists overwrites it"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||
|
@ -31,7 +32,7 @@ async def test_add_birthday_overwrite(postgres, user: User):
|
|||
assert user.birthday.birthday == new_bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_exists(postgres, user: User):
|
||||
async def test_get_birthday_exists(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's birthday when it exists"""
|
||||
bd_date = datetime.today().date()
|
||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||
|
@ -42,14 +43,14 @@ async def test_get_birthday_exists(postgres, user: User):
|
|||
assert bd.birthday == bd_date
|
||||
|
||||
|
||||
async def test_get_birthday_not_exists(postgres, user: User):
|
||||
async def test_get_birthday_not_exists(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's birthday when it doesn't exist"""
|
||||
bd = await crud.get_birthday_for_user(postgres, user.user_id)
|
||||
assert bd is None
|
||||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_on_day(postgres, user: User):
|
||||
async def test_get_birthdays_on_day(postgres: AsyncSession, user: User):
|
||||
"""Test getting all birthdays on a given day"""
|
||||
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
|
||||
|
||||
|
@ -61,7 +62,7 @@ async def test_get_birthdays_on_day(postgres, user: User):
|
|||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_get_birthdays_none_present(postgres):
|
||||
async def test_get_birthdays_none_present(postgres: AsyncSession):
|
||||
"""Test getting all birthdays when there are none"""
|
||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||
assert len(birthdays) == 0
|
||||
|
|
|
@ -2,13 +2,14 @@ import datetime
|
|||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import currency as crud
|
||||
from database.exceptions import currency as exceptions
|
||||
from database.schemas.relational import Bank
|
||||
|
||||
|
||||
async def test_add_dinks(postgres, bank: Bank):
|
||||
async def test_add_dinks(postgres: AsyncSession, bank: Bank):
|
||||
"""Test adding dinks to an account"""
|
||||
assert bank.dinks == 0
|
||||
await crud.add_dinks(postgres, bank.user_id, 10)
|
||||
|
@ -17,7 +18,7 @@ async def test_add_dinks(postgres, bank: Bank):
|
|||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_available(postgres, bank: Bank):
|
||||
async def test_claim_nightly_available(postgres: AsyncSession, bank: Bank):
|
||||
"""Test claiming nightlies when it hasn't been done yet"""
|
||||
await crud.claim_nightly(postgres, bank.user_id)
|
||||
await postgres.refresh(bank)
|
||||
|
@ -28,7 +29,7 @@ async def test_claim_nightly_available(postgres, bank: Bank):
|
|||
|
||||
|
||||
@freeze_time("2022/07/23")
|
||||
async def test_claim_nightly_unavailable(postgres, bank: Bank):
|
||||
async def test_claim_nightly_unavailable(postgres: AsyncSession, bank: Bank):
|
||||
"""Test claiming nightlies twice in a day"""
|
||||
await crud.claim_nightly(postgres, bank.user_id)
|
||||
|
||||
|
@ -39,7 +40,7 @@ async def test_claim_nightly_unavailable(postgres, bank: Bank):
|
|||
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||
|
||||
|
||||
async def test_invest(postgres, bank: Bank):
|
||||
async def test_invest(postgres: AsyncSession, bank: Bank):
|
||||
"""Test investing some Dinks"""
|
||||
bank.dinks = 100
|
||||
postgres.add(bank)
|
||||
|
@ -52,7 +53,7 @@ async def test_invest(postgres, bank: Bank):
|
|||
assert bank.invested == 20
|
||||
|
||||
|
||||
async def test_invest_all(postgres, bank: Bank):
|
||||
async def test_invest_all(postgres: AsyncSession, bank: Bank):
|
||||
"""Test investing all dinks"""
|
||||
bank.dinks = 100
|
||||
postgres.add(bank)
|
||||
|
@ -65,7 +66,7 @@ async def test_invest_all(postgres, bank: Bank):
|
|||
assert bank.invested == 100
|
||||
|
||||
|
||||
async def test_invest_more_than_owned(postgres, bank: Bank):
|
||||
async def test_invest_more_than_owned(postgres: AsyncSession, bank: Bank):
|
||||
"""Test investing more Dinks than you own"""
|
||||
bank.dinks = 100
|
||||
postgres.add(bank)
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
import pytest
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import custom_commands as crud
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
|
@ -7,7 +8,7 @@ from database.exceptions.not_found import NoResultFoundException
|
|||
from database.schemas.relational import CustomCommand
|
||||
|
||||
|
||||
async def test_create_command_non_existing(postgres):
|
||||
async def test_create_command_non_existing(postgres: AsyncSession):
|
||||
"""Test creating a new command when it doesn't exist yet"""
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
|
||||
|
@ -16,7 +17,7 @@ async def test_create_command_non_existing(postgres):
|
|||
assert commands[0].name == "name"
|
||||
|
||||
|
||||
async def test_create_command_duplicate_name(postgres):
|
||||
async def test_create_command_duplicate_name(postgres: AsyncSession):
|
||||
"""Test creating a command when the name already exists"""
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
|
||||
|
@ -24,7 +25,7 @@ async def test_create_command_duplicate_name(postgres):
|
|||
await crud.create_command(postgres, "name", "other response")
|
||||
|
||||
|
||||
async def test_create_command_name_is_alias(postgres):
|
||||
async def test_create_command_name_is_alias(postgres: AsyncSession):
|
||||
"""Test creating a command when the name is taken by an alias"""
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, "name", "n")
|
||||
|
@ -33,7 +34,7 @@ async def test_create_command_name_is_alias(postgres):
|
|||
await crud.create_command(postgres, "n", "other response")
|
||||
|
||||
|
||||
async def test_create_alias(postgres):
|
||||
async def test_create_alias(postgres: AsyncSession):
|
||||
"""Test creating an alias when the name is still free"""
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
@ -43,13 +44,13 @@ async def test_create_alias(postgres):
|
|||
assert command.aliases[0].alias == "n"
|
||||
|
||||
|
||||
async def test_create_alias_non_existing(postgres):
|
||||
async def test_create_alias_non_existing(postgres: AsyncSession):
|
||||
"""Test creating an alias when the command doesn't exist"""
|
||||
with pytest.raises(NoResultFoundException):
|
||||
await crud.create_alias(postgres, "name", "alias")
|
||||
|
||||
|
||||
async def test_create_alias_duplicate(postgres):
|
||||
async def test_create_alias_duplicate(postgres: AsyncSession):
|
||||
"""Test creating an alias when another alias already has this name"""
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
@ -58,7 +59,7 @@ async def test_create_alias_duplicate(postgres):
|
|||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_is_command(postgres):
|
||||
async def test_create_alias_is_command(postgres: AsyncSession):
|
||||
"""Test creating an alias when the name is taken by a command"""
|
||||
await crud.create_command(postgres, "n", "response")
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
|
@ -67,7 +68,7 @@ async def test_create_alias_is_command(postgres):
|
|||
await crud.create_alias(postgres, command.name, "n")
|
||||
|
||||
|
||||
async def test_create_alias_match_by_alias(postgres):
|
||||
async def test_create_alias_match_by_alias(postgres: AsyncSession):
|
||||
"""Test creating an alias for a command when matching the name to another alias"""
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "a1")
|
||||
|
@ -75,21 +76,21 @@ async def test_create_alias_match_by_alias(postgres):
|
|||
assert alias.command == command
|
||||
|
||||
|
||||
async def test_get_command_by_name_exists(postgres):
|
||||
async def test_get_command_by_name_exists(postgres: AsyncSession):
|
||||
"""Test getting a command by name"""
|
||||
await crud.create_command(postgres, "name", "response")
|
||||
command = await crud.get_command(postgres, "name")
|
||||
assert command is not None
|
||||
|
||||
|
||||
async def test_get_command_by_cleaned_name(postgres):
|
||||
async def test_get_command_by_cleaned_name(postgres: AsyncSession):
|
||||
"""Test getting a command by the cleaned version of the name"""
|
||||
command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response")
|
||||
found = await crud.get_command(postgres, "capitalizednamewithspaces")
|
||||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_by_alias(postgres):
|
||||
async def test_get_command_by_alias(postgres: AsyncSession):
|
||||
"""Test getting a command by an alias"""
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.create_alias(postgres, command.name, "a1")
|
||||
|
@ -99,12 +100,12 @@ async def test_get_command_by_alias(postgres):
|
|||
assert command == found
|
||||
|
||||
|
||||
async def test_get_command_non_existing(postgres):
|
||||
async def test_get_command_non_existing(postgres: AsyncSession):
|
||||
"""Test getting a command when it doesn't exist"""
|
||||
assert await crud.get_command(postgres, "name") is None
|
||||
|
||||
|
||||
async def test_edit_command(postgres):
|
||||
async def test_edit_command(postgres: AsyncSession):
|
||||
"""Test editing an existing command"""
|
||||
command = await crud.create_command(postgres, "name", "response")
|
||||
await crud.edit_command(postgres, command.name, "new name", "new response")
|
||||
|
@ -112,7 +113,7 @@ async def test_edit_command(postgres):
|
|||
assert command.response == "new response"
|
||||
|
||||
|
||||
async def test_edit_command_non_existing(postgres):
|
||||
async def test_edit_command_non_existing(postgres: AsyncSession):
|
||||
"""Test editing a command that doesn't exist"""
|
||||
with pytest.raises(NoResultFoundException):
|
||||
await crud.edit_command(postgres, "name", "n", "r")
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import dad_jokes as crud
|
||||
from database.schemas.relational import DadJoke
|
||||
|
||||
|
||||
async def test_add_dad_joke(postgres):
|
||||
async def test_add_dad_joke(postgres: AsyncSession):
|
||||
"""Test creating a new joke"""
|
||||
statement = select(DadJoke)
|
||||
result = (await postgres.execute(statement)).scalars().all()
|
||||
|
|
|
@ -0,0 +1,63 @@
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from database.crud import game_stats as crud
|
||||
from database.mongo_types import MongoDatabase
|
||||
from database.schemas.mongo.game_stats import GameStats
|
||||
from database.utils.datetime import today_only_date
|
||||
|
||||
|
||||
async def insert_game_stats(mongodb: MongoDatabase, stats: GameStats):
|
||||
"""Helper function to insert some stats"""
|
||||
collection = mongodb[GameStats.collection()]
|
||||
await collection.insert_one(stats.dict(by_alias=True))
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_stats_non_existent_creates(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test getting a user's stats when the db is empty"""
|
||||
collection = mongodb[GameStats.collection()]
|
||||
assert await collection.find_one({"user_id": test_user_id}) is None
|
||||
await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert await collection.find_one({"user_id": test_user_id}) is not None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_stats_existing_returns(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test getting a user's stats when there's already an entry present"""
|
||||
stats = GameStats(user_id=test_user_id)
|
||||
stats.wordle.games = 20
|
||||
await insert_game_stats(mongodb, stats)
|
||||
found_stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert found_stats.wordle.games == 20
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_won(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test completing a wordle game when you win"""
|
||||
await crud.complete_wordle_game(mongodb, test_user_id, win=True, guesses=2)
|
||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
assert stats.wordle.guess_distribution == [0, 1, 0, 0, 0, 0]
|
||||
assert stats.wordle.games == 1
|
||||
assert stats.wordle.wins == 1
|
||||
assert stats.wordle.current_streak == 1
|
||||
assert stats.wordle.max_streak == 1
|
||||
assert stats.wordle.last_win == today_only_date()
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_lost(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test completing a wordle game when you lose"""
|
||||
stats = GameStats(user_id=test_user_id)
|
||||
stats.wordle.current_streak = 10
|
||||
await insert_game_stats(mongodb, stats)
|
||||
|
||||
await crud.complete_wordle_game(mongodb, test_user_id, win=False)
|
||||
stats = await crud.get_game_stats(mongodb, test_user_id)
|
||||
|
||||
# Check that streak was broken
|
||||
assert stats.wordle.current_streak == 0
|
||||
assert stats.wordle.games == 1
|
||||
assert stats.wordle.guess_distribution == [0, 0, 0, 0, 0, 0]
|
|
@ -3,6 +3,7 @@ import datetime
|
|||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import tasks as crud
|
||||
from database.enums import TaskType
|
||||
|
@ -16,7 +17,7 @@ def task_type() -> TaskType:
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def task(postgres, task_type: TaskType) -> Task:
|
||||
async def task(postgres: AsyncSession, task_type: TaskType) -> Task:
|
||||
"""Fixture to create a task"""
|
||||
task = Task(task=task_type)
|
||||
postgres.add(task)
|
||||
|
@ -24,21 +25,21 @@ async def task(postgres, task_type: TaskType) -> Task:
|
|||
return task
|
||||
|
||||
|
||||
async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType):
|
||||
async def test_get_task_by_enum_present(postgres: AsyncSession, task: Task, task_type: TaskType):
|
||||
"""Test getting a task by its enum type when it exists"""
|
||||
result = await crud.get_task_by_enum(postgres, task_type)
|
||||
assert result is not None
|
||||
assert result == task
|
||||
|
||||
|
||||
async def test_get_task_by_enum_not_present(postgres, task_type: TaskType):
|
||||
async def test_get_task_by_enum_not_present(postgres: AsyncSession, task_type: TaskType):
|
||||
"""Test getting a task by its enum type when it doesn't exist"""
|
||||
result = await crud.get_task_by_enum(postgres, task_type)
|
||||
assert result is None
|
||||
|
||||
|
||||
@freeze_time("2022/07/24")
|
||||
async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType):
|
||||
async def test_set_execution_time_exists(postgres: AsyncSession, task: Task, task_type: TaskType):
|
||||
"""Test setting the execution time of an existing task"""
|
||||
await postgres.refresh(task)
|
||||
assert task.previous_run is None
|
||||
|
@ -49,7 +50,7 @@ async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskTy
|
|||
|
||||
|
||||
@freeze_time("2022/07/24")
|
||||
async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType):
|
||||
async def test_set_execution_time_doesnt_exist(postgres: AsyncSession, task_type: TaskType):
|
||||
"""Test setting the execution time of a non-existing task"""
|
||||
statement = select(Task).where(Task.task == task_type)
|
||||
results = list((await postgres.execute(statement)).scalars().all())
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
import datetime
|
||||
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_announcements as crud
|
||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||
|
||||
|
||||
async def test_get_courses_with_announcements_none(postgres):
|
||||
async def test_get_courses_with_announcements_none(postgres: AsyncSession):
|
||||
"""Test getting all courses with announcements when there are none"""
|
||||
results = await crud.get_courses_with_announcements(postgres)
|
||||
assert len(results) == 0
|
||||
|
||||
|
||||
async def test_get_courses_with_announcements(postgres):
|
||||
async def test_get_courses_with_announcements(postgres: AsyncSession):
|
||||
"""Test getting all courses with announcements"""
|
||||
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
|
||||
|
@ -22,14 +24,14 @@ async def test_get_courses_with_announcements(postgres):
|
|||
assert results[0] == course_1
|
||||
|
||||
|
||||
async def test_create_new_announcement(ufora_course: UforaCourse, postgres):
|
||||
async def test_create_new_announcement(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||
"""Test creating a new announcement"""
|
||||
await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now())
|
||||
await postgres.refresh(ufora_course)
|
||||
assert len(ufora_course.announcements) == 1
|
||||
|
||||
|
||||
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres):
|
||||
async def test_remove_old_announcements(postgres: AsyncSession, ufora_announcement: UforaAnnouncement):
|
||||
"""Test removing all stale announcements"""
|
||||
course = ufora_announcement.course
|
||||
ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
|
||||
|
|
|
@ -1,20 +1,22 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import ufora_courses as crud
|
||||
from database.schemas.relational import UforaCourse
|
||||
|
||||
|
||||
async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse):
|
||||
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||
"""Test getting a course by its name when the query is an exact match"""
|
||||
match = await crud.get_course_by_name(postgres, "Test")
|
||||
assert match == ufora_course
|
||||
|
||||
|
||||
async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse):
|
||||
async def test_get_course_by_name_substring(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||
"""Test getting a course by its name when the query is a substring"""
|
||||
match = await crud.get_course_by_name(postgres, "es")
|
||||
assert match == ufora_course
|
||||
|
||||
|
||||
async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse):
|
||||
async def test_get_course_by_name_alias(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||
"""Test getting a course by its name when the name doesn't match, but the alias does"""
|
||||
match = await crud.get_course_by_name(postgres, "ali")
|
||||
assert match == ufora_course_with_alias
|
||||
|
|
|
@ -1,10 +1,11 @@
|
|||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import users as crud
|
||||
from database.schemas.relational import User
|
||||
|
||||
|
||||
async def test_get_or_add_non_existing(postgres):
|
||||
async def test_get_or_add_non_existing(postgres: AsyncSession):
|
||||
"""Test get_or_add for a user that doesn't exist"""
|
||||
await crud.get_or_add(postgres, 1)
|
||||
statement = select(User)
|
||||
|
@ -15,7 +16,7 @@ async def test_get_or_add_non_existing(postgres):
|
|||
assert res[0].nightly_data is not None
|
||||
|
||||
|
||||
async def test_get_or_add_existing(postgres):
|
||||
async def test_get_or_add_existing(postgres: AsyncSession):
|
||||
"""Test get_or_add for a user that does exist"""
|
||||
user = await crud.get_or_add(postgres, 1)
|
||||
bank = user.bank
|
||||
|
|
|
@ -0,0 +1,136 @@
|
|||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
|
||||
from database.crud import wordle as crud
|
||||
from database.enums import TempStorageKey
|
||||
from database.mongo_types import MongoCollection, MongoDatabase
|
||||
from database.schemas.mongo.temporary_storage import TemporaryStorage
|
||||
from database.schemas.mongo.wordle import WordleGame
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def wordle_collection(mongodb: MongoDatabase) -> MongoCollection:
|
||||
"""Fixture to get a reference to the wordle collection"""
|
||||
yield mongodb[WordleGame.collection()]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> WordleGame:
|
||||
"""Fixture to create a new game"""
|
||||
game = WordleGame(user_id=test_user_id)
|
||||
await wordle_collection.insert_one(game.dict(by_alias=True))
|
||||
yield game
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int):
|
||||
"""Test starting a new game"""
|
||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||
assert result is None
|
||||
|
||||
await crud.start_new_wordle_game(mongodb, test_user_id)
|
||||
|
||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||
assert result is not None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int):
|
||||
"""Test getting an active game when there is none"""
|
||||
result = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame):
|
||||
"""Test getting an active game when there is one"""
|
||||
result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id)
|
||||
assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True)
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_get_daily_word_none(mongodb: MongoDatabase):
|
||||
"""Test getting the daily word when the database is empty"""
|
||||
result = await crud.get_daily_word(mongodb)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_not_today(mongodb: MongoDatabase):
|
||||
"""Test getting the daily word when there is an entry, but not for today"""
|
||||
day = datetime.today() - timedelta(days=1)
|
||||
collection = mongodb[TemporaryStorage.collection()]
|
||||
|
||||
word = "testword"
|
||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
||||
|
||||
assert await crud.get_daily_word(mongodb) is None
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_present(mongodb: MongoDatabase):
|
||||
"""Test getting the daily word when there is one for today"""
|
||||
day = datetime.today()
|
||||
collection = mongodb[TemporaryStorage.collection()]
|
||||
|
||||
word = "testword"
|
||||
await collection.insert_one({"key": TempStorageKey.WORDLE_WORD, "day": day, "word": word})
|
||||
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_none_present(mongodb: MongoDatabase):
|
||||
"""Test setting the daily word when there is none"""
|
||||
assert await crud.get_daily_word(mongodb) is None
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_present(mongodb: MongoDatabase):
|
||||
"""Test setting the daily word when there already is one"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
await crud.set_daily_word(mongodb, "another word")
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_force_overwrite(mongodb: MongoDatabase):
|
||||
"""Test setting the daily word when there already is one, but "forced" is set to True"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(mongodb, word)
|
||||
word = "anotherword"
|
||||
await crud.set_daily_word(mongodb, word, forced=True)
|
||||
assert await crud.get_daily_word(mongodb) == word
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_make_wordle_guess(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
||||
"""Test making a guess in your current game"""
|
||||
guess = "guess"
|
||||
await crud.make_wordle_guess(mongodb, test_user_id, guess)
|
||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert wordle_game.guesses == [guess]
|
||||
|
||||
other_guess = "otherguess"
|
||||
await crud.make_wordle_guess(mongodb, test_user_id, other_guess)
|
||||
wordle_game = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||
assert wordle_game.guesses == [guess, other_guess]
|
||||
|
||||
|
||||
@pytest.mark.mongo
|
||||
async def test_reset_wordle_games(mongodb: MongoDatabase, wordle_game: WordleGame, test_user_id: int):
|
||||
"""Test dropping the collection of active games"""
|
||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is not None
|
||||
await crud.reset_wordle_games(mongodb)
|
||||
assert await crud.get_active_wordle_game(mongodb, test_user_id) is None
|
|
@ -1,8 +1,10 @@
|
|||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.schemas.relational import UforaCourse
|
||||
from database.utils.caches import UforaCourseCache
|
||||
|
||||
|
||||
async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse):
|
||||
async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||
"""Test loading the data for the Ufora Course cache when it's empty"""
|
||||
cache = UforaCourseCache()
|
||||
await cache.refresh(postgres)
|
||||
|
@ -12,7 +14,7 @@ async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alia
|
|||
assert cache.aliases == {"alias": "test"}
|
||||
|
||||
|
||||
async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse):
|
||||
async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||
"""Test loading the data for the Ufora Course cache when it's not empty anymore"""
|
||||
cache = UforaCourseCache()
|
||||
cache.data = ["Something"]
|
||||
|
|
Loading…
Reference in New Issue