Merge pull request #177 from stijndcl/didier-3.7.0

Didier v3.7.0
pull/180/head v3.7.0
Stijn De Clercq 2023-09-24 16:28:02 +02:00 committed by GitHub
commit 5bfd3a92a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
62 changed files with 768 additions and 22942 deletions

View File

@ -3,12 +3,12 @@ default_language_version:
repos:
- repo: https://github.com/ambv/black
rev: 22.3.0
rev: 23.3.0
hooks:
- id: black
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.4.0
hooks:
- id: check-json
- id: end-of-file-fixer
@ -21,7 +21,7 @@ repos:
- id: isort
- repo: https://github.com/PyCQA/autoflake
rev: v1.4
rev: v2.2.0
hooks:
- id: autoflake
name: autoflake (python)
@ -31,7 +31,7 @@ repos:
- "--ignore-init-module-imports"
- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 6.0.0
hooks:
- id: flake8
exclude: ^(alembic|.github)

View File

@ -0,0 +1,94 @@
"""Migrate to 2.x
Revision ID: 09128b6e34dd
Revises: 1e3e7f4192c4
Create Date: 2023-07-07 16:23:15.990231
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "09128b6e34dd"
down_revision = "1e3e7f4192c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("bank", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("birthdays", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("command_stats", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=False)
with op.batch_alter_table("deadlines", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
with op.batch_alter_table("github_links", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("reminders", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=False)
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=True)
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("reminders", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("github_links", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("deadlines", schema=None) as batch_op:
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=True)
with op.batch_alter_table("command_stats", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("birthdays", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
with op.batch_alter_table("bank", schema=None) as batch_op:
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
# ### end Alembic commands ###

View File

@ -0,0 +1,57 @@
"""Remove wordle
Revision ID: 1e3e7f4192c4
Revises: 954ad804f057
Create Date: 2023-07-07 14:52:20.993687
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "1e3e7f4192c4"
down_revision = "954ad804f057"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("wordle_word")
op.drop_table("wordle_stats")
op.drop_table("wordle_guesses")
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"wordle_guesses",
sa.Column("wordle_guess_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
sa.Column("guess", sa.TEXT(), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_guesses_user_id_fkey"),
sa.PrimaryKeyConstraint("wordle_guess_id", name="wordle_guesses_pkey"),
)
op.create_table(
"wordle_stats",
sa.Column("wordle_stats_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
sa.Column("last_win", sa.DATE(), autoincrement=False, nullable=True),
sa.Column("games", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("wins", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("current_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.Column("highest_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_stats_user_id_fkey"),
sa.PrimaryKeyConstraint("wordle_stats_id", name="wordle_stats_pkey"),
)
op.create_table(
"wordle_word",
sa.Column("word_id", sa.INTEGER(), autoincrement=True, nullable=False),
sa.Column("word", sa.TEXT(), autoincrement=False, nullable=False),
sa.Column("day", sa.DATE(), autoincrement=False, nullable=False),
sa.PrimaryKeyConstraint("word_id", name="wordle_word_pkey"),
sa.UniqueConstraint("day", name="wordle_word_day_key"),
)
# ### end Alembic commands ###

View File

@ -1,2 +0,0 @@
WORDLE_GUESS_COUNT = 6
WORDLE_WORD_LENGTH = 5

View File

@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[
if query is not None:
statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%"))
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]:

View File

@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
"""Get a list of all commands"""
statement = select(CustomCommand)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:

View File

@ -38,4 +38,4 @@ async def get_deadlines(
statement = statement.where(Deadline.course_id == course.course_id)
statement = statement.options(selectinload(Deadline.course))
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"]
async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]:
"""Return a list of all easter eggs"""
statement = select(EasterEgg)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
"""Get a list of all upcoming events"""
statement = select(Event).where(Event.timestamp > now)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:

View File

@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]):
async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]:
"""Filter a list of game IDs down to the ones that aren't in the database yet"""
statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids))
matches: list[int] = (await session.execute(statement)).scalars().all()
matches: list[int] = list((await session.execute(statement)).scalars().all())
return list(set(game_ids).difference(matches))

View File

@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id:
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
"""Get a user's GitHub links"""
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
async def get_all_links(session: AsyncSession) -> list[Link]:
"""Get a list of all links"""
statement = select(Link)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def add_link(session: AsyncSession, name: str, url: str) -> Link:

View File

@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
"""Get a list of all memes"""
statement = select(MemeTemplate)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:

View File

@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"]
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
"""Get a list of all Reminders for a given category"""
statement = select(Reminder).where(Reminder.category == category)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:

View File

@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_
async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]:
"""Get all courses where announcements are enabled"""
statement = select(UforaCourse).where(UforaCourse.log_announcements)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def create_new_announcement(

View File

@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor
# Search case-insensitively
query = query.lower()
statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first()
if result:
return result
course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
course_result = (await session.execute(course_statement)).scalars().first()
if course_result:
return course_result
statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first()
return result.course if result else None
alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
alias_result = (await session.execute(alias_statement)).scalars().first()
return alias_result.course if alias_result else None

View File

@ -1,85 +0,0 @@
import datetime
from typing import Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import WordleGuess, WordleWord
__all__ = [
"get_active_wordle_game",
"get_wordle_guesses",
"make_wordle_guess",
"set_daily_word",
"reset_wordle_games",
]
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
"""Find a player's active game"""
await get_or_add_user(session, user_id)
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
guesses = (await session.execute(statement)).scalars().all()
return guesses
async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]:
"""Get the strings of a player's guesses"""
active_game = await get_active_wordle_game(session, user_id)
return list(map(lambda g: g.guess.lower(), active_game))
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
"""Make a guess in your current game"""
guess_instance = WordleGuess(user_id=user_id, guess=guess)
session.add(guess_instance)
await session.commit()
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
"""Get the word of today"""
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
row = (await session.execute(statement)).scalar_one_or_none()
if row is None:
return None
return row
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
"""Set the word of today
This does NOT overwrite the existing word if there is one, so that it can safely run
on startup every time.
In order to always overwrite the current word, set the "forced"-kwarg to True.
Returns the word that was chosen. If one already existed, return that instead.
"""
current_word = await get_daily_word(session)
if current_word is None:
current_word = WordleWord(word=word, day=datetime.date.today())
session.add(current_word)
await session.commit()
# Remove all active games
await reset_wordle_games(session)
elif forced:
current_word.word = word
session.add(current_word)
await session.commit()
# Remove all active games
await reset_wordle_games(session)
return current_word.word
async def reset_wordle_games(session: AsyncSession):
"""Reset all active games"""
statement = delete(WordleGuess)
await session.execute(statement)
await session.commit()

View File

@ -1,60 +0,0 @@
from datetime import date
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import WordleStats
__all__ = ["get_wordle_stats", "complete_wordle_game"]
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
"""Get a user's wordle stats
If no entry is found, it is first created
"""
await get_or_add_user(session, user_id)
statement = select(WordleStats).where(WordleStats.user_id == user_id)
stats = (await session.execute(statement)).scalar_one_or_none()
if stats is not None:
return stats
stats = WordleStats(user_id=user_id)
session.add(stats)
await session.commit()
await session.refresh(stats)
return stats
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
"""Update the user's Wordle stats"""
stats = await get_wordle_stats(session, user_id)
stats.games += 1
if win:
stats.wins += 1
# Update streak
today = date.today()
last_win = stats.last_win
stats.last_win = today
if last_win is None or (today - last_win).days > 1:
# Never won a game before or streak is over
stats.current_streak = 1
else:
# On a streak: increase counter
stats.current_streak += 1
# Update max streak if necessary
if stats.current_streak > stats.highest_streak:
stats.highest_streak = stats.current_streak
else:
# Streak is over
stats.current_streak = 0
session.add(stats)
await session.commit()

View File

@ -1,8 +1,7 @@
from urllib.parse import quote_plus
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
import settings
@ -22,6 +21,4 @@ postgres_engine = create_async_engine(
future=True,
)
DBSession = sessionmaker(
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
)
DBSession = async_sessionmaker(autocommit=False, autoflush=False, bind=postgres_engine, expire_on_commit=False)

View File

