Merge pull request #177 from stijndcl/didier-3.7.0

Didier v3.7.0
pull/180/head v3.7.0
Stijn De Clercq 2023-09-24 16:28:02 +02:00 committed by GitHub
commit 5bfd3a92a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 768 additions and 22942 deletions

View File

@ -3,12 +3,12 @@ default_language_version:
repos: repos:
- repo: https://github.com/ambv/black - repo: https://github.com/ambv/black
rev: 22.3.0 rev: 23.3.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0 rev: v4.4.0
hooks: hooks:
- id: check-json - id: check-json
- id: end-of-file-fixer - id: end-of-file-fixer
@ -21,7 +21,7 @@ repos:
- id: isort - id: isort
- repo: https://github.com/PyCQA/autoflake - repo: https://github.com/PyCQA/autoflake
rev: v1.4 rev: v2.2.0
hooks: hooks:
- id: autoflake - id: autoflake
name: autoflake (python) name: autoflake (python)
@ -31,7 +31,7 @@ repos:
- "--ignore-init-module-imports" - "--ignore-init-module-imports"
- repo: https://github.com/PyCQA/flake8 - repo: https://github.com/PyCQA/flake8
rev: 4.0.1 rev: 6.0.0
hooks: hooks:
- id: flake8 - id: flake8
exclude: ^(alembic|.github) exclude: ^(alembic|.github)

View File

@ -0,0 +1,94 @@
"""Migrate to 2.x
Revision ID: 09128b6e34dd
Revises: 1e3e7f4192c4
Create Date: 2023-07-07 16:23:15.990231
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "09128b6e34dd"
down_revision = "1e3e7f4192c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("bank", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("birthdays", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("command_stats", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=False)
with op.batch_alter_table("deadlines", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
with op.batch_alter_table("github_links", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("reminders", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=False)
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=True)
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("reminders", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("github_links", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("deadlines", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("command_stats", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("birthdays", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("bank", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
# ### end Alembic commands ###

View File

@ -0,0 +1,57 @@
"""Remove wordle
Revision ID: 1e3e7f4192c4
Revises: 954ad804f057
Create Date: 2023-07-07 14:52:20.993687
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "1e3e7f4192c4"
down_revision = "954ad804f057"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("wordle_word")
op.drop_table("wordle_stats")
op.drop_table("wordle_guesses")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"wordle_guesses",
sa.Column("wordle_guess_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
sa.Column("guess", sa.TEXT(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_guesses_user_id_fkey"),
sa.PrimaryKeyConstraint("wordle_guess_id", name="wordle_guesses_pkey"),
)
op.create_table(
"wordle_stats",
sa.Column("wordle_stats_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
sa.Column("last_win", sa.DATE(), autoincrement=False, nullable=True),
sa.Column("games", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("wins", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("current_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("highest_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_stats_user_id_fkey"),
sa.PrimaryKeyConstraint("wordle_stats_id", name="wordle_stats_pkey"),
)
op.create_table(
"wordle_word",
sa.Column("word_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("word", sa.TEXT(), autoincrement=False, nullable=False),
sa.Column("day", sa.DATE(), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint("word_id", name="wordle_word_pkey"),
sa.UniqueConstraint("day", name="wordle_word_day_key"),
)
# ### end Alembic commands ###

View File

@ -1,2 +0,0 @@
WORDLE_GUESS_COUNT = 6
WORDLE_WORD_LENGTH = 5

View File

@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[
if query is not None: if query is not None:
statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%"))
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]:

View File

@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
"""Get a list of all commands""" """Get a list of all commands"""
statement = select(CustomCommand) statement = select(CustomCommand)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:

View File

@ -38,4 +38,4 @@ async def get_deadlines(
statement = statement.where(Deadline.course_id == course.course_id) statement = statement.where(Deadline.course_id == course.course_id)
statement = statement.options(selectinload(Deadline.course)) statement = statement.options(selectinload(Deadline.course))
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())

View File

@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"]
async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]:
"""Return a list of all easter eggs""" """Return a list of all easter eggs"""
statement = select(EasterEgg) statement = select(EasterEgg)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())

View File

@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
"""Get a list of all upcoming events""" """Get a list of all upcoming events"""
statement = select(Event).where(Event.timestamp > now) statement = select(Event).where(Event.timestamp > now)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:

View File

@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]):
async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]: async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]:
"""Filter a list of game IDs down to the ones that aren't in the database yet""" """Filter a list of game IDs down to the ones that aren't in the database yet"""
statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids)) statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids))
matches: list[int] = (await session.execute(statement)).scalars().all() matches: list[int] = list((await session.execute(statement)).scalars().all())
return list(set(game_ids).difference(matches)) return list(set(game_ids).difference(matches))

View File

@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id:
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
"""Get a user's GitHub links""" """Get a user's GitHub links"""
statement = select(GitHubLink).where(GitHubLink.user_id == user_id) statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())

View File

@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
async def get_all_links(session: AsyncSession) -> list[Link]: async def get_all_links(session: AsyncSession) -> list[Link]:
"""Get a list of all links""" """Get a list of all links"""
statement = select(Link) statement = select(Link)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def add_link(session: AsyncSession, name: str, url: str) -> Link: async def add_link(session: AsyncSession, name: str, url: str) -> Link:

View File

@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
"""Get a list of all memes""" """Get a list of all memes"""
statement = select(MemeTemplate) statement = select(MemeTemplate)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:

View File

@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"]
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
"""Get a list of all Reminders for a given category""" """Get a list of all Reminders for a given category"""
statement = select(Reminder).where(Reminder.category == category) statement = select(Reminder).where(Reminder.category == category)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:

View File

@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_
async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]: async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]:
"""Get all courses where announcements are enabled""" """Get all courses where announcements are enabled"""
statement = select(UforaCourse).where(UforaCourse.log_announcements) statement = select(UforaCourse).where(UforaCourse.log_announcements)
return (await session.execute(statement)).scalars().all() return list((await session.execute(statement)).scalars().all())
async def create_new_announcement( async def create_new_announcement(

View File

@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor
# Search case-insensitively # Search case-insensitively
query = query.lower() query = query.lower()
statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first() course_result = (await session.execute(course_statement)).scalars().first()
if result: if course_result:
return result return course_result
statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first() alias_result = (await session.execute(alias_statement)).scalars().first()
return result.course if result else None return alias_result.course if alias_result else None

View File

@ -1,85 +0,0 @@
import datetime
from typing import Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import WordleGuess, WordleWord
__all__ = [
"get_active_wordle_game",
"get_wordle_guesses",
"make_wordle_guess",
"set_daily_word",
"reset_wordle_games",
]
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
"""Find a player's active game"""
await get_or_add_user(session, user_id)
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
guesses = (await session.execute(statement)).scalars().all()
return guesses
async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]:
"""Get the strings of a player's guesses"""
active_game = await get_active_wordle_game(session, user_id)
return list(map(lambda g: g.guess.lower(), active_game))
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
"""Make a guess in your current game"""
guess_instance = WordleGuess(user_id=user_id, guess=guess)
session.add(guess_instance)
await session.commit()
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
"""Get the word of today"""
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
row = (await session.execute(statement)).scalar_one_or_none()
if row is None:
return None
return row
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
"""Set the word of today
This does NOT overwrite the existing word if there is one, so that it can safely run
on startup every time.
In order to always overwrite the current word, set the "forced"-kwarg to True.
Returns the word that was chosen. If one already existed, return that instead.
"""
current_word = await get_daily_word(session)
if current_word is None:
current_word = WordleWord(word=word, day=datetime.date.today())
session.add(current_word)
await session.commit()
# Remove all active games
await reset_wordle_games(session)
elif forced:
current_word.word = word
session.add(current_word)
await session.commit()
# Remove all active games
await reset_wordle_games(session)
return current_word.word
async def reset_wordle_games(session: AsyncSession):
"""Reset all active games"""
statement = delete(WordleGuess)
await session.execute(statement)
await session.commit()

View File

@ -1,60 +0,0 @@
from datetime import date
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import WordleStats
__all__ = ["get_wordle_stats", "complete_wordle_game"]
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
"""Get a user's wordle stats
If no entry is found, it is first created
"""
await get_or_add_user(session, user_id)
statement = select(WordleStats).where(WordleStats.user_id == user_id)
stats = (await session.execute(statement)).scalar_one_or_none()
if stats is not None:
return stats
stats = WordleStats(user_id=user_id)
session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
"""Update the user's Wordle stats"""
stats = await get_wordle_stats(session, user_id)
stats.games += 1
if win:
stats.wins += 1
# Update streak
today = date.today()
last_win = stats.last_win
stats.last_win = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
stats.current_streak = 1
else:
# On a streak: increase counter
stats.current_streak += 1
# Update max streak if necessary
if stats.current_streak > stats.highest_streak:
stats.highest_streak = stats.current_streak
else:
# Streak is over
stats.current_streak = 0
session.add(stats)
await session.commit()

View File

@ -1,8 +1,7 @@
from urllib.parse import quote_plus from urllib.parse import quote_plus
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from sqlalchemy.orm import sessionmaker
import settings import settings
@ -22,6 +21,4 @@ postgres_engine = create_async_engine(
future=True, future=True,
) )
DBSession = sessionmaker( DBSession = async_sessionmaker(autocommit=False, autoflush=False, bind=postgres_engine, expire_on_commit=False)
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
)

View File

