From db499f3742030501512fb45a927ca10308dfb61f Mon Sep 17 00:00:00 2001 From: stijndcl Date: Wed, 27 Jul 2022 21:10:43 +0200 Subject: [PATCH] WORDLE --- database/constants.py | 2 + database/crud/wordle.py | 58 ++++++-- database/schemas/mongo.py | 24 +++- database/utils/caches.py | 32 +++-- didier/cogs/games.py | 41 +++++- didier/cogs/tasks.py | 21 +-- didier/data/embeds/wordle.py | 130 ++++++++++++++++++ didier/didier.py | 18 +-- didier/utils/types/datetime.py | 2 +- tests/test_database/conftest.py | 11 +- .../test_database/test_crud/test_birthdays.py | 13 +- .../test_database/test_crud/test_currency.py | 13 +- .../test_crud/test_custom_commands.py | 29 ++-- .../test_database/test_crud/test_dad_jokes.py | 3 +- tests/test_database/test_crud/test_tasks.py | 11 +- .../test_crud/test_ufora_announcements.py | 10 +- .../test_crud/test_ufora_courses.py | 8 +- tests/test_database/test_crud/test_users.py | 5 +- tests/test_database/test_crud/test_wordle.py | 14 +- tests/test_database/test_utils/test_caches.py | 6 +- 20 files changed, 350 insertions(+), 101 deletions(-) create mode 100644 database/constants.py create mode 100644 didier/data/embeds/wordle.py diff --git a/database/constants.py b/database/constants.py new file mode 100644 index 0000000..0b0da00 --- /dev/null +++ b/database/constants.py @@ -0,0 +1,2 @@ +WORDLE_GUESS_COUNT = 6 +WORDLE_WORD_LENGTH = 5 diff --git a/database/crud/wordle.py b/database/crud/wordle.py index 4f4e650..9ebd75f 100644 --- a/database/crud/wordle.py +++ b/database/crud/wordle.py @@ -1,32 +1,47 @@ from typing import Optional from database.enums import TempStorageKey -from database.mongo_types import MongoCollection -from database.schemas.mongo import WordleGame +from database.mongo_types import MongoDatabase +from database.schemas.mongo import TemporaryStorage, WordleGame from database.utils.datetime import today_only_date -__all__ = ["get_active_wordle_game", "make_wordle_guess", "start_new_wordle_game"] +__all__ = [ + "get_active_wordle_game", + "make_wordle_guess", + "start_new_wordle_game", + "set_daily_word", + "reset_wordle_games", +] -async def get_active_wordle_game(collection: MongoCollection, user_id: int) -> Optional[WordleGame]: +async def get_active_wordle_game(database: MongoDatabase, user_id: int) -> Optional[WordleGame]: """Find a player's active game""" - return await collection.find_one({"user_id": user_id}) + collection = database[WordleGame.collection()] + result = await collection.find_one({"user_id": user_id}) + if result is None: + return None + + return WordleGame(**result) -async def start_new_wordle_game(collection: MongoCollection, user_id: int) -> WordleGame: +async def start_new_wordle_game(database: MongoDatabase, user_id: int) -> WordleGame: """Start a new game""" + collection = database[WordleGame.collection()] game = WordleGame(user_id=user_id) await collection.insert_one(game.dict(by_alias=True)) return game -async def make_wordle_guess(collection: MongoCollection, user_id: int, guess: str): +async def make_wordle_guess(database: MongoDatabase, user_id: int, guess: str): """Make a guess in your current game""" + collection = database[WordleGame.collection()] await collection.update_one({"user_id": user_id}, {"$push": {"guesses": guess}}) -async def get_daily_word(collection: MongoCollection) -> Optional[str]: +async def get_daily_word(database: MongoDatabase) -> Optional[str]: """Get the word of today""" + collection = database[TemporaryStorage.collection()] + result = await collection.find_one({"key": TempStorageKey.WORDLE_WORD, "day": today_only_date()}) if result is None: return None @@ -34,16 +49,33 @@ async def get_daily_word(collection: MongoCollection) -> Optional[str]: return result["word"] -async def set_daily_word(collection: MongoCollection, word: str): +async def set_daily_word(database: MongoDatabase, word: str, *, forced: bool = False) -> str: """Set the word of today This does NOT overwrite the existing word if there is one, so that it can safely run - on startup every time + on startup every time. + + In order to always overwrite the current word, set the "forced"-kwarg to True. + + Returns the word that was chosen. If one already existed, return that instead. """ - current_word = await get_daily_word(collection) + collection = database[TemporaryStorage.collection()] + + current_word = None if forced else await get_daily_word(collection) if current_word is not None: - return + return current_word await collection.update_one( - {"key": TempStorageKey.WORDLE_WORD}, {"day": today_only_date(), "word": word}, upsert=True + {"key": TempStorageKey.WORDLE_WORD}, {"$set": {"day": today_only_date(), "word": word}}, upsert=True ) + + # Remove all active games + await reset_wordle_games(database) + + return word + + +async def reset_wordle_games(database: MongoDatabase): + """Reset all active games""" + collection = database[WordleGame.collection()] + await collection.drop() diff --git a/database/schemas/mongo.py b/database/schemas/mongo.py index 90b9deb..8cbed86 100644 --- a/database/schemas/mongo.py +++ b/database/schemas/mongo.py @@ -4,14 +4,14 @@ from typing import Optional from bson import ObjectId from overrides import overrides -from pydantic import BaseModel, Field, conlist +from pydantic import BaseModel, Field, validator __all__ = ["MongoBase", "TemporaryStorage", "WordleGame"] from database.utils.datetime import today_only_date -class PyObjectId(str): +class PyObjectId(ObjectId): """Custom type for bson ObjectIds""" @classmethod @@ -71,12 +71,20 @@ class TemporaryStorage(MongoCollection): class WordleStats(BaseModel): """Model that holds stats about a player's Wordle performance""" - guess_distribution: conlist(int, min_items=6, max_items=6) = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) + guess_distribution: list[int] = Field(default_factory=lambda: [0, 0, 0, 0, 0, 0]) last_guess: Optional[datetime.date] = None win_rate: float = 0 current_streak: int = 0 max_streak: int = 0 + @validator("guess_distribution") + def validate_guesses_length(cls, value: list[int]): + """Check that the distribution of guesses is of the correct length""" + if len(value) != 6: + raise ValueError(f"guess_distribution must be length 6, found {len(value)}") + + return value + class GameStats(MongoCollection): """Collection that holds stats about how well a user has performed in games""" @@ -94,10 +102,18 @@ class WordleGame(MongoCollection): """Collection that holds people's active Wordle games""" day: datetime.date = Field(default_factory=lambda: today_only_date()) - guesses: conlist(str, min_items=0, max_items=6) = Field(default_factory=list) + guesses: list[str] = Field(default_factory=list) user_id: int @staticmethod @overrides def collection() -> str: return "wordle" + + @validator("guesses") + def validate_guesses_length(cls, value: list[int]): + """Check that the amount of guesses is of the correct length""" + if len(value) > 6: + raise ValueError(f"guess_distribution must be no longer than 6 elements, found {len(value)}") + + return value diff --git a/database/utils/caches.py b/database/utils/caches.py index edc3a5e..4e35147 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -1,14 +1,18 @@ from abc import ABC, abstractmethod +from typing import Generic, TypeVar from overrides import overrides from sqlalchemy.ext.asyncio import AsyncSession -from database.crud import ufora_courses +from database.crud import ufora_courses, wordle +from database.mongo_types import MongoDatabase -__all__ = ["CacheManager"] +__all__ = ["CacheManager", "UforaCourseCache"] + +T = TypeVar("T") -class DatabaseCache(ABC): +class DatabaseCache(ABC, Generic[T]): """Base class for a simple cache-like structure The goal of this class is to store data for Discord auto-completion results @@ -31,10 +35,10 @@ class DatabaseCache(ABC): self.data.clear() @abstractmethod - async def refresh(self, database_session: AsyncSession): + async def refresh(self, database_session: T): """Refresh the data stored in this cache""" - async def invalidate(self, database_session: AsyncSession): + async def invalidate(self, database_session: T): """Invalidate the data stored in this cache""" await self.refresh(database_session) @@ -45,7 +49,7 @@ class DatabaseCache(ABC): return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] -class UforaCourseCache(DatabaseCache): +class UforaCourseCache(DatabaseCache[AsyncSession]): """Cache to store the names of Ufora courses""" # Also store the aliases to add additional support @@ -90,14 +94,26 @@ class UforaCourseCache(DatabaseCache): return sorted(list(results)) +class WordleCache(DatabaseCache[MongoDatabase]): + """Cache to store the current daily Wordle word""" + + async def refresh(self, database_session: MongoDatabase): + word = await wordle.get_daily_word(database_session) + if word is not None: + self.data = [word] + + class CacheManager: """Class that keeps track of all caches""" ufora_courses: UforaCourseCache + wordle_word: WordleCache def __init__(self): self.ufora_courses = UforaCourseCache() + self.wordle_word = WordleCache() - async def initialize_caches(self, database_session: AsyncSession): + async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): """Initialize the contents of all caches""" - await self.ufora_courses.refresh(database_session) + await self.ufora_courses.refresh(postgres_session) + await self.wordle_word.refresh(mongo_db) diff --git a/didier/cogs/games.py b/didier/cogs/games.py index 765448f..ed7c27e 100644 --- a/didier/cogs/games.py +++ b/didier/cogs/games.py @@ -1,9 +1,17 @@ from typing import Optional +import discord from discord import app_commands from discord.ext import commands +from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH +from database.crud.wordle import ( + get_active_wordle_game, + make_wordle_guess, + start_new_wordle_game, +) from didier import Didier +from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed class Games(commands.Cog): @@ -15,11 +23,42 @@ class Games(commands.Cog): self.client = client @app_commands.command(name="wordle", description="Play Wordle!") - async def wordle(self, ctx: commands.Context, guess: Optional[str] = None): + async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None): """View your active Wordle game If an argument is provided, make a guess instead """ + await interaction.response.defer(ephemeral=True) + + # Guess is wrong length + if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH: + embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed() + return await interaction.followup.send(embed=embed) + + active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id) + if active_game is None: + active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id) + + # Trying to guess with a complete game + if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess: + embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed() + return await interaction.followup.send(embed=embed) + + # Make a guess + if guess: + # The guess is not a real word + if guess not in self.client.wordle_words: + embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed() + return await interaction.followup.send(embed=embed) + + await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess) + + # Don't re-request the game, we already have it + # just append locally + active_game.guesses.append(guess) + + embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed() + await interaction.followup.send(embed=embed) async def setup(client: Didier): diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index d37990e..836b28c 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -9,7 +9,6 @@ from database import enums from database.crud.birthdays import get_birthdays_on_day from database.crud.ufora_announcements import remove_old_announcements from database.crud.wordle import set_daily_word -from database.schemas.mongo import TemporaryStorage from didier import Didier from didier.data.embeds.ufora.announcements import fetch_ufora_announcements from didier.decorators.tasks import timed_task @@ -74,12 +73,15 @@ class Tasks(commands.Cog): return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False) task = self._tasks[name] - await task() + await task(forced=True) + await self.client.confirm_message(ctx.message) @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) @timed_task(enums.TaskType.BIRTHDAYS) - async def check_birthdays(self): + async def check_birthdays(self, **kwargs): """Check if it's currently anyone's birthday""" + _ = kwargs + now = tz_aware_now().date() async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) @@ -99,8 +101,10 @@ class Tasks(commands.Cog): @tasks.loop(minutes=10) @timed_task(enums.TaskType.UFORA_ANNOUNCEMENTS) - async def pull_ufora_announcements(self): + async def pull_ufora_announcements(self, **kwargs): """Task that checks for new Ufora announcements & logs them in a channel""" + _ = kwargs + # In theory this shouldn't happen but just to please Mypy if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: return @@ -123,11 +127,11 @@ class Tasks(commands.Cog): await remove_old_announcements(session) @tasks.loop(time=DAILY_RESET_TIME) - async def reset_wordle_word(self): + async def reset_wordle_word(self, forced: bool = False): """Reset the daily Wordle word""" db = self.client.mongo_db - collection = db[TemporaryStorage.collection()] - await set_daily_word(collection, random.choice(self.client.wordle_words)) + word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words))) + self.client.database_caches.wordle_word.data = [word] @reset_wordle_word.before_loop async def _before_reset_wordle_word(self): @@ -145,7 +149,8 @@ class Tasks(commands.Cog): async def setup(client: Didier): """Load the cog - Initially reset the Wordle word + Initially fetch the wordle word from the database, or reset it + if there hasn't been a reset yet today """ cog = Tasks(client) await client.add_cog(cog) diff --git a/didier/data/embeds/wordle.py b/didier/data/embeds/wordle.py new file mode 100644 index 0000000..d012550 --- /dev/null +++ b/didier/data/embeds/wordle.py @@ -0,0 +1,130 @@ +import enum +from dataclasses import dataclass +from typing import Optional + +import discord +from overrides import overrides + +from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH +from database.schemas.mongo import WordleGame +from didier.data.embeds.base import EmbedBaseModel +from didier.utils.types.datetime import int_to_weekday, tz_aware_now + +__all__ = ["WordleEmbed", "WordleErrorEmbed"] + + +def footer() -> str: + """Create the footer to put on the embed""" + today = tz_aware_now() + return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}" + + +class WordleColour(enum.IntEnum): + """Colours for the Wordle embed""" + + EMPTY = 0 + WRONG_LETTER = 1 + WRONG_POSITION = 2 + CORRECT = 3 + + +@dataclass +class WordleEmbed(EmbedBaseModel): + """Embed for a Wordle game""" + + game: Optional[WordleGame] + word: str + + def _letter_colour(self, guess: str, index: int) -> WordleColour: + """Get the colour for a guess at a given position""" + if guess[index] == self.word[index]: + return WordleColour.CORRECT + + wrong_letter = 0 + wrong_position = 0 + + for i, letter in enumerate(self.word): + if letter == guess[index] and guess[i] != guess[index]: + wrong_letter += 1 + + if i <= index and guess[i] == guess[index] and letter != guess[index]: + wrong_position += 1 + + if i >= index: + if wrong_position == 0: + break + + if wrong_position <= wrong_letter: + return WordleColour.WRONG_POSITION + + return WordleColour.WRONG_LETTER + + def _guess_colours(self, guess: str) -> list[WordleColour]: + """Create the colour codes for a specific guess""" + return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)] + + def colour_code_game(self) -> list[list[WordleColour]]: + """Create the colour codes for an entire game""" + colours = [] + + # Add all the guesses + if self.game is not None: + for guess in self.game.guesses: + colours.append(self._guess_colours(guess)) + + # Fill the rest with empty spots + for _ in range(WORDLE_GUESS_COUNT - len(colours)): + colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH) + + return colours + + def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]: + """Turn the colours of the board into Discord emojis""" + colour_map = { + WordleColour.EMPTY: ":white_large_square:", + WordleColour.WRONG_LETTER: ":black_large_square:", + WordleColour.WRONG_POSITION: ":orange_square:", + WordleColour.CORRECT: ":green_square:", + } + + emojis = [] + for row in colours: + emojis.append(list(map(lambda char: colour_map[char], row))) + + return emojis + + @overrides + def to_embed(self) -> discord.Embed: + colours = self.colour_code_game() + + embed = discord.Embed(colour=discord.Colour.blue(), title="Wordle") + emojis = self._colours_to_emojis(colours) + + rows = [" ".join(row) for row in emojis] + + for i, guess in enumerate(self.game.guesses): + rows[i] += f" ||{guess.upper()}||" + + embed.description = "\n\n".join(rows) + + # If the game is over, reveal the word + if len(self.game.guesses) == WORDLE_GUESS_COUNT or (self.game.guesses and self.game.guesses[-1] == self.word): + embed.description += f"\n\nThe word was **{self.word.upper()}**!" + + embed.set_footer(text=footer()) + + return embed + + +@dataclass +class WordleErrorEmbed(EmbedBaseModel): + """Embed to send error messages to the user""" + + message: str + + @overrides + def to_embed(self) -> discord.Embed: + embed = discord.Embed(colour=discord.Colour.red(), title="Wordle") + embed.description = self.message + embed.set_footer(text=footer()) + return embed diff --git a/didier/didier.py b/didier/didier.py index ea0ac41..fc57a83 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -27,7 +27,7 @@ class Didier(commands.Bot): error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession - wordle_words: tuple[str] = tuple() + wordle_words: set[str, ...] = set() def __init__(self): activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE) @@ -64,14 +64,14 @@ class Didier(commands.Bot): # Load the Wordle dictionary self._load_wordle_words() - # Load extensions - await self._load_initial_extensions() - await self._load_directory_extensions("didier/cogs") - # Initialize caches self.database_caches = CacheManager() async with self.postgres_session as session: - await self.database_caches.initialize_caches(session) + await self.database_caches.initialize_caches(session, self.mongo_db) + + # Load extensions + await self._load_initial_extensions() + await self._load_directory_extensions("didier/cogs") # Create aiohttp session self.http_session = ClientSession() @@ -107,13 +107,9 @@ class Didier(commands.Bot): def _load_wordle_words(self): """Load the dictionary of Wordle words""" - words = [] - with open("files/dictionaries/words-english-wordle.txt", "r") as fp: for line in fp: - words.append(line.strip()) - - self.wordle_words = tuple(words) + self.wordle_words.add(line.strip()) async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: """Fetch a message from a reference""" diff --git a/didier/utils/types/datetime.py b/didier/utils/types/datetime.py index 42f58a9..7b2d5c1 100644 --- a/didier/utils/types/datetime.py +++ b/didier/utils/types/datetime.py @@ -10,7 +10,7 @@ LOCAL_TIMEZONE = zoneinfo.ZoneInfo("Europe/Brussels") def int_to_weekday(number: int) -> str: # pragma: no cover # it's useless to write a test for this """Get the Dutch name of a weekday from the number""" - return ["Maandag", "Dinsdag", "Woensdag", "Donderdag", "Vrijdag", "Zaterdag", "Zondag"][number] + return ["Monday", "Tuesday", "Wednesday", "Thursday", "Friday", "Saturday", "Sunday"][number] def str_to_date(date_str: str, formats: Union[list[str], str] = "%d/%m/%Y") -> datetime.date: diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index b2556c4..dc5ce2a 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -1,6 +1,7 @@ import datetime import pytest +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users from database.schemas.relational import ( @@ -22,7 +23,7 @@ def test_user_id() -> int: @pytest.fixture -async def user(postgres, test_user_id) -> User: +async def user(postgres: AsyncSession, test_user_id: int) -> User: """Fixture to create a user""" _user = await users.get_or_add(postgres, test_user_id) await postgres.refresh(_user) @@ -30,7 +31,7 @@ async def user(postgres, test_user_id) -> User: @pytest.fixture -async def bank(postgres, user: User) -> Bank: +async def bank(postgres: AsyncSession, user: User) -> Bank: """Fixture to fetch the test user's bank""" _bank = user.bank await postgres.refresh(_bank) @@ -38,7 +39,7 @@ async def bank(postgres, user: User) -> Bank: @pytest.fixture -async def ufora_course(postgres) -> UforaCourse: +async def ufora_course(postgres: AsyncSession) -> UforaCourse: """Fixture to create a course""" course = UforaCourse(name="test", code="code", year=1, log_announcements=True) postgres.add(course) @@ -47,7 +48,7 @@ async def ufora_course(postgres) -> UforaCourse: @pytest.fixture -async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaCourse: +async def ufora_course_with_alias(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaCourse: """Fixture to create a course with an alias""" alias = UforaCourseAlias(course_id=ufora_course.course_id, alias="alias") postgres.add(alias) @@ -57,7 +58,7 @@ async def ufora_course_with_alias(postgres, ufora_course: UforaCourse) -> UforaC @pytest.fixture -async def ufora_announcement(ufora_course: UforaCourse, postgres) -> UforaAnnouncement: +async def ufora_announcement(postgres: AsyncSession, ufora_course: UforaCourse) -> UforaAnnouncement: """Fixture to create an announcement""" announcement = UforaAnnouncement(course_id=ufora_course.course_id, publication_date=datetime.datetime.now()) postgres.add(announcement) diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 21639b1..e7f2242 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -1,13 +1,14 @@ from datetime import datetime, timedelta from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import birthdays as crud from database.crud import users from database.schemas.relational import User -async def test_add_birthday_not_present(postgres, user: User): +async def test_add_birthday_not_present(postgres: AsyncSession, user: User): """Test setting a user's birthday when it doesn't exist yet""" assert user.birthday is None @@ -18,7 +19,7 @@ async def test_add_birthday_not_present(postgres, user: User): assert user.birthday.birthday == bd_date -async def test_add_birthday_overwrite(postgres, user: User): +async def test_add_birthday_overwrite(postgres: AsyncSession, user: User): """Test that setting a user's birthday when it already exists overwrites it""" bd_date = datetime.today().date() await crud.add_birthday(postgres, user.user_id, bd_date) @@ -31,7 +32,7 @@ async def test_add_birthday_overwrite(postgres, user: User): assert user.birthday.birthday == new_bd_date -async def test_get_birthday_exists(postgres, user: User): +async def test_get_birthday_exists(postgres: AsyncSession, user: User): """Test getting a user's birthday when it exists""" bd_date = datetime.today().date() await crud.add_birthday(postgres, user.user_id, bd_date) @@ -42,14 +43,14 @@ async def test_get_birthday_exists(postgres, user: User): assert bd.birthday == bd_date -async def test_get_birthday_not_exists(postgres, user: User): +async def test_get_birthday_not_exists(postgres: AsyncSession, user: User): """Test getting a user's birthday when it doesn't exist""" bd = await crud.get_birthday_for_user(postgres, user.user_id) assert bd is None @freeze_time("2022/07/23") -async def test_get_birthdays_on_day(postgres, user: User): +async def test_get_birthdays_on_day(postgres: AsyncSession, user: User): """Test getting all birthdays on a given day""" await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) @@ -61,7 +62,7 @@ async def test_get_birthdays_on_day(postgres, user: User): @freeze_time("2022/07/23") -async def test_get_birthdays_none_present(postgres): +async def test_get_birthdays_none_present(postgres: AsyncSession): """Test getting all birthdays when there are none""" birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 0 diff --git a/tests/test_database/test_crud/test_currency.py b/tests/test_database/test_crud/test_currency.py index e5cdc0c..8bd7e8f 100644 --- a/tests/test_database/test_crud/test_currency.py +++ b/tests/test_database/test_crud/test_currency.py @@ -2,13 +2,14 @@ import datetime import pytest from freezegun import freeze_time +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import currency as crud from database.exceptions import currency as exceptions from database.schemas.relational import Bank -async def test_add_dinks(postgres, bank: Bank): +async def test_add_dinks(postgres: AsyncSession, bank: Bank): """Test adding dinks to an account""" assert bank.dinks == 0 await crud.add_dinks(postgres, bank.user_id, 10) @@ -17,7 +18,7 @@ async def test_add_dinks(postgres, bank: Bank): @freeze_time("2022/07/23") -async def test_claim_nightly_available(postgres, bank: Bank): +async def test_claim_nightly_available(postgres: AsyncSession, bank: Bank): """Test claiming nightlies when it hasn't been done yet""" await crud.claim_nightly(postgres, bank.user_id) await postgres.refresh(bank) @@ -28,7 +29,7 @@ async def test_claim_nightly_available(postgres, bank: Bank): @freeze_time("2022/07/23") -async def test_claim_nightly_unavailable(postgres, bank: Bank): +async def test_claim_nightly_unavailable(postgres: AsyncSession, bank: Bank): """Test claiming nightlies twice in a day""" await crud.claim_nightly(postgres, bank.user_id) @@ -39,7 +40,7 @@ async def test_claim_nightly_unavailable(postgres, bank: Bank): assert bank.dinks == crud.NIGHTLY_AMOUNT -async def test_invest(postgres, bank: Bank): +async def test_invest(postgres: AsyncSession, bank: Bank): """Test investing some Dinks""" bank.dinks = 100 postgres.add(bank) @@ -52,7 +53,7 @@ async def test_invest(postgres, bank: Bank): assert bank.invested == 20 -async def test_invest_all(postgres, bank: Bank): +async def test_invest_all(postgres: AsyncSession, bank: Bank): """Test investing all dinks""" bank.dinks = 100 postgres.add(bank) @@ -65,7 +66,7 @@ async def test_invest_all(postgres, bank: Bank): assert bank.invested == 100 -async def test_invest_more_than_owned(postgres, bank: Bank): +async def test_invest_more_than_owned(postgres: AsyncSession, bank: Bank): """Test investing more Dinks than you own""" bank.dinks = 100 postgres.add(bank) diff --git a/tests/test_database/test_crud/test_custom_commands.py b/tests/test_database/test_crud/test_custom_commands.py index 88810d4..6f141bc 100644 --- a/tests/test_database/test_crud/test_custom_commands.py +++ b/tests/test_database/test_crud/test_custom_commands.py @@ -1,5 +1,6 @@ import pytest from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import custom_commands as crud from database.exceptions.constraints import DuplicateInsertException @@ -7,7 +8,7 @@ from database.exceptions.not_found import NoResultFoundException from database.schemas.relational import CustomCommand -async def test_create_command_non_existing(postgres): +async def test_create_command_non_existing(postgres: AsyncSession): """Test creating a new command when it doesn't exist yet""" await crud.create_command(postgres, "name", "response") @@ -16,7 +17,7 @@ async def test_create_command_non_existing(postgres): assert commands[0].name == "name" -async def test_create_command_duplicate_name(postgres): +async def test_create_command_duplicate_name(postgres: AsyncSession): """Test creating a command when the name already exists""" await crud.create_command(postgres, "name", "response") @@ -24,7 +25,7 @@ async def test_create_command_duplicate_name(postgres): await crud.create_command(postgres, "name", "other response") -async def test_create_command_name_is_alias(postgres): +async def test_create_command_name_is_alias(postgres: AsyncSession): """Test creating a command when the name is taken by an alias""" await crud.create_command(postgres, "name", "response") await crud.create_alias(postgres, "name", "n") @@ -33,7 +34,7 @@ async def test_create_command_name_is_alias(postgres): await crud.create_command(postgres, "n", "other response") -async def test_create_alias(postgres): +async def test_create_alias(postgres: AsyncSession): """Test creating an alias when the name is still free""" command = await crud.create_command(postgres, "name", "response") await crud.create_alias(postgres, command.name, "n") @@ -43,13 +44,13 @@ async def test_create_alias(postgres): assert command.aliases[0].alias == "n" -async def test_create_alias_non_existing(postgres): +async def test_create_alias_non_existing(postgres: AsyncSession): """Test creating an alias when the command doesn't exist""" with pytest.raises(NoResultFoundException): await crud.create_alias(postgres, "name", "alias") -async def test_create_alias_duplicate(postgres): +async def test_create_alias_duplicate(postgres: AsyncSession): """Test creating an alias when another alias already has this name""" command = await crud.create_command(postgres, "name", "response") await crud.create_alias(postgres, command.name, "n") @@ -58,7 +59,7 @@ async def test_create_alias_duplicate(postgres): await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_is_command(postgres): +async def test_create_alias_is_command(postgres: AsyncSession): """Test creating an alias when the name is taken by a command""" await crud.create_command(postgres, "n", "response") command = await crud.create_command(postgres, "name", "response") @@ -67,7 +68,7 @@ async def test_create_alias_is_command(postgres): await crud.create_alias(postgres, command.name, "n") -async def test_create_alias_match_by_alias(postgres): +async def test_create_alias_match_by_alias(postgres: AsyncSession): """Test creating an alias for a command when matching the name to another alias""" command = await crud.create_command(postgres, "name", "response") await crud.create_alias(postgres, command.name, "a1") @@ -75,21 +76,21 @@ async def test_create_alias_match_by_alias(postgres): assert alias.command == command -async def test_get_command_by_name_exists(postgres): +async def test_get_command_by_name_exists(postgres: AsyncSession): """Test getting a command by name""" await crud.create_command(postgres, "name", "response") command = await crud.get_command(postgres, "name") assert command is not None -async def test_get_command_by_cleaned_name(postgres): +async def test_get_command_by_cleaned_name(postgres: AsyncSession): """Test getting a command by the cleaned version of the name""" command = await crud.create_command(postgres, "CAPITALIZED NAME WITH SPACES", "response") found = await crud.get_command(postgres, "capitalizednamewithspaces") assert command == found -async def test_get_command_by_alias(postgres): +async def test_get_command_by_alias(postgres: AsyncSession): """Test getting a command by an alias""" command = await crud.create_command(postgres, "name", "response") await crud.create_alias(postgres, command.name, "a1") @@ -99,12 +100,12 @@ async def test_get_command_by_alias(postgres): assert command == found -async def test_get_command_non_existing(postgres): +async def test_get_command_non_existing(postgres: AsyncSession): """Test getting a command when it doesn't exist""" assert await crud.get_command(postgres, "name") is None -async def test_edit_command(postgres): +async def test_edit_command(postgres: AsyncSession): """Test editing an existing command""" command = await crud.create_command(postgres, "name", "response") await crud.edit_command(postgres, command.name, "new name", "new response") @@ -112,7 +113,7 @@ async def test_edit_command(postgres): assert command.response == "new response" -async def test_edit_command_non_existing(postgres): +async def test_edit_command_non_existing(postgres: AsyncSession): """Test editing a command that doesn't exist""" with pytest.raises(NoResultFoundException): await crud.edit_command(postgres, "name", "n", "r") diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py index 22c28c2..f34d0fa 100644 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ b/tests/test_database/test_crud/test_dad_jokes.py @@ -1,10 +1,11 @@ from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import dad_jokes as crud from database.schemas.relational import DadJoke -async def test_add_dad_joke(postgres): +async def test_add_dad_joke(postgres: AsyncSession): """Test creating a new joke""" statement = select(DadJoke) result = (await postgres.execute(statement)).scalars().all() diff --git a/tests/test_database/test_crud/test_tasks.py b/tests/test_database/test_crud/test_tasks.py index c4c7ba0..b13b221 100644 --- a/tests/test_database/test_crud/test_tasks.py +++ b/tests/test_database/test_crud/test_tasks.py @@ -3,6 +3,7 @@ import datetime import pytest from freezegun import freeze_time from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import tasks as crud from database.enums import TaskType @@ -16,7 +17,7 @@ def task_type() -> TaskType: @pytest.fixture -async def task(postgres, task_type: TaskType) -> Task: +async def task(postgres: AsyncSession, task_type: TaskType) -> Task: """Fixture to create a task""" task = Task(task=task_type) postgres.add(task) @@ -24,21 +25,21 @@ async def task(postgres, task_type: TaskType) -> Task: return task -async def test_get_task_by_enum_present(postgres, task: Task, task_type: TaskType): +async def test_get_task_by_enum_present(postgres: AsyncSession, task: Task, task_type: TaskType): """Test getting a task by its enum type when it exists""" result = await crud.get_task_by_enum(postgres, task_type) assert result is not None assert result == task -async def test_get_task_by_enum_not_present(postgres, task_type: TaskType): +async def test_get_task_by_enum_not_present(postgres: AsyncSession, task_type: TaskType): """Test getting a task by its enum type when it doesn't exist""" result = await crud.get_task_by_enum(postgres, task_type) assert result is None @freeze_time("2022/07/24") -async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskType): +async def test_set_execution_time_exists(postgres: AsyncSession, task: Task, task_type: TaskType): """Test setting the execution time of an existing task""" await postgres.refresh(task) assert task.previous_run is None @@ -49,7 +50,7 @@ async def test_set_execution_time_exists(postgres, task: Task, task_type: TaskTy @freeze_time("2022/07/24") -async def test_set_execution_time_doesnt_exist(postgres, task_type: TaskType): +async def test_set_execution_time_doesnt_exist(postgres: AsyncSession, task_type: TaskType): """Test setting the execution time of a non-existing task""" statement = select(Task).where(Task.task == task_type) results = list((await postgres.execute(statement)).scalars().all()) diff --git a/tests/test_database/test_crud/test_ufora_announcements.py b/tests/test_database/test_crud/test_ufora_announcements.py index 1aa45ee..34f4222 100644 --- a/tests/test_database/test_crud/test_ufora_announcements.py +++ b/tests/test_database/test_crud/test_ufora_announcements.py @@ -1,16 +1,18 @@ import datetime +from sqlalchemy.ext.asyncio import AsyncSession + from database.crud import ufora_announcements as crud from database.schemas.relational import UforaAnnouncement, UforaCourse -async def test_get_courses_with_announcements_none(postgres): +async def test_get_courses_with_announcements_none(postgres: AsyncSession): """Test getting all courses with announcements when there are none""" results = await crud.get_courses_with_announcements(postgres) assert len(results) == 0 -async def test_get_courses_with_announcements(postgres): +async def test_get_courses_with_announcements(postgres: AsyncSession): """Test getting all courses with announcements""" course_1 = UforaCourse(name="test", code="code", year=1, log_announcements=True) course_2 = UforaCourse(name="test2", code="code2", year=1, log_announcements=False) @@ -22,14 +24,14 @@ async def test_get_courses_with_announcements(postgres): assert results[0] == course_1 -async def test_create_new_announcement(ufora_course: UforaCourse, postgres): +async def test_create_new_announcement(postgres: AsyncSession, ufora_course: UforaCourse): """Test creating a new announcement""" await crud.create_new_announcement(postgres, 1, course=ufora_course, publication_date=datetime.datetime.now()) await postgres.refresh(ufora_course) assert len(ufora_course.announcements) == 1 -async def test_remove_old_announcements(ufora_announcement: UforaAnnouncement, postgres): +async def test_remove_old_announcements(postgres: AsyncSession, ufora_announcement: UforaAnnouncement): """Test removing all stale announcements""" course = ufora_announcement.course ufora_announcement.publication_date -= datetime.timedelta(weeks=2) diff --git a/tests/test_database/test_crud/test_ufora_courses.py b/tests/test_database/test_crud/test_ufora_courses.py index 34748c0..140bc4a 100644 --- a/tests/test_database/test_crud/test_ufora_courses.py +++ b/tests/test_database/test_crud/test_ufora_courses.py @@ -1,20 +1,22 @@ +from sqlalchemy.ext.asyncio import AsyncSession + from database.crud import ufora_courses as crud from database.schemas.relational import UforaCourse -async def test_get_course_by_name_exact(postgres, ufora_course: UforaCourse): +async def test_get_course_by_name_exact(postgres: AsyncSession, ufora_course: UforaCourse): """Test getting a course by its name when the query is an exact match""" match = await crud.get_course_by_name(postgres, "Test") assert match == ufora_course -async def test_get_course_by_name_substring(postgres, ufora_course: UforaCourse): +async def test_get_course_by_name_substring(postgres: AsyncSession, ufora_course: UforaCourse): """Test getting a course by its name when the query is a substring""" match = await crud.get_course_by_name(postgres, "es") assert match == ufora_course -async def test_get_course_by_name_alias(postgres, ufora_course_with_alias: UforaCourse): +async def test_get_course_by_name_alias(postgres: AsyncSession, ufora_course_with_alias: UforaCourse): """Test getting a course by its name when the name doesn't match, but the alias does""" match = await crud.get_course_by_name(postgres, "ali") assert match == ufora_course_with_alias diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index e852298..96d3383 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -1,10 +1,11 @@ from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession from database.crud import users as crud from database.schemas.relational import User -async def test_get_or_add_non_existing(postgres): +async def test_get_or_add_non_existing(postgres: AsyncSession): """Test get_or_add for a user that doesn't exist""" await crud.get_or_add(postgres, 1) statement = select(User) @@ -15,7 +16,7 @@ async def test_get_or_add_non_existing(postgres): assert res[0].nightly_data is not None -async def test_get_or_add_existing(postgres): +async def test_get_or_add_existing(postgres: AsyncSession): """Test get_or_add for a user that does exist""" user = await crud.get_or_add(postgres, 1) bank = user.bank diff --git a/tests/test_database/test_crud/test_wordle.py b/tests/test_database/test_crud/test_wordle.py index 382bb62..a1720de 100644 --- a/tests/test_database/test_crud/test_wordle.py +++ b/tests/test_database/test_crud/test_wordle.py @@ -19,24 +19,24 @@ async def wordle_game(wordle_collection: MongoCollection, test_user_id: int) -> yield game -async def test_start_new_game(wordle_collection: MongoCollection, test_user_id: int): +async def test_start_new_game(mongodb: MongoDatabase, wordle_collection: MongoCollection, test_user_id: int): """Test starting a new game""" result = await wordle_collection.find_one({"user_id": test_user_id}) assert result is None - await crud.start_new_wordle_game(wordle_collection, test_user_id) + await crud.start_new_wordle_game(mongodb, test_user_id) result = await wordle_collection.find_one({"user_id": test_user_id}) assert result is not None -async def test_get_active_wordle_game_none(wordle_collection: MongoCollection, test_user_id: int): +async def test_get_active_wordle_game_none(mongodb: MongoDatabase, test_user_id: int): """Test getting an active game when there is none""" - result = await crud.get_active_wordle_game(wordle_collection, test_user_id) + result = await crud.get_active_wordle_game(mongodb, test_user_id) assert result is None -async def test_get_active_wordle_game(wordle_collection: MongoCollection, wordle_game: WordleGame): +async def test_get_active_wordle_game(mongodb: MongoDatabase, wordle_game: WordleGame): """Test getting an active game when there is none""" - result = await crud.get_active_wordle_game(wordle_collection, wordle_game.user_id) - assert result == wordle_game.dict(by_alias=True) + result = await crud.get_active_wordle_game(mongodb, wordle_game.user_id) + assert result.dict(by_alias=True) == wordle_game.dict(by_alias=True) diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 69a6ff2..b613737 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -1,8 +1,10 @@ +from sqlalchemy.ext.asyncio import AsyncSession + from database.schemas.relational import UforaCourse from database.utils.caches import UforaCourseCache -async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alias: UforaCourse): +async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's empty""" cache = UforaCourseCache() await cache.refresh(postgres) @@ -12,7 +14,7 @@ async def test_ufora_course_cache_refresh_empty(postgres, ufora_course_with_alia assert cache.aliases == {"alias": "test"} -async def test_ufora_course_cache_refresh_not_empty(postgres, ufora_course_with_alias: UforaCourse): +async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse): """Test loading the data for the Ufora Course cache when it's not empty anymore""" cache = UforaCourseCache() cache.data = ["Something"]