@ -1,27 +1,14 @@
from __future__ import annotations
from datetime import date, datetime
from typing import Optional
from typing import List, Optional
from sqlalchemy import (
BigInteger,
Boolean,
Column,
Date,
DateTime,
Enum,
ForeignKey,
Integer,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import declarative_base, relationship
from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
from sqlalchemy.types import DateTime
from database import enums
Base = declarative_base()
__all__ = [
"Base",
"Bank",
@ -45,33 +32,37 @@ __all__ = [
"UforaCourse",
"UforaCourseAlias",
"User",
"WordleGuess",
"WordleStats",
"WordleWord",
]
class Base(DeclarativeBase):
"""Required base class for all tables"""
# Make all DateTimes timezone-aware
type_annotation_map = {datetime: DateTime(timezone=True)}
class Bank(Base):
"""A user's currency information"""
__tablename__ = "bank"
bank_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
bank_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
dinks: int = Column(BigInteger, server_default="0", nullable=False)
invested: int = Column(BigInteger, server_default="0", nullable=False)
dinks: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
invested: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
# Interest rate
interest_level: int = Column(Integer, server_default="1", nullable=False)
interest_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
# Maximum amount that can be stored in the bank
capacity_level: int = Column(Integer, server_default="1", nullable=False)
capacity_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
# Maximum amount that can be robbed
rob_level: int = Column(Integer, server_default="1", nullable=False)
rob_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
user: Mapped[User] = relationship(uselist=False, back_populates="bank", lazy="selectin")
class Birthday(Base):
@ -79,11 +70,11 @@ class Birthday(Base):
__tablename__ = "birthdays"
birthday_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
birthday: date = Column(Date, nullable=False)
birthday_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
birthday: Mapped[date] = mapped_column(nullable=False)
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
user: Mapped[User] = relationship(uselist=False, back_populates="birthday", lazy="selectin")
class Bookmark(Base):
@ -92,26 +83,26 @@ class Bookmark(Base):
__tablename__ = "bookmarks"
__table_args__ = (UniqueConstraint("user_id", "label"),)
bookmark_id: int = Column(Integer, primary_key=True)
label: str = Column(Text, nullable=False)
jump_url: str = Column(Text, nullable=False)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
bookmark_id: Mapped[int] = mapped_column(primary_key=True)
label: Mapped[str] = mapped_column(nullable=False)
jump_url: Mapped[str] = mapped_column(nullable=False)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin")
user: Mapped[User] = relationship(back_populates="bookmarks", uselist=False, lazy="selectin")
class CommandStats(Base):
"""Metrics on how often commands are used"""
__tablename__ = "command_stats"
command_stats_id: int = Column(Integer, primary_key=True)
command: str = Column(Text, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
slash: bool = Column(Boolean, nullable=False)
context_menu: bool = Column(Boolean, nullable=False)
command_stats_id: Mapped[int] = mapped_column(primary_key=True)
command: Mapped[str] = mapped_column(nullable=False)
timestamp: Mapped[datetime] = mapped_column(nullable=False)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
slash: Mapped[bool] = mapped_column(nullable=False)
context_menu: Mapped[bool] = mapped_column(nullable=False)
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
user: Mapped[User] = relationship(back_populates="command_stats", uselist=False, lazy="selectin")
class CustomCommand(Base):
@ -119,13 +110,13 @@ class CustomCommand(Base):
__tablename__ = "custom_commands"
command_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
indexed_name: str = Column(Text, nullable=False, index=True)
response: str = Column(Text, nullable=False)
command_id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, unique=True)
indexed_name: Mapped[str] = mapped_column(nullable=False, index=True)
response: Mapped[str] = mapped_column(nullable=False)
aliases: list[CustomCommandAlias] = relationship(
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
aliases: Mapped[List[CustomCommandAlias]] = relationship(
back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
)
@ -134,12 +125,12 @@ class CustomCommandAlias(Base):
__tablename__ = "custom_command_aliases"
alias_id: int = Column(Integer, primary_key=True)
alias: str = Column(Text, nullable=False, unique=True)
indexed_alias: str = Column(Text, nullable=False, index=True)
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
alias_id: Mapped[int] = mapped_column(primary_key=True)
alias: Mapped[str] = mapped_column(nullable=False, unique=True)
indexed_alias: Mapped[str] = mapped_column(nullable=False, index=True)
command_id: Mapped[int] = mapped_column(ForeignKey("custom_commands.command_id"))
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
command: Mapped[CustomCommand] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
class DadJoke(Base):
@ -147,8 +138,8 @@ class DadJoke(Base):
__tablename__ = "dad_jokes"
dad_joke_id: int = Column(Integer, primary_key=True)
joke: str = Column(Text, nullable=False)
dad_joke_id: Mapped[int] = mapped_column(primary_key=True)
joke: Mapped[str] = mapped_column(nullable=False)
class Deadline(Base):
@ -156,12 +147,12 @@ class Deadline(Base):
__tablename__ = "deadlines"
deadline_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
name: str = Column(Text, nullable=False)
deadline: datetime = Column(DateTime(timezone=True), nullable=False)
deadline_id: Mapped[int] = mapped_column(primary_key=True)
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
name: Mapped[str] = mapped_column(nullable=False)
deadline: Mapped[datetime] = mapped_column(nullable=False)
course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin")
course: Mapped[UforaCourse] = relationship(back_populates="deadlines", uselist=False, lazy="selectin")
class EasterEgg(Base):
@ -169,11 +160,11 @@ class EasterEgg(Base):
__tablename__ = "easter_eggs"
easter_egg_id: int = Column(Integer, primary_key=True)
match: str = Column(Text, nullable=False)
response: str = Column(Text, nullable=False)
exact: bool = Column(Boolean, nullable=False, server_default="1")
startswith: bool = Column(Boolean, nullable=False, server_default="1")
easter_egg_id: Mapped[int] = mapped_column(primary_key=True)
match: Mapped[str] = mapped_column(nullable=False)
response: Mapped[str] = mapped_column(nullable=False)
exact: Mapped[bool] = mapped_column(nullable=False, server_default="1")
startswith: Mapped[bool] = mapped_column(nullable=False, server_default="1")
class Event(Base):
@ -181,11 +172,11 @@ class Event(Base):
__tablename__ = "events"
event_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False)
description: Optional[str] = Column(Text, nullable=True)
notification_channel: int = Column(BigInteger, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
event_id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False)
description: Mapped[Optional[str]] = mapped_column(nullable=True)
notification_channel: Mapped[int] = mapped_column(BigInteger, nullable=False)
timestamp: Mapped[datetime] = mapped_column(nullable=False)
class FreeGame(Base):
@ -193,7 +184,7 @@ class FreeGame(Base):
__tablename__ = "free_games"
free_game_id: int = Column(Integer, primary_key=True)
free_game_id: Mapped[int] = mapped_column(primary_key=True)
class GitHubLink(Base):
@ -201,11 +192,11 @@ class GitHubLink(Base):
__tablename__ = "github_links"
github_link_id: int = Column(Integer, primary_key=True)
url: str = Column(Text, nullable=False, unique=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
github_link_id: Mapped[int] = mapped_column(primary_key=True)
url: Mapped[str] = mapped_column(nullable=False, unique=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
user: Mapped[User] = relationship(back_populates="github_links", uselist=False, lazy="selectin")
class Link(Base):
@ -213,9 +204,9 @@ class Link(Base):
__tablename__ = "links"
link_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
url: str = Column(Text, nullable=False)
link_id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, unique=True)
url: Mapped[str] = mapped_column(nullable=False)
class MemeTemplate(Base):
@ -223,10 +214,10 @@ class MemeTemplate(Base):
__tablename__ = "meme"
meme_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
template_id: int = Column(Integer, nullable=False, unique=True)
field_count: int = Column(Integer, nullable=False)
meme_id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, unique=True)
template_id: Mapped[int] = mapped_column(nullable=False, unique=True)
field_count: Mapped[int] = mapped_column(nullable=False)
class NightlyData(Base):
@ -234,12 +225,12 @@ class NightlyData(Base):
__tablename__ = "nightly_data"
nightly_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Optional[date] = Column(Date, nullable=True)
count: int = Column(Integer, server_default="0", nullable=False)
nightly_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
last_nightly: Mapped[Optional[date]] = mapped_column(nullable=True)
count: Mapped[int] = mapped_column(server_default="0", nullable=False)
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
user: Mapped[User] = relationship(back_populates="nightly_data", uselist=False, lazy="selectin")
class Reminder(Base):
@ -247,11 +238,11 @@ class Reminder(Base):
__tablename__ = "reminders"
reminder_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False)
reminder_id: Mapped[int] = mapped_column(primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
category: Mapped[enums.ReminderCategory] = mapped_column(nullable=False)
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin")
user: Mapped[User] = relationship(back_populates="reminders", uselist=False, lazy="selectin")
class Task(Base):
@ -259,9 +250,9 @@ class Task(Base):
__tablename__ = "tasks"
task_id: int = Column(Integer, primary_key=True)
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
task_id: Mapped[int] = mapped_column(primary_key=True)
task: Mapped[enums.TaskType] = mapped_column(nullable=False, unique=True)
previous_run: Mapped[datetime] = mapped_column(nullable=True)
class UforaCourse(Base):
@ -269,25 +260,25 @@ class UforaCourse(Base):
__tablename__ = "ufora_courses"
course_id: int = Column(Integer, primary_key=True)
name: str = Column(Text, nullable=False, unique=True)
code: str = Column(Text, nullable=False, unique=True)
year: int = Column(Integer, nullable=False)
compulsory: bool = Column(Boolean, server_default="1", nullable=False)
role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
course_id: Mapped[int] = mapped_column(primary_key=True)
name: Mapped[str] = mapped_column(nullable=False, unique=True)
code: Mapped[str] = mapped_column(nullable=False, unique=True)
year: Mapped[int] = mapped_column(nullable=False)
compulsory: Mapped[bool] = mapped_column(server_default="1", nullable=False)
role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
# This is not the greatest fix, but there can only ever be two, so it will do the job
alternative_overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
log_announcements: bool = Column(Boolean, server_default="0", nullable=False)
alternative_overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
log_announcements: Mapped[bool] = mapped_column(server_default="0", nullable=False)
announcements: list[UforaAnnouncement] = relationship(
"UforaAnnouncement", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
announcements: Mapped[List[UforaAnnouncement]] = relationship(
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
)
aliases: list[UforaCourseAlias] = relationship(
"UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
aliases: Mapped[List[UforaCourseAlias]] = relationship(
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
)
deadlines: list[Deadline] = relationship(
"Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
deadlines: Mapped[List[Deadline]] = relationship(
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
)
@ -296,11 +287,11 @@ class UforaCourseAlias(Base):
__tablename__ = "ufora_course_aliases"
alias_id: int = Column(Integer, primary_key=True)
alias: str = Column(Text, nullable=False, unique=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
alias_id: Mapped[int] = mapped_column(primary_key=True)
alias: Mapped[str] = mapped_column(nullable=False, unique=True)
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
course: UforaCourse = relationship("UforaCourse", back_populates="aliases", uselist=False, lazy="selectin")
course: Mapped[UforaCourse] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
class UforaAnnouncement(Base):
@ -308,11 +299,11 @@ class UforaAnnouncement(Base):
__tablename__ = "ufora_announcements"
announcement_id: int = Column(Integer, primary_key=True)
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
publication_date: date = Column(Date)
announcement_id: Mapped[int] = mapped_column(primary_key=True)
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
publication_date: Mapped[date] = mapped_column()
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")
course: Mapped[UforaCourse] = relationship(back_populates="announcements", uselist=False, lazy="selectin")
class User(Base):
@ -320,70 +311,26 @@ class User(Base):
__tablename__ = "users"
user_id: int = Column(BigInteger, primary_key=True)
user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
bank: Bank = relationship(
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
bank: Mapped[Bank] = relationship(
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
birthday: Optional[Birthday] = relationship(
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
birthday: Mapped[Optional[Birthday]] = relationship(
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
bookmarks: list[Bookmark] = relationship(
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
bookmarks: Mapped[List[Bookmark]] = relationship(
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
command_stats: list[CommandStats] = relationship(
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
command_stats: Mapped[List[CommandStats]] = relationship(
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
github_links: list[GitHubLink] = relationship(
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
github_links: Mapped[List[GitHubLink]] = relationship(
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
nightly_data: Mapped[NightlyData] = relationship(
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
reminders: list[Reminder] = relationship(
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
reminders: Mapped[List[Reminder]] = relationship(
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_guesses: list[WordleGuess] = relationship(
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
wordle_stats: WordleStats = relationship(
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
)
class WordleGuess(Base):
"""A user's Wordle guesses for today"""
__tablename__ = "wordle_guesses"
wordle_guess_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
guess: str = Column(Text, nullable=False)
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
class WordleStats(Base):
"""Stats about a user's wordle performance"""
__tablename__ = "wordle_stats"
wordle_stats_id: int = Column(Integer, primary_key=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
last_win: Optional[date] = Column(Date, nullable=True)
games: int = Column(Integer, server_default="0", nullable=False)
wins: int = Column(Integer, server_default="0", nullable=False)
current_streak: int = Column(Integer, server_default="0", nullable=False)
highest_streak: int = Column(Integer, server_default="0", nullable=False)
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
class WordleWord(Base):
"""The current Wordle word"""
__tablename__ = "wordle_word"
word_id: int = Column(Integer, primary_key=True)
word: str = Column(Text, nullable=False)
day: date = Column(Date, nullable=False, unique=True)

View File

@ -0,0 +1,171 @@
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from database.engine import DBSession
from database.schemas import UforaCourse, UforaCourseAlias
__all__ = ["main"]
async def purge_aliases(session: AsyncSession):
"""Delete old aliases for courses that will get a new id"""
codes = ["C004074", "C004073", "C004075", "C002309"]
for course_code in codes:
select_stmt = (
select(UforaCourse).where(UforaCourse.code == course_code).options(selectinload(UforaCourse.aliases))
)
course: UforaCourse = (await session.execute(select_stmt)).scalar_one()
for alias in list(course.aliases):
await session.delete(alias)
await session.commit()
async def main():
"""Add the Ufora courses for the 2023-2024 academic year"""
session: AsyncSession
async with DBSession() as session:
# # Remove Advanced Databases (which no longer exists)
delete_stmt = delete(UforaCourse).where(UforaCourse.code == "E018441")
await session.execute(delete_stmt)
await session.commit()
# Delete aliases of courses with new IDs
await purge_aliases(session)
# Fix IDs of compulsory courses and enable announcements
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004074")
bds: UforaCourse = (await session.execute(select_stmt)).scalar_one()
bds.course_id = 828305
bds.log_announcements = True
session.add(bds)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004073")
cg: UforaCourse = (await session.execute(select_stmt)).scalar_one()
cg.course_id = 828293
cg.log_announcements = True
session.add(cg)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004075")
stage: UforaCourse = (await session.execute(select_stmt)).scalar_one()
stage.course_id = 857878
stage.log_announcements = True
session.add(stage)
select_stmt = select(UforaCourse).where(UforaCourse.code == "C002309")
thesis: UforaCourse = (await session.execute(select_stmt)).scalar_one()
thesis.course_id = 828446
thesis.log_announcements = True
session.add(thesis)
await session.commit()
# Add new aliases for these courses
cg_alias = UforaCourseAlias(course_id=cg.course_id, alias="Computer Graphics")
stage_alias = UforaCourseAlias(course_id=stage.course_id, alias="Stage")
thesis_alias = UforaCourseAlias(course_id=thesis.course_id, alias="Thesis")
session.add_all([cg_alias, stage_alias, thesis_alias])
await session.commit()
# New elective courses
bed_eco = UforaCourse(
code="H001535", name="Bedrijfseconomie", year=6, compulsory=False, role_id=1155496199000952922
)
beg_eco = UforaCourse(
code="D012144", name="Beginselen van economie", year=6, compulsory=False, role_id=1155495997024247948
)
big_data_tech = UforaCourse(
code="E018240", name="Big Data Technology", year=6, compulsory=False, role_id=1155490114282201148
)
cloud_storage = UforaCourse(
code="E017310", name="Cloud Storage and Computing", year=6, compulsory=False, role_id=1155490706849271841
)
criminologie = UforaCourse(
code="B001623", name="Inleiding tot Criminologie", year=6, compulsory=False, role_id=1155486768477515936
)
data_quality = UforaCourse(
code="E018700", name="Data Quality", year=6, compulsory=False, role_id=1155491028707586180
)
data_vis_ai = UforaCourse(
code="E061370",
name="Data Visualization for and with AI",
year=6,
compulsory=False,
role_id=1155491854687686746,
)
db_design = UforaCourse(
code="E018610", name="Database Design", year=6, compulsory=False, role_id=1155489846345875489
)
finance_markets = UforaCourse(
code="F000093",
name="Financiële Markten en Instellingen",
year=6,
compulsory=False,
role_id=1155492815615299634,
)
game_theory = UforaCourse(
code="E003710",
name="Game Theory and Multiagent Systems",
year=6,
compulsory=False,
role_id=1155488481666154506,
)
knowledge_graphs = UforaCourse(
code="E018160", name="Knowledge Graphs", year=6, compulsory=False, role_id=1155491648323735592
)
natural_language_processing = UforaCourse(
code="E061341", name="Natural Language Processing", year=6, compulsory=False, role_id=1155487540992823348
)
nosql = UforaCourse(
code="E018130", name="NoSQL Databases", year=6, compulsory=False, role_id=1155491405955878973
)
secure_ss = UforaCourse(
code="E017950", name="Secure Software and Systems", year=6, compulsory=False, role_id=1155492095281340467
)
session.add_all(
[
bed_eco,
beg_eco,
big_data_tech,
cloud_storage,
criminologie,
data_quality,
data_vis_ai,
db_design,
finance_markets,
game_theory,
knowledge_graphs,
natural_language_processing,
nosql,
secure_ss,
]
)
await session.commit()
# Aliases for new elective courses
datakwaliteit = UforaCourseAlias(course_id=data_quality.course_id, alias="Datakwaliteit")
devops = UforaCourseAlias(course_id=cloud_storage.course_id, alias="DevOps")
nlp = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="NLP")
nlp_nl = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="Natuurlijke Taalverwerking")
session.add_all([datakwaliteit, devops, nlp, nlp_nl])
await session.commit()

View File

@ -4,8 +4,8 @@ from discord import app_commands
from overrides import overrides
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import easter_eggs, links, memes, ufora_courses, wordle
from database.schemas import EasterEgg, WordleWord
from database.crud import easter_eggs, links, memes, ufora_courses
from database.schemas import EasterEgg
__all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"]
@ -69,7 +69,7 @@ class LinkCache(DatabaseCache):
self.clear()
all_links = await links.get_all_links(database_session)
self.data = list(map(lambda l: l.name, all_links))
self.data = list(map(lambda link: link.name, all_links))
self.data.sort()
self.data_transformed = list(map(str.lower, self.data))
@ -132,17 +132,6 @@ class UforaCourseCache(DatabaseCache):
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
class WordleCache(DatabaseCache):
"""Cache to store the current daily Wordle word"""
word: WordleWord
async def invalidate(self, database_session: AsyncSession):
word = await wordle.get_daily_word(database_session)
if word is not None:
self.word = word
class CacheManager:
"""Class that keeps track of all caches"""
@ -150,14 +139,12 @@ class CacheManager:
links: LinkCache
memes: MemeCache
ufora_courses: UforaCourseCache
wordle_word: WordleCache
def __init__(self):
self.easter_eggs = EasterEggCache()
self.links = LinkCache()
self.memes = MemeCache()
self.ufora_courses = UforaCourseCache()
self.wordle_word = WordleCache()
async def initialize_caches(self, postgres_session: AsyncSession):
"""Initialize the contents of all caches"""
@ -165,4 +152,3 @@ class CacheManager:
await self.links.invalidate(postgres_session)
await self.memes.invalidate(postgres_session)
await self.ufora_courses.invalidate(postgres_session)
await self.wordle_word.invalidate(postgres_session)

View File

@ -25,7 +25,7 @@ class Currency(commands.Cog):
super().__init__()
self.client = client
@commands.command(name="award")
@commands.command(name="award") # type: ignore[arg-type]
@commands.check(is_owner)
async def award(
self,
@ -49,7 +49,9 @@ class Currency(commands.Cog):
bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue())
embed.set_thumbnail(url=ctx.author.avatar.url)
if ctx.author.avatar is not None:
embed.set_thumbnail(url=ctx.author.avatar.url)
embed.add_field(name="Interest level", value=bank.interest_level)
embed.add_field(name="Capacity level", value=bank.capacity_level)
@ -57,7 +59,9 @@ class Currency(commands.Cog):
await ctx.reply(embed=embed, mention_author=False)
@bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True)
@bank.group( # type: ignore[arg-type]
name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True
)
async def bank_upgrades(self, ctx: commands.Context):
"""List the upgrades you can buy & their prices."""
async with self.client.postgres_session as session:
@ -77,7 +81,7 @@ class Currency(commands.Cog):
await ctx.reply(embed=embed, mention_author=False)
@bank_upgrades.command(name="capacity", aliases=["c"])
@bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type]
async def bank_upgrade_capacity(self, ctx: commands.Context):
"""Upgrade the capacity level of your bank."""
async with self.client.postgres_session as session:
@ -88,7 +92,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message)
@bank_upgrades.command(name="interest", aliases=["i"])
@bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type]
async def bank_upgrade_interest(self, ctx: commands.Context):
"""Upgrade the interest level of your bank."""
async with self.client.postgres_session as session:
@ -99,7 +103,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message)
@bank_upgrades.command(name="rob", aliases=["r"])
@bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type]
async def bank_upgrade_rob(self, ctx: commands.Context):
"""Upgrade the rob level of your bank."""
async with self.client.postgres_session as session:
@ -110,7 +114,7 @@ class Currency(commands.Cog):
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
await self.client.reject_message(ctx.message)
@commands.hybrid_command(name="dinks")
@commands.hybrid_command(name="dinks") # type: ignore[arg-type]
async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks."""
async with self.client.postgres_session as session:
@ -118,7 +122,7 @@ class Currency(commands.Cog):
plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False)
@commands.command(name="invest", aliases=["deposit", "dep"])
@commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type]
async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]):
"""Invest `amount` Didier Dinks into your bank.
@ -144,7 +148,7 @@ class Currency(commands.Cog):
f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False
)
@commands.hybrid_command(name="nightly")
@commands.hybrid_command(name="nightly") # type: ignore[arg-type]
async def nightly(self, ctx: commands.Context):
"""Claim nightly Didier Dinks."""
async with self.client.postgres_session as session:

View File

@ -13,7 +13,7 @@ class DebugCog(commands.Cog):
self.client = client
@overrides
async def cog_check(self, ctx: commands.Context) -> bool:
async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override]
return await self.client.is_owner(ctx.author)
@commands.command(aliases=["Dev"])

View File

@ -1,4 +1,4 @@
from typing import Optional
from typing import Optional, Union, cast
import discord
from discord import app_commands
@ -17,6 +17,7 @@ from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource
from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.constants import Limits
from didier.utils.timer import Timer
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
@ -60,9 +61,19 @@ class Discord(commands.Cog):
event = await events.get_event_by_id(session, event_id)
if event is None:
return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True)
return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True)
channel = self.client.get_channel(event.notification_channel)
if channel is None:
return await self.client.log_error(
f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)."
)
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable."
)
human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M")
embed = discord.Embed(title=event.name, colour=discord.Colour.blue())
@ -81,7 +92,7 @@ class Discord(commands.Cog):
self.client.loop.create_task(self.timer.update())
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
async def birthday(self, ctx: commands.Context, user: discord.User = None):
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
"""Command to check the birthday of `user`.
Not passing an argument for `user` will show yours instead.
@ -98,8 +109,10 @@ class Discord(commands.Cog):
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False)
@birthday.command(name="set", aliases=["config"])
async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None):
@birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type]
async def birthday_set(
self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None
):
"""Set your birthday to `day`.
Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`.
@ -113,6 +126,9 @@ class Discord(commands.Cog):
if user is None:
user = ctx.author
# Please Mypy
user = cast(Union[discord.User, discord.Member], user)
try:
default_year = 2001
date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
@ -141,7 +157,7 @@ class Discord(commands.Cog):
"""
# No label: shortcut to display bookmarks
if label is None:
return await self.bookmark_search(ctx, query=None)
return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type]
async with self.client.postgres_session as session:
result = expect(
@ -151,7 +167,7 @@ class Discord(commands.Cog):
)
await ctx.reply(result.jump_url, mention_author=False)
@bookmark.command(name="create", aliases=["new"])
@bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type]
async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]):
"""Create a new bookmark for message `message` with label `label`.
@ -182,7 +198,7 @@ class Discord(commands.Cog):
# Label isn't allowed
return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False)
@bookmark.command(name="delete", aliases=["rm"])
@bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type]
async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str):
"""Delete the bookmark with id `bookmark_id`.
@ -207,7 +223,7 @@ class Discord(commands.Cog):
return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False)
@bookmark.command(name="search", aliases=["list", "ls"])
@bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type]
async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None):
"""Search through the list of bookmarks.
@ -236,7 +252,7 @@ class Discord(commands.Cog):
modal = CreateBookmark(self.client, message.jump_url)
await interaction.response.send_modal(modal)
@commands.hybrid_command(name="events")
@commands.hybrid_command(name="events") # type: ignore[arg-type]
@app_commands.rename(event_id="id")
@app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.")
async def events(self, ctx: commands.Context, event_id: Optional[int] = None):
@ -276,16 +292,16 @@ class Discord(commands.Cog):
embed.add_field(
name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True
)
embed.add_field(
name="Channel",
value=self.client.get_channel(result_event.notification_channel).mention,
inline=False,
)
channel = self.client.get_channel(result_event.notification_channel)
if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
embed.add_field(name="Channel", value=channel.mention, inline=False)
embed.description = result_event.description
return await ctx.reply(embed=embed, mention_author=False)
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None):
async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None):
"""Show a user's GitHub links.
If no user is provided, this shows your links instead.
@ -293,6 +309,9 @@ class Discord(commands.Cog):
# Default to author
user = user or ctx.author
# Please Mypy
user = cast(Union[discord.User, discord.Member], user)
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
embed.set_author(name=user.display_name, icon_url=get_user_avatar(user))
@ -324,7 +343,7 @@ class Discord(commands.Cog):
return await ctx.reply(embed=embed, mention_author=False)
@github_group.command(name="add", aliases=["create", "insert"])
@github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type]
async def github_add(self, ctx: commands.Context, link: str):
"""Add a new link into the database."""
# Remove wrapping <brackets> which can be used to escape Discord embeds
@ -339,7 +358,7 @@ class Discord(commands.Cog):
await self.client.confirm_message(ctx.message)
return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False)
@github_group.command(name="delete", aliases=["del", "remove", "rm"])
@github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type]
async def github_delete(self, ctx: commands.Context, link_id: str):
"""Delete the link with it `link_id` from the database.
@ -411,7 +430,7 @@ class Discord(commands.Cog):
await message.add_reaction("📌")
return await interaction.response.send_message("📌", ephemeral=True)
@commands.hybrid_command(name="snipe")
@commands.hybrid_command(name="snipe") # type: ignore[arg-type]
async def snipe(self, ctx: commands.Context):
"""Publicly shame people when they edit or delete one of their messages.
@ -420,7 +439,7 @@ class Discord(commands.Cog):
if ctx.guild is None:
return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True)
sniped_data = self.client.sniped.get(ctx.channel.id, None)
sniped_data = self.client.sniped.get(ctx.channel.id)
if sniped_data is None:
return await ctx.reply(
"There's no one to make fun of in this channel.", mention_author=False, ephemeral=True

View File

@ -28,7 +28,7 @@ class Fun(commands.Cog):
def __init__(self, client: Didier):
self.client = client
@commands.hybrid_command(name="clap")
@commands.hybrid_command(name="clap") # type: ignore[arg-type]
async def clap(self, ctx: commands.Context, *, text: str):
"""Clap a message with emojis for extra dramatic effect"""
chars = list(filter(lambda c: c in constants.EMOJI_MAP, text))
@ -50,10 +50,7 @@ class Fun(commands.Cog):
meme = await generate_meme(self.client.http_session, result, fields)
return meme
@commands.hybrid_command(
name="dadjoke",
aliases=["dad", "dj"],
)
@commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type]
async def dad_joke(self, ctx: commands.Context):
"""Why does Yoda's code always crash? Because there is no try."""
async with self.client.postgres_session as session:
@ -83,13 +80,13 @@ class Fun(commands.Cog):
return await self.memegen_ls_msg(ctx)
if fields is None:
return await self.memegen_preview_msg(ctx, template)
return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type]
async with ctx.typing():
meme = await self._do_generate_meme(template, shlex.split(fields))
return await ctx.reply(meme, mention_author=False)
@memegen_msg.command(name="list", aliases=["ls"])
@memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type]
async def memegen_ls_msg(self, ctx: commands.Context):
"""Get a list of all available meme templates.
@ -100,14 +97,14 @@ class Fun(commands.Cog):
await MemeSource(ctx, results).start()
@memegen_msg.command(name="preview", aliases=["p"])
@memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type]
async def memegen_preview_msg(self, ctx: commands.Context, template: str):
"""Generate a preview for the meme template `template`, to see how the fields are structured."""
async with ctx.typing():
meme = await self._do_generate_meme(template, [])
return await ctx.reply(meme, mention_author=False)
@memes_slash.command(name="generate")
@memes_slash.command(name="generate") # type: ignore[arg-type]
async def memegen_slash(self, interaction: discord.Interaction, template: str):
"""Generate a meme."""
async with self.client.postgres_session as session:
@ -116,7 +113,7 @@ class Fun(commands.Cog):
modal = GenerateMeme(self.client, result)
await interaction.response.send_modal(modal)
@memes_slash.command(name="preview")
@memes_slash.command(name="preview") # type: ignore[arg-type]
@app_commands.describe(template="The meme template to use in the preview.")
async def memegen_preview_slash(self, interaction: discord.Interaction, template: str):
"""Generate a preview for a meme, to see how the fields are structured."""
@ -134,7 +131,7 @@ class Fun(commands.Cog):
"""Autocompletion for the 'template'-parameter"""
return self.client.database_caches.memes.get_autocomplete_suggestions(current)
@app_commands.command()
@app_commands.command() # type: ignore[arg-type]
@app_commands.describe(message="The text to convert.")
async def mock(self, interaction: discord.Interaction, message: str):
"""Mock a message.
@ -158,7 +155,7 @@ class Fun(commands.Cog):
return await interaction.followup.send(mock(message))
@commands.hybrid_command(name="xkcd")
@commands.hybrid_command(name="xkcd") # type: ignore[arg-type]
@app_commands.rename(comic_id="id")
async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None):
"""Fetch comic `#id` from xkcd.

View File

@ -1,14 +1,6 @@
from typing import Optional
import discord
from discord import app_commands
from discord.ext import commands
from database.constants import WORDLE_WORD_LENGTH
from database.crud.wordle import get_wordle_guesses, make_wordle_guess
from database.crud.wordle_stats import complete_wordle_game
from didier import Didier
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over
class Games(commands.Cog):
@ -19,53 +11,6 @@ class Games(commands.Cog):
def __init__(self, client: Didier):
self.client = client
@app_commands.command(name="wordle", description="Play Wordle!")
async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None):
"""View your active Wordle game
If an argument is provided, make a guess instead
"""
await interaction.response.defer(ephemeral=True)
# Guess is wrong length
if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH:
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
return await interaction.followup.send(embed=embed)
word_instance = self.client.database_caches.wordle_word.word
async with self.client.postgres_session as session:
guesses = await get_wordle_guesses(session, interaction.user.id)
# Trying to guess with a complete game
if is_wordle_game_over(guesses, word_instance.word):
embed = WordleErrorEmbed(
message="You've already completed today's Wordle.\nTry again tomorrow!"
).to_embed()
return await interaction.followup.send(embed=embed)
# Make a guess
if guess:
# The guess is not a real word
if guess.lower() not in self.client.wordle_words:
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
return await interaction.followup.send(embed=embed)
guess = guess.lower()
await make_wordle_guess(session, interaction.user.id, guess)
# Don't re-request the game, we already have it
# just append locally
guesses.append(guess)
embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed()
await interaction.followup.send(embed=embed)
# After responding to the interaction: update stats in the background
game_over = is_wordle_game_over(guesses, word_instance.word)
if game_over:
await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses)
async def setup(client: Didier):
"""Load the cog"""

View File

@ -159,6 +159,9 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
Code in codeblocks is ignored, as it is used to create examples.
"""
description = command.help
if description is None:
return ""
codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL)
# Regex borrowed from https://stackoverflow.com/a/59843498/13568999
@ -198,13 +201,10 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
return None
async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]:
async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]:
"""Filter the list of cogs down to all those that the user can see"""
async def _predicate(cog: Optional[commands.Cog]) -> bool:
if cog is None:
return False
async def _predicate(cog: commands.Cog) -> bool:
# Remove cogs that we never want to see in the help page because they
# don't contain commands, or shouldn't be visible at all
if not cog.get_commands():
@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
return True
# Filter list of cogs down
filtered_cogs = [cog for cog in cogs if await _predicate(cog)]
filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)]
return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name))
def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]:
"""Check if a command has flags"""
flag_param = command.params.get("flags", None)
flag_param = command.params.get("flags")
if flag_param is None:
return None

View File

@ -1,6 +1,6 @@
import inspect
import os
from typing import Optional
from typing import Any, Optional, Union
from discord.ext import commands
@ -76,18 +76,24 @@ class Meta(commands.Cog):
if command_name is None:
return await ctx.reply(repo_home, mention_author=False)
command: Optional[Union[commands.HelpCommand, commands.Command]]
src: Any
if command_name == "help":
command = self.client.help_command
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
src = type(self.client.help_command)
filename = inspect.getsourcefile(src)
else:
command = self.client.get_command(command_name)
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
src = command.callback.__code__
filename = src.co_filename
if command is None:
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
lines, first_line = inspect.getsourcelines(src)
if filename is None:

View File

@ -22,7 +22,7 @@ class Other(commands.Cog):
def __init__(self, client: Didier):
self.client = client
@commands.hybrid_command(name="corona", aliases=["covid", "rona"])
@commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type]
async def covid(self, ctx: commands.Context, country: str = "Belgium"):
"""Show Covid-19 info for a specific country.
@ -43,7 +43,7 @@ class Other(commands.Cog):
"""Autocompletion for the 'country'-parameter"""
return autocomplete_country(value)[:25]
@commands.hybrid_command(
@commands.hybrid_command( # type: ignore[arg-type]
name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary"
)
async def define(self, ctx: commands.Context, *, query: str):
@ -55,7 +55,7 @@ class Other(commands.Cog):
mention_author=False,
)
@commands.hybrid_command(name="google", description="Google search")
@commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type]
@app_commands.describe(query="Search query")
async def google(self, ctx: commands.Context, *, query: str):
"""Show the Google search results for `query`.
@ -71,7 +71,7 @@ class Other(commands.Cog):
embed = GoogleSearch(results).to_embed()
await ctx.reply(embed=embed, mention_author=False)
@commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.")
@commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type]
async def inspire(self, ctx: commands.Context):
"""Generate an [InspiroBot](https://inspirobot.me/) quote."""
async with ctx.typing():
@ -82,7 +82,7 @@ class Other(commands.Cog):
async with self.client.postgres_session as session:
return await get_link_by_name(session, name.lower())
@commands.command(name="Link", aliases=["Links"])
@commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type]
async def link_msg(self, ctx: commands.Context, name: str):
"""Get the link to the resource named `name`."""
link = await self._get_link(name)
@ -92,7 +92,7 @@ class Other(commands.Cog):
target_message = await self.client.get_reply_target(ctx)
await target_message.reply(link.url, mention_author=False)
@app_commands.command(name="link")
@app_commands.command(name="link") # type: ignore[arg-type]
@app_commands.describe(name="The name of the resource")
async def link_slash(self, interaction: discord.Interaction, name: str):
"""Get the link to something."""

View File

@ -42,7 +42,7 @@ class Owner(commands.Cog):
def __init__(self, client: Didier):
self.client = client
async def cog_check(self, ctx: commands.Context) -> bool:
async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override]
"""Global check for every command in this cog
This means that we don't have to add is_owner() to every single command separately
@ -102,7 +102,7 @@ class Owner(commands.Cog):
async def add_msg(self, ctx: commands.Context):
"""Command group for [add X] message commands"""
@add_msg.command(name="Alias")
@add_msg.command(name="Alias") # type: ignore[arg-type]
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command"""
async with self.client.postgres_session as session:
@ -116,7 +116,7 @@ class Owner(commands.Cog):
await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message)
@add_msg.command(name="Custom")
@add_msg.command(name="Custom") # type: ignore[arg-type]
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command"""
async with self.client.postgres_session as session:
@ -127,7 +127,7 @@ class Owner(commands.Cog):
await ctx.reply("There is already a command with this name.")
await self.client.reject_message(ctx.message)
@add_msg.command(name="Link")
@add_msg.command(name="Link") # type: ignore[arg-type]
async def add_link_msg(self, ctx: commands.Context, name: str, url: str):
"""Add a new link"""
async with self.client.postgres_session as session:
@ -136,7 +136,7 @@ class Owner(commands.Cog):
await self.client.confirm_message(ctx.message)
@add_slash.command(name="custom", description="Add a custom command")
@add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type]
async def add_custom_slash(self, interaction: discord.Interaction):
"""Slash command to add a custom command"""
if not await self.client.is_owner(interaction.user):
@ -145,7 +145,7 @@ class Owner(commands.Cog):
modal = CreateCustomCommand(self.client)
await interaction.response.send_modal(modal)
@add_slash.command(name="dadjoke", description="Add a dad joke")
@add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type]
async def add_dad_joke_slash(self, interaction: discord.Interaction):
"""Slash command to add a dad joke"""
if not await self.client.is_owner(interaction.user):
@ -154,7 +154,7 @@ class Owner(commands.Cog):
modal = AddDadJoke(self.client)
await interaction.response.send_modal(modal)
@add_slash.command(name="deadline", description="Add a deadline")
@add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type]
@app_commands.describe(course="The name of the course to add a deadline for (aliases work too)")
async def add_deadline_slash(self, interaction: discord.Interaction, course: str):
"""Slash command to add a deadline"""
@ -174,7 +174,7 @@ class Owner(commands.Cog):
"""Autocompletion for the 'course'-parameter"""
return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
@add_slash.command(name="event", description="Add a new event")
@add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type]
async def add_event_slash(self, interaction: discord.Interaction):
"""Slash command to add new events"""
if not await self.client.is_owner(interaction.user):
@ -183,7 +183,7 @@ class Owner(commands.Cog):
modal = AddEvent(self.client)
await interaction.response.send_modal(modal)
@add_slash.command(name="link", description="Add a new link")
@add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type]
async def add_link_slash(self, interaction: discord.Interaction):
"""Slash command to add new links"""
if not await self.client.is_owner(interaction.user):
@ -192,7 +192,7 @@ class Owner(commands.Cog):
modal = AddLink(self.client)
await interaction.response.send_modal(modal)
@add_slash.command(name="meme", description="Add a new meme")
@add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type]
async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int):
"""Slash command to add new memes"""
await interaction.response.defer(ephemeral=True)
@ -205,11 +205,11 @@ class Owner(commands.Cog):
await interaction.followup.send(f"Added meme `{meme.meme_id}`.")
await self.client.database_caches.memes.invalidate(session)
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type]
async def edit_msg(self, ctx: commands.Context):
"""Command group for [edit X] commands"""
@edit_msg.command(name="Custom")
@edit_msg.command(name="Custom") # type: ignore[arg-type]
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
"""Edit an existing custom command"""
async with self.client.postgres_session as session:
@ -220,7 +220,7 @@ class Owner(commands.Cog):
await ctx.reply(f"No command found matching `{command}`.")
return await self.client.reject_message(ctx.message)
@edit_slash.command(name="custom", description="Edit a custom command")
@edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type]
@app_commands.describe(command="The name of the command to edit")
async def edit_custom_slash(self, interaction: discord.Interaction, command: str):
"""Slash command to edit a custom command"""

View File

@ -27,7 +27,7 @@ class School(commands.Cog):
def __init__(self, client: Didier):
self.client = client
@commands.hybrid_command(name="deadlines")
@commands.hybrid_command(name="deadlines") # type: ignore[arg-type]
async def deadlines(self, ctx: commands.Context):
"""Show upcoming deadlines."""
async with ctx.typing():
@ -40,7 +40,7 @@ class School(commands.Cog):
embed = Deadlines(deadlines).to_embed()
await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
@commands.hybrid_command(name="les", aliases=["sched", "schedule"])
@commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type]
@app_commands.rename(day_dt="date")
async def les(
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
@ -72,10 +72,7 @@ class School(commands.Cog):
except NotInMainGuildException:
return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False)
@commands.hybrid_command(
name="menu",
aliases=["eten", "food"],
)
@commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type]
@app_commands.rename(day_dt="date")
async def menu(
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
@ -96,7 +93,7 @@ class School(commands.Cog):
embed = no_menu_found(day_dt)
await ctx.reply(embed=embed, mention_author=False)
@commands.hybrid_command(
@commands.hybrid_command( # type: ignore[arg-type]
name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"]
)
@app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)")
@ -124,7 +121,7 @@ class School(commands.Cog):
mention_author=False,
)
@commands.hybrid_command(name="ufora")
@commands.hybrid_command(name="ufora") # type: ignore[arg-type]
async def ufora(self, ctx: commands.Context, course: str):
"""Link the Ufora page for a course."""
async with self.client.postgres_session as session:

View File

@ -1,4 +1,6 @@
import asyncio
import datetime
import logging
import random
import discord
@ -10,7 +12,6 @@ from database import enums
from database.crud.birthdays import get_birthdays_on_day
from database.crud.reminders import get_all_reminders_for_category
from database.crud.ufora_announcements import remove_old_announcements
from database.crud.wordle import set_daily_word
from database.schemas import Reminder
from didier import Didier
from didier.data.embeds.schedules import (
@ -21,9 +22,12 @@ from didier.data.embeds.schedules import (
from didier.data.rss_feeds.free_games import fetch_free_games
from didier.data.rss_feeds.ufora import fetch_ufora_announcements
from didier.decorators.tasks import timed_task
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
from didier.utils.discord.checks import is_owner
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
logger = logging.getLogger(__name__)
# datetime.time()-instances for when every task should run
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
@ -54,11 +58,10 @@ class Tasks(commands.Cog):
"reminders": self.reminders,
"ufora": self.pull_ufora_announcements,
"remove_ufora": self.remove_old_ufora_announcements,
"wordle": self.reset_wordle_word,
}
@overrides
def cog_load(self) -> None:
async def cog_load(self) -> None:
# Only check birthdays if there's a channel to send it to
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
self.check_birthdays.start()
@ -74,10 +77,10 @@ class Tasks(commands.Cog):
# Start other tasks
self.reminders.start()
self.reset_wordle_word.start()
asyncio.create_task(self.get_error_channel())
@overrides
def cog_unload(self) -> None:
async def cog_unload(self) -> None:
# Cancel all pending tasks
for task in self._tasks.values():
if task.is_running():
@ -99,7 +102,7 @@ class Tasks(commands.Cog):
await ctx.reply(embed=embed, mention_author=False)
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]")
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type]
async def force_task(self, ctx: commands.Context, name: str):
"""Command to force-run a task without waiting for the specified run time"""
name = name.lower()
@ -110,23 +113,53 @@ class Tasks(commands.Cog):
await task(forced=True)
await self.client.confirm_message(ctx.message)
async def get_error_channel(self):
"""Get the configured channel from the cache"""
await self.client.wait_until_ready()
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
channel = self.client.get_channel(settings.ERRORS_CHANNEL)
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.")
else:
self.client.error_channel = channel
elif self.client.owner_id is not None:
self.client.error_channel = self.client.get_user(self.client.owner_id)
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
@timed_task(enums.TaskType.BIRTHDAYS)
async def check_birthdays(self, **kwargs):
"""Check if it's currently anyone's birthday"""
_ = kwargs
# Can't happen (task isn't started if this is None), but Mypy doesn't know
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None:
return
now = tz_aware_now().date()
async with self.client.postgres_session as session:
birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to find channel for birthday announcements")
return await self.client.log_error("Unable to fetch channel for birthday announcements.")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable."
)
for birthday in birthdays:
user = self.client.get_user(birthday.user_id)
if user is None:
await self.client.log_error(
f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement"
)
continue
await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention))
@check_birthdays.before_loop
@ -146,6 +179,14 @@ class Tasks(commands.Cog):
games = await fetch_free_games(self.client.http_session, session)
channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL)
if channel is None:
return await self.client.log_error("Unable to fetch channel for free games announcements.")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable."
)
for game in games:
await channel.send(embed=game.to_embed())
@ -207,6 +248,17 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as db_session:
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
if announcements_channel is None:
return await self.client.log_error(
f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)."
)
if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await self.client.log_error(
f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable."
)
announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
for announcement in announcements:
@ -266,34 +318,16 @@ class Tasks(commands.Cog):
async with self.client.postgres_session as session:
await remove_old_announcements(session)
@tasks.loop(time=DAILY_RESET_TIME)
async def reset_wordle_word(self, forced: bool = False):
"""Reset the daily Wordle word"""
async with self.client.postgres_session as session:
await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced)
await self.client.database_caches.wordle_word.invalidate(session)
@reset_wordle_word.before_loop
async def _before_reset_wordle_word(self):
await self.client.wait_until_ready()
@check_birthdays.error
@pull_schedules.error
@pull_ufora_announcements.error
@reminders.error
@remove_old_ufora_announcements.error
@reset_wordle_word.error
async def _on_tasks_error(self, error: BaseException):
"""Error handler for all tasks"""
self.client.dispatch("task_error", error)
async def setup(client: Didier):
"""Load the cog
Initially fetch the wordle word from the database, or reset it
if there hasn't been a reset yet today
"""
cog = Tasks(client)
await client.add_cog(cog)
await cog.reset_wordle_word()
"""Load the cog"""
await client.add_cog(Tasks(client))

View File

@ -19,7 +19,7 @@ async def get_country_info(http_session: ClientSession, country: str) -> CovidDa
yesterday = response
data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data)
return CovidData.model_validate(data)
async def get_global_info(http_session: ClientSession) -> CovidData:
@ -35,4 +35,4 @@ async def get_global_info(http_session: ClientSession) -> CovidData:
yesterday = response
data = {"today": today, "yesterday": yesterday}
return CovidData.parse_obj(data)
return CovidData.model_validate(data)

View File

@ -12,4 +12,4 @@ async def fetch_menu(http_session: ClientSession, day_dt: date) -> Menu:
"""Fetch the menu for a given day"""
endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json"
async with ensure_get(http_session, endpoint, log_exceptions=False) as response:
return Menu.parse_obj(response)
return Menu.model_validate(response)

View File

@ -14,4 +14,4 @@ async def lookup(http_session: ClientSession, query: str) -> list[Definition]:
url = "https://api.urbandictionary.com/v0/define"
async with ensure_get(http_session, url, params={"term": query}) as response:
return list(map(Definition.parse_obj, response["list"]))
return list(map(Definition.model_validate, response["list"]))

View File

@ -13,4 +13,4 @@ async def fetch_xkcd_post(http_session: ClientSession, *, num: Optional[int] = N
url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json"
async with ensure_get(http_session, url) as response:
return XKCDPost.parse_obj(response)
return XKCDPost.model_validate(response)

View File

@ -1,6 +1,6 @@
import discord
from overrides import overrides
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
from didier.data.embeds.base import EmbedPydantic
@ -24,7 +24,7 @@ class _CovidNumbers(BaseModel):
active: int
tests: int
@validator("updated")
@field_validator("updated")
def updated_to_seconds(cls, value: int) -> int:
"""Turn the updated field into seconds instead of milliseconds"""
return int(value) // 1000

View File

@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) ->
embed = discord.Embed(title="Error", colour=discord.Colour.red())
if ctx is not None:
if ctx.guild is None:
if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel):
origin = "DM"
else:
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
origin = f"<#{ctx.channel.id}> ({ctx.guild.name})"
invocation = f"{ctx.author.display_name} in {origin}"

View File

@ -4,18 +4,17 @@ from typing import Optional
import discord
from aiohttp import ClientSession
from overrides import overrides
from pydantic import validator
from pydantic import field_validator
from didier.data.embeds.base import EmbedPydantic
from didier.data.scrapers.common import GameStorePage
from didier.data.scrapers.steam import get_steam_webpage_info
from didier.utils.discord import colours
__all__ = ["SEPARATOR", "FreeGameEmbed"]
from didier.utils.discord.constants import Limits
from didier.utils.types.string import abbreviate
__all__ = ["SEPARATOR", "FreeGameEmbed"]
SEPARATOR = " • Free • "
@ -58,7 +57,7 @@ class FreeGameEmbed(EmbedPydantic):
store_page: Optional[GameStorePage] = None
@validator("title")
@field_validator("title")
def _clean_title(cls, value: str) -> str:
return html.unescape(value)
@ -107,7 +106,6 @@ class FreeGameEmbed(EmbedPydantic):
embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})")
if self.store_page.xdg_open_url is not None:
embed.add_field(
name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})"
)

View File

@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"]
def create_logging_embed(level: int, message: str) -> discord.Embed:
"""Create an embed to send to the logging channel"""
colours = {
logging.DEBUG: discord.Colour.light_gray,
logging.DEBUG: discord.Colour.light_grey(),
logging.ERROR: discord.Colour.red(),
logging.INFO: discord.Colour.blue(),
logging.WARNING: discord.Colour.yellow(),

View File

@ -2,7 +2,7 @@ from datetime import datetime
import discord
from overrides import overrides
from pydantic import validator
from pydantic import field_validator
from didier.data.embeds.base import EmbedPydantic
from didier.utils.discord import colours
@ -39,8 +39,8 @@ class Definition(EmbedPydantic):
total_votes = self.thumbs_up + self.thumbs_down
return round(100 * self.thumbs_up / total_votes, 2)
@validator("definition", "example")
def modify_long_text(cls, field):
@field_validator("definition", "example")
def modify_long_text(cls, field: str):
"""Remove brackets from fields & cut them off if they are too long"""
field = field.replace("[", "").replace("]", "")
return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH)

View File

@ -1,142 +0,0 @@
import enum
from dataclasses import dataclass
import discord
from overrides import overrides
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
from database.schemas import WordleWord
from didier.data.embeds.base import EmbedBaseModel
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
__all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"]
def is_wordle_game_over(guesses: list[str], word: str) -> bool:
"""Check if the current game is over or not"""
if not guesses:
return False
if len(guesses) == WORDLE_GUESS_COUNT:
return True
return word.lower() in guesses
def footer() -> str:
"""Create the footer to put on the embed"""
today = tz_aware_now()
return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}"
class WordleColour(enum.IntEnum):
"""Colours for the Wordle embed"""
EMPTY = 0
WRONG_LETTER = 1
WRONG_POSITION = 2
CORRECT = 3
@dataclass
class WordleEmbed(EmbedBaseModel):
"""Embed for a Wordle game"""
guesses: list[str]
word: WordleWord
def _letter_colour(self, guess: str, index: int) -> WordleColour:
"""Get the colour for a guess at a given position"""
if guess[index] == self.word.word[index]:
return WordleColour.CORRECT
wrong_letter = 0
wrong_position = 0
for i, letter in enumerate(self.word.word):
if letter == guess[index] and guess[i] != guess[index]:
wrong_letter += 1
if i <= index and guess[i] == guess[index] and letter != guess[index]:
wrong_position += 1
if i >= index:
if wrong_position == 0:
break
if wrong_position <= wrong_letter:
return WordleColour.WRONG_POSITION
return WordleColour.WRONG_LETTER
def _guess_colours(self, guess: str) -> list[WordleColour]:
"""Create the colour codes for a specific guess"""
return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)]
def colour_code_game(self) -> list[list[WordleColour]]:
"""Create the colour codes for an entire game"""
colours = []
# Add all the guesses
for guess in self.guesses:
colours.append(self._guess_colours(guess))
# Fill the rest with empty spots
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH)
return colours
def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]:
"""Turn the colours of the board into Discord emojis"""
colour_map = {
WordleColour.EMPTY: ":white_large_square:",
WordleColour.WRONG_LETTER: ":black_large_square:",
WordleColour.WRONG_POSITION: ":orange_square:",
WordleColour.CORRECT: ":green_square:",
}
emojis = []
for row in colours:
emojis.append(list(map(lambda char: colour_map[char], row)))
return emojis
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
only_colours = kwargs.get("only_colours", False)
colours = self.colour_code_game()
embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}")
emojis = self._colours_to_emojis(colours)
rows = [" ".join(row) for row in emojis]
# Don't reveal anything if we only want to show the colours
if not only_colours and self.guesses:
for i, guess in enumerate(self.guesses):
rows[i] += f" ||{guess.upper()}||"
# If the game is over, reveal the word
if is_wordle_game_over(self.guesses, self.word.word):
rows.append(f"\n\nThe word was **{self.word.word.upper()}**!")
embed.description = "\n\n".join(rows)
embed.set_footer(text=footer())
return embed
@dataclass
class WordleErrorEmbed(EmbedBaseModel):
"""Embed to send error messages to the user"""
message: str
@overrides
def to_embed(self, **kwargs) -> discord.Embed:
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
embed.description = self.message
embed.set_footer(text=footer())
return embed

View File

@ -32,7 +32,7 @@ async def fetch_free_games(http_session: ClientSession, database_session: AsyncS
if SEPARATOR not in entry["title"]:
continue
game = FreeGameEmbed.parse_obj(entry)
game = FreeGameEmbed.model_validate(entry)
games.append(game)
game_ids.append(game.dc_identifier)

View File

@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]:
return list(dict.fromkeys(results))
async def google_search(http_client: ClientSession, query: str):
async def google_search(http_session: ClientSession, query: str):
"""Get the first 10 Google search results"""
query = urlencode({"q": query})
# Request 20 results in case of duplicates, bad matches, ...
async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response:
async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response:
# Something went wrong
if response.status != http.HTTPStatus.OK:
return SearchData(query, response.status)

View File

@ -17,7 +17,7 @@ from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed
from didier.data.embeds.logging_embed import create_logging_embed
from didier.data.embeds.schedules import Schedule, parse_schedule
from didier.exceptions import HTTPException, NoMatch
from didier.exceptions import GetNoneException, HTTPException, NoMatch
from didier.utils.discord.prefix import get_prefix
from didier.utils.discord.snipe import should_snipe
from didier.utils.easter_eggs import detect_easter_egg
@ -33,12 +33,11 @@ class Didier(commands.Bot):
"""DIDIER <3"""
database_caches: CacheManager
error_channel: discord.abc.Messageable
error_channel: Optional[discord.abc.Messageable] = None
initial_extensions: tuple[str, ...] = ()
http_session: ClientSession
schedules: dict[settings.ScheduleType, Schedule] = {}
sniped: dict[int, tuple[discord.Message, Optional[discord.Message]]] = {}
wordle_words: set[str] = set()
def __init__(self):
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
@ -57,12 +56,17 @@ class Didier(commands.Bot):
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
)
self.tree.on_error = self.on_app_command_error
# I'm not creating a custom tree, this is the way to do it
self.tree.on_error = self.on_app_command_error # type: ignore[method-assign]
@cached_property
def main_guild(self) -> discord.Guild:
"""Obtain a reference to the main guild"""
return self.get_guild(settings.DISCORD_MAIN_GUILD)
guild = self.get_guild(settings.DISCORD_MAIN_GUILD)
if guild is None:
raise GetNoneException("Main guild could not be found in the bot's cache")
return guild
@property
def postgres_session(self) -> AsyncSession:
@ -77,9 +81,6 @@ class Didier(commands.Bot):
# Create directories that are ignored on GitHub
self._create_ignored_directories()
# Load the Wordle dictionary
self._load_wordle_words()
# Initialize caches
self.database_caches = CacheManager()
async with self.postgres_session as session:
@ -97,12 +98,6 @@ class Didier(commands.Bot):
await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs")
# Configure channel to send errors to
if settings.ERRORS_CHANNEL is not None:
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
else:
self.error_channel = self.get_user(self.owner_id)
def _create_ignored_directories(self):
"""Create directories that store ignored data"""
ignored = ["files/schedules"]
@ -137,12 +132,6 @@ class Didier(commands.Bot):
elif os.path.isdir(new_path := f"{path}/{file}"):
await self._load_directory_extensions(new_path)
def _load_wordle_words(self):
"""Load the dictionary of Wordle words"""
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
for line in fp:
self.wordle_words.add(line.strip())
async def load_schedules(self):
"""Parse & load all schedules into memory"""
self.schedules = {}
@ -162,18 +151,27 @@ class Didier(commands.Bot):
original message instead
"""
if ctx.message.reference is not None:
return await self.resolve_message(ctx.message.reference)
return await self.resolve_message(ctx.message.reference) or ctx.message
return ctx.message
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]:
"""Fetch a message from a reference"""
# Message is in the cache, return it
if reference.cached_message is not None:
return reference.cached_message
if reference.message_id is None:
return None
# For older messages: fetch them from the API
channel = self.get_channel(reference.channel_id)
if channel is None or isinstance(
channel,
(discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel),
): # Logically this can't happen, but we have to please Mypy
return None
return await channel.fetch_message(reference.message_id)
async def confirm_message(self, message: discord.Message):
@ -194,7 +192,7 @@ class Didier(commands.Bot):
}
methods.get(level, logger.error)(message)
if log_to_discord:
if log_to_discord and self.error_channel is not None:
embed = create_logging_embed(level, message)
await self.error_channel.send(embed=embed)
@ -263,10 +261,9 @@ class Didier(commands.Bot):
await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True)
if settings.ERRORS_CHANNEL is not None:
if self.error_channel is not None:
embed = create_error_embed(await commands.Context.from_interaction(interaction), exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)
await self.error_channel.send(embed=embed)
async def on_command_completion(self, ctx: commands.Context):
"""Event triggered when a message command completes successfully"""
@ -291,7 +288,7 @@ class Didier(commands.Bot):
# Hybrid command errors are wrapped in an additional error, so wrap it back out
if isinstance(exception, commands.HybridCommandError):
exception = exception.original
exception = exception.original # type: ignore[assignment]
# Ignore exceptions that aren't important
if isinstance(
@ -342,10 +339,9 @@ class Didier(commands.Bot):
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if settings.ERRORS_CHANNEL is not None:
if self.error_channel is not None:
embed = create_error_embed(ctx, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)
await self.error_channel.send(embed=embed)
async def on_message(self, message: discord.Message, /) -> None:
"""Event triggered when a message is sent"""
@ -354,7 +350,7 @@ class Didier(commands.Bot):
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
@ -384,7 +380,7 @@ class Didier(commands.Bot):
# If the edited message is currently present in the snipe cache,
# don't update the <before>, but instead change the <after>
existing = self.sniped.get(before.channel.id, None)
existing = self.sniped.get(before.channel.id)
if existing is not None and existing[0].id == before.id:
before = existing[0]
@ -399,10 +395,9 @@ class Didier(commands.Bot):
async def on_task_error(self, exception: Exception):
"""Event triggered when a task raises an exception"""
if settings.ERRORS_CHANNEL is not None:
if self.error_channel:
embed = create_error_embed(None, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed)
await self.error_channel.send(embed=embed)
async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a new thread is created"""

View File

@ -1,6 +1,14 @@
from .get_none_exception import GetNoneException
from .http_exception import HTTPException
from .missing_env import MissingEnvironmentVariable
from .no_match import NoMatch, expect
from .not_in_main_guild_exception import NotInMainGuildException
__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"]
__all__ = [
"GetNoneException",
"HTTPException",
"MissingEnvironmentVariable",
"NoMatch",
"expect",
"NotInMainGuildException",
]

View File

@ -0,0 +1,5 @@
__all__ = ["GetNoneException"]
class GetNoneException(RuntimeError):
"""Exception raised when a Bot.get()-method returned None"""

View File

@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError):
def __init__(self, user: Union[discord.User, discord.Member]):
super().__init__(
f"User {user.display_name} (id {user.id}) "
f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})."
f"User {user.display_name} (id `{user.id}`) "
f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)."
)

View File

@ -0,0 +1,5 @@
import discord
__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"]
NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel)

View File

@ -15,11 +15,14 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]:
This is done dynamically through regexes to allow case-insensitivity
and variable amounts of whitespace among other things.
"""
mention = f"<@!?{client.user.id}>"
mention = f"<@!?{client.user.id}>" if client.user else None
regex = r"^({})\s*"
# Check which prefix was used
for prefix in [*constants.PREFIXES, mention]:
if prefix is None:
continue
match = re.match(regex.format(prefix), message.content, flags=re.I)
if match is not None:

View File

@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"):
@overrides
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
label = self.name.value.strip()
try:
async with self.client.postgres_session as session:
bm = await create_bookmark(session, interaction.user.id, label, self.jump_url)
return await interaction.response.send_message(
f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True
return await interaction.followup.send(
f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)."
)
except DuplicateInsertException:
# Label is already in use
return await interaction.response.send_message(
f"You already have a bookmark named `{label}`.", ephemeral=True
)
return await interaction.followup.send(f"You already have a bookmark named `{label}`.")
except ForbiddenNameException:
# Label isn't allowed
return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True)
return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.")
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)