@ -1,27 +1,14 @@
from __future__ import annotations from __future__ import annotations
from datetime import date, datetime from datetime import date, datetime
from typing import Optional from typing import List, Optional
from sqlalchemy import ( from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint
BigInteger, from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
Boolean, from sqlalchemy.types import DateTime
Column,
Date,
DateTime,
Enum,
ForeignKey,
Integer,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import declarative_base, relationship
from database import enums from database import enums
Base = declarative_base()
__all__ = [ __all__ = [
"Base", "Base",
"Bank", "Bank",
@ -45,33 +32,37 @@ __all__ = [
"UforaCourse", "UforaCourse",
"UforaCourseAlias", "UforaCourseAlias",
"User", "User",
"WordleGuess",
"WordleStats",
"WordleWord",
] ]
class Base(DeclarativeBase):
"""Required base class for all tables"""
# Make all DateTimes timezone-aware
type_annotation_map = {datetime: DateTime(timezone=True)}
class Bank(Base): class Bank(Base):
"""A user's currency information""" """A user's currency information"""
__tablename__ = "bank" __tablename__ = "bank"
bank_id: int = Column(Integer, primary_key=True) bank_id: Mapped[int] = mapped_column(primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
dinks: int = Column(BigInteger, server_default="0", nullable=False) dinks: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
invested: int = Column(BigInteger, server_default="0", nullable=False) invested: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
# Interest rate # Interest rate
interest_level: int = Column(Integer, server_default="1", nullable=False) interest_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
# Maximum amount that can be stored in the bank # Maximum amount that can be stored in the bank
capacity_level: int = Column(Integer, server_default="1", nullable=False) capacity_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
# Maximum amount that can be robbed # Maximum amount that can be robbed
rob_level: int = Column(Integer, server_default="1", nullable=False) rob_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") user: Mapped[User] = relationship(uselist=False, back_populates="bank", lazy="selectin")
class Birthday(Base): class Birthday(Base):
@ -79,11 +70,11 @@ class Birthday(Base):
__tablename__ = "birthdays" __tablename__ = "birthdays"
birthday_id: int = Column(Integer, primary_key=True) birthday_id: Mapped[int] = mapped_column(primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
birthday: date = Column(Date, nullable=False) birthday: Mapped[date] = mapped_column(nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") user: Mapped[User] = relationship(uselist=False, back_populates="birthday", lazy="selectin")
class Bookmark(Base): class Bookmark(Base):
@ -92,26 +83,26 @@ class Bookmark(Base):
__tablename__ = "bookmarks" __tablename__ = "bookmarks"
__table_args__ = (UniqueConstraint("user_id", "label"),) __table_args__ = (UniqueConstraint("user_id", "label"),)
bookmark_id: int = Column(Integer, primary_key=True) bookmark_id: Mapped[int] = mapped_column(primary_key=True)
label: str = Column(Text, nullable=False) label: Mapped[str] = mapped_column(nullable=False)
jump_url: str = Column(Text, nullable=False) jump_url: Mapped[str] = mapped_column(nullable=False)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") user: Mapped[User] = relationship(back_populates="bookmarks", uselist=False, lazy="selectin")
class CommandStats(Base): class CommandStats(Base):
"""Metrics on how often commands are used""" """Metrics on how often commands are used"""
__tablename__ = "command_stats" __tablename__ = "command_stats"
command_stats_id: int = Column(Integer, primary_key=True) command_stats_id: Mapped[int] = mapped_column(primary_key=True)
command: str = Column(Text, nullable=False) command: Mapped[str] = mapped_column(nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False) timestamp: Mapped[datetime] = mapped_column(nullable=False)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
slash: bool = Column(Boolean, nullable=False) slash: Mapped[bool] = mapped_column(nullable=False)
context_menu: bool = Column(Boolean, nullable=False) context_menu: Mapped[bool] = mapped_column(nullable=False)
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin") user: Mapped[User] = relationship(back_populates="command_stats", uselist=False, lazy="selectin")
class CustomCommand(Base): class CustomCommand(Base):
@ -119,13 +110,13 @@ class CustomCommand(Base):
__tablename__ = "custom_commands" __tablename__ = "custom_commands"
command_id: int = Column(Integer, primary_key=True) command_id: Mapped[int] = mapped_column(primary_key=True)
name: str = Column(Text, nullable=False, unique=True) name: Mapped[str] = mapped_column(nullable=False, unique=True)
indexed_name: str = Column(Text, nullable=False, index=True) indexed_name: Mapped[str] = mapped_column(nullable=False, index=True)
response: str = Column(Text, nullable=False) response: Mapped[str] = mapped_column(nullable=False)
aliases: list[CustomCommandAlias] = relationship( aliases: Mapped[List[CustomCommandAlias]] = relationship(
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
) )
@ -134,12 +125,12 @@ class CustomCommandAlias(Base):
__tablename__ = "custom_command_aliases" __tablename__ = "custom_command_aliases"
alias_id: int = Column(Integer, primary_key=True) alias_id: Mapped[int] = mapped_column(primary_key=True)
alias: str = Column(Text, nullable=False, unique=True) alias: Mapped[str] = mapped_column(nullable=False, unique=True)
indexed_alias: str = Column(Text, nullable=False, index=True) indexed_alias: Mapped[str] = mapped_column(nullable=False, index=True)
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) command_id: Mapped[int] = mapped_column(ForeignKey("custom_commands.command_id"))
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") command: Mapped[CustomCommand] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
class DadJoke(Base): class DadJoke(Base):
@ -147,8 +138,8 @@ class DadJoke(Base):
__tablename__ = "dad_jokes" __tablename__ = "dad_jokes"
dad_joke_id: int = Column(Integer, primary_key=True) dad_joke_id: Mapped[int] = mapped_column(primary_key=True)
joke: str = Column(Text, nullable=False) joke: Mapped[str] = mapped_column(nullable=False)
class Deadline(Base): class Deadline(Base):
@ -156,12 +147,12 @@ class Deadline(Base):
__tablename__ = "deadlines" __tablename__ = "deadlines"
deadline_id: int = Column(Integer, primary_key=True) deadline_id: Mapped[int] = mapped_column(primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
name: str = Column(Text, nullable=False) name: Mapped[str] = mapped_column(nullable=False)
deadline: datetime = Column(DateTime(timezone=True), nullable=False) deadline: Mapped[datetime] = mapped_column(nullable=False)
course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") course: Mapped[UforaCourse] = relationship(back_populates="deadlines", uselist=False, lazy="selectin")
class EasterEgg(Base): class EasterEgg(Base):
@ -169,11 +160,11 @@ class EasterEgg(Base):
__tablename__ = "easter_eggs" __tablename__ = "easter_eggs"
easter_egg_id: int = Column(Integer, primary_key=True) easter_egg_id: Mapped[int] = mapped_column(primary_key=True)
match: str = Column(Text, nullable=False) match: Mapped[str] = mapped_column(nullable=False)
response: str = Column(Text, nullable=False) response: Mapped[str] = mapped_column(nullable=False)
exact: bool = Column(Boolean, nullable=False, server_default="1") exact: Mapped[bool] = mapped_column(nullable=False, server_default="1")
startswith: bool = Column(Boolean, nullable=False, server_default="1") startswith: Mapped[bool] = mapped_column(nullable=False, server_default="1")
class Event(Base): class Event(Base):
@ -181,11 +172,11 @@ class Event(Base):
__tablename__ = "events" __tablename__ = "events"
event_id: int = Column(Integer, primary_key=True) event_id: Mapped[int] = mapped_column(primary_key=True)
name: str = Column(Text, nullable=False) name: Mapped[str] = mapped_column(nullable=False)
description: Optional[str] = Column(Text, nullable=True) description: Mapped[Optional[str]] = mapped_column(nullable=True)
notification_channel: int = Column(BigInteger, nullable=False) notification_channel: Mapped[int] = mapped_column(BigInteger, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False) timestamp: Mapped[datetime] = mapped_column(nullable=False)
class FreeGame(Base): class FreeGame(Base):
@ -193,7 +184,7 @@ class FreeGame(Base):
__tablename__ = "free_games" __tablename__ = "free_games"
free_game_id: int = Column(Integer, primary_key=True) free_game_id: Mapped[int] = mapped_column(primary_key=True)
class GitHubLink(Base): class GitHubLink(Base):
@ -201,11 +192,11 @@ class GitHubLink(Base):
__tablename__ = "github_links" __tablename__ = "github_links"
github_link_id: int = Column(Integer, primary_key=True) github_link_id: Mapped[int] = mapped_column(primary_key=True)
url: str = Column(Text, nullable=False, unique=True) url: Mapped[str] = mapped_column(nullable=False, unique=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin") user: Mapped[User] = relationship(back_populates="github_links", uselist=False, lazy="selectin")
class Link(Base): class Link(Base):
@ -213,9 +204,9 @@ class Link(Base):
__tablename__ = "links" __tablename__ = "links"
link_id: int = Column(Integer, primary_key=True) link_id: Mapped[int] = mapped_column(primary_key=True)
name: str = Column(Text, nullable=False, unique=True) name: Mapped[str] = mapped_column(nullable=False, unique=True)
url: str = Column(Text, nullable=False) url: Mapped[str] = mapped_column(nullable=False)
class MemeTemplate(Base): class MemeTemplate(Base):
@ -223,10 +214,10 @@ class MemeTemplate(Base):
__tablename__ = "meme" __tablename__ = "meme"
meme_id: int = Column(Integer, primary_key=True) meme_id: Mapped[int] = mapped_column(primary_key=True)
name: str = Column(Text, nullable=False, unique=True) name: Mapped[str] = mapped_column(nullable=False, unique=True)
template_id: int = Column(Integer, nullable=False, unique=True) template_id: Mapped[int] = mapped_column(nullable=False, unique=True)
field_count: int = Column(Integer, nullable=False) field_count: Mapped[int] = mapped_column(nullable=False)
class NightlyData(Base): class NightlyData(Base):
@ -234,12 +225,12 @@ class NightlyData(Base):
__tablename__ = "nightly_data" __tablename__ = "nightly_data"
nightly_id: int = Column(Integer, primary_key=True) nightly_id: Mapped[int] = mapped_column(primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[date] = Column(Date, nullable=True) last_nightly: Mapped[Optional[date]] = mapped_column(nullable=True)
count: int = Column(Integer, server_default="0", nullable=False) count: Mapped[int] = mapped_column(server_default="0", nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") user: Mapped[User] = relationship(back_populates="nightly_data", uselist=False, lazy="selectin")
class Reminder(Base): class Reminder(Base):
@ -247,11 +238,11 @@ class Reminder(Base):
__tablename__ = "reminders" __tablename__ = "reminders"
reminder_id: int = Column(Integer, primary_key=True) reminder_id: Mapped[int] = mapped_column(primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id")) user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False) category: Mapped[enums.ReminderCategory] = mapped_column(nullable=False)
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin") user: Mapped[User] = relationship(back_populates="reminders", uselist=False, lazy="selectin")
class Task(Base): class Task(Base):
@ -259,9 +250,9 @@ class Task(Base):
__tablename__ = "tasks" __tablename__ = "tasks"
task_id: int = Column(Integer, primary_key=True) task_id: Mapped[int] = mapped_column(primary_key=True)
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) task: Mapped[enums.TaskType] = mapped_column(nullable=False, unique=True)
previous_run: datetime = Column(DateTime(timezone=True), nullable=True) previous_run: Mapped[datetime] = mapped_column(nullable=True)
class UforaCourse(Base): class UforaCourse(Base):
@ -269,25 +260,25 @@ class UforaCourse(Base):
__tablename__ = "ufora_courses" __tablename__ = "ufora_courses"
course_id: int = Column(Integer, primary_key=True) course_id: Mapped[int] = mapped_column(primary_key=True)
name: str = Column(Text, nullable=False, unique=True) name: Mapped[str] = mapped_column(nullable=False, unique=True)
code: str = Column(Text, nullable=False, unique=True) code: Mapped[str] = mapped_column(nullable=False, unique=True)
year: int = Column(Integer, nullable=False) year: Mapped[int] = mapped_column(nullable=False)
compulsory: bool = Column(Boolean, server_default="1", nullable=False) compulsory: Mapped[bool] = mapped_column(server_default="1", nullable=False)
role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
# This is not the greatest fix, but there can only ever be two, so it will do the job # This is not the greatest fix, but there can only ever be two, so it will do the job
alternative_overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) alternative_overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
log_announcements: bool = Column(Boolean, server_default="0", nullable=False) log_announcements: Mapped[bool] = mapped_column(server_default="0", nullable=False)
announcements: list[UforaAnnouncement] = relationship( announcements: Mapped[List[UforaAnnouncement]] = relationship(
"UforaAnnouncement", back_populates="course", cascade="all, delete-orphan", lazy="selectin" back_populates="course", cascade="all, delete-orphan", lazy="selectin"
) )
aliases: list[UforaCourseAlias] = relationship( aliases: Mapped[List[UforaCourseAlias]] = relationship(
"UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin" back_populates="course", cascade="all, delete-orphan", lazy="selectin"
) )
deadlines: list[Deadline] = relationship( deadlines: Mapped[List[Deadline]] = relationship(
"Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin" back_populates="course", cascade="all, delete-orphan", lazy="selectin"
) )
@ -296,11 +287,11 @@ class UforaCourseAlias(Base):
__tablename__ = "ufora_course_aliases" __tablename__ = "ufora_course_aliases"
alias_id: int = Column(Integer, primary_key=True) alias_id: Mapped[int] = mapped_column(primary_key=True)
alias: str = Column(Text, nullable=False, unique=True) alias: Mapped[str] = mapped_column(nullable=False, unique=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
course: UforaCourse = relationship("UforaCourse", back_populates="aliases", uselist=False, lazy="selectin") course: Mapped[UforaCourse] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
class UforaAnnouncement(Base): class UforaAnnouncement(Base):
@ -308,11 +299,11 @@ class UforaAnnouncement(Base):
__tablename__ = "ufora_announcements" __tablename__ = "ufora_announcements"
announcement_id: int = Column(Integer, primary_key=True) announcement_id: Mapped[int] = mapped_column(primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
publication_date: date = Column(Date) publication_date: Mapped[date] = mapped_column()
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") course: Mapped[UforaCourse] = relationship(back_populates="announcements", uselist=False, lazy="selectin")
class User(Base): class User(Base):
@ -320,70 +311,26 @@ class User(Base):
__tablename__ = "users" __tablename__ = "users"
user_id: int = Column(BigInteger, primary_key=True) user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
bank: Bank = relationship( bank: Mapped[Bank] = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
birthday: Optional[Birthday] = relationship( birthday: Mapped[Optional[Birthday]] = relationship(
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
bookmarks: list[Bookmark] = relationship( bookmarks: Mapped[List[Bookmark]] = relationship(
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
command_stats: list[CommandStats] = relationship( command_stats: Mapped[List[CommandStats]] = relationship(
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
github_links: list[GitHubLink] = relationship( github_links: Mapped[List[GitHubLink]] = relationship(
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
nightly_data: NightlyData = relationship( nightly_data: Mapped[NightlyData] = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )
reminders: list[Reminder] = relationship( reminders: Mapped[List[Reminder]] = relationship(
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_stats: WordleStats = relationship(
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
class WordleGuess(Base):
"""A user's Wordle guesses for today"""
__tablename__ = "wordle_guesses"
wordle_guess_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
guess: str = Column(Text, nullable=False)
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
class WordleStats(Base):
"""Stats about a user's wordle performance"""
__tablename__ = "wordle_stats"
wordle_stats_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_win: Optional[date] = Column(Date, nullable=True)
games: int = Column(Integer, server_default="0", nullable=False)
wins: int = Column(Integer, server_default="0", nullable=False)
current_streak: int = Column(Integer, server_default="0", nullable=False)
highest_streak: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
class WordleWord(Base):
"""The current Wordle word"""
__tablename__ = "wordle_word"
word_id: int = Column(Integer, primary_key=True)
word: str = Column(Text, nullable=False)
day: date = Column(Date, nullable=False, unique=True)

View File

@ -0,0 +1,171 @@
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.engine import DBSession
from database.schemas import UforaCourse, UforaCourseAlias
__all__ = ["main"]
async def purge_aliases(session: AsyncSession):
"""Delete old aliases for courses that will get a new id"""
codes = ["C004074", "C004073", "C004075", "C002309"]
for course_code in codes:
select_stmt = (
select(UforaCourse).where(UforaCourse.code == course_code).options(selectinload(UforaCourse.aliases))
)
course: UforaCourse = (await session.execute(select_stmt)).scalar_one()
for alias in list(course.aliases):
await session.delete(alias)
await session.commit()
async def main():
"""Add the Ufora courses for the 2023-2024 academic year"""
session: AsyncSession
async with DBSession() as session:
# # Remove Advanced Databases (which no longer exists)
delete_stmt = delete(UforaCourse).where(UforaCourse.code == "E018441")
await session.execute(delete_stmt)
await session.commit()
# Delete aliases of courses with new IDs
await purge_aliases(session)
# Fix IDs of compulsory courses and enable announcements
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004074")
bds: UforaCourse = (await session.execute(select_stmt)).scalar_one()
bds.course_id = 828305
bds.log_announcements = True
session.add(bds)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004073")
cg: UforaCourse = (await session.execute(select_stmt)).scalar_one()
cg.course_id = 828293
cg.log_announcements = True
session.add(cg)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004075")
stage: UforaCourse = (await session.execute(select_stmt)).scalar_one()
stage.course_id = 857878
stage.log_announcements = True
session.add(stage)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C002309")
thesis: UforaCourse = (await session.execute(select_stmt)).scalar_one()
thesis.course_id = 828446
thesis.log_announcements = True
session.add(thesis)
await session.commit()
# Add new aliases for these courses
cg_alias = UforaCourseAlias(course_id=cg.course_id, alias="Computer Graphics")
stage_alias = UforaCourseAlias(course_id=stage.course_id, alias="Stage")
thesis_alias = UforaCourseAlias(course_id=thesis.course_id, alias="Thesis")
session.add_all([cg_alias, stage_alias, thesis_alias])
await session.commit()
# New elective courses
bed_eco = UforaCourse(
code="H001535", name="Bedrijfseconomie", year=6, compulsory=False, role_id=1155496199000952922
)
beg_eco = UforaCourse(
code="D012144", name="Beginselen van economie", year=6, compulsory=False, role_id=1155495997024247948
)
big_data_tech = UforaCourse(
code="E018240", name="Big Data Technology", year=6, compulsory=False, role_id=1155490114282201148
)
cloud_storage = UforaCourse(
code="E017310", name="Cloud Storage and Computing", year=6, compulsory=False, role_id=1155490706849271841
)
criminologie = UforaCourse(
code="B001623", name="Inleiding tot Criminologie", year=6, compulsory=False, role_id=1155486768477515936
)
data_quality = UforaCourse(
code="E018700", name="Data Quality", year=6, compulsory=False, role_id=1155491028707586180
)
data_vis_ai = UforaCourse(
code="E061370",
name="Data Visualization for and with AI",
year=6,
compulsory=False,
role_id=1155491854687686746,
)
db_design = UforaCourse(
code="E018610", name="Database Design", year=6, compulsory=False, role_id=1155489846345875489
)
finance_markets = UforaCourse(
code="F000093",
name="Financiële Markten en Instellingen",
year=6,
compulsory=False,
role_id=1155492815615299634,
)
game_theory = UforaCourse(
code="E003710",
name="Game Theory and Multiagent Systems",
year=6,
compulsory=False,
role_id=1155488481666154506,
)
knowledge_graphs = UforaCourse(
code="E018160", name="Knowledge Graphs", year=6, compulsory=False, role_id=1155491648323735592
)
natural_language_processing = UforaCourse(
code="E061341", name="Natural Language Processing", year=6, compulsory=False, role_id=1155487540992823348
)
nosql = UforaCourse(
code="E018130", name="NoSQL Databases", year=6, compulsory=False, role_id=1155491405955878973
)
secure_ss = UforaCourse(
code="E017950", name="Secure Software and Systems", year=6, compulsory=False, role_id=1155492095281340467
)
session.add_all(
[
bed_eco,
beg_eco,
big_data_tech,
cloud_storage,
criminologie,
data_quality,
data_vis_ai,
db_design,
finance_markets,
game_theory,
knowledge_graphs,
natural_language_processing,
nosql,
secure_ss,
]
)
await session.commit()
# Aliases for new elective courses
datakwaliteit = UforaCourseAlias(course_id=data_quality.course_id, alias="Datakwaliteit")
devops = UforaCourseAlias(course_id=cloud_storage.course_id, alias="DevOps")
nlp = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="NLP")
nlp_nl = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="Natuurlijke Taalverwerking")
session.add_all([datakwaliteit, devops, nlp, nlp_nl])
await session.commit()

View File

@ -4,8 +4,8 @@ from discord import app_commands
from overrides import overrides from overrides import overrides
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import easter_eggs, links, memes, ufora_courses, wordle from database.crud import easter_eggs, links, memes, ufora_courses
from database.schemas import EasterEgg, WordleWord from database.schemas import EasterEgg
__all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"] __all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"]
@ -69,7 +69,7 @@ class LinkCache(DatabaseCache):
self.clear() self.clear()
all_links = await links.get_all_links(database_session) all_links = await links.get_all_links(database_session)
self.data = list(map(lambda l: l.name, all_links)) self.data = list(map(lambda link: link.name, all_links))
self.data.sort() self.data.sort()
self.data_transformed = list(map(str.lower, self.data)) self.data_transformed = list(map(str.lower, self.data))
@ -132,17 +132,6 @@ class UforaCourseCache(DatabaseCache):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions] return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class WordleCache(DatabaseCache):
"""Cache to store the current daily Wordle word"""
word: WordleWord
async def invalidate(self, database_session: AsyncSession):
word = await wordle.get_daily_word(database_session)
if word is not None:
self.word = word
class CacheManager: class CacheManager:
"""Class that keeps track of all caches""" """Class that keeps track of all caches"""
@ -150,14 +139,12 @@ class CacheManager:
links: LinkCache links: LinkCache
memes: MemeCache memes: MemeCache
ufora_courses: UforaCourseCache ufora_courses: UforaCourseCache
wordle_word: WordleCache
def __init__(self): def __init__(self):
self.easter_eggs = EasterEggCache() self.easter_eggs = EasterEggCache()
self.links = LinkCache() self.links = LinkCache()
self.memes = MemeCache() self.memes = MemeCache()
self.ufora_courses = UforaCourseCache() self.ufora_courses = UforaCourseCache()
self.wordle_word = WordleCache()
async def initialize_caches(self, postgres_session: AsyncSession): async def initialize_caches(self, postgres_session: AsyncSession):
"""Initialize the contents of all caches""" """Initialize the contents of all caches"""
@ -165,4 +152,3 @@ class CacheManager:
await self.links.invalidate(postgres_session) await self.links.invalidate(postgres_session)
await self.memes.invalidate(postgres_session) await self.memes.invalidate(postgres_session)
await self.ufora_courses.invalidate(postgres_session) await self.ufora_courses.invalidate(postgres_session)
await self.wordle_word.invalidate(postgres_session)

View File

@ -25,7 +25,7 @@ class Currency(commands.Cog):
super().__init__() super().__init__()
self.client = client self.client = client
@commands.command(name="award") @commands.command(name="award") # type: ignore[arg-type]
@commands.check(is_owner) @commands.check(is_owner)
async def award( async def award(
self, self,
@ -49,7 +49,9 @@ class Currency(commands.Cog):
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue()) embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue())
embed.set_thumbnail(url=ctx.author.avatar.url)
if ctx.author.avatar is not None:
embed.set_thumbnail(url=ctx.author.avatar.url)
embed.add_field(name="Interest level", value=bank.interest_level) embed.add_field(name="Interest level", value=bank.interest_level)
embed.add_field(name="Capacity level", value=bank.capacity_level) embed.add_field(name="Capacity level", value=bank.capacity_level)
@ -57,7 +59,9 @@ class Currency(commands.Cog):
await ctx.reply(embed=embed, mention_author=False) await ctx.reply(embed=embed, mention_author=False)
@bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True) @bank.group( # type: ignore[arg-type]
name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True
)
async def bank_upgrades(self, ctx: commands.Context): async def bank_upgrades(self, ctx: commands.Context):
"""List the upgrades you can buy & their prices.""" """List the upgrades you can buy & their prices."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -77,7 +81,7 @@ class Currency(commands.Cog):
await ctx.reply(embed=embed, mention_author=False) await ctx.reply(embed=embed, mention_author=False)
@bank_upgrades.command(name="capacity", aliases=["c"]) @bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type]
async def bank_upgrade_capacity(self, ctx: commands.Context): async def bank_upgrade_capacity(self, ctx: commands.Context):
"""Upgrade the capacity level of your bank.""" """Upgrade the capacity level of your bank."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -88,7 +92,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@bank_upgrades.command(name="interest", aliases=["i"]) @bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type]
async def bank_upgrade_interest(self, ctx: commands.Context): async def bank_upgrade_interest(self, ctx: commands.Context):
"""Upgrade the interest level of your bank.""" """Upgrade the interest level of your bank."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -99,7 +103,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@bank_upgrades.command(name="rob", aliases=["r"]) @bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type]
async def bank_upgrade_rob(self, ctx: commands.Context): async def bank_upgrade_rob(self, ctx: commands.Context):
"""Upgrade the rob level of your bank.""" """Upgrade the rob level of your bank."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -110,7 +114,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@commands.hybrid_command(name="dinks") @commands.hybrid_command(name="dinks") # type: ignore[arg-type]
async def dinks(self, ctx: commands.Context): async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks.""" """Check your Didier Dinks."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -118,7 +122,7 @@ class Currency(commands.Cog):
plural = pluralize("Didier Dink", bank.dinks) plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False)
@commands.command(name="invest", aliases=["deposit", "dep"]) @commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type]
async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]): async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]):
"""Invest `amount` Didier Dinks into your bank. """Invest `amount` Didier Dinks into your bank.
@ -144,7 +148,7 @@ class Currency(commands.Cog):
f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False
) )
@commands.hybrid_command(name="nightly") @commands.hybrid_command(name="nightly") # type: ignore[arg-type]
async def nightly(self, ctx: commands.Context): async def nightly(self, ctx: commands.Context):
"""Claim nightly Didier Dinks.""" """Claim nightly Didier Dinks."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:

View File

@ -13,7 +13,7 @@ class DebugCog(commands.Cog):
self.client = client self.client = client
@overrides @overrides
async def cog_check(self, ctx: commands.Context) -> bool: async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override]
return await self.client.is_owner(ctx.author) return await self.client.is_owner(ctx.author)
@commands.command(aliases=["Dev"]) @commands.command(aliases=["Dev"])

View File

@ -1,4 +1,4 @@
from typing import Optional from typing import Optional, Union, cast
import discord import discord
from discord import app_commands from discord import app_commands
@ -17,6 +17,7 @@ from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.utils.discord import colours from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
from didier.utils.timer import Timer from didier.utils.timer import Timer
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
@ -60,9 +61,19 @@ class Discord(commands.Cog):
event = await events.get_event_by_id(session, event_id) event = await events.get_event_by_id(session, event_id)
if event is None: if event is None:
return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True) return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True)
channel = self.client.get_channel(event.notification_channel) channel = self.client.get_channel(event.notification_channel)
if channel is None:
return await self.client.log_error(
f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)."
)
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable."
)
human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M") human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M")
embed = discord.Embed(title=event.name, colour=discord.Colour.blue()) embed = discord.Embed(title=event.name, colour=discord.Colour.blue())
@ -81,7 +92,7 @@ class Discord(commands.Cog):
self.client.loop.create_task(self.timer.update()) self.client.loop.create_task(self.timer.update())
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None): async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
"""Command to check the birthday of `user`. """Command to check the birthday of `user`.
Not passing an argument for `user` will show yours instead. Not passing an argument for `user` will show yours instead.
@ -98,8 +109,10 @@ class Discord(commands.Cog):
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False) return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False)
@birthday.command(name="set", aliases=["config"]) @birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type]
async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None): async def birthday_set(
self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None
):
"""Set your birthday to `day`. """Set your birthday to `day`.
Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`. Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`.
@ -113,6 +126,9 @@ class Discord(commands.Cog):
if user is None: if user is None:
user = ctx.author user = ctx.author
# Please Mypy
user = cast(Union[discord.User, discord.Member], user)
try: try:
default_year = 2001 default_year = 2001
date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
@ -141,7 +157,7 @@ class Discord(commands.Cog):
""" """
# No label: shortcut to display bookmarks # No label: shortcut to display bookmarks
if label is None: if label is None:
return await self.bookmark_search(ctx, query=None) return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type]
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
result = expect( result = expect(
@ -151,7 +167,7 @@ class Discord(commands.Cog):
) )
await ctx.reply(result.jump_url, mention_author=False) await ctx.reply(result.jump_url, mention_author=False)
@bookmark.command(name="create", aliases=["new"]) @bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type]
async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]): async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]):
"""Create a new bookmark for message `message` with label `label`. """Create a new bookmark for message `message` with label `label`.
@ -182,7 +198,7 @@ class Discord(commands.Cog):
# Label isn't allowed # Label isn't allowed
return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False)
@bookmark.command(name="delete", aliases=["rm"]) @bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type]
async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str):
"""Delete the bookmark with id `bookmark_id`. """Delete the bookmark with id `bookmark_id`.
@ -207,7 +223,7 @@ class Discord(commands.Cog):
return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False)
@bookmark.command(name="search", aliases=["list", "ls"]) @bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type]
async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None):
"""Search through the list of bookmarks. """Search through the list of bookmarks.
@ -236,7 +252,7 @@ class Discord(commands.Cog):
modal = CreateBookmark(self.client, message.jump_url) modal = CreateBookmark(self.client, message.jump_url)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@commands.hybrid_command(name="events") @commands.hybrid_command(name="events") # type: ignore[arg-type]
@app_commands.rename(event_id="id") @app_commands.rename(event_id="id")
@app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.") @app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.")
async def events(self, ctx: commands.Context, event_id: Optional[int] = None): async def events(self, ctx: commands.Context, event_id: Optional[int] = None):
@ -276,16 +292,16 @@ class Discord(commands.Cog):
embed.add_field( embed.add_field(
name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True
) )
embed.add_field(
name="Channel", channel = self.client.get_channel(result_event.notification_channel)
value=self.client.get_channel(result_event.notification_channel).mention, if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
inline=False, embed.add_field(name="Channel", value=channel.mention, inline=False)
)
embed.description = result_event.description embed.description = result_event.description
return await ctx.reply(embed=embed, mention_author=False) return await ctx.reply(embed=embed, mention_author=False)
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None):
"""Show a user's GitHub links. """Show a user's GitHub links.
If no user is provided, this shows your links instead. If no user is provided, this shows your links instead.
@ -293,6 +309,9 @@ class Discord(commands.Cog):
# Default to author # Default to author
user = user or ctx.author user = user or ctx.author
# Please Mypy
user = cast(Union[discord.User, discord.Member], user)
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
embed.set_author(name=user.display_name, icon_url=get_user_avatar(user)) embed.set_author(name=user.display_name, icon_url=get_user_avatar(user))
@ -324,7 +343,7 @@ class Discord(commands.Cog):
return await ctx.reply(embed=embed, mention_author=False) return await ctx.reply(embed=embed, mention_author=False)
@github_group.command(name="add", aliases=["create", "insert"]) @github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type]
async def github_add(self, ctx: commands.Context, link: str): async def github_add(self, ctx: commands.Context, link: str):
"""Add a new link into the database.""" """Add a new link into the database."""
# Remove wrapping <brackets> which can be used to escape Discord embeds # Remove wrapping <brackets> which can be used to escape Discord embeds
@ -339,7 +358,7 @@ class Discord(commands.Cog):
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False) return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False)
@github_group.command(name="delete", aliases=["del", "remove", "rm"]) @github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type]
async def github_delete(self, ctx: commands.Context, link_id: str): async def github_delete(self, ctx: commands.Context, link_id: str):
"""Delete the link with it `link_id` from the database. """Delete the link with it `link_id` from the database.
@ -411,7 +430,7 @@ class Discord(commands.Cog):
await message.add_reaction("📌") await message.add_reaction("📌")
return await interaction.response.send_message("📌", ephemeral=True) return await interaction.response.send_message("📌", ephemeral=True)
@commands.hybrid_command(name="snipe") @commands.hybrid_command(name="snipe") # type: ignore[arg-type]
async def snipe(self, ctx: commands.Context): async def snipe(self, ctx: commands.Context):
"""Publicly shame people when they edit or delete one of their messages. """Publicly shame people when they edit or delete one of their messages.
@ -420,7 +439,7 @@ class Discord(commands.Cog):
if ctx.guild is None: if ctx.guild is None:
return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True) return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True)
sniped_data = self.client.sniped.get(ctx.channel.id, None) sniped_data = self.client.sniped.get(ctx.channel.id)
if sniped_data is None: if sniped_data is None:
return await ctx.reply( return await ctx.reply(
"There's no one to make fun of in this channel.", mention_author=False, ephemeral=True "There's no one to make fun of in this channel.", mention_author=False, ephemeral=True

View File

@ -28,7 +28,7 @@ class Fun(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.hybrid_command(name="clap") @commands.hybrid_command(name="clap") # type: ignore[arg-type]
async def clap(self, ctx: commands.Context, *, text: str): async def clap(self, ctx: commands.Context, *, text: str):
"""Clap a message with emojis for extra dramatic effect""" """Clap a message with emojis for extra dramatic effect"""
chars = list(filter(lambda c: c in constants.EMOJI_MAP, text)) chars = list(filter(lambda c: c in constants.EMOJI_MAP, text))
@ -50,10 +50,7 @@ class Fun(commands.Cog):
meme = await generate_meme(self.client.http_session, result, fields) meme = await generate_meme(self.client.http_session, result, fields)
return meme return meme
@commands.hybrid_command( @commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type]
name="dadjoke",
aliases=["dad", "dj"],
)
async def dad_joke(self, ctx: commands.Context): async def dad_joke(self, ctx: commands.Context):
"""Why does Yoda's code always crash? Because there is no try.""" """Why does Yoda's code always crash? Because there is no try."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -83,13 +80,13 @@ class Fun(commands.Cog):
return await self.memegen_ls_msg(ctx) return await self.memegen_ls_msg(ctx)
if fields is None: if fields is None:
return await self.memegen_preview_msg(ctx, template) return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type]
async with ctx.typing(): async with ctx.typing():
meme = await self._do_generate_meme(template, shlex.split(fields)) meme = await self._do_generate_meme(template, shlex.split(fields))
return await ctx.reply(meme, mention_author=False) return await ctx.reply(meme, mention_author=False)
@memegen_msg.command(name="list", aliases=["ls"]) @memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type]
async def memegen_ls_msg(self, ctx: commands.Context): async def memegen_ls_msg(self, ctx: commands.Context):
"""Get a list of all available meme templates. """Get a list of all available meme templates.
@ -100,14 +97,14 @@ class Fun(commands.Cog):
await MemeSource(ctx, results).start() await MemeSource(ctx, results).start()
@memegen_msg.command(name="preview", aliases=["p"]) @memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type]
async def memegen_preview_msg(self, ctx: commands.Context, template: str): async def memegen_preview_msg(self, ctx: commands.Context, template: str):
"""Generate a preview for the meme template `template`, to see how the fields are structured.""" """Generate a preview for the meme template `template`, to see how the fields are structured."""
async with ctx.typing(): async with ctx.typing():
meme = await self._do_generate_meme(template, []) meme = await self._do_generate_meme(template, [])
return await ctx.reply(meme, mention_author=False) return await ctx.reply(meme, mention_author=False)
@memes_slash.command(name="generate") @memes_slash.command(name="generate") # type: ignore[arg-type]
async def memegen_slash(self, interaction: discord.Interaction, template: str): async def memegen_slash(self, interaction: discord.Interaction, template: str):
"""Generate a meme.""" """Generate a meme."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -116,7 +113,7 @@ class Fun(commands.Cog):
modal = GenerateMeme(self.client, result) modal = GenerateMeme(self.client, result)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@memes_slash.command(name="preview") @memes_slash.command(name="preview") # type: ignore[arg-type]
@app_commands.describe(template="The meme template to use in the preview.") @app_commands.describe(template="The meme template to use in the preview.")
async def memegen_preview_slash(self, interaction: discord.Interaction, template: str): async def memegen_preview_slash(self, interaction: discord.Interaction, template: str):
"""Generate a preview for a meme, to see how the fields are structured.""" """Generate a preview for a meme, to see how the fields are structured."""
@ -134,7 +131,7 @@ class Fun(commands.Cog):
"""Autocompletion for the 'template'-parameter""" """Autocompletion for the 'template'-parameter"""
return self.client.database_caches.memes.get_autocomplete_suggestions(current) return self.client.database_caches.memes.get_autocomplete_suggestions(current)
@app_commands.command() @app_commands.command() # type: ignore[arg-type]
@app_commands.describe(message="The text to convert.") @app_commands.describe(message="The text to convert.")
async def mock(self, interaction: discord.Interaction, message: str): async def mock(self, interaction: discord.Interaction, message: str):
"""Mock a message. """Mock a message.
@ -158,7 +155,7 @@ class Fun(commands.Cog):
return await interaction.followup.send(mock(message)) return await interaction.followup.send(mock(message))
@commands.hybrid_command(name="xkcd") @commands.hybrid_command(name="xkcd") # type: ignore[arg-type]
@app_commands.rename(comic_id="id") @app_commands.rename(comic_id="id")
async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None): async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None):
"""Fetch comic `#id` from xkcd. """Fetch comic `#id` from xkcd.

View File

@ -1,14 +1,6 @@
from typing import Optional
import discord
from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.constants import WORDLE_WORD_LENGTH
from database.crud.wordle import get_wordle_guesses, make_wordle_guess
from database.crud.wordle_stats import complete_wordle_game
from didier import Didier from didier import Didier
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over
class Games(commands.Cog): class Games(commands.Cog):
@ -19,53 +11,6 @@ class Games(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@app_commands.command(name="wordle", description="Play Wordle!")
async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None):
"""View your active Wordle game
If an argument is provided, make a guess instead
"""
await interaction.response.defer(ephemeral=True)
# Guess is wrong length
if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH:
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
return await interaction.followup.send(embed=embed)
word_instance = self.client.database_caches.wordle_word.word
async with self.client.postgres_session as session:
guesses = await get_wordle_guesses(session, interaction.user.id)
# Trying to guess with a complete game
if is_wordle_game_over(guesses, word_instance.word):
embed = WordleErrorEmbed(
message="You've already completed today's Wordle.\nTry again tomorrow!"
).to_embed()
return await interaction.followup.send(embed=embed)
# Make a guess
if guess:
# The guess is not a real word
if guess.lower() not in self.client.wordle_words:
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
return await interaction.followup.send(embed=embed)
guess = guess.lower()
await make_wordle_guess(session, interaction.user.id, guess)
# Don't re-request the game, we already have it
# just append locally
guesses.append(guess)
embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed()
await interaction.followup.send(embed=embed)
# After responding to the interaction: update stats in the background
game_over = is_wordle_game_over(guesses, word_instance.word)
if game_over:
await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses)
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog""" """Load the cog"""

View File

@ -159,6 +159,9 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
Code in codeblocks is ignored, as it is used to create examples. Code in codeblocks is ignored, as it is used to create examples.
""" """
description = command.help description = command.help
if description is None:
return ""
codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL) codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL)
# Regex borrowed from https://stackoverflow.com/a/59843498/13568999 # Regex borrowed from https://stackoverflow.com/a/59843498/13568999
@ -198,13 +201,10 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
return None return None
async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]: async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]:
"""Filter the list of cogs down to all those that the user can see""" """Filter the list of cogs down to all those that the user can see"""
async def _predicate(cog: Optional[commands.Cog]) -> bool: async def _predicate(cog: commands.Cog) -> bool:
if cog is None:
return False
# Remove cogs that we never want to see in the help page because they # Remove cogs that we never want to see in the help page because they
# don't contain commands, or shouldn't be visible at all # don't contain commands, or shouldn't be visible at all
if not cog.get_commands(): if not cog.get_commands():
@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
return True return True
# Filter list of cogs down # Filter list of cogs down
filtered_cogs = [cog for cog in cogs if await _predicate(cog)] filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)]
return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name)) return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name))
def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]: def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]:
"""Check if a command has flags""" """Check if a command has flags"""
flag_param = command.params.get("flags", None) flag_param = command.params.get("flags")
if flag_param is None: if flag_param is None:
return None return None

View File

@ -1,6 +1,6 @@
import inspect import inspect
import os import os
from typing import Optional from typing import Any, Optional, Union
from discord.ext import commands from discord.ext import commands
@ -76,18 +76,24 @@ class Meta(commands.Cog):
if command_name is None: if command_name is None:
return await ctx.reply(repo_home, mention_author=False) return await ctx.reply(repo_home, mention_author=False)
command: Optional[Union[commands.HelpCommand, commands.Command]]
src: Any
if command_name == "help": if command_name == "help":
command = self.client.help_command command = self.client.help_command
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
src = type(self.client.help_command) src = type(self.client.help_command)
filename = inspect.getsourcefile(src) filename = inspect.getsourcefile(src)
else: else:
command = self.client.get_command(command_name) command = self.client.get_command(command_name)
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
src = command.callback.__code__ src = command.callback.__code__
filename = src.co_filename filename = src.co_filename
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
lines, first_line = inspect.getsourcelines(src) lines, first_line = inspect.getsourcelines(src)
if filename is None: if filename is None:

View File

@ -22,7 +22,7 @@ class Other(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.hybrid_command(name="corona", aliases=["covid", "rona"]) @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type]
async def covid(self, ctx: commands.Context, country: str = "Belgium"): async def covid(self, ctx: commands.Context, country: str = "Belgium"):
"""Show Covid-19 info for a specific country. """Show Covid-19 info for a specific country.
@ -43,7 +43,7 @@ class Other(commands.Cog):
"""Autocompletion for the 'country'-parameter""" """Autocompletion for the 'country'-parameter"""
return autocomplete_country(value)[:25] return autocomplete_country(value)[:25]
@commands.hybrid_command( @commands.hybrid_command( # type: ignore[arg-type]
name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary" name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary"
) )
async def define(self, ctx: commands.Context, *, query: str): async def define(self, ctx: commands.Context, *, query: str):
@ -55,7 +55,7 @@ class Other(commands.Cog):
mention_author=False, mention_author=False,
) )
@commands.hybrid_command(name="google", description="Google search") @commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type]
@app_commands.describe(query="Search query") @app_commands.describe(query="Search query")
async def google(self, ctx: commands.Context, *, query: str): async def google(self, ctx: commands.Context, *, query: str):
"""Show the Google search results for `query`. """Show the Google search results for `query`.
@ -71,7 +71,7 @@ class Other(commands.Cog):
embed = GoogleSearch(results).to_embed() embed = GoogleSearch(results).to_embed()
await ctx.reply(embed=embed, mention_author=False) await ctx.reply(embed=embed, mention_author=False)
@commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type]
async def inspire(self, ctx: commands.Context): async def inspire(self, ctx: commands.Context):
"""Generate an [InspiroBot](https://inspirobot.me/) quote.""" """Generate an [InspiroBot](https://inspirobot.me/) quote."""
async with ctx.typing(): async with ctx.typing():
@ -82,7 +82,7 @@ class Other(commands.Cog):
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
return await get_link_by_name(session, name.lower()) return await get_link_by_name(session, name.lower())
@commands.command(name="Link", aliases=["Links"]) @commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type]
async def link_msg(self, ctx: commands.Context, name: str): async def link_msg(self, ctx: commands.Context, name: str):
"""Get the link to the resource named `name`.""" """Get the link to the resource named `name`."""
link = await self._get_link(name) link = await self._get_link(name)
@ -92,7 +92,7 @@ class Other(commands.Cog):
target_message = await self.client.get_reply_target(ctx) target_message = await self.client.get_reply_target(ctx)
await target_message.reply(link.url, mention_author=False) await target_message.reply(link.url, mention_author=False)
@app_commands.command(name="link") @app_commands.command(name="link") # type: ignore[arg-type]
@app_commands.describe(name="The name of the resource") @app_commands.describe(name="The name of the resource")
async def link_slash(self, interaction: discord.Interaction, name: str): async def link_slash(self, interaction: discord.Interaction, name: str):
"""Get the link to something.""" """Get the link to something."""

View File

@ -42,7 +42,7 @@ class Owner(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
async def cog_check(self, ctx: commands.Context) -> bool: async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override]
"""Global check for every command in this cog """Global check for every command in this cog
This means that we don't have to add is_owner() to every single command separately This means that we don't have to add is_owner() to every single command separately
@ -102,7 +102,7 @@ class Owner(commands.Cog):
async def add_msg(self, ctx: commands.Context): async def add_msg(self, ctx: commands.Context):
"""Command group for [add X] message commands""" """Command group for [add X] message commands"""
@add_msg.command(name="Alias") @add_msg.command(name="Alias") # type: ignore[arg-type]
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command""" """Add a new alias for a custom command"""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -116,7 +116,7 @@ class Owner(commands.Cog):
await ctx.reply("There is already a command with this name.") await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@add_msg.command(name="Custom") @add_msg.command(name="Custom") # type: ignore[arg-type]
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command""" """Add a new custom command"""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -127,7 +127,7 @@ class Owner(commands.Cog):
await ctx.reply("There is already a command with this name.") await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message) await self.client.reject_message(ctx.message)
@add_msg.command(name="Link") @add_msg.command(name="Link") # type: ignore[arg-type]
async def add_link_msg(self, ctx: commands.Context, name: str, url: str): async def add_link_msg(self, ctx: commands.Context, name: str, url: str):
"""Add a new link""" """Add a new link"""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -136,7 +136,7 @@ class Owner(commands.Cog):
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@add_slash.command(name="custom", description="Add a custom command") @add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type]
async def add_custom_slash(self, interaction: discord.Interaction): async def add_custom_slash(self, interaction: discord.Interaction):
"""Slash command to add a custom command""" """Slash command to add a custom command"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
@ -145,7 +145,7 @@ class Owner(commands.Cog):
modal = CreateCustomCommand(self.client) modal = CreateCustomCommand(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@add_slash.command(name="dadjoke", description="Add a dad joke") @add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type]
async def add_dad_joke_slash(self, interaction: discord.Interaction): async def add_dad_joke_slash(self, interaction: discord.Interaction):
"""Slash command to add a dad joke""" """Slash command to add a dad joke"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
@ -154,7 +154,7 @@ class Owner(commands.Cog):
modal = AddDadJoke(self.client) modal = AddDadJoke(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@add_slash.command(name="deadline", description="Add a deadline") @add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type]
@app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)")
async def add_deadline_slash(self, interaction: discord.Interaction, course: str): async def add_deadline_slash(self, interaction: discord.Interaction, course: str):
"""Slash command to add a deadline""" """Slash command to add a deadline"""
@ -174,7 +174,7 @@ class Owner(commands.Cog):
"""Autocompletion for the 'course'-parameter""" """Autocompletion for the 'course'-parameter"""
return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
@add_slash.command(name="event", description="Add a new event") @add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type]
async def add_event_slash(self, interaction: discord.Interaction): async def add_event_slash(self, interaction: discord.Interaction):
"""Slash command to add new events""" """Slash command to add new events"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
@ -183,7 +183,7 @@ class Owner(commands.Cog):
modal = AddEvent(self.client) modal = AddEvent(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@add_slash.command(name="link", description="Add a new link") @add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type]
async def add_link_slash(self, interaction: discord.Interaction): async def add_link_slash(self, interaction: discord.Interaction):
"""Slash command to add new links""" """Slash command to add new links"""
if not await self.client.is_owner(interaction.user): if not await self.client.is_owner(interaction.user):
@ -192,7 +192,7 @@ class Owner(commands.Cog):
modal = AddLink(self.client) modal = AddLink(self.client)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@add_slash.command(name="meme", description="Add a new meme") @add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type]
async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int):
"""Slash command to add new memes""" """Slash command to add new memes"""
await interaction.response.defer(ephemeral=True) await interaction.response.defer(ephemeral=True)
@ -205,11 +205,11 @@ class Owner(commands.Cog):
await interaction.followup.send(f"Added meme `{meme.meme_id}`.") await interaction.followup.send(f"Added meme `{meme.meme_id}`.")
await self.client.database_caches.memes.invalidate(session) await self.client.database_caches.memes.invalidate(session)
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type]
async def edit_msg(self, ctx: commands.Context): async def edit_msg(self, ctx: commands.Context):
"""Command group for [edit X] commands""" """Command group for [edit X] commands"""
@edit_msg.command(name="Custom") @edit_msg.command(name="Custom") # type: ignore[arg-type]
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
"""Edit an existing custom command""" """Edit an existing custom command"""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
@ -220,7 +220,7 @@ class Owner(commands.Cog):
await ctx.reply(f"No command found matching `{command}`.") await ctx.reply(f"No command found matching `{command}`.")
return await self.client.reject_message(ctx.message) return await self.client.reject_message(ctx.message)
@edit_slash.command(name="custom", description="Edit a custom command") @edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type]
@app_commands.describe(command="The name of the command to edit") @app_commands.describe(command="The name of the command to edit")
async def edit_custom_slash(self, interaction: discord.Interaction, command: str): async def edit_custom_slash(self, interaction: discord.Interaction, command: str):
"""Slash command to edit a custom command""" """Slash command to edit a custom command"""

View File

@ -27,7 +27,7 @@ class School(commands.Cog):
def __init__(self, client: Didier): def __init__(self, client: Didier):
self.client = client self.client = client
@commands.hybrid_command(name="deadlines") @commands.hybrid_command(name="deadlines") # type: ignore[arg-type]
async def deadlines(self, ctx: commands.Context): async def deadlines(self, ctx: commands.Context):
"""Show upcoming deadlines.""" """Show upcoming deadlines."""
async with ctx.typing(): async with ctx.typing():
@ -40,7 +40,7 @@ class School(commands.Cog):
embed = Deadlines(deadlines).to_embed() embed = Deadlines(deadlines).to_embed()
await ctx.reply(embed=embed, mention_author=False, ephemeral=False) await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
@commands.hybrid_command(name="les", aliases=["sched", "schedule"]) @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type]
@app_commands.rename(day_dt="date") @app_commands.rename(day_dt="date")
async def les( async def les(
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
@ -72,10 +72,7 @@ class School(commands.Cog):
except NotInMainGuildException: except NotInMainGuildException:
return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False) return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False)
@commands.hybrid_command( @commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type]
name="menu",
aliases=["eten", "food"],
)
@app_commands.rename(day_dt="date") @app_commands.rename(day_dt="date")
async def menu( async def menu(
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
@ -96,7 +93,7 @@ class School(commands.Cog):
embed = no_menu_found(day_dt) embed = no_menu_found(day_dt)
await ctx.reply(embed=embed, mention_author=False) await ctx.reply(embed=embed, mention_author=False)
@commands.hybrid_command( @commands.hybrid_command( # type: ignore[arg-type]
name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"] name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"]
) )
@app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)")
@ -124,7 +121,7 @@ class School(commands.Cog):
mention_author=False, mention_author=False,
) )
@commands.hybrid_command(name="ufora") @commands.hybrid_command(name="ufora") # type: ignore[arg-type]
async def ufora(self, ctx: commands.Context, course: str): async def ufora(self, ctx: commands.Context, course: str):
"""Link the Ufora page for a course.""" """Link the Ufora page for a course."""
async with self.client.postgres_session as session: async with self.client.postgres_session as session:

View File

@ -1,4 +1,6 @@
import asyncio
import datetime import datetime
import logging
import random import random
import discord import discord
@ -10,7 +12,6 @@ from database import enums
from database.crud.birthdays import get_birthdays_on_day from database.crud.birthdays import get_birthdays_on_day
from database.crud.reminders import get_all_reminders_for_category from database.crud.reminders import get_all_reminders_for_category
from database.crud.ufora_announcements import remove_old_announcements from database.crud.ufora_announcements import remove_old_announcements
from database.crud.wordle import set_daily_word
from database.schemas import Reminder from database.schemas import Reminder
from didier import Didier from didier import Didier
from didier.data.embeds.schedules import ( from didier.data.embeds.schedules import (
@ -21,9 +22,12 @@ from didier.data.embeds.schedules import (
from didier.data.rss_feeds.free_games import fetch_free_games from didier.data.rss_feeds.free_games import fetch_free_games
from didier.data.rss_feeds.ufora import fetch_ufora_announcements from didier.data.rss_feeds.ufora import fetch_ufora_announcements
from didier.decorators.tasks import timed_task from didier.decorators.tasks import timed_task
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.checks import is_owner from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
logger = logging.getLogger(__name__)
# datetime.time()-instances for when every task should run # datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
@ -54,11 +58,10 @@ class Tasks(commands.Cog):
"reminders": self.reminders, "reminders": self.reminders,
"ufora": self.pull_ufora_announcements, "ufora": self.pull_ufora_announcements,
"remove_ufora": self.remove_old_ufora_announcements, "remove_ufora": self.remove_old_ufora_announcements,
"wordle": self.reset_wordle_word,
} }
@overrides @overrides
def cog_load(self) -> None: async def cog_load(self) -> None:
# Only check birthdays if there's a channel to send it to # Only check birthdays if there's a channel to send it to
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
self.check_birthdays.start() self.check_birthdays.start()
@ -74,10 +77,10 @@ class Tasks(commands.Cog):
# Start other tasks # Start other tasks
self.reminders.start() self.reminders.start()
self.reset_wordle_word.start() asyncio.create_task(self.get_error_channel())
@overrides @overrides
def cog_unload(self) -> None: async def cog_unload(self) -> None:
# Cancel all pending tasks # Cancel all pending tasks
for task in self._tasks.values(): for task in self._tasks.values():
if task.is_running(): if task.is_running():
@ -99,7 +102,7 @@ class Tasks(commands.Cog):
await ctx.reply(embed=embed, mention_author=False) await ctx.reply(embed=embed, mention_author=False)
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type]
async def force_task(self, ctx: commands.Context, name: str): async def force_task(self, ctx: commands.Context, name: str):
"""Command to force-run a task without waiting for the specified run time""" """Command to force-run a task without waiting for the specified run time"""
name = name.lower() name = name.lower()
@ -110,23 +113,53 @@ class Tasks(commands.Cog):
await task(forced=True) await task(forced=True)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
async def get_error_channel(self):
"""Get the configured channel from the cache"""
await self.client.wait_until_ready()
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
channel = self.client.get_channel(settings.ERRORS_CHANNEL)
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.")
else:
self.client.error_channel = channel
elif self.client.owner_id is not None:
self.client.error_channel = self.client.get_user(self.client.owner_id)
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
@timed_task(enums.TaskType.BIRTHDAYS) @timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self, **kwargs): async def check_birthdays(self, **kwargs):
"""Check if it's currently anyone's birthday""" """Check if it's currently anyone's birthday"""
_ = kwargs _ = kwargs
# Can't happen (task isn't started if this is None), but Mypy doesn't know
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None:
return
now = tz_aware_now().date() now = tz_aware_now().date()
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
birthdays = await get_birthdays_on_day(session, now) birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
if channel is None: if channel is None:
return await self.client.log_error("Unable to find channel for birthday announcements") return await self.client.log_error("Unable to fetch channel for birthday announcements.")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable."
)
for birthday in birthdays: for birthday in birthdays:
user = self.client.get_user(birthday.user_id) user = self.client.get_user(birthday.user_id)
if user is None:
await self.client.log_error(
f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement"
)
continue
await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention)) await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention))
@check_birthdays.before_loop @check_birthdays.before_loop
@ -146,6 +179,14 @@ class Tasks(commands.Cog):
games = await fetch_free_games(self.client.http_session, session) games = await fetch_free_games(self.client.http_session, session)
channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL) channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to fetch channel for free games announcements.")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable."
)
for game in games: for game in games:
await channel.send(embed=game.to_embed()) await channel.send(embed=game.to_embed())
@ -207,6 +248,17 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as db_session: async with self.client.postgres_session as db_session:
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
if announcements_channel is None:
return await self.client.log_error(
f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)."
)
if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable."
)
announcements = await fetch_ufora_announcements(self.client.http_session, db_session) announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
for announcement in announcements: for announcement in announcements:
@ -266,34 +318,16 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
await remove_old_announcements(session) await remove_old_announcements(session)
@tasks.loop(time=DAILY_RESET_TIME)
async def reset_wordle_word(self, forced: bool = False):
"""Reset the daily Wordle word"""
async with self.client.postgres_session as session:
await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced)
await self.client.database_caches.wordle_word.invalidate(session)
@reset_wordle_word.before_loop
async def _before_reset_wordle_word(self):
await self.client.wait_until_ready()
@check_birthdays.error @check_birthdays.error
@pull_schedules.error @pull_schedules.error
@pull_ufora_announcements.error @pull_ufora_announcements.error
@reminders.error @reminders.error
@remove_old_ufora_announcements.error @remove_old_ufora_announcements.error
@reset_wordle_word.error
async def _on_tasks_error(self, error: BaseException): async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks""" """Error handler for all tasks"""
self.client.dispatch("task_error", error) self.client.dispatch("task_error", error)
async def setup(client: Didier): async def setup(client: Didier):
"""Load the cog """Load the cog"""
await client.add_cog(Tasks(client))
Initially fetch the wordle word from the database, or reset it
if there hasn't been a reset yet today
"""
cog = Tasks(client)
await client.add_cog(cog)
await cog.reset_wordle_word()

