mirror of https://github.com/stijndcl/didier
WORDLE
parent
ea4181eac0
commit
db499f3742
|
@ -0,0 +1,2 @@
|
||||||
|
WORDLE_GUESS_COUNT = 6
|
||||||
|
WORDLE_WORD_LENGTH = 5
|
|
@ -1,32 +1,47 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from database.enums import TempStorageKey
|
from database.enums import TempStorageKey
|
||||||
from database.mongo_types import MongoCollection
|
from database.mongo_types import MongoDatabase
|
||||||
from database.schemas.mongo import WordleGame
|
from database.schemas.mongo import TemporaryStorage, WordleGame
|
||||||
from database.utils.datetime import today_only_date
|
from database.utils.datetime import today_only_date
|
||||||
|
|
||||||
__all__ = ["get_active_wordle_game", "make_wordle_guess", "start_new_wordle_game"]
|
__all__ = [
|
||||||
|
"get_active_wordle_game",
|
||||||
|
"make_wordle_guess",
|
||||||
|
"start_new_wordle_game",
|
||||||
|
"set_daily_word",
|
||||||
|
"reset_wordle_games",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
async def get_active_wordle_game(collection: MongoCollection, user_id: int) -> Optional[WordleGame]:
|
async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]:
|
||||||
"""Find a player's active game"""
|
"""Find a player's active game"""
|
||||||
return await collection.find_one({"user_id": user_id})
|
collection = database[WordleGame.collection()]
|
||||||
|
result = await collection.find_one({"user_id": user_id})
|
||||||
|
if result is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return WordleGame(**result)
|
||||||
|
|
||||||
|
|
||||||
async def start_new_wordle_game(collection: MongoCollection, user_id: int) -> WordleGame:
|
async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame:
|
||||||
"""Start a new game"""
|
"""Start a new game"""
|
||||||
|
collection = database[WordleGame.collection()]
|
||||||
game = WordleGame(user_id=user_id)
|
game = WordleGame(user_id=user_id)
|
||||||
await collection.insert_one(game.dict(by_alias=True))
|
await collection.insert_one(game.dict(by_alias=True))
|
||||||
return game
|
return game
|
||||||
|
|
||||||
|
|
||||||
async def make_wordle_guess(collection: MongoCollection, user_id: int, guess: str):
|
async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str):
|
||||||
"""Make a guess in your current game"""
|
"""Make a guess in your current game"""
|
||||||
|
collection = database[WordleGame.collection()]
|
||||||
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}})
|
||||||
|
|
||||||
|
|
||||||
async def get_daily_word(collection: MongoCollection) -> Optional[str]:
|
async def get_daily_word(database: MongoDatabase) -> Optional[str]:
|
||||||
"""Get the word of today"""
|
"""Get the word of today"""
|
||||||
|
collection = database[TemporaryStorage.collection()]
|
||||||
|
|
||||||
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()})
|
||||||
if result is None:
|
if result is None:
|
||||||
return None
|
return None
|
||||||
|
@ -34,16 +49,33 @@ async def get_daily_word(collection: MongoCollection) -> Optional[str]:
|
||||||
return result["word"]
|
return result["word"]
|
||||||
|
|
||||||
|
|
||||||
async def set_daily_word(collection: MongoCollection, word: str):
|
async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str:
|
||||||
"""Set the word of today
|
"""Set the word of today
|
||||||
|
|
||||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||||
on startup every time
|
on startup every time.
|
||||||
|
|
||||||
|
In order to always overwrite the current word, set the "forced"-kwarg to True.
|
||||||
|
|
||||||
|
Returns the word that was chosen. If one already existed, return that instead.
|
||||||
"""
|
"""
|
||||||
current_word = await get_daily_word(collection)
|
collection = database[TemporaryStorage.collection()]
|
||||||
|
|
||||||
|
current_word = None if forced else await get_daily_word(collection)
|
||||||
if current_word is not None:
|
if current_word is not None:
|
||||||
return
|
return current_word
|
||||||
|
|
||||||
await collection.update_one(
|
await collection.update_one(
|
||||||
{"key": TempStorageKey.WORDLE_WORD}, {"day": today_only_date(), "word": word}, upsert=True
|
{"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove all active games
|
||||||
|
await reset_wordle_games(database)
|
||||||
|
|
||||||
|
return word
|
||||||
|
|
||||||
|
|
||||||
|
async def reset_wordle_games(database: MongoDatabase):
|
||||||
|
"""Reset all active games"""
|
||||||
|
collection = database[WordleGame.collection()]
|
||||||
|
await collection.drop()
|
||||||
|
|
|
@ -4,14 +4,14 @@ from typing import Optional
|
||||||
|
|
||||||
from bson import ObjectId
|
from bson import ObjectId
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
from pydantic import BaseModel, Field, conlist
|
from pydantic import BaseModel, Field, validator
|
||||||
|
|
||||||
__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"]
|
__all__ = ["MongoBase", "TemporaryStorage", "WordleGame"]
|
||||||
|
|
||||||
from database.utils.datetime import today_only_date
|
from database.utils.datetime import today_only_date
|
||||||
|
|
||||||
|
|
||||||
class PyObjectId(str):
|
class PyObjectId(ObjectId):
|
||||||
"""Custom type for bson ObjectIds"""
|
"""Custom type for bson ObjectIds"""
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
@ -71,12 +71,20 @@ class TemporaryStorage(MongoCollection):
|
||||||
class WordleStats(BaseModel):
|
class WordleStats(BaseModel):
|
||||||
"""Model that holds stats about a player's Wordle performance"""
|
"""Model that holds stats about a player's Wordle performance"""
|
||||||
|
|
||||||
guess_distribution: conlist(int, min_items=6, max_items=6) = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0])
|
||||||
last_guess: Optional[datetime.date] = None
|
last_guess: Optional[datetime.date] = None
|
||||||
win_rate: float = 0
|
win_rate: float = 0
|
||||||
current_streak: int = 0
|
current_streak: int = 0
|
||||||
max_streak: int = 0
|
max_streak: int = 0
|
||||||
|
|
||||||
|
@validator("guess_distribution")
|
||||||
|
def validate_guesses_length(cls, value: list[int]):
|
||||||
|
"""Check that the distribution of guesses is of the correct length"""
|
||||||
|
if len(value) != 6:
|
||||||
|
raise ValueError(f"guess_distribution must be length 6, found {len(value)}")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class GameStats(MongoCollection):
|
class GameStats(MongoCollection):
|
||||||
"""Collection that holds stats about how well a user has performed in games"""
|
"""Collection that holds stats about how well a user has performed in games"""
|
||||||
|
@ -94,10 +102,18 @@ class WordleGame(MongoCollection):
|
||||||
"""Collection that holds people's active Wordle games"""
|
"""Collection that holds people's active Wordle games"""
|
||||||
|
|
||||||
day: datetime.date = Field(default_factory=lambda: today_only_date())
|
day: datetime.date = Field(default_factory=lambda: today_only_date())
|
||||||
guesses: conlist(str, min_items=0, max_items=6) = Field(default_factory=list)
|
guesses: list[str] = Field(default_factory=list)
|
||||||
user_id: int
|
user_id: int
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@overrides
|
@overrides
|
||||||
def collection() -> str:
|
def collection() -> str:
|
||||||
return "wordle"
|
return "wordle"
|
||||||
|
|
||||||
|
@validator("guesses")
|
||||||
|
def validate_guesses_length(cls, value: list[int]):
|
||||||
|
"""Check that the amount of guesses is of the correct length"""
|
||||||
|
if len(value) > 6:
|
||||||
|
raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}")
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
|
@ -1,14 +1,18 @@
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Generic, TypeVar
|
||||||
|
|
||||||
from overrides import overrides
|
from overrides import overrides
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_courses
|
from database.crud import ufora_courses, wordle
|
||||||
|
from database.mongo_types import MongoDatabase
|
||||||
|
|
||||||
__all__ = ["CacheManager"]
|
__all__ = ["CacheManager", "UforaCourseCache"]
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
class DatabaseCache(ABC):
|
class DatabaseCache(ABC, Generic[T]):
|
||||||
"""Base class for a simple cache-like structure
|
"""Base class for a simple cache-like structure
|
||||||
|
|
||||||
The goal of this class is to store data for Discord auto-completion results
|
The goal of this class is to store data for Discord auto-completion results
|
||||||
|
@ -31,10 +35,10 @@ class DatabaseCache(ABC):
|
||||||
self.data.clear()
|
self.data.clear()
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def refresh(self, database_session: AsyncSession):
|
async def refresh(self, database_session: T):
|
||||||
"""Refresh the data stored in this cache"""
|
"""Refresh the data stored in this cache"""
|
||||||
|
|
||||||
async def invalidate(self, database_session: AsyncSession):
|
async def invalidate(self, database_session: T):
|
||||||
"""Invalidate the data stored in this cache"""
|
"""Invalidate the data stored in this cache"""
|
||||||
await self.refresh(database_session)
|
await self.refresh(database_session)
|
||||||
|
|
||||||
|
@ -45,7 +49,7 @@ class DatabaseCache(ABC):
|
||||||
return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
|
return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
|
||||||
|
|
||||||
|
|
||||||
class UforaCourseCache(DatabaseCache):
|
class UforaCourseCache(DatabaseCache[AsyncSession]):
|
||||||
"""Cache to store the names of Ufora courses"""
|
"""Cache to store the names of Ufora courses"""
|
||||||
|
|
||||||
# Also store the aliases to add additional support
|
# Also store the aliases to add additional support
|
||||||
|
@ -90,14 +94,26 @@ class UforaCourseCache(DatabaseCache):
|
||||||
return sorted(list(results))
|
return sorted(list(results))
|
||||||
|
|
||||||
|
|
||||||
|
class WordleCache(DatabaseCache[MongoDatabase]):
|
||||||
|
"""Cache to store the current daily Wordle word"""
|
||||||
|
|
||||||
|
async def refresh(self, database_session: MongoDatabase):
|
||||||
|
word = await wordle.get_daily_word(database_session)
|
||||||
|
if word is not None:
|
||||||
|
self.data = [word]
|
||||||
|
|
||||||
|
|
||||||
class CacheManager:
|
class CacheManager:
|
||||||
"""Class that keeps track of all caches"""
|
"""Class that keeps track of all caches"""
|
||||||
|
|
||||||
ufora_courses: UforaCourseCache
|
ufora_courses: UforaCourseCache
|
||||||
|
wordle_word: WordleCache
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.ufora_courses = UforaCourseCache()
|
self.ufora_courses = UforaCourseCache()
|
||||||
|
self.wordle_word = WordleCache()
|
||||||
|
|
||||||
async def initialize_caches(self, database_session: AsyncSession):
|
async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
|
||||||
"""Initialize the contents of all caches"""
|
"""Initialize the contents of all caches"""
|
||||||
await self.ufora_courses.refresh(database_session)
|
await self.ufora_courses.refresh(postgres_session)
|
||||||
|
await self.wordle_word.refresh(mongo_db)
|
||||||
|
|
|
@ -1,9 +1,17 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||||
|
from database.crud.wordle import (
|
||||||
|
get_active_wordle_game,
|
||||||
|
make_wordle_guess,
|
||||||
|
start_new_wordle_game,
|
||||||
|
)
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed
|
||||||
|
|
||||||
|
|
||||||
class Games(commands.Cog):
|
class Games(commands.Cog):
|
||||||
|
@ -15,11 +23,42 @@ class Games(commands.Cog):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
@app_commands.command(name="wordle", description="Play Wordle!")
|
@app_commands.command(name="wordle", description="Play Wordle!")
|
||||||
async def wordle(self, ctx: commands.Context, guess: Optional[str] = None):
|
async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None):
|
||||||
"""View your active Wordle game
|
"""View your active Wordle game
|
||||||
|
|
||||||
If an argument is provided, make a guess instead
|
If an argument is provided, make a guess instead
|
||||||
"""
|
"""
|
||||||
|
await interaction.response.defer(ephemeral=True)
|
||||||
|
|
||||||
|
# Guess is wrong length
|
||||||
|
if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH:
|
||||||
|
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
|
||||||
|
return await interaction.followup.send(embed=embed)
|
||||||
|
|
||||||
|
active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||||
|
if active_game is None:
|
||||||
|
active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id)
|
||||||
|
|
||||||
|
# Trying to guess with a complete game
|
||||||
|
if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess:
|
||||||
|
embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed()
|
||||||
|
return await interaction.followup.send(embed=embed)
|
||||||
|
|
||||||
|
# Make a guess
|
||||||
|
if guess:
|
||||||
|
# The guess is not a real word
|
||||||
|
if guess not in self.client.wordle_words:
|
||||||
|
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
|
||||||
|
return await interaction.followup.send(embed=embed)
|
||||||
|
|
||||||
|
await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess)
|
||||||
|
|
||||||
|
# Don't re-request the game, we already have it
|
||||||
|
# just append locally
|
||||||
|
active_game.guesses.append(guess)
|
||||||
|
|
||||||
|
embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed()
|
||||||
|
await interaction.followup.send(embed=embed)
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
|
|
|
@ -9,7 +9,6 @@ from database import enums
|
||||||
from database.crud.birthdays import get_birthdays_on_day
|
from database.crud.birthdays import get_birthdays_on_day
|
||||||
from database.crud.ufora_announcements import remove_old_announcements
|
from database.crud.ufora_announcements import remove_old_announcements
|
||||||
from database.crud.wordle import set_daily_word
|
from database.crud.wordle import set_daily_word
|
||||||
from database.schemas.mongo import TemporaryStorage
|
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
from didier.data.embeds.ufora.announcements import fetch_ufora_announcements
|
||||||
from didier.decorators.tasks import timed_task
|
from didier.decorators.tasks import timed_task
|
||||||
|
@ -74,12 +73,15 @@ class Tasks(commands.Cog):
|
||||||
return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False)
|
return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False)
|
||||||
|
|
||||||
task = self._tasks[name]
|
task = self._tasks[name]
|
||||||
await task()
|
await task(forced=True)
|
||||||
|
await self.client.confirm_message(ctx.message)
|
||||||
|
|
||||||
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
||||||
@timed_task(enums.TaskType.BIRTHDAYS)
|
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||||
async def check_birthdays(self):
|
async def check_birthdays(self, **kwargs):
|
||||||
"""Check if it's currently anyone's birthday"""
|
"""Check if it's currently anyone's birthday"""
|
||||||
|
_ = kwargs
|
||||||
|
|
||||||
now = tz_aware_now().date()
|
now = tz_aware_now().date()
|
||||||
async with self.client.postgres_session as session:
|
async with self.client.postgres_session as session:
|
||||||
birthdays = await get_birthdays_on_day(session, now)
|
birthdays = await get_birthdays_on_day(session, now)
|
||||||
|
@ -99,8 +101,10 @@ class Tasks(commands.Cog):
|
||||||
|
|
||||||
@tasks.loop(minutes=10)
|
@tasks.loop(minutes=10)
|
||||||
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
|
@timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS)
|
||||||
async def pull_ufora_announcements(self):
|
async def pull_ufora_announcements(self, **kwargs):
|
||||||
"""Task that checks for new Ufora announcements & logs them in a channel"""
|
"""Task that checks for new Ufora announcements & logs them in a channel"""
|
||||||
|
_ = kwargs
|
||||||
|
|
||||||
# In theory this shouldn't happen but just to please Mypy
|
# In theory this shouldn't happen but just to please Mypy
|
||||||
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
||||||
return
|
return
|
||||||
|
@ -123,11 +127,11 @@ class Tasks(commands.Cog):
|
||||||
await remove_old_announcements(session)
|
await remove_old_announcements(session)
|
||||||
|
|
||||||
@tasks.loop(time=DAILY_RESET_TIME)
|
@tasks.loop(time=DAILY_RESET_TIME)
|
||||||
async def reset_wordle_word(self):
|
async def reset_wordle_word(self, forced: bool = False):
|
||||||
"""Reset the daily Wordle word"""
|
"""Reset the daily Wordle word"""
|
||||||
db = self.client.mongo_db
|
db = self.client.mongo_db
|
||||||
collection = db[TemporaryStorage.collection()]
|
word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)))
|
||||||
await set_daily_word(collection, random.choice(self.client.wordle_words))
|
self.client.database_caches.wordle_word.data = [word]
|
||||||
|
|
||||||
@reset_wordle_word.before_loop
|
@reset_wordle_word.before_loop
|
||||||
async def _before_reset_wordle_word(self):
|
async def _before_reset_wordle_word(self):
|
||||||
|
@ -145,7 +149,8 @@ class Tasks(commands.Cog):
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog
|
"""Load the cog
|
||||||
|
|
||||||
Initially reset the Wordle word
|
Initially fetch the wordle word from the database, or reset it
|
||||||
|
if there hasn't been a reset yet today
|
||||||
"""
|
"""
|
||||||
cog = Tasks(client)
|
cog = Tasks(client)
|
||||||
await client.add_cog(cog)
|
await client.add_cog(cog)
|
||||||
|
|
|
@ -0,0 +1,130 @@
|
||||||
|
import enum
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||||
|
from database.schemas.mongo import WordleGame
|
||||||
|
from didier.data.embeds.base import EmbedBaseModel
|
||||||
|
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
||||||
|
|
||||||
|
__all__ = ["WordleEmbed", "WordleErrorEmbed"]
|
||||||
|
|
||||||
|
|
||||||
|
def footer() -> str:
|
||||||
|
"""Create the footer to put on the embed"""
|
||||||
|
today = tz_aware_now()
|
||||||
|
return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}"
|
||||||
|
|
||||||
|
|
||||||
|
class WordleColour(enum.IntEnum):
|
||||||
|
"""Colours for the Wordle embed"""
|
||||||
|
|
||||||
|
EMPTY = 0
|
||||||
|
WRONG_LETTER = 1
|
||||||
|
WRONG_POSITION = 2
|
||||||
|
CORRECT = 3
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordleEmbed(EmbedBaseModel):
|
||||||
|
"""Embed for a Wordle game"""
|
||||||
|
|
||||||
|
game: Optional[WordleGame]
|
||||||
|
word: str
|
||||||
|
|
||||||
|
def _letter_colour(self, guess: str, index: int) -> WordleColour:
|
||||||
|
"""Get the colour for a guess at a given position"""
|
||||||
|
if guess[index] == self.word[index]:
|
||||||
|
return WordleColour.CORRECT
|
||||||
|
|
||||||
|
wrong_letter = 0
|
||||||
|
wrong_position = 0
|
||||||
|
|
||||||
|
for i, letter in enumerate(self.word):
|
||||||
|
if letter == guess[index] and guess[i] != guess[index]:
|
||||||
|
wrong_letter += 1
|
||||||
|
|
||||||
|
if i <= index and guess[i] == guess[index] and letter != guess[index]:
|
||||||
|
wrong_position += 1
|
||||||
|
|
||||||
|
if i >= index:
|
||||||
|
if wrong_position == 0:
|
||||||
|
break
|
||||||
|
|
||||||
|
if wrong_position <= wrong_letter:
|
||||||
|
return WordleColour.WRONG_POSITION
|
||||||
|
|
||||||
|
return WordleColour.WRONG_LETTER
|
||||||
|
|
||||||
|
def _guess_colours(self, guess: str) -> list[WordleColour]:
|
||||||
|
"""Create the colour codes for a specific guess"""
|
||||||
|
return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)]
|
||||||
|
|
||||||
|
def colour_code_game(self) -> list[list[WordleColour]]:
|
||||||
|
"""Create the colour codes for an entire game"""
|
||||||
|
colours = []
|
||||||
|
|
||||||
|
# Add all the guesses
|
||||||
|
if self.game is not None:
|
||||||
|
for guess in self.game.guesses:
|
||||||
|
colours.append(self._guess_colours(guess))
|
||||||
|
|
||||||
|
# Fill the rest with empty spots
|
||||||
|
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
|
||||||
|
colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH)
|
||||||
|
|
||||||
|
return colours
|
||||||
|
|
||||||
|
def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]:
|
||||||
|
"""Turn the colours of the board into Discord emojis"""
|
||||||
|
colour_map = {
|
||||||
|
WordleColour.EMPTY: ":white_large_square:",
|
||||||
|
WordleColour.WRONG_LETTER: ":black_large_square:",
|
||||||
|
WordleColour.WRONG_POSITION: ":orange_square:",
|
||||||
|
WordleColour.CORRECT: ":green_square:",
|
||||||
|
}
|
||||||
|
|
||||||
|
emojis = []
|
||||||
|
for row in colours:
|
||||||
|
emojis.append(list(map(lambda char: colour_map[char], row)))
|
||||||
|
|
||||||
|
return emojis
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def to_embed(self) -> discord.Embed:
|
||||||
|
colours = self.colour_code_game()
|
||||||
|
|
||||||
|
embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle")
|
||||||
|
emojis = self._colours_to_emojis(colours)
|
||||||
|
|
||||||
|
rows = [" ".join(row) for row in emojis]
|
||||||
|
|
||||||
|
for i, guess in enumerate(self.game.guesses):
|
||||||
|
rows[i] += f" ||{guess.upper()}||"
|
||||||
|
|
||||||
|
embed.description = "\n\n".join(rows)
|
||||||
|
|
||||||
|
# If the game is over, reveal the word
|
||||||
|
if len(self.game.guesses) == WORDLE_GUESS_COUNT or (self.game.guesses and self.game.guesses[-1] == self.word):
|
||||||
|
embed.description += f"\n\nThe word was **{self.word.upper()}**!"
|
||||||
|
|
||||||
|
embed.set_footer(text=footer())
|
||||||
|
|
||||||
|
return embed
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class WordleErrorEmbed(EmbedBaseModel):
|
||||||
|
"""Embed to send error messages to the user"""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def to_embed(self) -> discord.Embed:
|
||||||
|
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
|
||||||
|
embed.description = self.message
|
||||||
|
embed.set_footer(text=footer())
|
||||||
|
return embed
|
|
@ -27,7 +27,7 @@ class Didier(commands.Bot):
|
||||||
error_channel: discord.abc.Messageable
|
error_channel: discord.abc.Messageable
|
||||||
initial_extensions: tuple[str, ...] = ()
|
initial_extensions: tuple[str, ...] = ()
|
||||||
http_session: ClientSession
|
http_session: ClientSession
|
||||||
wordle_words: tuple[str] = tuple()
|
wordle_words: set[str, ...] = set()
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
|
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
|
||||||
|
@ -64,14 +64,14 @@ class Didier(commands.Bot):
|
||||||
# Load the Wordle dictionary
|
# Load the Wordle dictionary
|
||||||
self._load_wordle_words()
|
self._load_wordle_words()
|
||||||
|
|
||||||
# Load extensions
|
|
||||||
await self._load_initial_extensions()
|
|
||||||
await self._load_directory_extensions("didier/cogs")
|
|
||||||
|
|
||||||
# Initialize caches
|
# Initialize caches
|
||||||
self.database_caches = CacheManager()
|
self.database_caches = CacheManager()
|
||||||
async with self.postgres_session as session:
|
async with self.postgres_session as session:
|
||||||
await self.database_caches.initialize_caches(session)
|
await self.database_caches.initialize_caches(session, self.mongo_db)
|
||||||
|
|
||||||
|
# Load extensions
|
||||||
|
await self._load_initial_extensions()
|
||||||
|
await self._load_directory_extensions("didier/cogs")
|
||||||
|
|
||||||
# Create aiohttp session
|
# Create aiohttp session
|
||||||
self.http_session = ClientSession()
|
self.http_session = ClientSession()
|
||||||
|
@ -107,13 +107,9 @@ class Didier(commands.Bot):
|
||||||
|
|
||||||
def _load_wordle_words(self):
|
def _load_wordle_words(self):
|
||||||
"""Load the dictionary of Wordle words"""
|
"""Load the dictionary of Wordle words"""
|
||||||
words = []
|
|
||||||
|
|
||||||
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
|
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
|
||||||
for line in fp:
|
for line in fp:
|
||||||
words.append(line.strip())
|
self.wordle_words.add(line.strip())
|
||||||
|
|
||||||
self.wordle_words = tuple(words)
|
|
||||||
|
|
||||||
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||||
"""Fetch a message from a reference"""
|
"""Fetch a message from a reference"""
|
||||||
|
|
|
@ -10,7 +10,7 @@ LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels")
|
||||||
|
|
||||||
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
|
def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this
|
||||||
"""Get the Dutch name of a weekday from the number"""
|
"""Get the Dutch name of a weekday from the number"""
|
||||||
return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number]
|
return ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"][number]
|
||||||
|
|
||||||
|
|
||||||
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date:
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.schemas.relational import (
|
from database.schemas.relational import (
|
||||||
|
@ -22,7 +23,7 @@ def test_user_id() -> int:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def user(postgres, test_user_id) -> User:
|
async def user(postgres: AsyncSession, test_user_id: int) -> User:
|
||||||
"""Fixture to create a user"""
|
"""Fixture to create a user"""
|
||||||
_user = await users.get_or_add(postgres, test_user_id)
|
_user = await users.get_or_add(postgres, test_user_id)
|
||||||
await postgres.refresh(_user)
|
await postgres.refresh(_user)
|
||||||
|
@ -30,7 +31,7 @@ async def user(postgres, test_user_id) -> User:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def bank(postgres, user: User) -> Bank:
|
async def bank(postgres: AsyncSession, user: User) -> Bank:
|
||||||
"""Fixture to fetch the test user's bank"""
|
"""Fixture to fetch the test user's bank"""
|
||||||
_bank = user.bank
|
_bank = user.bank
|
||||||
await postgres.refresh(_bank)
|
await postgres.refresh(_bank)
|
||||||
|
@ -38,7 +39,7 @@ async def bank(postgres, user: User) -> Bank:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def ufora_course(postgres) -> UforaCourse:
|
async def ufora_course(postgres: AsyncSession) -> UforaCourse:
|
||||||
"""Fixture to create a course"""
|
"""Fixture to create a course"""
|
||||||
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
course = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||||
postgres.add(course)
|
postgres.add(course)
|
||||||
|
@ -47,7 +48,7 @@ async def ufora_course(postgres) -> UforaCourse:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse:
|
async def ufora_course_with_alias(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaCourse:
|
||||||
"""Fixture to create a course with an alias"""
|
"""Fixture to create a course with an alias"""
|
||||||
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
|
alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias")
|
||||||
postgres.add(alias)
|
postgres.add(alias)
|
||||||
|
@ -57,7 +58,7 @@ async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaC
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement:
|
async def ufora_announcement(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaAnnouncement:
|
||||||
"""Fixture to create an announcement"""
|
"""Fixture to create an announcement"""
|
||||||
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
|
announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now())
|
||||||
postgres.add(announcement)
|
postgres.add(announcement)
|
||||||
|
|
|
@ -1,13 +1,14 @@
|
||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
|
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import birthdays as crud
|
from database.crud import birthdays as crud
|
||||||
from database.crud import users
|
from database.crud import users
|
||||||
from database.schemas.relational import User
|
from database.schemas.relational import User
|
||||||
|
|
||||||
|
|
||||||
async def test_add_birthday_not_present(postgres, user: User):
|
async def test_add_birthday_not_present(postgres: AsyncSession, user: User):
|
||||||
"""Test setting a user's birthday when it doesn't exist yet"""
|
"""Test setting a user's birthday when it doesn't exist yet"""
|
||||||
assert user.birthday is None
|
assert user.birthday is None
|
||||||
|
|
||||||
|
@ -18,7 +19,7 @@ async def test_add_birthday_not_present(postgres, user: User):
|
||||||
assert user.birthday.birthday == bd_date
|
assert user.birthday.birthday == bd_date
|
||||||
|
|
||||||
|
|
||||||
async def test_add_birthday_overwrite(postgres, user: User):
|
async def test_add_birthday_overwrite(postgres: AsyncSession, user: User):
|
||||||
"""Test that setting a user's birthday when it already exists overwrites it"""
|
"""Test that setting a user's birthday when it already exists overwrites it"""
|
||||||
bd_date = datetime.today().date()
|
bd_date = datetime.today().date()
|
||||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||||
|
@ -31,7 +32,7 @@ async def test_add_birthday_overwrite(postgres, user: User):
|
||||||
assert user.birthday.birthday == new_bd_date
|
assert user.birthday.birthday == new_bd_date
|
||||||
|
|
||||||
|
|
||||||
async def test_get_birthday_exists(postgres, user: User):
|
async def test_get_birthday_exists(postgres: AsyncSession, user: User):
|
||||||
"""Test getting a user's birthday when it exists"""
|
"""Test getting a user's birthday when it exists"""
|
||||||
bd_date = datetime.today().date()
|
bd_date = datetime.today().date()
|
||||||
await crud.add_birthday(postgres, user.user_id, bd_date)
|
await crud.add_birthday(postgres, user.user_id, bd_date)
|
||||||
|
@ -42,14 +43,14 @@ async def test_get_birthday_exists(postgres, user: User):
|
||||||
assert bd.birthday == bd_date
|
assert bd.birthday == bd_date
|
||||||
|
|
||||||
|
|
||||||
async def test_get_birthday_not_exists(postgres, user: User):
|
async def test_get_birthday_not_exists(postgres: AsyncSession, user: User):
|
||||||
"""Test getting a user's birthday when it doesn't exist"""
|
"""Test getting a user's birthday when it doesn't exist"""
|
||||||
bd = await crud.get_birthday_for_user(postgres, user.user_id)
|
bd = await crud.get_birthday_for_user(postgres, user.user_id)
|
||||||
assert bd is None
|
assert bd is None
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/23")
|
@freeze_time("2022/07/23")
|
||||||
async def test_get_birthdays_on_day(postgres, user: User):
|
async def test_get_birthdays_on_day(postgres: AsyncSession, user: User):
|
||||||
"""Test getting all birthdays on a given day"""
|
"""Test getting all birthdays on a given day"""
|
||||||
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
|
await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001))
|
||||||
|
|
||||||
|
@ -61,7 +62,7 @@ async def test_get_birthdays_on_day(postgres, user: User):
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/23")
|
@freeze_time("2022/07/23")
|
||||||
async def test_get_birthdays_none_present(postgres):
|
async def test_get_birthdays_none_present(postgres: AsyncSession):
|
||||||
"""Test getting all birthdays when there are none"""
|
"""Test getting all birthdays when there are none"""
|
||||||
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
birthdays = await crud.get_birthdays_on_day(postgres, datetime.today())
|
||||||
assert len(birthdays) == 0
|
assert len(birthdays) == 0
|
||||||
|
|
|
@ -2,13 +2,14 @@ import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import currency as crud
|
from database.crud import currency as crud
|
||||||
from database.exceptions import currency as exceptions
|
from database.exceptions import currency as exceptions
|
||||||
from database.schemas.relational import Bank
|
from database.schemas.relational import Bank
|
||||||
|
|
||||||
|
|
||||||
async def test_add_dinks(postgres, bank: Bank):
|
async def test_add_dinks(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test adding dinks to an account"""
|
"""Test adding dinks to an account"""
|
||||||
assert bank.dinks == 0
|
assert bank.dinks == 0
|
||||||
await crud.add_dinks(postgres, bank.user_id, 10)
|
await crud.add_dinks(postgres, bank.user_id, 10)
|
||||||
|
@ -17,7 +18,7 @@ async def test_add_dinks(postgres, bank: Bank):
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/23")
|
@freeze_time("2022/07/23")
|
||||||
async def test_claim_nightly_available(postgres, bank: Bank):
|
async def test_claim_nightly_available(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test claiming nightlies when it hasn't been done yet"""
|
"""Test claiming nightlies when it hasn't been done yet"""
|
||||||
await crud.claim_nightly(postgres, bank.user_id)
|
await crud.claim_nightly(postgres, bank.user_id)
|
||||||
await postgres.refresh(bank)
|
await postgres.refresh(bank)
|
||||||
|
@ -28,7 +29,7 @@ async def test_claim_nightly_available(postgres, bank: Bank):
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/23")
|
@freeze_time("2022/07/23")
|
||||||
async def test_claim_nightly_unavailable(postgres, bank: Bank):
|
async def test_claim_nightly_unavailable(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test claiming nightlies twice in a day"""
|
"""Test claiming nightlies twice in a day"""
|
||||||
await crud.claim_nightly(postgres, bank.user_id)
|
await crud.claim_nightly(postgres, bank.user_id)
|
||||||
|
|
||||||
|
@ -39,7 +40,7 @@ async def test_claim_nightly_unavailable(postgres, bank: Bank):
|
||||||
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
assert bank.dinks == crud.NIGHTLY_AMOUNT
|
||||||
|
|
||||||
|
|
||||||
async def test_invest(postgres, bank: Bank):
|
async def test_invest(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test investing some Dinks"""
|
"""Test investing some Dinks"""
|
||||||
bank.dinks = 100
|
bank.dinks = 100
|
||||||
postgres.add(bank)
|
postgres.add(bank)
|
||||||
|
@ -52,7 +53,7 @@ async def test_invest(postgres, bank: Bank):
|
||||||
assert bank.invested == 20
|
assert bank.invested == 20
|
||||||
|
|
||||||
|
|
||||||
async def test_invest_all(postgres, bank: Bank):
|
async def test_invest_all(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test investing all dinks"""
|
"""Test investing all dinks"""
|
||||||
bank.dinks = 100
|
bank.dinks = 100
|
||||||
postgres.add(bank)
|
postgres.add(bank)
|
||||||
|
@ -65,7 +66,7 @@ async def test_invest_all(postgres, bank: Bank):
|
||||||
assert bank.invested == 100
|
assert bank.invested == 100
|
||||||
|
|
||||||
|
|
||||||
async def test_invest_more_than_owned(postgres, bank: Bank):
|
async def test_invest_more_than_owned(postgres: AsyncSession, bank: Bank):
|
||||||
"""Test investing more Dinks than you own"""
|
"""Test investing more Dinks than you own"""
|
||||||
bank.dinks = 100
|
bank.dinks = 100
|
||||||
postgres.add(bank)
|
postgres.add(bank)
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import custom_commands as crud
|
from database.crud import custom_commands as crud
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
|
@ -7,7 +8,7 @@ from database.exceptions.not_found import NoResultFoundException
|
||||||
from database.schemas.relational import CustomCommand
|
from database.schemas.relational import CustomCommand
|
||||||
|
|
||||||
|
|
||||||
async def test_create_command_non_existing(postgres):
|
async def test_create_command_non_existing(postgres: AsyncSession):
|
||||||
"""Test creating a new command when it doesn't exist yet"""
|
"""Test creating a new command when it doesn't exist yet"""
|
||||||
await crud.create_command(postgres, "name", "response")
|
await crud.create_command(postgres, "name", "response")
|
||||||
|
|
||||||
|
@ -16,7 +17,7 @@ async def test_create_command_non_existing(postgres):
|
||||||
assert commands[0].name == "name"
|
assert commands[0].name == "name"
|
||||||
|
|
||||||
|
|
||||||
async def test_create_command_duplicate_name(postgres):
|
async def test_create_command_duplicate_name(postgres: AsyncSession):
|
||||||
"""Test creating a command when the name already exists"""
|
"""Test creating a command when the name already exists"""
|
||||||
await crud.create_command(postgres, "name", "response")
|
await crud.create_command(postgres, "name", "response")
|
||||||
|
|
||||||
|
@ -24,7 +25,7 @@ async def test_create_command_duplicate_name(postgres):
|
||||||
await crud.create_command(postgres, "name", "other response")
|
await crud.create_command(postgres, "name", "other response")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_command_name_is_alias(postgres):
|
async def test_create_command_name_is_alias(postgres: AsyncSession):
|
||||||
"""Test creating a command when the name is taken by an alias"""
|
"""Test creating a command when the name is taken by an alias"""
|
||||||
await crud.create_command(postgres, "name", "response")
|
await crud.create_command(postgres, "name", "response")
|
||||||
await crud.create_alias(postgres, "name", "n")
|
await crud.create_alias(postgres, "name", "n")
|
||||||
|
@ -33,7 +34,7 @@ async def test_create_command_name_is_alias(postgres):
|
||||||
await crud.create_command(postgres, "n", "other response")
|
await crud.create_command(postgres, "n", "other response")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias(postgres):
|
async def test_create_alias(postgres: AsyncSession):
|
||||||
"""Test creating an alias when the name is still free"""
|
"""Test creating an alias when the name is still free"""
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
await crud.create_alias(postgres, command.name, "n")
|
await crud.create_alias(postgres, command.name, "n")
|
||||||
|
@ -43,13 +44,13 @@ async def test_create_alias(postgres):
|
||||||
assert command.aliases[0].alias == "n"
|
assert command.aliases[0].alias == "n"
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_non_existing(postgres):
|
async def test_create_alias_non_existing(postgres: AsyncSession):
|
||||||
"""Test creating an alias when the command doesn't exist"""
|
"""Test creating an alias when the command doesn't exist"""
|
||||||
with pytest.raises(NoResultFoundException):
|
with pytest.raises(NoResultFoundException):
|
||||||
await crud.create_alias(postgres, "name", "alias")
|
await crud.create_alias(postgres, "name", "alias")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_duplicate(postgres):
|
async def test_create_alias_duplicate(postgres: AsyncSession):
|
||||||
"""Test creating an alias when another alias already has this name"""
|
"""Test creating an alias when another alias already has this name"""
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
await crud.create_alias(postgres, command.name, "n")
|
await crud.create_alias(postgres, command.name, "n")
|
||||||
|
@ -58,7 +59,7 @@ async def test_create_alias_duplicate(postgres):
|
||||||
await crud.create_alias(postgres, command.name, "n")
|
await crud.create_alias(postgres, command.name, "n")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_is_command(postgres):
|
async def test_create_alias_is_command(postgres: AsyncSession):
|
||||||
"""Test creating an alias when the name is taken by a command"""
|
"""Test creating an alias when the name is taken by a command"""
|
||||||
await crud.create_command(postgres, "n", "response")
|
await crud.create_command(postgres, "n", "response")
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
|
@ -67,7 +68,7 @@ async def test_create_alias_is_command(postgres):
|
||||||
await crud.create_alias(postgres, command.name, "n")
|
await crud.create_alias(postgres, command.name, "n")
|
||||||
|
|
||||||
|
|
||||||
async def test_create_alias_match_by_alias(postgres):
|
async def test_create_alias_match_by_alias(postgres: AsyncSession):
|
||||||
"""Test creating an alias for a command when matching the name to another alias"""
|
"""Test creating an alias for a command when matching the name to another alias"""
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
await crud.create_alias(postgres, command.name, "a1")
|
await crud.create_alias(postgres, command.name, "a1")
|
||||||
|
@ -75,21 +76,21 @@ async def test_create_alias_match_by_alias(postgres):
|
||||||
assert alias.command == command
|
assert alias.command == command
|
||||||
|
|
||||||
|
|
||||||
async def test_get_command_by_name_exists(postgres):
|
async def test_get_command_by_name_exists(postgres: AsyncSession):
|
||||||
"""Test getting a command by name"""
|
"""Test getting a command by name"""
|
||||||
await crud.create_command(postgres, "name", "response")
|
await crud.create_command(postgres, "name", "response")
|
||||||
command = await crud.get_command(postgres, "name")
|
command = await crud.get_command(postgres, "name")
|
||||||
assert command is not None
|
assert command is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_command_by_cleaned_name(postgres):
|
async def test_get_command_by_cleaned_name(postgres: AsyncSession):
|
||||||
"""Test getting a command by the cleaned version of the name"""
|
"""Test getting a command by the cleaned version of the name"""
|
||||||
command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response")
|
command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response")
|
||||||
found = await crud.get_command(postgres, "capitalizednamewithspaces")
|
found = await crud.get_command(postgres, "capitalizednamewithspaces")
|
||||||
assert command == found
|
assert command == found
|
||||||
|
|
||||||
|
|
||||||
async def test_get_command_by_alias(postgres):
|
async def test_get_command_by_alias(postgres: AsyncSession):
|
||||||
"""Test getting a command by an alias"""
|
"""Test getting a command by an alias"""
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
await crud.create_alias(postgres, command.name, "a1")
|
await crud.create_alias(postgres, command.name, "a1")
|
||||||
|
@ -99,12 +100,12 @@ async def test_get_command_by_alias(postgres):
|
||||||
assert command == found
|
assert command == found
|
||||||
|
|
||||||
|
|
||||||
async def test_get_command_non_existing(postgres):
|
async def test_get_command_non_existing(postgres: AsyncSession):
|
||||||
"""Test getting a command when it doesn't exist"""
|
"""Test getting a command when it doesn't exist"""
|
||||||
assert await crud.get_command(postgres, "name") is None
|
assert await crud.get_command(postgres, "name") is None
|
||||||
|
|
||||||
|
|
||||||
async def test_edit_command(postgres):
|
async def test_edit_command(postgres: AsyncSession):
|
||||||
"""Test editing an existing command"""
|
"""Test editing an existing command"""
|
||||||
command = await crud.create_command(postgres, "name", "response")
|
command = await crud.create_command(postgres, "name", "response")
|
||||||
await crud.edit_command(postgres, command.name, "new name", "new response")
|
await crud.edit_command(postgres, command.name, "new name", "new response")
|
||||||
|
@ -112,7 +113,7 @@ async def test_edit_command(postgres):
|
||||||
assert command.response == "new response"
|
assert command.response == "new response"
|
||||||
|
|
||||||
|
|
||||||
async def test_edit_command_non_existing(postgres):
|
async def test_edit_command_non_existing(postgres: AsyncSession):
|
||||||
"""Test editing a command that doesn't exist"""
|
"""Test editing a command that doesn't exist"""
|
||||||
with pytest.raises(NoResultFoundException):
|
with pytest.raises(NoResultFoundException):
|
||||||
await crud.edit_command(postgres, "name", "n", "r")
|
await crud.edit_command(postgres, "name", "n", "r")
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import dad_jokes as crud
|
from database.crud import dad_jokes as crud
|
||||||
from database.schemas.relational import DadJoke
|
from database.schemas.relational import DadJoke
|
||||||
|
|
||||||
|
|
||||||
async def test_add_dad_joke(postgres):
|
async def test_add_dad_joke(postgres: AsyncSession):
|
||||||
"""Test creating a new joke"""
|
"""Test creating a new joke"""
|
||||||
statement = select(DadJoke)
|
statement = select(DadJoke)
|
||||||
result = (await postgres.execute(statement)).scalars().all()
|
result = (await postgres.execute(statement)).scalars().all()
|
||||||
|
|
|
@ -3,6 +3,7 @@ import datetime
|
||||||
import pytest
|
import pytest
|
||||||
from freezegun import freeze_time
|
from freezegun import freeze_time
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import tasks as crud
|
from database.crud import tasks as crud
|
||||||
from database.enums import TaskType
|
from database.enums import TaskType
|
||||||
|
@ -16,7 +17,7 @@ def task_type() -> TaskType:
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
async def task(postgres, task_type: TaskType) -> Task:
|
async def task(postgres: AsyncSession, task_type: TaskType) -> Task:
|
||||||
"""Fixture to create a task"""
|
"""Fixture to create a task"""
|
||||||
task = Task(task=task_type)
|
task = Task(task=task_type)
|
||||||
postgres.add(task)
|
postgres.add(task)
|
||||||
|
@ -24,21 +25,21 @@ async def task(postgres, task_type: TaskType) -> Task:
|
||||||
return task
|
return task
|
||||||
|
|
||||||
|
|
||||||
async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType):
|
async def test_get_task_by_enum_present(postgres: AsyncSession, task: Task, task_type: TaskType):
|
||||||
"""Test getting a task by its enum type when it exists"""
|
"""Test getting a task by its enum type when it exists"""
|
||||||
result = await crud.get_task_by_enum(postgres, task_type)
|
result = await crud.get_task_by_enum(postgres, task_type)
|
||||||
assert result is not None
|
assert result is not None
|
||||||
assert result == task
|
assert result == task
|
||||||
|
|
||||||
|
|
||||||
async def test_get_task_by_enum_not_present(postgres, task_type: TaskType):
|
async def test_get_task_by_enum_not_present(postgres: AsyncSession, task_type: TaskType):
|
||||||
"""Test getting a task by its enum type when it doesn't exist"""
|
"""Test getting a task by its enum type when it doesn't exist"""
|
||||||
result = await crud.get_task_by_enum(postgres, task_type)
|
result = await crud.get_task_by_enum(postgres, task_type)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/24")
|
@freeze_time("2022/07/24")
|
||||||
async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType):
|
async def test_set_execution_time_exists(postgres: AsyncSession, task: Task, task_type: TaskType):
|
||||||
"""Test setting the execution time of an existing task"""
|
"""Test setting the execution time of an existing task"""
|
||||||
await postgres.refresh(task)
|
await postgres.refresh(task)
|
||||||
assert task.previous_run is None
|
assert task.previous_run is None
|
||||||
|
@ -49,7 +50,7 @@ async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskTy
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2022/07/24")
|
@freeze_time("2022/07/24")
|
||||||
async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType):
|
async def test_set_execution_time_doesnt_exist(postgres: AsyncSession, task_type: TaskType):
|
||||||
"""Test setting the execution time of a non-existing task"""
|
"""Test setting the execution time of a non-existing task"""
|
||||||
statement = select(Task).where(Task.task == task_type)
|
statement = select(Task).where(Task.task == task_type)
|
||||||
results = list((await postgres.execute(statement)).scalars().all())
|
results = list((await postgres.execute(statement)).scalars().all())
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_announcements as crud
|
from database.crud import ufora_announcements as crud
|
||||||
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
from database.schemas.relational import UforaAnnouncement, UforaCourse
|
||||||
|
|
||||||
|
|
||||||
async def test_get_courses_with_announcements_none(postgres):
|
async def test_get_courses_with_announcements_none(postgres: AsyncSession):
|
||||||
"""Test getting all courses with announcements when there are none"""
|
"""Test getting all courses with announcements when there are none"""
|
||||||
results = await crud.get_courses_with_announcements(postgres)
|
results = await crud.get_courses_with_announcements(postgres)
|
||||||
assert len(results) == 0
|
assert len(results) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_get_courses_with_announcements(postgres):
|
async def test_get_courses_with_announcements(postgres: AsyncSession):
|
||||||
"""Test getting all courses with announcements"""
|
"""Test getting all courses with announcements"""
|
||||||
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True)
|
||||||
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
|
course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False)
|
||||||
|
@ -22,14 +24,14 @@ async def test_get_courses_with_announcements(postgres):
|
||||||
assert results[0] == course_1
|
assert results[0] == course_1
|
||||||
|
|
||||||
|
|
||||||
async def test_create_new_announcement(ufora_course: UforaCourse, postgres):
|
async def test_create_new_announcement(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||||
"""Test creating a new announcement"""
|
"""Test creating a new announcement"""
|
||||||
await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now())
|
await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now())
|
||||||
await postgres.refresh(ufora_course)
|
await postgres.refresh(ufora_course)
|
||||||
assert len(ufora_course.announcements) == 1
|
assert len(ufora_course.announcements) == 1
|
||||||
|
|
||||||
|
|
||||||
async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres):
|
async def test_remove_old_announcements(postgres: AsyncSession, ufora_announcement: UforaAnnouncement):
|
||||||
"""Test removing all stale announcements"""
|
"""Test removing all stale announcements"""
|
||||||
course = ufora_announcement.course
|
course = ufora_announcement.course
|
||||||
ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
|
ufora_announcement.publication_date -= datetime.timedelta(weeks=2)
|
||||||
|
|
|
@ -1,20 +1,22 @@
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import ufora_courses as crud
|
from database.crud import ufora_courses as crud
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas.relational import UforaCourse
|
||||||
|
|
||||||
|
|
||||||
async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse):
|
async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||||
"""Test getting a course by its name when the query is an exact match"""
|
"""Test getting a course by its name when the query is an exact match"""
|
||||||
match = await crud.get_course_by_name(postgres, "Test")
|
match = await crud.get_course_by_name(postgres, "Test")
|
||||||
assert match == ufora_course
|
assert match == ufora_course
|
||||||
|
|
||||||
|
|
||||||
async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse):
|
async def test_get_course_by_name_substring(postgres: AsyncSession, ufora_course: UforaCourse):
|
||||||
"""Test getting a course by its name when the query is a substring"""
|
"""Test getting a course by its name when the query is a substring"""
|
||||||
match = await crud.get_course_by_name(postgres, "es")
|
match = await crud.get_course_by_name(postgres, "es")
|
||||||
assert match == ufora_course
|
assert match == ufora_course
|
||||||
|
|
||||||
|
|
||||||
async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse):
|
async def test_get_course_by_name_alias(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||||
"""Test getting a course by its name when the name doesn't match, but the alias does"""
|
"""Test getting a course by its name when the name doesn't match, but the alias does"""
|
||||||
match = await crud.get_course_by_name(postgres, "ali")
|
match = await crud.get_course_by_name(postgres, "ali")
|
||||||
assert match == ufora_course_with_alias
|
assert match == ufora_course_with_alias
|
||||||
|
|
|
@ -1,10 +1,11 @@
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.crud import users as crud
|
from database.crud import users as crud
|
||||||
from database.schemas.relational import User
|
from database.schemas.relational import User
|
||||||
|
|
||||||
|
|
||||||
async def test_get_or_add_non_existing(postgres):
|
async def test_get_or_add_non_existing(postgres: AsyncSession):
|
||||||
"""Test get_or_add for a user that doesn't exist"""
|
"""Test get_or_add for a user that doesn't exist"""
|
||||||
await crud.get_or_add(postgres, 1)
|
await crud.get_or_add(postgres, 1)
|
||||||
statement = select(User)
|
statement = select(User)
|
||||||
|
@ -15,7 +16,7 @@ async def test_get_or_add_non_existing(postgres):
|
||||||
assert res[0].nightly_data is not None
|
assert res[0].nightly_data is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_or_add_existing(postgres):
|
async def test_get_or_add_existing(postgres: AsyncSession):
|
||||||
"""Test get_or_add for a user that does exist"""
|
"""Test get_or_add for a user that does exist"""
|
||||||
user = await crud.get_or_add(postgres, 1)
|
user = await crud.get_or_add(postgres, 1)
|
||||||
bank = user.bank
|
bank = user.bank
|
||||||
|
|
|
@ -19,24 +19,24 @@ async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) ->
|
||||||
yield game
|
yield game
|
||||||
|
|
||||||
|
|
||||||
async def test_start_new_game(wordle_collection: MongoCollection, test_user_id: int):
|
async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int):
|
||||||
"""Test starting a new game"""
|
"""Test starting a new game"""
|
||||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
await crud.start_new_wordle_game(wordle_collection, test_user_id)
|
await crud.start_new_wordle_game(mongodb, test_user_id)
|
||||||
|
|
||||||
result = await wordle_collection.find_one({"user_id": test_user_id})
|
result = await wordle_collection.find_one({"user_id": test_user_id})
|
||||||
assert result is not None
|
assert result is not None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_active_wordle_game_none(wordle_collection: MongoCollection, test_user_id: int):
|
async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int):
|
||||||
"""Test getting an active game when there is none"""
|
"""Test getting an active game when there is none"""
|
||||||
result = await crud.get_active_wordle_game(wordle_collection, test_user_id)
|
result = await crud.get_active_wordle_game(mongodb, test_user_id)
|
||||||
assert result is None
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
async def test_get_active_wordle_game(wordle_collection: MongoCollection, wordle_game: WordleGame):
|
async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame):
|
||||||
"""Test getting an active game when there is none"""
|
"""Test getting an active game when there is none"""
|
||||||
result = await crud.get_active_wordle_game(wordle_collection, wordle_game.user_id)
|
result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id)
|
||||||
assert result == wordle_game.dict(by_alias=True)
|
assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True)
|
||||||
|
|
|
@ -1,8 +1,10 @@
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.schemas.relational import UforaCourse
|
from database.schemas.relational import UforaCourse
|
||||||
from database.utils.caches import UforaCourseCache
|
from database.utils.caches import UforaCourseCache
|
||||||
|
|
||||||
|
|
||||||
async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse):
|
async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||||
"""Test loading the data for the Ufora Course cache when it's empty"""
|
"""Test loading the data for the Ufora Course cache when it's empty"""
|
||||||
cache = UforaCourseCache()
|
cache = UforaCourseCache()
|
||||||
await cache.refresh(postgres)
|
await cache.refresh(postgres)
|
||||||
|
@ -12,7 +14,7 @@ async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alia
|
||||||
assert cache.aliases == {"alias": "test"}
|
assert cache.aliases == {"alias": "test"}
|
||||||
|
|
||||||
|
|
||||||
async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse):
|
async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
|
||||||
"""Test loading the data for the Ufora Course cache when it's not empty anymore"""
|
"""Test loading the data for the Ufora Course cache when it's not empty anymore"""
|
||||||
cache = UforaCourseCache()
|
cache = UforaCourseCache()
|
||||||
cache.data = ["Something"]
|
cache.data = ["Something"]
|
||||||
|
|
Loading…
Reference in New Issue