View File

@ -26,12 +26,14 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
@overrides
async def on_submit(self, interaction: discord.Interaction):
await interaction.response.defer(ephemeral=True)
async with self.client.postgres_session as session:
joke = await add_dad_joke(session, str(self.joke.value))
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}")
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)

View File

@ -10,6 +10,8 @@ from didier import Didier
__all__ = ["AddEvent"]
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
class AddEvent(discord.ui.Modal, title="Add Event"):
"""Modal to add a new event"""
@ -33,15 +35,20 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
@overrides
async def on_submit(self, interaction: discord.Interaction) -> None:
await interaction.response.defer(ephemeral=True)
try:
parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
except ParserError:
return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True)
return await interaction.followup.send("Unable to parse date argument.")
if self.client.get_channel(int(self.channel.value)) is None:
return await interaction.response.send_message(
f"Unable to find channel `{self.channel.value}`", ephemeral=True
)
channel = self.client.get_channel(int(self.channel.value))
if channel is None:
return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`")
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.")
async with self.client.postgres_session as session:
event = await add_event(
@ -52,10 +59,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
channel_id=int(self.channel.value),
)
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
await interaction.followup.send(f"Successfully added event `{event.event_id}`.")
self.client.dispatch("event_create", event)
@overrides
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
await interaction.response.send_message("Something went wrong.", ephemeral=True)
await interaction.followup.send("Something went wrong.", ephemeral=True)
traceback.print_tb(error.__traceback__)

File diff suppressed because it is too large Load Diff

View File

@ -36,6 +36,7 @@ def setup_logging():
# Configure discord handler
discord_log = logging.getLogger("discord")
discord_handler: logging.StreamHandler
# Make dev print to stderr instead, so you don't have to watch the file
if settings.SANDBOX:

View File

@ -28,6 +28,7 @@ omit = [
profile = "black"
[tool.mypy]
check_untyped_defs = true
files = [
"database/**/*.py",
"didier/**/*.py",
@ -35,7 +36,6 @@ files = [
]
plugins = [
"pydantic.mypy",
"sqlalchemy.ext.mypy.plugin"
]
[[tool.mypy.overrides]]
module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"]

View File

@ -1,22 +1,21 @@
black==22.3.0
coverage[toml]==6.4.1
freezegun==1.2.1
black==23.3.0
coverage[toml]==7.2.7
freezegun==1.2.2
isort==5.12.0
mypy==0.961
pre-commit==2.20.0
pytest==7.1.2
pytest-asyncio==0.18.3
pytest-env==0.6.2
sqlalchemy2-stubs==0.0.2a23
types-beautifulsoup4==4.11.3
types-python-dateutil==2.8.19
mypy==1.4.1
pre-commit==3.3.3
pytest==7.4.0
pytest-asyncio==0.21.0
pytest-env==0.8.2
types-beautifulsoup4==4.12.0.5
types-python-dateutil==2.8.19.13
# Flake8 + plugins
flake8==4.0.1
flake8-bandit==3.0.0
flake8-bugbear==22.7.1
flake8-docstrings==1.6.0
flake8-dunder-all==0.2.1
flake8-eradicate==1.2.1
flake8-isort==4.1.1
flake8-simplify==0.19.2
flake8==6.0.0
flake8-bandit==4.1.1
flake8-bugbear==23.6.5
flake8-docstrings==1.7.0
flake8-dunder-all==0.3.0
flake8-eradicate==1.5.0
flake8-isort==6.0.0
flake8-simplify==0.20.0

View File

@ -1,13 +1,13 @@
aiohttp==3.8.1
alembic==1.8.0
asyncpg==0.25.0
beautifulsoup4==4.11.1
discord.py==2.0.1
aiohttp==3.8.4
alembic==1.11.1
asyncpg==0.28.0
beautifulsoup4==4.12.2
discord.py==2.3.1
environs==9.5.0
feedparser==6.0.10
ics==0.7.2
markdownify==0.11.2
overrides==6.1.0
pydantic==1.9.1
markdownify==0.11.6
overrides==7.3.1
pydantic==2.0.2
python-dateutil==2.8.2
sqlalchemy[asyncio]==1.4.37
sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18

View File

@ -111,7 +111,7 @@ class ScheduleInfo:
role_id: Optional[int]
schedule_url: Optional[str]
name: Optional[str] = None
name: ScheduleType
SCHEDULE_DATA = [

View File

@ -1,138 +0,0 @@
from datetime import date, timedelta
import pytest
from freezegun import freeze_time
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle as crud
from database.schemas import User, WordleGuess, WordleWord
@pytest.fixture
async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
"""Fixture to generate some guesses"""
guesses = []
for guess in ["TEST", "WORDLE", "WORDS"]:
guess = WordleGuess(user_id=user.user_id, guess=guess)
postgres.add(guess)
await postgres.commit()
guesses.append(guess)
return guesses
@pytest.mark.postgres
async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
"""Test getting an active game when there is none"""
result = await crud.get_active_wordle_game(postgres, user.user_id)
assert not result
@pytest.mark.postgres
async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
"""Test getting an active game when there is one"""
result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
assert result == wordle_guesses
@pytest.mark.postgres
async def test_get_daily_word_none(postgres: AsyncSession):
"""Test getting the daily word when the database is empty"""
result = await crud.get_daily_word(postgres)
assert result is None
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_get_daily_word_not_today(postgres: AsyncSession):
"""Test getting the daily word when there is an entry, but not for today"""
day = date.today() - timedelta(days=1)
word = "testword"
word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
assert await crud.get_daily_word(postgres) is None
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_get_daily_word_present(postgres: AsyncSession):
"""Test getting the daily word when there is one for today"""
day = date.today()
word = "testword"
word_instance = WordleWord(word=word, day=day)
postgres.add(word_instance)
await postgres.commit()
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_none_present(postgres: AsyncSession):
"""Test setting the daily word when there is none"""
assert await crud.get_daily_word(postgres) is None
word = "testword"
await crud.set_daily_word(postgres, word)
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_present(postgres: AsyncSession):
"""Test setting the daily word when there already is one"""
word = "testword"
await crud.set_daily_word(postgres, word)
await crud.set_daily_word(postgres, "another word")
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
"""Test setting the daily word when there already is one, but "forced" is set to True"""
word = "testword"
await crud.set_daily_word(postgres, word)
word = "anotherword"
await crud.set_daily_word(postgres, word, forced=True)
daily_word = await crud.get_daily_word(postgres)
assert daily_word is not None
assert daily_word.word == word
@pytest.mark.postgres
async def test_make_wordle_guess(postgres: AsyncSession, user: User):
"""Test making a guess in your current game"""
test_user_id = user.user_id
guess = "guess"
await crud.make_wordle_guess(postgres, test_user_id, guess)
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess]
other_guess = "otherguess"
await crud.make_wordle_guess(postgres, test_user_id, other_guess)
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess]
@pytest.mark.postgres
async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
"""Test dropping the collection of active games"""
test_user_id = user.user_id
assert await crud.get_active_wordle_game(postgres, test_user_id)
await crud.reset_wordle_games(postgres)
assert not await crud.get_active_wordle_game(postgres, test_user_id)

View File

@ -1,72 +0,0 @@
import datetime
import pytest
from freezegun import freeze_time
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud import wordle_stats as crud
from database.schemas import User, WordleStats
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
"""Helper function to insert some stats"""
postgres.add(stats)
await postgres.commit()
@pytest.mark.postgres
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
"""Test getting a user's stats when the db is empty"""
test_user_id = user.user_id
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is None
await crud.get_wordle_stats(postgres, test_user_id)
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
@pytest.mark.postgres
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
"""Test getting a user's stats when there's already an entry present"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.games = 20
await insert_game_stats(postgres, stats)
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
assert found_stats.games == 20
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you win"""
test_user_id = user.user_id
await crud.complete_wordle_game(postgres, test_user_id, win=True)
stats = await crud.get_wordle_stats(postgres, test_user_id)
assert stats.games == 1
assert stats.wins == 1
assert stats.current_streak == 1
assert stats.highest_streak == 1
assert stats.last_win == datetime.date.today()
@pytest.mark.postgres
@freeze_time("2022-07-30")
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
"""Test completing a wordle game when you lose"""
test_user_id = user.user_id
stats = WordleStats(user_id=test_user_id)
stats.current_streak = 10
await insert_game_stats(postgres, stats)
await crud.complete_wordle_game(postgres, test_user_id, win=False)
stats = await crud.get_wordle_stats(postgres, test_user_id)
# Check that streak was broken
assert stats.current_streak == 0
assert stats.games == 1