View File

@ -19,7 +19,7 @@ async def get_country_info(http_session: ClientSession, country: str) -> CovidDa
yesterday = response yesterday = response
data = {"today": today, "yesterday": yesterday} data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data) return CovidData.model_validate(data)
async def get_global_info(http_session: ClientSession) -> CovidData: async def get_global_info(http_session: ClientSession) -> CovidData:
@ -35,4 +35,4 @@ async def get_global_info(http_session: ClientSession) -> CovidData:
yesterday = response yesterday = response
data = {"today": today, "yesterday": yesterday} data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data) return CovidData.model_validate(data)

View File

@ -12,4 +12,4 @@ async def fetch_menu(http_session: ClientSession, day_dt: date) -> Menu:
"""Fetch the menu for a given day""" """Fetch the menu for a given day"""
endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json" endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json"
async with ensure_get(http_session, endpoint, log_exceptions=False) as response: async with ensure_get(http_session, endpoint, log_exceptions=False) as response:
return Menu.parse_obj(response) return Menu.model_validate(response)

View File

@ -14,4 +14,4 @@ async def lookup(http_session: ClientSession, query: str) -> list[Definition]:
url = "https://api.urbandictionary.com/v0/define" url = "https://api.urbandictionary.com/v0/define"
async with ensure_get(http_session, url, params={"term": query}) as response: async with ensure_get(http_session, url, params={"term": query}) as response:
return list(map(Definition.parse_obj, response["list"])) return list(map(Definition.model_validate, response["list"]))

