diff --git a/database/crud/birthdays.py b/database/crud/birthdays.py index 229ef89..7ffa308 100644 --- a/database/crud/birthdays.py +++ b/database/crud/birthdays.py @@ -17,7 +17,7 @@ async def add_birthday(session: AsyncSession, user_id: int, birthday: date): If already present, overwrites the existing one """ - user = await users.get_or_add(session, user_id, options=[selectinload(User.birthday)]) + user = await users.get_or_add_user(session, user_id, options=[selectinload(User.birthday)]) if user.birthday is not None: bd = user.birthday diff --git a/database/crud/currency.py b/database/crud/currency.py index f720c69..da0ff84 100644 --- a/database/crud/currency.py +++ b/database/crud/currency.py @@ -29,13 +29,13 @@ NIGHTLY_AMOUNT = 420 async def get_bank(session: AsyncSession, user_id: int) -> Bank: """Get a user's bank info""" - user = await users.get_or_add(session, user_id) + user = await users.get_or_add_user(session, user_id) return user.bank async def get_nightly_data(session: AsyncSession, user_id: int) -> NightlyData: """Get a user's nightly info""" - user = await users.get_or_add(session, user_id) + user = await users.get_or_add_user(session, user_id) return user.nightly_data diff --git a/database/crud/users.py b/database/crud/users.py index 8f885b6..bd4f2ad 100644 --- a/database/crud/users.py +++ b/database/crud/users.py @@ -6,11 +6,11 @@ from sqlalchemy.ext.asyncio import AsyncSession from database.schemas import Bank, NightlyData, User __all__ = [ - "get_or_add", + "get_or_add_user", ] -async def get_or_add(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: +async def get_or_add_user(session: AsyncSession, user_id: int, *, options: Optional[list] = None) -> User: """Get a user's profile If it doesn't exist yet, create it (along with all linked datastructures) diff --git a/database/crud/wordle.py b/database/crud/wordle.py index ea8b892..a918256 100644 --- a/database/crud/wordle.py +++ b/database/crud/wordle.py @@ -4,6 +4,7 @@ from typing import Optional from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession +from database.crud.users import get_or_add_user from database.schemas import WordleGuess, WordleWord __all__ = [ @@ -16,6 +17,7 @@ __all__ = [ async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]: """Find a player's active game""" + await get_or_add_user(session, user_id) statement = select(WordleGuess).where(WordleGuess.user_id == user_id) guesses = (await session.execute(statement)).scalars().all() return guesses diff --git a/database/utils/caches.py b/database/utils/caches.py index 3911a0e..35165a5 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -121,7 +121,7 @@ class WordleCache(DatabaseCache): async def invalidate(self, database_session: AsyncSession): word = await wordle.get_daily_word(database_session) if word is not None: - self.data = [word] + self.data = [word.word] class CacheManager: diff --git a/didier/cogs/games.py b/didier/cogs/games.py index 1a4453a..a98fe24 100644 --- a/didier/cogs/games.py +++ b/didier/cogs/games.py @@ -5,11 +5,7 @@ from discord import app_commands from discord.ext import commands from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH -from database.crud.wordle import ( - get_active_wordle_game, - make_wordle_guess, - start_new_wordle_game, -) +from database.crud.wordle import get_active_wordle_game, make_wordle_guess from didier import Didier from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed @@ -35,31 +31,35 @@ class Games(commands.Cog): embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed() return await interaction.followup.send(embed=embed) - active_game = await get_active_wordle_game(self.client.mongo_db, interaction.user.id) - if active_game is None: - active_game = await start_new_wordle_game(self.client.mongo_db, interaction.user.id) + word = self.client.database_caches.wordle_word.data[0].lower() - # Trying to guess with a complete game - if len(active_game.guesses) == WORDLE_GUESS_COUNT and guess: - embed = WordleErrorEmbed(message="You've already completed today's Wordle.\nTry again tomorrow!").to_embed() - return await interaction.followup.send(embed=embed) + async with self.client.postgres_session as session: + guesses_instances = await get_active_wordle_game(session, interaction.user.id) + guesses = list(map(lambda g: g.guess, guesses_instances)) - # Make a guess - if guess: - # The guess is not a real word - if guess.lower() not in self.client.wordle_words: - embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed() + # Trying to guess with a complete game + if (len(guesses) == WORDLE_GUESS_COUNT and guess) or word in guesses: + embed = WordleErrorEmbed( + message="You've already completed today's Wordle.\nTry again tomorrow!" + ).to_embed() return await interaction.followup.send(embed=embed) - guess = guess.lower() - await make_wordle_guess(self.client.mongo_db, interaction.user.id, guess) + # Make a guess + if guess: + # The guess is not a real word + if guess.lower() not in self.client.wordle_words: + embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed() + return await interaction.followup.send(embed=embed) - # Don't re-request the game, we already have it - # just append locally - active_game.guesses.append(guess) + guess = guess.lower() + await make_wordle_guess(session, interaction.user.id, guess) - embed = WordleEmbed(game=active_game, word=self.client.database_caches.wordle_word.data[0]).to_embed() - await interaction.followup.send(embed=embed) + # Don't re-request the game, we already have it + # just append locally + guesses.append(guess) + + embed = WordleEmbed(guesses=guesses, word=word).to_embed() + await interaction.followup.send(embed=embed) async def setup(client: Didier): diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 901aac6..2e62044 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -140,9 +140,9 @@ class Tasks(commands.Cog): @tasks.loop(time=DAILY_RESET_TIME) async def reset_wordle_word(self, forced: bool = False): """Reset the daily Wordle word""" - db = self.client.mongo_db - word = await set_daily_word(db, random.choice(tuple(self.client.wordle_words)), forced=forced) - self.client.database_caches.wordle_word.data = [word] + async with self.client.postgres_session as session: + word = await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced) + self.client.database_caches.wordle_word.data = [word] @reset_wordle_word.before_loop async def _before_reset_wordle_word(self): diff --git a/didier/data/embeds/wordle.py b/didier/data/embeds/wordle.py index d29a29f..1cf70fb 100644 --- a/didier/data/embeds/wordle.py +++ b/didier/data/embeds/wordle.py @@ -1,12 +1,11 @@ import enum from dataclasses import dataclass -from typing import Optional import discord from overrides import overrides from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH -from database.schemas.mongo.wordle import WordleGame +from database.schemas import WordleGuess from didier.data.embeds.base import EmbedBaseModel from didier.utils.types.datetime import int_to_weekday, tz_aware_now @@ -32,7 +31,7 @@ class WordleColour(enum.IntEnum): class WordleEmbed(EmbedBaseModel): """Embed for a Wordle game""" - game: Optional[WordleGame] + guesses: list[WordleGuess] word: str def _letter_colour(self, guess: str, index: int) -> WordleColour: @@ -68,9 +67,8 @@ class WordleEmbed(EmbedBaseModel): colours = [] # Add all the guesses - if self.game is not None: - for guess in self.game.guesses: - colours.append(self._guess_colours(guess)) + for guess in self.guesses: + colours.append(self._guess_colours(guess)) # Fill the rest with empty spots for _ in range(WORDLE_GUESS_COUNT - len(colours)): @@ -93,6 +91,16 @@ class WordleEmbed(EmbedBaseModel): return emojis + def _is_game_over(self) -> bool: + """Check if the current game is over or not""" + if not self.guesses: + return False + + if len(self.guesses) == WORDLE_GUESS_COUNT: + return True + + return self.word.lower() in self.guesses + @overrides def to_embed(self, **kwargs) -> discord.Embed: only_colours = kwargs.get("only_colours", False) @@ -105,12 +113,12 @@ class WordleEmbed(EmbedBaseModel): rows = [" ".join(row) for row in emojis] # Don't reveal anything if we only want to show the colours - if not only_colours and self.game is not None: - for i, guess in enumerate(self.game.guesses): + if not only_colours and self.guesses: + for i, guess in enumerate(self.guesses): rows[i] += f" ||{guess.upper()}||" # If the game is over, reveal the word - if self.game.is_game_over(self.word): + if self._is_game_over(): rows.append(f"\n\nThe word was **{self.word.upper()}**!") embed.description = "\n\n".join(rows) diff --git a/tests/test_database/conftest.py b/tests/test_database/conftest.py index 8383b05..a675a35 100644 --- a/tests/test_database/conftest.py +++ b/tests/test_database/conftest.py @@ -25,7 +25,7 @@ def test_user_id() -> int: @pytest.fixture async def user(postgres: AsyncSession, test_user_id: int) -> User: """Fixture to create a user""" - _user = await users.get_or_add(postgres, test_user_id) + _user = await users.get_or_add_user(postgres, test_user_id) await postgres.refresh(_user) return _user diff --git a/tests/test_database/test_crud/test_birthdays.py b/tests/test_database/test_crud/test_birthdays.py index 86740d5..29a3791 100644 --- a/tests/test_database/test_crud/test_birthdays.py +++ b/tests/test_database/test_crud/test_birthdays.py @@ -54,7 +54,7 @@ async def test_get_birthdays_on_day(postgres: AsyncSession, user: User): """Test getting all birthdays on a given day""" await crud.add_birthday(postgres, user.user_id, datetime.today().replace(year=2001)) - user_2 = await users.get_or_add(postgres, user.user_id + 1) + user_2 = await users.get_or_add_user(postgres, user.user_id + 1) await crud.add_birthday(postgres, user_2.user_id, datetime.today() + timedelta(weeks=1)) birthdays = await crud.get_birthdays_on_day(postgres, datetime.today()) assert len(birthdays) == 1 diff --git a/tests/test_database/test_crud/test_users.py b/tests/test_database/test_crud/test_users.py index b726fab..b5c2fc6 100644 --- a/tests/test_database/test_crud/test_users.py +++ b/tests/test_database/test_crud/test_users.py @@ -7,7 +7,7 @@ from database.schemas import User async def test_get_or_add_non_existing(postgres: AsyncSession): """Test get_or_add for a user that doesn't exist""" - await crud.get_or_add(postgres, 1) + await crud.get_or_add_user(postgres, 1) statement = select(User) res = (await postgres.execute(statement)).scalars().all() @@ -18,8 +18,8 @@ async def test_get_or_add_non_existing(postgres: AsyncSession): async def test_get_or_add_existing(postgres: AsyncSession): """Test get_or_add for a user that does exist""" - user = await crud.get_or_add(postgres, 1) + user = await crud.get_or_add_user(postgres, 1) bank = user.bank - assert await crud.get_or_add(postgres, 1) == user - assert (await crud.get_or_add(postgres, 1)).bank == bank + assert await crud.get_or_add_user(postgres, 1) == user + assert (await crud.get_or_add_user(postgres, 1)).bank == bank