View File

@ -13,4 +13,4 @@ async def fetch_xkcd_post(http_session: ClientSession, *, num: Optional[int] = N
url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json" url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json"
async with ensure_get(http_session, url) as response: async with ensure_get(http_session, url) as response:
return XKCDPost.parse_obj(response) return XKCDPost.model_validate(response)

View File

@ -1,6 +1,6 @@
import discord import discord
from overrides import overrides from overrides import overrides
from pydantic import BaseModel, Field, validator from pydantic import BaseModel, Field, field_validator
from didier.data.embeds.base import EmbedPydantic from didier.data.embeds.base import EmbedPydantic
@ -24,7 +24,7 @@ class _CovidNumbers(BaseModel):
active: int active: int
tests: int tests: int
@validator("updated") @field_validator("updated")
def updated_to_seconds(cls, value: int) -> int: def updated_to_seconds(cls, value: int) -> int:
"""Turn the updated field into seconds instead of milliseconds""" """Turn the updated field into seconds instead of milliseconds"""
return int(value) // 1000 return int(value) // 1000

View File

@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) ->
embed = discord.Embed(title="Error", colour=discord.Colour.red()) embed = discord.Embed(title="Error", colour=discord.Colour.red())
if ctx is not None: if ctx is not None:
if ctx.guild is None: if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel):
origin = "DM" origin = "DM"
else: else:
origin = f"{ctx.channel.mention} ({ctx.guild.name})" origin = f"<#{ctx.channel.id}> ({ctx.guild.name})"
invocation = f"{ctx.author.display_name} in {origin}" invocation = f"{ctx.author.display_name} in {origin}"

View File

@ -4,18 +4,17 @@ from typing import Optional
import discord import discord
from aiohttp import ClientSession from aiohttp import ClientSession
from overrides import overrides from overrides import overrides
from pydantic import validator from pydantic import field_validator
from didier.data.embeds.base import EmbedPydantic from didier.data.embeds.base import EmbedPydantic
from didier.data.scrapers.common import GameStorePage from didier.data.scrapers.common import GameStorePage
from didier.data.scrapers.steam import get_steam_webpage_info from didier.data.scrapers.steam import get_steam_webpage_info
from didier.utils.discord import colours from didier.utils.discord import colours
__all__ = ["SEPARATOR", "FreeGameEmbed"]
from didier.utils.discord.constants import Limits from didier.utils.discord.constants import Limits
from didier.utils.types.string import abbreviate from didier.utils.types.string import abbreviate
__all__ = ["SEPARATOR", "FreeGameEmbed"]
SEPARATOR = " • Free • " SEPARATOR = " • Free • "
@ -58,7 +57,7 @@ class FreeGameEmbed(EmbedPydantic):
store_page: Optional[GameStorePage] = None store_page: Optional[GameStorePage] = None
@validator("title") @field_validator("title")
def _clean_title(cls, value: str) -> str: def _clean_title(cls, value: str) -> str:
return html.unescape(value) return html.unescape(value)
@ -107,7 +106,6 @@ class FreeGameEmbed(EmbedPydantic):
embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})") embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})")
if self.store_page.xdg_open_url is not None: if self.store_page.xdg_open_url is not None:
embed.add_field( embed.add_field(
name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})" name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})"
) )

View File

@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"]
def create_logging_embed(level: int, message: str) -> discord.Embed: def create_logging_embed(level: int, message: str) -> discord.Embed:
"""Create an embed to send to the logging channel""" """Create an embed to send to the logging channel"""
colours = { colours = {
logging.DEBUG: discord.Colour.light_gray, logging.DEBUG: discord.Colour.light_grey(),
logging.ERROR: discord.Colour.red(), logging.ERROR: discord.Colour.red(),
logging.INFO: discord.Colour.blue(), logging.INFO: discord.Colour.blue(),
logging.WARNING: discord.Colour.yellow(), logging.WARNING: discord.Colour.yellow(),

View File

@ -2,7 +2,7 @@ from datetime import datetime
import discord import discord
from overrides import overrides from overrides import overrides
from pydantic import validator from pydantic import field_validator
from didier.data.embeds.base import EmbedPydantic from didier.data.embeds.base import EmbedPydantic
from didier.utils.discord import colours from didier.utils.discord import colours
@ -39,8 +39,8 @@ class Definition(EmbedPydantic):
total_votes = self.thumbs_up + self.thumbs_down total_votes = self.thumbs_up + self.thumbs_down
return round(100 * self.thumbs_up / total_votes, 2) return round(100 * self.thumbs_up / total_votes, 2)
@validator("definition", "example") @field_validator("definition", "example")
def modify_long_text(cls, field): def modify_long_text(cls, field: str):
"""Remove brackets from fields & cut them off if they are too long""" """Remove brackets from fields & cut them off if they are too long"""
field = field.replace("[", "").replace("]", "") field = field.replace("[", "").replace("]", "")
return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH) return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH)

View File

@ -1,142 +0,0 @@
import enum
from dataclasses import dataclass
import discord
from overrides import overrides
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
from database.schemas import WordleWord
from didier.data.embeds.base import EmbedBaseModel
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
__all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"]
def is_wordle_game_over(guesses: list[str], word: str) -> bool:
"""Check if the current game is over or not"""
if not guesses:
return False
if len(guesses) == WORDLE_GUESS_COUNT:
return True
return word.lower() in guesses
def footer() -> str:
"""Create the footer to put on the embed"""
today = tz_aware_now()
return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}"
class WordleColour(enum.IntEnum):
"""Colours for the Wordle embed"""
EMPTY = 0
WRONG_LETTER = 1
WRONG_POSITION = 2
CORRECT = 3
@dataclass
class WordleEmbed(EmbedBaseModel):
"""Embed for a Wordle game"""
guesses: list[str]
word: WordleWord
def _letter_colour(self, guess: str, index: int) -> WordleColour:
"""Get the colour for a guess at a given position"""
if guess[index] == self.word.word[index]:
return WordleColour.CORRECT
wrong_letter = 0
wrong_position = 0
for i, letter in enumerate(self.word.word):
if letter == guess[index] and guess[i] != guess[index]:
wrong_letter += 1
if i <= index and guess[i] == guess[index] and letter != guess[index]:
wrong_position += 1
if i >= index:
if wrong_position == 0:
break
if wrong_position <= wrong_letter:
return WordleColour.WRONG_POSITION
return WordleColour.WRONG_LETTER
def _guess_colours(self, guess: str) -> list[WordleColour]:
"""Create the colour codes for a specific guess"""
return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)]
def colour_code_game(self) -> list[list[WordleColour]]:
"""Create the colour codes for an entire game"""
colours = []
# Add all the guesses
for guess in self.guesses:
colours.append(self._guess_colours(guess))
# Fill the rest with empty spots
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH)
return colours
def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]:
"""Turn the colours of the board into Discord emojis"""
colour_map = {
WordleColour.EMPTY: ":white_large_square:",
WordleColour.WRONG_LETTER: ":black_large_square:",
WordleColour.WRONG_POSITION: ":orange_square:",
WordleColour.CORRECT: ":green_square:",
}
emojis = []
for row in colours:
emojis.append(list(map(lambda char: colour_map[char], row)))
return emojis
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
only_colours = kwargs.get("only_colours", False)
colours = self.colour_code_game()
embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}")
emojis = self._colours_to_emojis(colours)
rows = [" ".join(row) for row in emojis]
# Don't reveal anything if we only want to show the colours
if not only_colours and self.guesses:
for i, guess in enumerate(self.guesses):
rows[i] += f" ||{guess.upper()}||"
# If the game is over, reveal the word
if is_wordle_game_over(self.guesses, self.word.word):
rows.append(f"\n\nThe word was **{self.word.word.upper()}**!")
embed.description = "\n\n".join(rows)
embed.set_footer(text=footer())
return embed
@dataclass
class WordleErrorEmbed(EmbedBaseModel):
"""Embed to send error messages to the user"""
message: str
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
embed.description = self.message
embed.set_footer(text=footer())
return embed

View File

@ -32,7 +32,7 @@ async def fetch_free_games(http_session: ClientSession, database_session: AsyncS
if SEPARATOR not in entry["title"]: if SEPARATOR not in entry["title"]:
continue continue
game = FreeGameEmbed.parse_obj(entry) game = FreeGameEmbed.model_validate(entry)
games.append(game) games.append(game)
game_ids.append(game.dc_identifier) game_ids.append(game.dc_identifier)

View File

@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]:
return list(dict.fromkeys(results)) return list(dict.fromkeys(results))
async def google_search(http_client: ClientSession, query: str): async def google_search(http_session: ClientSession, query: str):
"""Get the first 10 Google search results""" """Get the first 10 Google search results"""
query = urlencode({"q": query}) query = urlencode({"q": query})
# Request 20 results in case of duplicates, bad matches, ... # Request 20 results in case of duplicates, bad matches, ...
async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response:
# Something went wrong # Something went wrong
if response.status != http.HTTPStatus.OK: if response.status != http.HTTPStatus.OK:
return SearchData(query, response.status) return SearchData(query, response.status)

View File

@ -17,7 +17,7 @@ from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.error_embed import create_error_embed
from didier.data.embeds.logging_embed import create_logging_embed from didier.data.embeds.logging_embed import create_logging_embed
from didier.data.embeds.schedules import Schedule, parse_schedule from didier.data.embeds.schedules import Schedule, parse_schedule
from didier.exceptions import HTTPException, NoMatch from didier.exceptions import GetNoneException, HTTPException, NoMatch
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
from didier.utils.discord.snipe import should_snipe from didier.utils.discord.snipe import should_snipe
from didier.utils.easter_eggs import detect_easter_egg from didier.utils.easter_eggs import detect_easter_egg
@ -33,12 +33,11 @@ class Didier(commands.Bot):
"""DIDIER <3""" """DIDIER <3"""
database_caches: CacheManager database_caches: CacheManager
error_channel: discord.abc.Messageable error_channel: Optional[discord.abc.Messageable] = None
initial_extensions: tuple[str, ...] = () initial_extensions: tuple[str, ...] = ()
http_session: ClientSession http_session: ClientSession
schedules: dict[settings.ScheduleType, Schedule] = {} schedules: dict[settings.ScheduleType, Schedule] = {}
sniped: dict[int, tuple[discord.Message, Optional[discord.Message]]] = {} sniped: dict[int, tuple[discord.Message, Optional[discord.Message]]] = {}
wordle_words: set[str] = set()
def __init__(self): def __init__(self):
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE) activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
@ -57,12 +56,17 @@ class Didier(commands.Bot):
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
) )
self.tree.on_error = self.on_app_command_error # I'm not creating a custom tree, this is the way to do it
self.tree.on_error = self.on_app_command_error # type: ignore[method-assign]
@cached_property @cached_property
def main_guild(self) -> discord.Guild: def main_guild(self) -> discord.Guild:
"""Obtain a reference to the main guild""" """Obtain a reference to the main guild"""
return self.get_guild(settings.DISCORD_MAIN_GUILD) guild = self.get_guild(settings.DISCORD_MAIN_GUILD)
if guild is None:
raise GetNoneException("Main guild could not be found in the bot's cache")
return guild
@property @property
def postgres_session(self) -> AsyncSession: def postgres_session(self) -> AsyncSession:
@ -77,9 +81,6 @@ class Didier(commands.Bot):
# Create directories that are ignored on GitHub # Create directories that are ignored on GitHub
self._create_ignored_directories() self._create_ignored_directories()
# Load the Wordle dictionary
self._load_wordle_words()
# Initialize caches # Initialize caches
self.database_caches = CacheManager() self.database_caches = CacheManager()
async with self.postgres_session as session: async with self.postgres_session as session:
@ -97,12 +98,6 @@ class Didier(commands.Bot):
await self._load_initial_extensions() await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs") await self._load_directory_extensions("didier/cogs")
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
else:
self.error_channel = self.get_user(self.owner_id)
def _create_ignored_directories(self): def _create_ignored_directories(self):
"""Create directories that store ignored data""" """Create directories that store ignored data"""
ignored = ["files/schedules"] ignored = ["files/schedules"]
@ -137,12 +132,6 @@ class Didier(commands.Bot):
elif os.path.isdir(new_path := f"{path}/{file}"): elif os.path.isdir(new_path := f"{path}/{file}"):
await self._load_directory_extensions(new_path) await self._load_directory_extensions(new_path)
def _load_wordle_words(self):
"""Load the dictionary of Wordle words"""
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
for line in fp:
self.wordle_words.add(line.strip())
async def load_schedules(self): async def load_schedules(self):
"""Parse & load all schedules into memory""" """Parse & load all schedules into memory"""
self.schedules = {} self.schedules = {}
@ -162,18 +151,27 @@ class Didier(commands.Bot):
original message instead original message instead
""" """
if ctx.message.reference is not None: if ctx.message.reference is not None:
return await self.resolve_message(ctx.message.reference) return await self.resolve_message(ctx.message.reference) or ctx.message
return ctx.message return ctx.message
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]:
"""Fetch a message from a reference""" """Fetch a message from a reference"""
# Message is in the cache, return it # Message is in the cache, return it
if reference.cached_message is not None: if reference.cached_message is not None:
return reference.cached_message return reference.cached_message
if reference.message_id is None:
return None
# For older messages: fetch them from the API # For older messages: fetch them from the API
channel = self.get_channel(reference.channel_id) channel = self.get_channel(reference.channel_id)
if channel is None or isinstance(
channel,
(discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel),
): # Logically this can't happen, but we have to please Mypy
return None
return await channel.fetch_message(reference.message_id) return await channel.fetch_message(reference.message_id)
async def confirm_message(self, message: discord.Message): async def confirm_message(self, message: discord.Message):
@ -194,7 +192,7 @@ class Didier(commands.Bot):
} }
methods.get(level, logger.error)(message) methods.get(level, logger.error)(message)
if log_to_discord: if log_to_discord and self.error_channel is not None:
embed = create_logging_embed(level, message) embed = create_logging_embed(level, message)
await self.error_channel.send(embed=embed) await self.error_channel.send(embed=embed)
@ -263,10 +261,9 @@ class Didier(commands.Bot):
await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True) await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True)
if settings.ERRORS_CHANNEL is not None: if self.error_channel is not None:
embed = create_error_embed(await commands.Context.from_interaction(interaction), exception) embed = create_error_embed(await commands.Context.from_interaction(interaction), exception)
channel = self.get_channel(settings.ERRORS_CHANNEL) await self.error_channel.send(embed=embed)
await channel.send(embed=embed)
async def on_command_completion(self, ctx: commands.Context): async def on_command_completion(self, ctx: commands.Context):
"""Event triggered when a message command completes successfully""" """Event triggered when a message command completes successfully"""
@ -291,7 +288,7 @@ class Didier(commands.Bot):
# Hybrid command errors are wrapped in an additional error, so wrap it back out # Hybrid command errors are wrapped in an additional error, so wrap it back out
if isinstance(exception, commands.HybridCommandError): if isinstance(exception, commands.HybridCommandError):
exception = exception.original exception = exception.original # type: ignore[assignment]
# Ignore exceptions that aren't important # Ignore exceptions that aren't important
if isinstance( if isinstance(
@ -342,10 +339,9 @@ class Didier(commands.Bot):
# Print everything that we care about to the logs/stderr # Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception) await super().on_command_error(ctx, exception)
if settings.ERRORS_CHANNEL is not None: if self.error_channel is not None:
embed = create_error_embed(ctx, exception) embed = create_error_embed(ctx, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL) await self.error_channel.send(embed=embed)
await channel.send(embed=embed)
async def on_message(self, message: discord.Message, /) -> None: async def on_message(self, message: discord.Message, /) -> None:
"""Event triggered when a message is sent""" """Event triggered when a message is sent"""
@ -354,7 +350,7 @@ class Didier(commands.Bot):
return return
# Boos react to people that say Dider # Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id: if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT) await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command # Potential custom command
@ -384,7 +380,7 @@ class Didier(commands.Bot):
# If the edited message is currently present in the snipe cache, # If the edited message is currently present in the snipe cache,
# don't update the <before>, but instead change the <after> # don't update the <before>, but instead change the <after>
existing = self.sniped.get(before.channel.id, None) existing = self.sniped.get(before.channel.id)
if existing is not None and existing[0].id == before.id: if existing is not None and existing[0].id == before.id:
before = existing[0] before = existing[0]
@ -399,10 +395,9 @@ class Didier(commands.Bot):
async def on_task_error(self, exception: Exception): async def on_task_error(self, exception: Exception):
"""Event triggered when a task raises an exception""" """Event triggered when a task raises an exception"""
if settings.ERRORS_CHANNEL is not None: if self.error_channel:
embed = create_error_embed(None, exception) embed = create_error_embed(None, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL) await self.error_channel.send(embed=embed)
await channel.send(embed=embed)
async def on_thread_create(self, thread: discord.Thread): async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a new thread is created""" """Event triggered when a new thread is created"""

View File

@ -1,6 +1,14 @@
from .get_none_exception import GetNoneException
from .http_exception import HTTPException from .http_exception import HTTPException
from .missing_env import MissingEnvironmentVariable from .missing_env import MissingEnvironmentVariable
from .no_match import NoMatch, expect from .no_match import NoMatch, expect
from .not_in_main_guild_exception import NotInMainGuildException from .not_in_main_guild_exception import NotInMainGuildException
__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"] __all__ = [
"GetNoneException",
"HTTPException",
"MissingEnvironmentVariable",
"NoMatch",
"expect",
"NotInMainGuildException",
]

View File

@ -0,0 +1,5 @@
__all__ = ["GetNoneException"]
class GetNoneException(RuntimeError):
"""Exception raised when a Bot.get()-method returned None"""

View File

@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError):
def __init__(self, user: Union[discord.User, discord.Member]): def __init__(self, user: Union[discord.User, discord.Member]):
super().__init__( super().__init__(
f"User {user.display_name} (id {user.id}) " f"User {user.display_name} (id `{user.id}`) "
f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})." f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)."
) )

View File

@ -0,0 +1,5 @@
import discord
__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"]
NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel)

View File

@ -15,11 +15,14 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]:
This is done dynamically through regexes to allow case-insensitivity This is done dynamically through regexes to allow case-insensitivity
and variable amounts of whitespace among other things. and variable amounts of whitespace among other things.
""" """
mention = f"<@!?{client.user.id}>" mention = f"<@!?{client.user.id}>" if client.user else None
regex = r"^({})\s*" regex = r"^({})\s*"
# Check which prefix was used # Check which prefix was used
for prefix in [*constants.PREFIXES, mention]: for prefix in [*constants.PREFIXES, mention]:
if prefix is None:
continue
match = re.match(regex.format(prefix), message.content, flags=re.I) match = re.match(regex.format(prefix), message.content, flags=re.I)
if match is not None: if match is not None:

View File

@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
label = self.name.value.strip() label = self.name.value.strip()
try: try:
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
bm = await create_bookmark(session, interaction.user.id, label, self.jump_url) bm = await create_bookmark(session, interaction.user.id, label, self.jump_url)
return await interaction.response.send_message( return await interaction.followup.send(
f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)."
) )
except DuplicateInsertException: except DuplicateInsertException:
# Label is already in use # Label is already in use
return await interaction.response.send_message( return await interaction.followup.send(f"You already have a bookmark named `{label}`.")
f"You already have a bookmark named `{label}`.", ephemeral=True
)
except ForbiddenNameException: except ForbiddenNameException:
# Label isn't allowed # Label isn't allowed
return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True) return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.")
@overrides @overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True) await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__) traceback.print_tb(error.__traceback__)

View File

@ -26,12 +26,14 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
joke = await add_dad_joke(session, str(self.joke.value)) joke = await add_dad_joke(session, str(self.joke.value))
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}")
@overrides @overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True) await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__) traceback.print_tb(error.__traceback__)

View File

@ -10,6 +10,8 @@ from didier import Didier
__all__ = ["AddEvent"] __all__ = ["AddEvent"]
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
class AddEvent(discord.ui.Modal, title="Add Event"): class AddEvent(discord.ui.Modal, title="Add Event"):
"""Modal to add a new event""" """Modal to add a new event"""
@ -33,15 +35,20 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction) -> None: async def on_submit(self, interaction: discord.Interaction) -> None:
await interaction.response.defer(ephemeral=True)
try: try:
parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
except ParserError: except ParserError:
return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True) return await interaction.followup.send("Unable to parse date argument.")
if self.client.get_channel(int(self.channel.value)) is None: channel = self.client.get_channel(int(self.channel.value))
return await interaction.response.send_message(
f"Unable to find channel `{self.channel.value}`", ephemeral=True if channel is None:
) return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.")
async with self.client.postgres_session as session: async with self.client.postgres_session as session:
event = await add_event( event = await add_event(
@ -52,10 +59,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
channel_id=int(self.channel.value), channel_id=int(self.channel.value),
) )
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) await interaction.followup.send(f"Successfully added event `{event.event_id}`.")
self.client.dispatch("event_create", event) self.client.dispatch("event_create", event)
@overrides @overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True) await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__) traceback.print_tb(error.__traceback__)

File diff suppressed because it is too large Load Diff

View File

@ -36,6 +36,7 @@ def setup_logging():
# Configure discord handler # Configure discord handler
discord_log = logging.getLogger("discord") discord_log = logging.getLogger("discord")
discord_handler: logging.StreamHandler
# Make dev print to stderr instead, so you don't have to watch the file # Make dev print to stderr instead, so you don't have to watch the file
if settings.SANDBOX: if settings.SANDBOX:

View File

@ -28,6 +28,7 @@ omit = [
profile = "black" profile = "black"
[tool.mypy] [tool.mypy]
check_untyped_defs = true
files = [ files = [
"database/**/*.py", "database/**/*.py",
"didier/**/*.py", "didier/**/*.py",
@ -35,7 +36,6 @@ files = [
] ]
plugins = [ plugins = [
"pydantic.mypy", "pydantic.mypy",
"sqlalchemy.ext.mypy.plugin"
] ]
[[tool.mypy.overrides]] [[tool.mypy.overrides]]
module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"] module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"]

View File

@ -1,22 +1,21 @@
black==22.3.0 black==23.3.0
coverage[toml]==6.4.1 coverage[toml]==7.2.7
freezegun==1.2.1 freezegun==1.2.2
isort==5.12.0 isort==5.12.0
mypy==0.961 mypy==1.4.1
pre-commit==2.20.0 pre-commit==3.3.3
pytest==7.1.2 pytest==7.4.0
pytest-asyncio==0.18.3 pytest-asyncio==0.21.0
pytest-env==0.6.2 pytest-env==0.8.2
sqlalchemy2-stubs==0.0.2a23 types-beautifulsoup4==4.12.0.5
types-beautifulsoup4==4.11.3 types-python-dateutil==2.8.19.13
types-python-dateutil==2.8.19
# Flake8 + plugins # Flake8 + plugins
flake8==4.0.1 flake8==6.0.0
flake8-bandit==3.0.0 flake8-bandit==4.1.1
flake8-bugbear==22.7.1 flake8-bugbear==23.6.5
flake8-docstrings==1.6.0 flake8-docstrings==1.7.0
flake8-dunder-all==0.2.1 flake8-dunder-all==0.3.0
flake8-eradicate==1.2.1 flake8-eradicate==1.5.0
flake8-isort==4.1.1 flake8-isort==6.0.0
flake8-simplify==0.19.2 flake8-simplify==0.20.0

View File

@ -1,13 +1,13 @@
aiohttp==3.8.1 aiohttp==3.8.4
alembic==1.8.0 alembic==1.11.1
asyncpg==0.25.0 asyncpg==0.28.0
beautifulsoup4==4.11.1 beautifulsoup4==4.12.2
discord.py==2.0.1 discord.py==2.3.1
environs==9.5.0 environs==9.5.0
feedparser==6.0.10 feedparser==6.0.10
ics==0.7.2 ics==0.7.2
markdownify==0.11.2 markdownify==0.11.6
overrides==6.1.0 overrides==7.3.1
pydantic==1.9.1 pydantic==2.0.2
python-dateutil==2.8.2 python-dateutil==2.8.2
sqlalchemy[asyncio]==1.4.37 sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18

View File

@ -111,7 +111,7 @@ class ScheduleInfo:
role_id: Optional[int] role_id: Optional[int]
schedule_url: Optional[str] schedule_url: Optional[str]
name: Optional[str] = None name: ScheduleType
SCHEDULE_DATA = [ SCHEDULE_DATA = [

View File

@ -1,138 +0,0 @@
from datetime import date, timedelta
import pytest
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle as crud
from database.schemas import User, WordleGuess, WordleWord
@pytest.fixture
async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
"""Fixture to generate some guesses"""
guesses = []
for guess in ["TEST", "WORDLE", "WORDS"]:
guess = WordleGuess(user_id=user.user_id, guess=guess)
postgres.add(guess)
await postgres.commit()
guesses.append(guess)
return guesses
@pytest.mark.postgres
async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
"""Test getting an active game when there is none"""
result = await crud.get_active_wordle_game(postgres, user.user_id)
assert not result
@pytest.mark.postgres
async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
"""Test getting an active game when there is one"""
result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
assert result == wordle_guesses
@pytest.mark.postgres
async def test_get_daily_word_none(postgres: AsyncSession):
"""Test getting the daily word when the database is empty"""
result = await crud.get_daily_word(postgres)
assert result is None
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_get_daily_word_not_today(postgres: AsyncSession):
"""Test getting the daily word when there is an entry, but not for today"""
day = date.today() - timedelta(days=1)
word = "testword"
word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
assert await crud.get_daily_word(postgres) is None
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_get_daily_word_present(postgres: AsyncSession):
"""Test getting the daily word when there is one for today"""
day = date.today()
word = "testword"
word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_none_present(postgres: AsyncSession):
"""Test setting the daily word when there is none"""
assert await crud.get_daily_word(postgres) is None
word = "testword"
await crud.set_daily_word(postgres, word)
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_present(postgres: AsyncSession):
"""Test setting the daily word when there already is one"""
word = "testword"
await crud.set_daily_word(postgres, word)
await crud.set_daily_word(postgres, "another word")
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
"""Test setting the daily word when there already is one, but "forced" is set to True"""
word = "testword"
await crud.set_daily_word(postgres, word)
word = "anotherword"
await crud.set_daily_word(postgres, word, forced=True)
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
async def test_make_wordle_guess(postgres: AsyncSession, user: User):
"""Test making a guess in your current game"""
test_user_id = user.user_id
guess = "guess"
await crud.make_wordle_guess(postgres, test_user_id, guess)
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess]
other_guess = "otherguess"
await crud.make_wordle_guess(postgres, test_user_id, other_guess)
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess]
@pytest.mark.postgres
async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
"""Test dropping the collection of active games"""
test_user_id = user.user_id
assert await crud.get_active_wordle_game(postgres, test_user_id)
await crud.reset_wordle_games(postgres)
assert not await crud.get_active_wordle_game(postgres, test_user_id)

View File

@ -1,72 +0,0 @@
import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle_stats as crud
from database.schemas import User, WordleStats
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
"""Helper function to insert some stats"""
postgres.add(stats)
await postgres.commit()
@pytest.mark.postgres
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
"""Test getting a user's stats when the db is empty"""
test_user_id = user.user_id
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is None
await crud.get_wordle_stats(postgres, test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
@pytest.mark.postgres
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
"""Test getting a user's stats when there's already an entry present"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.games = 20
await insert_game_stats(postgres, stats)
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
assert found_stats.games == 20
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you win"""
test_user_id = user.user_id
await crud.complete_wordle_game(postgres, test_user_id, win=True)
stats = await crud.get_wordle_stats(postgres, test_user_id)
assert stats.games == 1
assert stats.wins == 1
assert stats.current_streak == 1
assert stats.highest_streak == 1
assert stats.last_win == datetime.date.today()
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you lose"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.current_streak = 10
await insert_game_stats(postgres, stats)
await crud.complete_wordle_game(postgres, test_user_id, win=False)
stats = await crud.get_wordle_stats(postgres, test_user_id)
# Check that streak was broken
assert stats.current_streak == 0
assert stats.games == 1