mirror of https://github.com/stijndcl/didier
commit
5bfd3a92a9
|
@ -3,12 +3,12 @@ default_language_version:
|
|||
|
||||
repos:
|
||||
- repo: https://github.com/ambv/black
|
||||
rev: 22.3.0
|
||||
rev: 23.3.0
|
||||
hooks:
|
||||
- id: black
|
||||
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.3.0
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: check-json
|
||||
- id: end-of-file-fixer
|
||||
|
@ -21,7 +21,7 @@ repos:
|
|||
- id: isort
|
||||
|
||||
- repo: https://github.com/PyCQA/autoflake
|
||||
rev: v1.4
|
||||
rev: v2.2.0
|
||||
hooks:
|
||||
- id: autoflake
|
||||
name: autoflake (python)
|
||||
|
@ -31,7 +31,7 @@ repos:
|
|||
- "--ignore-init-module-imports"
|
||||
|
||||
- repo: https://github.com/PyCQA/flake8
|
||||
rev: 4.0.1
|
||||
rev: 6.0.0
|
||||
hooks:
|
||||
- id: flake8
|
||||
exclude: ^(alembic|.github)
|
||||
|
|
|
@ -0,0 +1,94 @@
|
|||
"""Migrate to 2.x
|
||||
|
||||
Revision ID: 09128b6e34dd
|
||||
Revises: 1e3e7f4192c4
|
||||
Create Date: 2023-07-07 16:23:15.990231
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "09128b6e34dd"
|
||||
down_revision = "1e3e7f4192c4"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("bank", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("birthdays", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("command_stats", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
|
||||
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("deadlines", schema=None) as batch_op:
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("github_links", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("reminders", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
|
||||
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=False)
|
||||
|
||||
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op:
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("ufora_announcements", schema=None) as batch_op:
|
||||
batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=True)
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("reminders", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("nightly_data", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("github_links", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("deadlines", schema=None) as batch_op:
|
||||
batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op:
|
||||
batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("command_stats", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("bookmarks", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("birthdays", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
with op.batch_alter_table("bank", schema=None) as batch_op:
|
||||
batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True)
|
||||
|
||||
# ### end Alembic commands ###
|
|
@ -0,0 +1,57 @@
|
|||
"""Remove wordle
|
||||
|
||||
Revision ID: 1e3e7f4192c4
|
||||
Revises: 954ad804f057
|
||||
Create Date: 2023-07-07 14:52:20.993687
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "1e3e7f4192c4"
|
||||
down_revision = "954ad804f057"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("wordle_word")
|
||||
op.drop_table("wordle_stats")
|
||||
op.drop_table("wordle_guesses")
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"wordle_guesses",
|
||||
sa.Column("wordle_guess_id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
|
||||
sa.Column("guess", sa.TEXT(), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_guesses_user_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("wordle_guess_id", name="wordle_guesses_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"wordle_stats",
|
||||
sa.Column("wordle_stats_id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("user_id", sa.BIGINT(), autoincrement=False, nullable=True),
|
||||
sa.Column("last_win", sa.DATE(), autoincrement=False, nullable=True),
|
||||
sa.Column("games", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("wins", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("current_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.Column("highest_streak", sa.INTEGER(), server_default=sa.text("0"), autoincrement=False, nullable=False),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.user_id"], name="wordle_stats_user_id_fkey"),
|
||||
sa.PrimaryKeyConstraint("wordle_stats_id", name="wordle_stats_pkey"),
|
||||
)
|
||||
op.create_table(
|
||||
"wordle_word",
|
||||
sa.Column("word_id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("word", sa.TEXT(), autoincrement=False, nullable=False),
|
||||
sa.Column("day", sa.DATE(), autoincrement=False, nullable=False),
|
||||
sa.PrimaryKeyConstraint("word_id", name="wordle_word_pkey"),
|
||||
sa.UniqueConstraint("day", name="wordle_word_day_key"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
|
@ -1,2 +0,0 @@
|
|||
WORDLE_GUESS_COUNT = 6
|
||||
WORDLE_WORD_LENGTH = 5
|
|
@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[
|
|||
if query is not None:
|
||||
statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%"))
|
||||
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]:
|
||||
|
|
|
@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
|
|||
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
|
||||
"""Get a list of all commands"""
|
||||
statement = select(CustomCommand)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
|
|
|
@ -38,4 +38,4 @@ async def get_deadlines(
|
|||
statement = statement.where(Deadline.course_id == course.course_id)
|
||||
|
||||
statement = statement.options(selectinload(Deadline.course))
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"]
|
|||
async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]:
|
||||
"""Return a list of all easter eggs"""
|
||||
statement = select(EasterEgg)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even
|
|||
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
|
||||
"""Get a list of all upcoming events"""
|
||||
statement = select(Event).where(Event.timestamp > now)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:
|
||||
|
|
|
@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]):
|
|||
async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]:
|
||||
"""Filter a list of game IDs down to the ones that aren't in the database yet"""
|
||||
statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids))
|
||||
matches: list[int] = (await session.execute(statement)).scalars().all()
|
||||
matches: list[int] = list((await session.execute(statement)).scalars().all())
|
||||
return list(set(game_ids).difference(matches))
|
||||
|
|
|
@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id:
|
|||
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
|
||||
"""Get a user's GitHub links"""
|
||||
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
|||
async def get_all_links(session: AsyncSession) -> list[Link]:
|
||||
"""Get a list of all links"""
|
||||
statement = select(Link)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def add_link(session: AsyncSession, name: str, url: str) -> Link:
|
||||
|
|
|
@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
|
|||
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
|
||||
"""Get a list of all memes"""
|
||||
statement = select(MemeTemplate)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:
|
||||
|
|
|
@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"]
|
|||
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
|
||||
"""Get a list of all Reminders for a given category"""
|
||||
statement = select(Reminder).where(Reminder.category == category)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:
|
||||
|
|
|
@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_
|
|||
async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]:
|
||||
"""Get all courses where announcements are enabled"""
|
||||
statement = select(UforaCourse).where(UforaCourse.log_announcements)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def create_new_announcement(
|
||||
|
|
|
@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor
|
|||
# Search case-insensitively
|
||||
query = query.lower()
|
||||
|
||||
statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
|
||||
result = (await session.execute(statement)).scalars().first()
|
||||
if result:
|
||||
return result
|
||||
course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
|
||||
course_result = (await session.execute(course_statement)).scalars().first()
|
||||
if course_result:
|
||||
return course_result
|
||||
|
||||
statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
|
||||
result = (await session.execute(statement)).scalars().first()
|
||||
return result.course if result else None
|
||||
alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
|
||||
alias_result = (await session.execute(alias_statement)).scalars().first()
|
||||
return alias_result.course if alias_result else None
|
||||
|
|
|
@ -1,85 +0,0 @@
|
|||
import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.schemas import WordleGuess, WordleWord
|
||||
|
||||
__all__ = [
|
||||
"get_active_wordle_game",
|
||||
"get_wordle_guesses",
|
||||
"make_wordle_guess",
|
||||
"set_daily_word",
|
||||
"reset_wordle_games",
|
||||
]
|
||||
|
||||
|
||||
async def get_active_wordle_game(session: AsyncSession, user_id: int) -> list[WordleGuess]:
|
||||
"""Find a player's active game"""
|
||||
await get_or_add_user(session, user_id)
|
||||
statement = select(WordleGuess).where(WordleGuess.user_id == user_id)
|
||||
guesses = (await session.execute(statement)).scalars().all()
|
||||
return guesses
|
||||
|
||||
|
||||
async def get_wordle_guesses(session: AsyncSession, user_id: int) -> list[str]:
|
||||
"""Get the strings of a player's guesses"""
|
||||
active_game = await get_active_wordle_game(session, user_id)
|
||||
return list(map(lambda g: g.guess.lower(), active_game))
|
||||
|
||||
|
||||
async def make_wordle_guess(session: AsyncSession, user_id: int, guess: str):
|
||||
"""Make a guess in your current game"""
|
||||
guess_instance = WordleGuess(user_id=user_id, guess=guess)
|
||||
session.add(guess_instance)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_daily_word(session: AsyncSession) -> Optional[WordleWord]:
|
||||
"""Get the word of today"""
|
||||
statement = select(WordleWord).where(WordleWord.day == datetime.date.today())
|
||||
row = (await session.execute(statement)).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
return None
|
||||
|
||||
return row
|
||||
|
||||
|
||||
async def set_daily_word(session: AsyncSession, word: str, *, forced: bool = False) -> str:
|
||||
"""Set the word of today
|
||||
|
||||
This does NOT overwrite the existing word if there is one, so that it can safely run
|
||||
on startup every time.
|
||||
|
||||
In order to always overwrite the current word, set the "forced"-kwarg to True.
|
||||
|
||||
Returns the word that was chosen. If one already existed, return that instead.
|
||||
"""
|
||||
current_word = await get_daily_word(session)
|
||||
|
||||
if current_word is None:
|
||||
current_word = WordleWord(word=word, day=datetime.date.today())
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
elif forced:
|
||||
current_word.word = word
|
||||
session.add(current_word)
|
||||
await session.commit()
|
||||
|
||||
# Remove all active games
|
||||
await reset_wordle_games(session)
|
||||
|
||||
return current_word.word
|
||||
|
||||
|
||||
async def reset_wordle_games(session: AsyncSession):
|
||||
"""Reset all active games"""
|
||||
statement = delete(WordleGuess)
|
||||
await session.execute(statement)
|
||||
await session.commit()
|
|
@ -1,60 +0,0 @@
|
|||
from datetime import date
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.schemas import WordleStats
|
||||
|
||||
__all__ = ["get_wordle_stats", "complete_wordle_game"]
|
||||
|
||||
|
||||
async def get_wordle_stats(session: AsyncSession, user_id: int) -> WordleStats:
|
||||
"""Get a user's wordle stats
|
||||
|
||||
If no entry is found, it is first created
|
||||
"""
|
||||
await get_or_add_user(session, user_id)
|
||||
|
||||
statement = select(WordleStats).where(WordleStats.user_id == user_id)
|
||||
stats = (await session.execute(statement)).scalar_one_or_none()
|
||||
if stats is not None:
|
||||
return stats
|
||||
|
||||
stats = WordleStats(user_id=user_id)
|
||||
session.add(stats)
|
||||
await session.commit()
|
||||
await session.refresh(stats)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
async def complete_wordle_game(session: AsyncSession, user_id: int, win: bool):
|
||||
"""Update the user's Wordle stats"""
|
||||
stats = await get_wordle_stats(session, user_id)
|
||||
stats.games += 1
|
||||
|
||||
if win:
|
||||
stats.wins += 1
|
||||
|
||||
# Update streak
|
||||
today = date.today()
|
||||
last_win = stats.last_win
|
||||
stats.last_win = today
|
||||
|
||||
if last_win is None or (today - last_win).days > 1:
|
||||
# Never won a game before or streak is over
|
||||
stats.current_streak = 1
|
||||
else:
|
||||
# On a streak: increase counter
|
||||
stats.current_streak += 1
|
||||
|
||||
# Update max streak if necessary
|
||||
if stats.current_streak > stats.highest_streak:
|
||||
stats.highest_streak = stats.current_streak
|
||||
else:
|
||||
# Streak is over
|
||||
stats.current_streak = 0
|
||||
|
||||
session.add(stats)
|
||||
await session.commit()
|
|
@ -1,8 +1,7 @@
|
|||
from urllib.parse import quote_plus
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
|
||||
|
||||
import settings
|
||||
|
||||
|
@ -22,6 +21,4 @@ postgres_engine = create_async_engine(
|
|||
future=True,
|
||||
)
|
||||
|
||||
DBSession = sessionmaker(
|
||||
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||
)
|
||||
DBSession = async_sessionmaker(autocommit=False, autoflush=False, bind=postgres_engine, expire_on_commit=False)
|
||||
|
|
|
@ -1,27 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import (
|
||||
BigInteger,
|
||||
Boolean,
|
||||
Column,
|
||||
Date,
|
||||
DateTime,
|
||||
Enum,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import declarative_base, relationship
|
||||
from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship
|
||||
from sqlalchemy.types import DateTime
|
||||
|
||||
from database import enums
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Base",
|
||||
"Bank",
|
||||
|
@ -45,33 +32,37 @@ __all__ = [
|
|||
"UforaCourse",
|
||||
"UforaCourseAlias",
|
||||
"User",
|
||||
"WordleGuess",
|
||||
"WordleStats",
|
||||
"WordleWord",
|
||||
]
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""Required base class for all tables"""
|
||||
|
||||
# Make all DateTimes timezone-aware
|
||||
type_annotation_map = {datetime: DateTime(timezone=True)}
|
||||
|
||||
|
||||
class Bank(Base):
|
||||
"""A user's currency information"""
|
||||
|
||||
__tablename__ = "bank"
|
||||
|
||||
bank_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
bank_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
|
||||
dinks: int = Column(BigInteger, server_default="0", nullable=False)
|
||||
invested: int = Column(BigInteger, server_default="0", nullable=False)
|
||||
dinks: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
|
||||
invested: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False)
|
||||
|
||||
# Interest rate
|
||||
interest_level: int = Column(Integer, server_default="1", nullable=False)
|
||||
interest_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
|
||||
|
||||
# Maximum amount that can be stored in the bank
|
||||
capacity_level: int = Column(Integer, server_default="1", nullable=False)
|
||||
capacity_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
|
||||
|
||||
# Maximum amount that can be robbed
|
||||
rob_level: int = Column(Integer, server_default="1", nullable=False)
|
||||
rob_level: Mapped[int] = mapped_column(server_default="1", nullable=False)
|
||||
|
||||
user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin")
|
||||
user: Mapped[User] = relationship(uselist=False, back_populates="bank", lazy="selectin")
|
||||
|
||||
|
||||
class Birthday(Base):
|
||||
|
@ -79,11 +70,11 @@ class Birthday(Base):
|
|||
|
||||
__tablename__ = "birthdays"
|
||||
|
||||
birthday_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
birthday: date = Column(Date, nullable=False)
|
||||
birthday_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
birthday: Mapped[date] = mapped_column(nullable=False)
|
||||
|
||||
user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin")
|
||||
user: Mapped[User] = relationship(uselist=False, back_populates="birthday", lazy="selectin")
|
||||
|
||||
|
||||
class Bookmark(Base):
|
||||
|
@ -92,26 +83,26 @@ class Bookmark(Base):
|
|||
__tablename__ = "bookmarks"
|
||||
__table_args__ = (UniqueConstraint("user_id", "label"),)
|
||||
|
||||
bookmark_id: int = Column(Integer, primary_key=True)
|
||||
label: str = Column(Text, nullable=False)
|
||||
jump_url: str = Column(Text, nullable=False)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
bookmark_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
label: Mapped[str] = mapped_column(nullable=False)
|
||||
jump_url: Mapped[str] = mapped_column(nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
|
||||
user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin")
|
||||
user: Mapped[User] = relationship(back_populates="bookmarks", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class CommandStats(Base):
|
||||
"""Metrics on how often commands are used"""
|
||||
|
||||
__tablename__ = "command_stats"
|
||||
command_stats_id: int = Column(Integer, primary_key=True)
|
||||
command: str = Column(Text, nullable=False)
|
||||
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
slash: bool = Column(Boolean, nullable=False)
|
||||
context_menu: bool = Column(Boolean, nullable=False)
|
||||
command_stats_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
command: Mapped[str] = mapped_column(nullable=False)
|
||||
timestamp: Mapped[datetime] = mapped_column(nullable=False)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
slash: Mapped[bool] = mapped_column(nullable=False)
|
||||
context_menu: Mapped[bool] = mapped_column(nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
|
||||
user: Mapped[User] = relationship(back_populates="command_stats", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class CustomCommand(Base):
|
||||
|
@ -119,13 +110,13 @@ class CustomCommand(Base):
|
|||
|
||||
__tablename__ = "custom_commands"
|
||||
|
||||
command_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False, unique=True)
|
||||
indexed_name: str = Column(Text, nullable=False, index=True)
|
||||
response: str = Column(Text, nullable=False)
|
||||
command_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
indexed_name: Mapped[str] = mapped_column(nullable=False, index=True)
|
||||
response: Mapped[str] = mapped_column(nullable=False)
|
||||
|
||||
aliases: list[CustomCommandAlias] = relationship(
|
||||
"CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
|
||||
aliases: Mapped[List[CustomCommandAlias]] = relationship(
|
||||
back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
|
@ -134,12 +125,12 @@ class CustomCommandAlias(Base):
|
|||
|
||||
__tablename__ = "custom_command_aliases"
|
||||
|
||||
alias_id: int = Column(Integer, primary_key=True)
|
||||
alias: str = Column(Text, nullable=False, unique=True)
|
||||
indexed_alias: str = Column(Text, nullable=False, index=True)
|
||||
command_id: int = Column(Integer, ForeignKey("custom_commands.command_id"))
|
||||
alias_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
alias: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
indexed_alias: Mapped[str] = mapped_column(nullable=False, index=True)
|
||||
command_id: Mapped[int] = mapped_column(ForeignKey("custom_commands.command_id"))
|
||||
|
||||
command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin")
|
||||
command: Mapped[CustomCommand] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class DadJoke(Base):
|
||||
|
@ -147,8 +138,8 @@ class DadJoke(Base):
|
|||
|
||||
__tablename__ = "dad_jokes"
|
||||
|
||||
dad_joke_id: int = Column(Integer, primary_key=True)
|
||||
joke: str = Column(Text, nullable=False)
|
||||
dad_joke_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
joke: Mapped[str] = mapped_column(nullable=False)
|
||||
|
||||
|
||||
class Deadline(Base):
|
||||
|
@ -156,12 +147,12 @@ class Deadline(Base):
|
|||
|
||||
__tablename__ = "deadlines"
|
||||
|
||||
deadline_id: int = Column(Integer, primary_key=True)
|
||||
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
|
||||
name: str = Column(Text, nullable=False)
|
||||
deadline: datetime = Column(DateTime(timezone=True), nullable=False)
|
||||
deadline_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
deadline: Mapped[datetime] = mapped_column(nullable=False)
|
||||
|
||||
course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin")
|
||||
course: Mapped[UforaCourse] = relationship(back_populates="deadlines", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class EasterEgg(Base):
|
||||
|
@ -169,11 +160,11 @@ class EasterEgg(Base):
|
|||
|
||||
__tablename__ = "easter_eggs"
|
||||
|
||||
easter_egg_id: int = Column(Integer, primary_key=True)
|
||||
match: str = Column(Text, nullable=False)
|
||||
response: str = Column(Text, nullable=False)
|
||||
exact: bool = Column(Boolean, nullable=False, server_default="1")
|
||||
startswith: bool = Column(Boolean, nullable=False, server_default="1")
|
||||
easter_egg_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
match: Mapped[str] = mapped_column(nullable=False)
|
||||
response: Mapped[str] = mapped_column(nullable=False)
|
||||
exact: Mapped[bool] = mapped_column(nullable=False, server_default="1")
|
||||
startswith: Mapped[bool] = mapped_column(nullable=False, server_default="1")
|
||||
|
||||
|
||||
class Event(Base):
|
||||
|
@ -181,11 +172,11 @@ class Event(Base):
|
|||
|
||||
__tablename__ = "events"
|
||||
|
||||
event_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False)
|
||||
description: Optional[str] = Column(Text, nullable=True)
|
||||
notification_channel: int = Column(BigInteger, nullable=False)
|
||||
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
|
||||
event_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(nullable=True)
|
||||
notification_channel: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||
timestamp: Mapped[datetime] = mapped_column(nullable=False)
|
||||
|
||||
|
||||
class FreeGame(Base):
|
||||
|
@ -193,7 +184,7 @@ class FreeGame(Base):
|
|||
|
||||
__tablename__ = "free_games"
|
||||
|
||||
free_game_id: int = Column(Integer, primary_key=True)
|
||||
free_game_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
|
||||
|
||||
class GitHubLink(Base):
|
||||
|
@ -201,11 +192,11 @@ class GitHubLink(Base):
|
|||
|
||||
__tablename__ = "github_links"
|
||||
|
||||
github_link_id: int = Column(Integer, primary_key=True)
|
||||
url: str = Column(Text, nullable=False, unique=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
github_link_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
url: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
|
||||
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
|
||||
user: Mapped[User] = relationship(back_populates="github_links", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Link(Base):
|
||||
|
@ -213,9 +204,9 @@ class Link(Base):
|
|||
|
||||
__tablename__ = "links"
|
||||
|
||||
link_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False, unique=True)
|
||||
url: str = Column(Text, nullable=False)
|
||||
link_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
url: Mapped[str] = mapped_column(nullable=False)
|
||||
|
||||
|
||||
class MemeTemplate(Base):
|
||||
|
@ -223,10 +214,10 @@ class MemeTemplate(Base):
|
|||
|
||||
__tablename__ = "meme"
|
||||
|
||||
meme_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False, unique=True)
|
||||
template_id: int = Column(Integer, nullable=False, unique=True)
|
||||
field_count: int = Column(Integer, nullable=False)
|
||||
meme_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
template_id: Mapped[int] = mapped_column(nullable=False, unique=True)
|
||||
field_count: Mapped[int] = mapped_column(nullable=False)
|
||||
|
||||
|
||||
class NightlyData(Base):
|
||||
|
@ -234,12 +225,12 @@ class NightlyData(Base):
|
|||
|
||||
__tablename__ = "nightly_data"
|
||||
|
||||
nightly_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_nightly: Optional[date] = Column(Date, nullable=True)
|
||||
count: int = Column(Integer, server_default="0", nullable=False)
|
||||
nightly_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_nightly: Mapped[Optional[date]] = mapped_column(nullable=True)
|
||||
count: Mapped[int] = mapped_column(server_default="0", nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin")
|
||||
user: Mapped[User] = relationship(back_populates="nightly_data", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Reminder(Base):
|
||||
|
@ -247,11 +238,11 @@ class Reminder(Base):
|
|||
|
||||
__tablename__ = "reminders"
|
||||
|
||||
reminder_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False)
|
||||
reminder_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id"))
|
||||
category: Mapped[enums.ReminderCategory] = mapped_column(nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin")
|
||||
user: Mapped[User] = relationship(back_populates="reminders", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Task(Base):
|
||||
|
@ -259,9 +250,9 @@ class Task(Base):
|
|||
|
||||
__tablename__ = "tasks"
|
||||
|
||||
task_id: int = Column(Integer, primary_key=True)
|
||||
task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True)
|
||||
previous_run: datetime = Column(DateTime(timezone=True), nullable=True)
|
||||
task_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
task: Mapped[enums.TaskType] = mapped_column(nullable=False, unique=True)
|
||||
previous_run: Mapped[datetime] = mapped_column(nullable=True)
|
||||
|
||||
|
||||
class UforaCourse(Base):
|
||||
|
@ -269,25 +260,25 @@ class UforaCourse(Base):
|
|||
|
||||
__tablename__ = "ufora_courses"
|
||||
|
||||
course_id: int = Column(Integer, primary_key=True)
|
||||
name: str = Column(Text, nullable=False, unique=True)
|
||||
code: str = Column(Text, nullable=False, unique=True)
|
||||
year: int = Column(Integer, nullable=False)
|
||||
compulsory: bool = Column(Boolean, server_default="1", nullable=False)
|
||||
role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
|
||||
overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
|
||||
course_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
name: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
code: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
year: Mapped[int] = mapped_column(nullable=False)
|
||||
compulsory: Mapped[bool] = mapped_column(server_default="1", nullable=False)
|
||||
role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
|
||||
overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
|
||||
# This is not the greatest fix, but there can only ever be two, so it will do the job
|
||||
alternative_overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False)
|
||||
log_announcements: bool = Column(Boolean, server_default="0", nullable=False)
|
||||
alternative_overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False)
|
||||
log_announcements: Mapped[bool] = mapped_column(server_default="0", nullable=False)
|
||||
|
||||
announcements: list[UforaAnnouncement] = relationship(
|
||||
"UforaAnnouncement", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
announcements: Mapped[List[UforaAnnouncement]] = relationship(
|
||||
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
)
|
||||
aliases: list[UforaCourseAlias] = relationship(
|
||||
"UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
aliases: Mapped[List[UforaCourseAlias]] = relationship(
|
||||
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
)
|
||||
deadlines: list[Deadline] = relationship(
|
||||
"Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
deadlines: Mapped[List[Deadline]] = relationship(
|
||||
back_populates="course", cascade="all, delete-orphan", lazy="selectin"
|
||||
)
|
||||
|
||||
|
||||
|
@ -296,11 +287,11 @@ class UforaCourseAlias(Base):
|
|||
|
||||
__tablename__ = "ufora_course_aliases"
|
||||
|
||||
alias_id: int = Column(Integer, primary_key=True)
|
||||
alias: str = Column(Text, nullable=False, unique=True)
|
||||
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
|
||||
alias_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
alias: Mapped[str] = mapped_column(nullable=False, unique=True)
|
||||
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
|
||||
|
||||
course: UforaCourse = relationship("UforaCourse", back_populates="aliases", uselist=False, lazy="selectin")
|
||||
course: Mapped[UforaCourse] = relationship(back_populates="aliases", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class UforaAnnouncement(Base):
|
||||
|
@ -308,11 +299,11 @@ class UforaAnnouncement(Base):
|
|||
|
||||
__tablename__ = "ufora_announcements"
|
||||
|
||||
announcement_id: int = Column(Integer, primary_key=True)
|
||||
course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id"))
|
||||
publication_date: date = Column(Date)
|
||||
announcement_id: Mapped[int] = mapped_column(primary_key=True)
|
||||
course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id"))
|
||||
publication_date: Mapped[date] = mapped_column()
|
||||
|
||||
course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin")
|
||||
course: Mapped[UforaCourse] = relationship(back_populates="announcements", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class User(Base):
|
||||
|
@ -320,70 +311,26 @@ class User(Base):
|
|||
|
||||
__tablename__ = "users"
|
||||
|
||||
user_id: int = Column(BigInteger, primary_key=True)
|
||||
user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True)
|
||||
|
||||
bank: Bank = relationship(
|
||||
"Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
bank: Mapped[Bank] = relationship(
|
||||
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
birthday: Optional[Birthday] = relationship(
|
||||
"Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
birthday: Mapped[Optional[Birthday]] = relationship(
|
||||
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
bookmarks: list[Bookmark] = relationship(
|
||||
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
bookmarks: Mapped[List[Bookmark]] = relationship(
|
||||
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
command_stats: list[CommandStats] = relationship(
|
||||
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
command_stats: Mapped[List[CommandStats]] = relationship(
|
||||
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
github_links: list[GitHubLink] = relationship(
|
||||
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
github_links: Mapped[List[GitHubLink]] = relationship(
|
||||
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
nightly_data: Mapped[NightlyData] = relationship(
|
||||
back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
reminders: list[Reminder] = relationship(
|
||||
"Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
reminders: Mapped[List[Reminder]] = relationship(
|
||||
back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_guesses: list[WordleGuess] = relationship(
|
||||
"WordleGuess", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
wordle_stats: WordleStats = relationship(
|
||||
"WordleStats", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
class WordleGuess(Base):
|
||||
"""A user's Wordle guesses for today"""
|
||||
|
||||
__tablename__ = "wordle_guesses"
|
||||
|
||||
wordle_guess_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
guess: str = Column(Text, nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_guesses", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleStats(Base):
|
||||
"""Stats about a user's wordle performance"""
|
||||
|
||||
__tablename__ = "wordle_stats"
|
||||
|
||||
wordle_stats_id: int = Column(Integer, primary_key=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
last_win: Optional[date] = Column(Date, nullable=True)
|
||||
games: int = Column(Integer, server_default="0", nullable=False)
|
||||
wins: int = Column(Integer, server_default="0", nullable=False)
|
||||
current_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
highest_streak: int = Column(Integer, server_default="0", nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="wordle_stats", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class WordleWord(Base):
|
||||
"""The current Wordle word"""
|
||||
|
||||
__tablename__ = "wordle_word"
|
||||
|
||||
word_id: int = Column(Integer, primary_key=True)
|
||||
word: str = Column(Text, nullable=False)
|
||||
day: date = Column(Date, nullable=False, unique=True)
|
||||
|
|
|
@ -0,0 +1,171 @@
|
|||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from database.engine import DBSession
|
||||
from database.schemas import UforaCourse, UforaCourseAlias
|
||||
|
||||
__all__ = ["main"]
|
||||
|
||||
|
||||
async def purge_aliases(session: AsyncSession):
|
||||
"""Delete old aliases for courses that will get a new id"""
|
||||
codes = ["C004074", "C004073", "C004075", "C002309"]
|
||||
for course_code in codes:
|
||||
select_stmt = (
|
||||
select(UforaCourse).where(UforaCourse.code == course_code).options(selectinload(UforaCourse.aliases))
|
||||
)
|
||||
course: UforaCourse = (await session.execute(select_stmt)).scalar_one()
|
||||
|
||||
for alias in list(course.aliases):
|
||||
await session.delete(alias)
|
||||
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def main():
|
||||
"""Add the Ufora courses for the 2023-2024 academic year"""
|
||||
session: AsyncSession
|
||||
async with DBSession() as session:
|
||||
# # Remove Advanced Databases (which no longer exists)
|
||||
delete_stmt = delete(UforaCourse).where(UforaCourse.code == "E018441")
|
||||
await session.execute(delete_stmt)
|
||||
await session.commit()
|
||||
|
||||
# Delete aliases of courses with new IDs
|
||||
await purge_aliases(session)
|
||||
|
||||
# Fix IDs of compulsory courses and enable announcements
|
||||
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004074")
|
||||
bds: UforaCourse = (await session.execute(select_stmt)).scalar_one()
|
||||
bds.course_id = 828305
|
||||
bds.log_announcements = True
|
||||
session.add(bds)
|
||||
|
||||
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004073")
|
||||
cg: UforaCourse = (await session.execute(select_stmt)).scalar_one()
|
||||
cg.course_id = 828293
|
||||
cg.log_announcements = True
|
||||
session.add(cg)
|
||||
|
||||
select_stmt = select(UforaCourse).where(UforaCourse.code == "C004075")
|
||||
stage: UforaCourse = (await session.execute(select_stmt)).scalar_one()
|
||||
stage.course_id = 857878
|
||||
stage.log_announcements = True
|
||||
session.add(stage)
|
||||
|
||||
select_stmt = select(UforaCourse).where(UforaCourse.code == "C002309")
|
||||
thesis: UforaCourse = (await session.execute(select_stmt)).scalar_one()
|
||||
thesis.course_id = 828446
|
||||
thesis.log_announcements = True
|
||||
session.add(thesis)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Add new aliases for these courses
|
||||
cg_alias = UforaCourseAlias(course_id=cg.course_id, alias="Computer Graphics")
|
||||
stage_alias = UforaCourseAlias(course_id=stage.course_id, alias="Stage")
|
||||
thesis_alias = UforaCourseAlias(course_id=thesis.course_id, alias="Thesis")
|
||||
|
||||
session.add_all([cg_alias, stage_alias, thesis_alias])
|
||||
await session.commit()
|
||||
|
||||
# New elective courses
|
||||
bed_eco = UforaCourse(
|
||||
code="H001535", name="Bedrijfseconomie", year=6, compulsory=False, role_id=1155496199000952922
|
||||
)
|
||||
|
||||
beg_eco = UforaCourse(
|
||||
code="D012144", name="Beginselen van economie", year=6, compulsory=False, role_id=1155495997024247948
|
||||
)
|
||||
|
||||
big_data_tech = UforaCourse(
|
||||
code="E018240", name="Big Data Technology", year=6, compulsory=False, role_id=1155490114282201148
|
||||
)
|
||||
|
||||
cloud_storage = UforaCourse(
|
||||
code="E017310", name="Cloud Storage and Computing", year=6, compulsory=False, role_id=1155490706849271841
|
||||
)
|
||||
|
||||
criminologie = UforaCourse(
|
||||
code="B001623", name="Inleiding tot Criminologie", year=6, compulsory=False, role_id=1155486768477515936
|
||||
)
|
||||
|
||||
data_quality = UforaCourse(
|
||||
code="E018700", name="Data Quality", year=6, compulsory=False, role_id=1155491028707586180
|
||||
)
|
||||
|
||||
data_vis_ai = UforaCourse(
|
||||
code="E061370",
|
||||
name="Data Visualization for and with AI",
|
||||
year=6,
|
||||
compulsory=False,
|
||||
role_id=1155491854687686746,
|
||||
)
|
||||
|
||||
db_design = UforaCourse(
|
||||
code="E018610", name="Database Design", year=6, compulsory=False, role_id=1155489846345875489
|
||||
)
|
||||
|
||||
finance_markets = UforaCourse(
|
||||
code="F000093",
|
||||
name="Financiële Markten en Instellingen",
|
||||
year=6,
|
||||
compulsory=False,
|
||||
role_id=1155492815615299634,
|
||||
)
|
||||
|
||||
game_theory = UforaCourse(
|
||||
code="E003710",
|
||||
name="Game Theory and Multiagent Systems",
|
||||
year=6,
|
||||
compulsory=False,
|
||||
role_id=1155488481666154506,
|
||||
)
|
||||
|
||||
knowledge_graphs = UforaCourse(
|
||||
code="E018160", name="Knowledge Graphs", year=6, compulsory=False, role_id=1155491648323735592
|
||||
)
|
||||
|
||||
natural_language_processing = UforaCourse(
|
||||
code="E061341", name="Natural Language Processing", year=6, compulsory=False, role_id=1155487540992823348
|
||||
)
|
||||
|
||||
nosql = UforaCourse(
|
||||
code="E018130", name="NoSQL Databases", year=6, compulsory=False, role_id=1155491405955878973
|
||||
)
|
||||
|
||||
secure_ss = UforaCourse(
|
||||
code="E017950", name="Secure Software and Systems", year=6, compulsory=False, role_id=1155492095281340467
|
||||
)
|
||||
|
||||
session.add_all(
|
||||
[
|
||||
bed_eco,
|
||||
beg_eco,
|
||||
big_data_tech,
|
||||
cloud_storage,
|
||||
criminologie,
|
||||
data_quality,
|
||||
data_vis_ai,
|
||||
db_design,
|
||||
finance_markets,
|
||||
game_theory,
|
||||
knowledge_graphs,
|
||||
natural_language_processing,
|
||||
nosql,
|
||||
secure_ss,
|
||||
]
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Aliases for new elective courses
|
||||
datakwaliteit = UforaCourseAlias(course_id=data_quality.course_id, alias="Datakwaliteit")
|
||||
devops = UforaCourseAlias(course_id=cloud_storage.course_id, alias="DevOps")
|
||||
nlp = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="NLP")
|
||||
nlp_nl = UforaCourseAlias(course_id=natural_language_processing.course_id, alias="Natuurlijke Taalverwerking")
|
||||
|
||||
session.add_all([datakwaliteit, devops, nlp, nlp_nl])
|
||||
|
||||
await session.commit()
|
|
@ -4,8 +4,8 @@ from discord import app_commands
|
|||
from overrides import overrides
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import easter_eggs, links, memes, ufora_courses, wordle
|
||||
from database.schemas import EasterEgg, WordleWord
|
||||
from database.crud import easter_eggs, links, memes, ufora_courses
|
||||
from database.schemas import EasterEgg
|
||||
|
||||
__all__ = ["CacheManager", "EasterEggCache", "LinkCache", "UforaCourseCache"]
|
||||
|
||||
|
@ -69,7 +69,7 @@ class LinkCache(DatabaseCache):
|
|||
self.clear()
|
||||
|
||||
all_links = await links.get_all_links(database_session)
|
||||
self.data = list(map(lambda l: l.name, all_links))
|
||||
self.data = list(map(lambda link: link.name, all_links))
|
||||
self.data.sort()
|
||||
self.data_transformed = list(map(str.lower, self.data))
|
||||
|
||||
|
@ -132,17 +132,6 @@ class UforaCourseCache(DatabaseCache):
|
|||
return [app_commands.Choice(name=suggestion, value=suggestion.lower()) for suggestion in suggestions]
|
||||
|
||||
|
||||
class WordleCache(DatabaseCache):
|
||||
"""Cache to store the current daily Wordle word"""
|
||||
|
||||
word: WordleWord
|
||||
|
||||
async def invalidate(self, database_session: AsyncSession):
|
||||
word = await wordle.get_daily_word(database_session)
|
||||
if word is not None:
|
||||
self.word = word
|
||||
|
||||
|
||||
class CacheManager:
|
||||
"""Class that keeps track of all caches"""
|
||||
|
||||
|
@ -150,14 +139,12 @@ class CacheManager:
|
|||
links: LinkCache
|
||||
memes: MemeCache
|
||||
ufora_courses: UforaCourseCache
|
||||
wordle_word: WordleCache
|
||||
|
||||
def __init__(self):
|
||||
self.easter_eggs = EasterEggCache()
|
||||
self.links = LinkCache()
|
||||
self.memes = MemeCache()
|
||||
self.ufora_courses = UforaCourseCache()
|
||||
self.wordle_word = WordleCache()
|
||||
|
||||
async def initialize_caches(self, postgres_session: AsyncSession):
|
||||
"""Initialize the contents of all caches"""
|
||||
|
@ -165,4 +152,3 @@ class CacheManager:
|
|||
await self.links.invalidate(postgres_session)
|
||||
await self.memes.invalidate(postgres_session)
|
||||
await self.ufora_courses.invalidate(postgres_session)
|
||||
await self.wordle_word.invalidate(postgres_session)
|
||||
|
|
|
@ -25,7 +25,7 @@ class Currency(commands.Cog):
|
|||
super().__init__()
|
||||
self.client = client
|
||||
|
||||
@commands.command(name="award")
|
||||
@commands.command(name="award") # type: ignore[arg-type]
|
||||
@commands.check(is_owner)
|
||||
async def award(
|
||||
self,
|
||||
|
@ -49,7 +49,9 @@ class Currency(commands.Cog):
|
|||
bank = await crud.get_bank(session, ctx.author.id)
|
||||
|
||||
embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue())
|
||||
embed.set_thumbnail(url=ctx.author.avatar.url)
|
||||
|
||||
if ctx.author.avatar is not None:
|
||||
embed.set_thumbnail(url=ctx.author.avatar.url)
|
||||
|
||||
embed.add_field(name="Interest level", value=bank.interest_level)
|
||||
embed.add_field(name="Capacity level", value=bank.capacity_level)
|
||||
|
@ -57,7 +59,9 @@ class Currency(commands.Cog):
|
|||
|
||||
await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True)
|
||||
@bank.group( # type: ignore[arg-type]
|
||||
name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True
|
||||
)
|
||||
async def bank_upgrades(self, ctx: commands.Context):
|
||||
"""List the upgrades you can buy & their prices."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -77,7 +81,7 @@ class Currency(commands.Cog):
|
|||
|
||||
await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@bank_upgrades.command(name="capacity", aliases=["c"])
|
||||
@bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type]
|
||||
async def bank_upgrade_capacity(self, ctx: commands.Context):
|
||||
"""Upgrade the capacity level of your bank."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -88,7 +92,7 @@ class Currency(commands.Cog):
|
|||
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@bank_upgrades.command(name="interest", aliases=["i"])
|
||||
@bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type]
|
||||
async def bank_upgrade_interest(self, ctx: commands.Context):
|
||||
"""Upgrade the interest level of your bank."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -99,7 +103,7 @@ class Currency(commands.Cog):
|
|||
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@bank_upgrades.command(name="rob", aliases=["r"])
|
||||
@bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type]
|
||||
async def bank_upgrade_rob(self, ctx: commands.Context):
|
||||
"""Upgrade the rob level of your bank."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -110,7 +114,7 @@ class Currency(commands.Cog):
|
|||
await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False)
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@commands.hybrid_command(name="dinks")
|
||||
@commands.hybrid_command(name="dinks") # type: ignore[arg-type]
|
||||
async def dinks(self, ctx: commands.Context):
|
||||
"""Check your Didier Dinks."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -118,7 +122,7 @@ class Currency(commands.Cog):
|
|||
plural = pluralize("Didier Dink", bank.dinks)
|
||||
await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False)
|
||||
|
||||
@commands.command(name="invest", aliases=["deposit", "dep"])
|
||||
@commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type]
|
||||
async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]):
|
||||
"""Invest `amount` Didier Dinks into your bank.
|
||||
|
||||
|
@ -144,7 +148,7 @@ class Currency(commands.Cog):
|
|||
f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False
|
||||
)
|
||||
|
||||
@commands.hybrid_command(name="nightly")
|
||||
@commands.hybrid_command(name="nightly") # type: ignore[arg-type]
|
||||
async def nightly(self, ctx: commands.Context):
|
||||
"""Claim nightly Didier Dinks."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
|
|
@ -13,7 +13,7 @@ class DebugCog(commands.Cog):
|
|||
self.client = client
|
||||
|
||||
@overrides
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override]
|
||||
return await self.client.is_owner(ctx.author)
|
||||
|
||||
@commands.command(aliases=["Dev"])
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import Optional, Union, cast
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
|
@ -17,6 +17,7 @@ from didier.exceptions import expect
|
|||
from didier.menus.bookmarks import BookmarkSource
|
||||
from didier.utils.discord import colours
|
||||
from didier.utils.discord.assets import get_author_avatar, get_user_avatar
|
||||
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
||||
from didier.utils.discord.constants import Limits
|
||||
from didier.utils.timer import Timer
|
||||
from didier.utils.types.datetime import localize, str_to_date, tz_aware_now
|
||||
|
@ -60,9 +61,19 @@ class Discord(commands.Cog):
|
|||
event = await events.get_event_by_id(session, event_id)
|
||||
|
||||
if event is None:
|
||||
return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True)
|
||||
return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True)
|
||||
|
||||
channel = self.client.get_channel(event.notification_channel)
|
||||
if channel is None:
|
||||
return await self.client.log_error(
|
||||
f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)."
|
||||
)
|
||||
|
||||
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
return await self.client.log_error(
|
||||
f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable."
|
||||
)
|
||||
|
||||
human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M")
|
||||
|
||||
embed = discord.Embed(title=event.name, colour=discord.Colour.blue())
|
||||
|
@ -81,7 +92,7 @@ class Discord(commands.Cog):
|
|||
self.client.loop.create_task(self.timer.update())
|
||||
|
||||
@commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True)
|
||||
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
||||
async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None):
|
||||
"""Command to check the birthday of `user`.
|
||||
|
||||
Not passing an argument for `user` will show yours instead.
|
||||
|
@ -98,8 +109,10 @@ class Discord(commands.Cog):
|
|||
day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month))
|
||||
return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False)
|
||||
|
||||
@birthday.command(name="set", aliases=["config"])
|
||||
async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None):
|
||||
@birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type]
|
||||
async def birthday_set(
|
||||
self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None
|
||||
):
|
||||
"""Set your birthday to `day`.
|
||||
|
||||
Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`.
|
||||
|
@ -113,6 +126,9 @@ class Discord(commands.Cog):
|
|||
if user is None:
|
||||
user = ctx.author
|
||||
|
||||
# Please Mypy
|
||||
user = cast(Union[discord.User, discord.Member], user)
|
||||
|
||||
try:
|
||||
default_year = 2001
|
||||
date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"])
|
||||
|
@ -141,7 +157,7 @@ class Discord(commands.Cog):
|
|||
"""
|
||||
# No label: shortcut to display bookmarks
|
||||
if label is None:
|
||||
return await self.bookmark_search(ctx, query=None)
|
||||
return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type]
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
result = expect(
|
||||
|
@ -151,7 +167,7 @@ class Discord(commands.Cog):
|
|||
)
|
||||
await ctx.reply(result.jump_url, mention_author=False)
|
||||
|
||||
@bookmark.command(name="create", aliases=["new"])
|
||||
@bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type]
|
||||
async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]):
|
||||
"""Create a new bookmark for message `message` with label `label`.
|
||||
|
||||
|
@ -182,7 +198,7 @@ class Discord(commands.Cog):
|
|||
# Label isn't allowed
|
||||
return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False)
|
||||
|
||||
@bookmark.command(name="delete", aliases=["rm"])
|
||||
@bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type]
|
||||
async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str):
|
||||
"""Delete the bookmark with id `bookmark_id`.
|
||||
|
||||
|
@ -207,7 +223,7 @@ class Discord(commands.Cog):
|
|||
|
||||
return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False)
|
||||
|
||||
@bookmark.command(name="search", aliases=["list", "ls"])
|
||||
@bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type]
|
||||
async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None):
|
||||
"""Search through the list of bookmarks.
|
||||
|
||||
|
@ -236,7 +252,7 @@ class Discord(commands.Cog):
|
|||
modal = CreateBookmark(self.client, message.jump_url)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@commands.hybrid_command(name="events")
|
||||
@commands.hybrid_command(name="events") # type: ignore[arg-type]
|
||||
@app_commands.rename(event_id="id")
|
||||
@app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.")
|
||||
async def events(self, ctx: commands.Context, event_id: Optional[int] = None):
|
||||
|
@ -276,16 +292,16 @@ class Discord(commands.Cog):
|
|||
embed.add_field(
|
||||
name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True
|
||||
)
|
||||
embed.add_field(
|
||||
name="Channel",
|
||||
value=self.client.get_channel(result_event.notification_channel).mention,
|
||||
inline=False,
|
||||
)
|
||||
|
||||
channel = self.client.get_channel(result_event.notification_channel)
|
||||
if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
embed.add_field(name="Channel", value=channel.mention, inline=False)
|
||||
|
||||
embed.description = result_event.description
|
||||
return await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
|
||||
async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None):
|
||||
async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None):
|
||||
"""Show a user's GitHub links.
|
||||
|
||||
If no user is provided, this shows your links instead.
|
||||
|
@ -293,6 +309,9 @@ class Discord(commands.Cog):
|
|||
# Default to author
|
||||
user = user or ctx.author
|
||||
|
||||
# Please Mypy
|
||||
user = cast(Union[discord.User, discord.Member], user)
|
||||
|
||||
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
|
||||
embed.set_author(name=user.display_name, icon_url=get_user_avatar(user))
|
||||
|
||||
|
@ -324,7 +343,7 @@ class Discord(commands.Cog):
|
|||
|
||||
return await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@github_group.command(name="add", aliases=["create", "insert"])
|
||||
@github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type]
|
||||
async def github_add(self, ctx: commands.Context, link: str):
|
||||
"""Add a new link into the database."""
|
||||
# Remove wrapping <brackets> which can be used to escape Discord embeds
|
||||
|
@ -339,7 +358,7 @@ class Discord(commands.Cog):
|
|||
await self.client.confirm_message(ctx.message)
|
||||
return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False)
|
||||
|
||||
@github_group.command(name="delete", aliases=["del", "remove", "rm"])
|
||||
@github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type]
|
||||
async def github_delete(self, ctx: commands.Context, link_id: str):
|
||||
"""Delete the link with it `link_id` from the database.
|
||||
|
||||
|
@ -411,7 +430,7 @@ class Discord(commands.Cog):
|
|||
await message.add_reaction("📌")
|
||||
return await interaction.response.send_message("📌", ephemeral=True)
|
||||
|
||||
@commands.hybrid_command(name="snipe")
|
||||
@commands.hybrid_command(name="snipe") # type: ignore[arg-type]
|
||||
async def snipe(self, ctx: commands.Context):
|
||||
"""Publicly shame people when they edit or delete one of their messages.
|
||||
|
||||
|
@ -420,7 +439,7 @@ class Discord(commands.Cog):
|
|||
if ctx.guild is None:
|
||||
return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True)
|
||||
|
||||
sniped_data = self.client.sniped.get(ctx.channel.id, None)
|
||||
sniped_data = self.client.sniped.get(ctx.channel.id)
|
||||
if sniped_data is None:
|
||||
return await ctx.reply(
|
||||
"There's no one to make fun of in this channel.", mention_author=False, ephemeral=True
|
||||
|
|
|
@ -28,7 +28,7 @@ class Fun(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.hybrid_command(name="clap")
|
||||
@commands.hybrid_command(name="clap") # type: ignore[arg-type]
|
||||
async def clap(self, ctx: commands.Context, *, text: str):
|
||||
"""Clap a message with emojis for extra dramatic effect"""
|
||||
chars = list(filter(lambda c: c in constants.EMOJI_MAP, text))
|
||||
|
@ -50,10 +50,7 @@ class Fun(commands.Cog):
|
|||
meme = await generate_meme(self.client.http_session, result, fields)
|
||||
return meme
|
||||
|
||||
@commands.hybrid_command(
|
||||
name="dadjoke",
|
||||
aliases=["dad", "dj"],
|
||||
)
|
||||
@commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type]
|
||||
async def dad_joke(self, ctx: commands.Context):
|
||||
"""Why does Yoda's code always crash? Because there is no try."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -83,13 +80,13 @@ class Fun(commands.Cog):
|
|||
return await self.memegen_ls_msg(ctx)
|
||||
|
||||
if fields is None:
|
||||
return await self.memegen_preview_msg(ctx, template)
|
||||
return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type]
|
||||
|
||||
async with ctx.typing():
|
||||
meme = await self._do_generate_meme(template, shlex.split(fields))
|
||||
return await ctx.reply(meme, mention_author=False)
|
||||
|
||||
@memegen_msg.command(name="list", aliases=["ls"])
|
||||
@memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type]
|
||||
async def memegen_ls_msg(self, ctx: commands.Context):
|
||||
"""Get a list of all available meme templates.
|
||||
|
||||
|
@ -100,14 +97,14 @@ class Fun(commands.Cog):
|
|||
|
||||
await MemeSource(ctx, results).start()
|
||||
|
||||
@memegen_msg.command(name="preview", aliases=["p"])
|
||||
@memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type]
|
||||
async def memegen_preview_msg(self, ctx: commands.Context, template: str):
|
||||
"""Generate a preview for the meme template `template`, to see how the fields are structured."""
|
||||
async with ctx.typing():
|
||||
meme = await self._do_generate_meme(template, [])
|
||||
return await ctx.reply(meme, mention_author=False)
|
||||
|
||||
@memes_slash.command(name="generate")
|
||||
@memes_slash.command(name="generate") # type: ignore[arg-type]
|
||||
async def memegen_slash(self, interaction: discord.Interaction, template: str):
|
||||
"""Generate a meme."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -116,7 +113,7 @@ class Fun(commands.Cog):
|
|||
modal = GenerateMeme(self.client, result)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@memes_slash.command(name="preview")
|
||||
@memes_slash.command(name="preview") # type: ignore[arg-type]
|
||||
@app_commands.describe(template="The meme template to use in the preview.")
|
||||
async def memegen_preview_slash(self, interaction: discord.Interaction, template: str):
|
||||
"""Generate a preview for a meme, to see how the fields are structured."""
|
||||
|
@ -134,7 +131,7 @@ class Fun(commands.Cog):
|
|||
"""Autocompletion for the 'template'-parameter"""
|
||||
return self.client.database_caches.memes.get_autocomplete_suggestions(current)
|
||||
|
||||
@app_commands.command()
|
||||
@app_commands.command() # type: ignore[arg-type]
|
||||
@app_commands.describe(message="The text to convert.")
|
||||
async def mock(self, interaction: discord.Interaction, message: str):
|
||||
"""Mock a message.
|
||||
|
@ -158,7 +155,7 @@ class Fun(commands.Cog):
|
|||
|
||||
return await interaction.followup.send(mock(message))
|
||||
|
||||
@commands.hybrid_command(name="xkcd")
|
||||
@commands.hybrid_command(name="xkcd") # type: ignore[arg-type]
|
||||
@app_commands.rename(comic_id="id")
|
||||
async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None):
|
||||
"""Fetch comic `#id` from xkcd.
|
||||
|
|
|
@ -1,14 +1,6 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from database.constants import WORDLE_WORD_LENGTH
|
||||
from database.crud.wordle import get_wordle_guesses, make_wordle_guess
|
||||
from database.crud.wordle_stats import complete_wordle_game
|
||||
from didier import Didier
|
||||
from didier.data.embeds.wordle import WordleEmbed, WordleErrorEmbed, is_wordle_game_over
|
||||
|
||||
|
||||
class Games(commands.Cog):
|
||||
|
@ -19,53 +11,6 @@ class Games(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@app_commands.command(name="wordle", description="Play Wordle!")
|
||||
async def wordle(self, interaction: discord.Interaction, guess: Optional[str] = None):
|
||||
"""View your active Wordle game
|
||||
|
||||
If an argument is provided, make a guess instead
|
||||
"""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
# Guess is wrong length
|
||||
if guess is not None and len(guess) != 0 and len(guess) != WORDLE_WORD_LENGTH:
|
||||
embed = WordleErrorEmbed(message=f"Guess must be 5 characters, but `{guess}` is {len(guess)}.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
word_instance = self.client.database_caches.wordle_word.word
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
guesses = await get_wordle_guesses(session, interaction.user.id)
|
||||
|
||||
# Trying to guess with a complete game
|
||||
if is_wordle_game_over(guesses, word_instance.word):
|
||||
embed = WordleErrorEmbed(
|
||||
message="You've already completed today's Wordle.\nTry again tomorrow!"
|
||||
).to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
# Make a guess
|
||||
if guess:
|
||||
# The guess is not a real word
|
||||
if guess.lower() not in self.client.wordle_words:
|
||||
embed = WordleErrorEmbed(message=f"`{guess}` is not a valid word.").to_embed()
|
||||
return await interaction.followup.send(embed=embed)
|
||||
|
||||
guess = guess.lower()
|
||||
await make_wordle_guess(session, interaction.user.id, guess)
|
||||
|
||||
# Don't re-request the game, we already have it
|
||||
# just append locally
|
||||
guesses.append(guess)
|
||||
|
||||
embed = WordleEmbed(guesses=guesses, word=word_instance).to_embed()
|
||||
await interaction.followup.send(embed=embed)
|
||||
|
||||
# After responding to the interaction: update stats in the background
|
||||
game_over = is_wordle_game_over(guesses, word_instance.word)
|
||||
if game_over:
|
||||
await complete_wordle_game(session, interaction.user.id, word_instance.word in guesses)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
|
|
|
@ -159,6 +159,9 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
|
|||
Code in codeblocks is ignored, as it is used to create examples.
|
||||
"""
|
||||
description = command.help
|
||||
if description is None:
|
||||
return ""
|
||||
|
||||
codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL)
|
||||
|
||||
# Regex borrowed from https://stackoverflow.com/a/59843498/13568999
|
||||
|
@ -198,13 +201,10 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
|
|||
|
||||
return None
|
||||
|
||||
async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]:
|
||||
async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]:
|
||||
"""Filter the list of cogs down to all those that the user can see"""
|
||||
|
||||
async def _predicate(cog: Optional[commands.Cog]) -> bool:
|
||||
if cog is None:
|
||||
return False
|
||||
|
||||
async def _predicate(cog: commands.Cog) -> bool:
|
||||
# Remove cogs that we never want to see in the help page because they
|
||||
# don't contain commands, or shouldn't be visible at all
|
||||
if not cog.get_commands():
|
||||
|
@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
|
|||
return True
|
||||
|
||||
# Filter list of cogs down
|
||||
filtered_cogs = [cog for cog in cogs if await _predicate(cog)]
|
||||
filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)]
|
||||
return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name))
|
||||
|
||||
def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]:
|
||||
"""Check if a command has flags"""
|
||||
flag_param = command.params.get("flags", None)
|
||||
flag_param = command.params.get("flags")
|
||||
if flag_param is None:
|
||||
return None
|
||||
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import inspect
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
from discord.ext import commands
|
||||
|
||||
|
@ -76,18 +76,24 @@ class Meta(commands.Cog):
|
|||
if command_name is None:
|
||||
return await ctx.reply(repo_home, mention_author=False)
|
||||
|
||||
command: Optional[Union[commands.HelpCommand, commands.Command]]
|
||||
src: Any
|
||||
|
||||
if command_name == "help":
|
||||
command = self.client.help_command
|
||||
if command is None:
|
||||
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
|
||||
|
||||
src = type(self.client.help_command)
|
||||
filename = inspect.getsourcefile(src)
|
||||
else:
|
||||
command = self.client.get_command(command_name)
|
||||
if command is None:
|
||||
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
|
||||
|
||||
src = command.callback.__code__
|
||||
filename = src.co_filename
|
||||
|
||||
if command is None:
|
||||
return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False)
|
||||
|
||||
lines, first_line = inspect.getsourcelines(src)
|
||||
|
||||
if filename is None:
|
||||
|
|
|
@ -22,7 +22,7 @@ class Other(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.hybrid_command(name="corona", aliases=["covid", "rona"])
|
||||
@commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type]
|
||||
async def covid(self, ctx: commands.Context, country: str = "Belgium"):
|
||||
"""Show Covid-19 info for a specific country.
|
||||
|
||||
|
@ -43,7 +43,7 @@ class Other(commands.Cog):
|
|||
"""Autocompletion for the 'country'-parameter"""
|
||||
return autocomplete_country(value)[:25]
|
||||
|
||||
@commands.hybrid_command(
|
||||
@commands.hybrid_command( # type: ignore[arg-type]
|
||||
name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary"
|
||||
)
|
||||
async def define(self, ctx: commands.Context, *, query: str):
|
||||
|
@ -55,7 +55,7 @@ class Other(commands.Cog):
|
|||
mention_author=False,
|
||||
)
|
||||
|
||||
@commands.hybrid_command(name="google", description="Google search")
|
||||
@commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type]
|
||||
@app_commands.describe(query="Search query")
|
||||
async def google(self, ctx: commands.Context, *, query: str):
|
||||
"""Show the Google search results for `query`.
|
||||
|
@ -71,7 +71,7 @@ class Other(commands.Cog):
|
|||
embed = GoogleSearch(results).to_embed()
|
||||
await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.")
|
||||
@commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type]
|
||||
async def inspire(self, ctx: commands.Context):
|
||||
"""Generate an [InspiroBot](https://inspirobot.me/) quote."""
|
||||
async with ctx.typing():
|
||||
|
@ -82,7 +82,7 @@ class Other(commands.Cog):
|
|||
async with self.client.postgres_session as session:
|
||||
return await get_link_by_name(session, name.lower())
|
||||
|
||||
@commands.command(name="Link", aliases=["Links"])
|
||||
@commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type]
|
||||
async def link_msg(self, ctx: commands.Context, name: str):
|
||||
"""Get the link to the resource named `name`."""
|
||||
link = await self._get_link(name)
|
||||
|
@ -92,7 +92,7 @@ class Other(commands.Cog):
|
|||
target_message = await self.client.get_reply_target(ctx)
|
||||
await target_message.reply(link.url, mention_author=False)
|
||||
|
||||
@app_commands.command(name="link")
|
||||
@app_commands.command(name="link") # type: ignore[arg-type]
|
||||
@app_commands.describe(name="The name of the resource")
|
||||
async def link_slash(self, interaction: discord.Interaction, name: str):
|
||||
"""Get the link to something."""
|
||||
|
|
|
@ -42,7 +42,7 @@ class Owner(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override]
|
||||
"""Global check for every command in this cog
|
||||
|
||||
This means that we don't have to add is_owner() to every single command separately
|
||||
|
@ -102,7 +102,7 @@ class Owner(commands.Cog):
|
|||
async def add_msg(self, ctx: commands.Context):
|
||||
"""Command group for [add X] message commands"""
|
||||
|
||||
@add_msg.command(name="Alias")
|
||||
@add_msg.command(name="Alias") # type: ignore[arg-type]
|
||||
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
|
||||
"""Add a new alias for a custom command"""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -116,7 +116,7 @@ class Owner(commands.Cog):
|
|||
await ctx.reply("There is already a command with this name.")
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@add_msg.command(name="Custom")
|
||||
@add_msg.command(name="Custom") # type: ignore[arg-type]
|
||||
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
|
||||
"""Add a new custom command"""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -127,7 +127,7 @@ class Owner(commands.Cog):
|
|||
await ctx.reply("There is already a command with this name.")
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@add_msg.command(name="Link")
|
||||
@add_msg.command(name="Link") # type: ignore[arg-type]
|
||||
async def add_link_msg(self, ctx: commands.Context, name: str, url: str):
|
||||
"""Add a new link"""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -136,7 +136,7 @@ class Owner(commands.Cog):
|
|||
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
@add_slash.command(name="custom", description="Add a custom command")
|
||||
@add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type]
|
||||
async def add_custom_slash(self, interaction: discord.Interaction):
|
||||
"""Slash command to add a custom command"""
|
||||
if not await self.client.is_owner(interaction.user):
|
||||
|
@ -145,7 +145,7 @@ class Owner(commands.Cog):
|
|||
modal = CreateCustomCommand(self.client)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@add_slash.command(name="dadjoke", description="Add a dad joke")
|
||||
@add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type]
|
||||
async def add_dad_joke_slash(self, interaction: discord.Interaction):
|
||||
"""Slash command to add a dad joke"""
|
||||
if not await self.client.is_owner(interaction.user):
|
||||
|
@ -154,7 +154,7 @@ class Owner(commands.Cog):
|
|||
modal = AddDadJoke(self.client)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@add_slash.command(name="deadline", description="Add a deadline")
|
||||
@add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type]
|
||||
@app_commands.describe(course="The name of the course to add a deadline for (aliases work too)")
|
||||
async def add_deadline_slash(self, interaction: discord.Interaction, course: str):
|
||||
"""Slash command to add a deadline"""
|
||||
|
@ -174,7 +174,7 @@ class Owner(commands.Cog):
|
|||
"""Autocompletion for the 'course'-parameter"""
|
||||
return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current)
|
||||
|
||||
@add_slash.command(name="event", description="Add a new event")
|
||||
@add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type]
|
||||
async def add_event_slash(self, interaction: discord.Interaction):
|
||||
"""Slash command to add new events"""
|
||||
if not await self.client.is_owner(interaction.user):
|
||||
|
@ -183,7 +183,7 @@ class Owner(commands.Cog):
|
|||
modal = AddEvent(self.client)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@add_slash.command(name="link", description="Add a new link")
|
||||
@add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type]
|
||||
async def add_link_slash(self, interaction: discord.Interaction):
|
||||
"""Slash command to add new links"""
|
||||
if not await self.client.is_owner(interaction.user):
|
||||
|
@ -192,7 +192,7 @@ class Owner(commands.Cog):
|
|||
modal = AddLink(self.client)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@add_slash.command(name="meme", description="Add a new meme")
|
||||
@add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type]
|
||||
async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int):
|
||||
"""Slash command to add new memes"""
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
@ -205,11 +205,11 @@ class Owner(commands.Cog):
|
|||
await interaction.followup.send(f"Added meme `{meme.meme_id}`.")
|
||||
await self.client.database_caches.memes.invalidate(session)
|
||||
|
||||
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
|
||||
@commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type]
|
||||
async def edit_msg(self, ctx: commands.Context):
|
||||
"""Command group for [edit X] commands"""
|
||||
|
||||
@edit_msg.command(name="Custom")
|
||||
@edit_msg.command(name="Custom") # type: ignore[arg-type]
|
||||
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
|
||||
"""Edit an existing custom command"""
|
||||
async with self.client.postgres_session as session:
|
||||
|
@ -220,7 +220,7 @@ class Owner(commands.Cog):
|
|||
await ctx.reply(f"No command found matching `{command}`.")
|
||||
return await self.client.reject_message(ctx.message)
|
||||
|
||||
@edit_slash.command(name="custom", description="Edit a custom command")
|
||||
@edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type]
|
||||
@app_commands.describe(command="The name of the command to edit")
|
||||
async def edit_custom_slash(self, interaction: discord.Interaction, command: str):
|
||||
"""Slash command to edit a custom command"""
|
||||
|
|
|
@ -27,7 +27,7 @@ class School(commands.Cog):
|
|||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
@commands.hybrid_command(name="deadlines")
|
||||
@commands.hybrid_command(name="deadlines") # type: ignore[arg-type]
|
||||
async def deadlines(self, ctx: commands.Context):
|
||||
"""Show upcoming deadlines."""
|
||||
async with ctx.typing():
|
||||
|
@ -40,7 +40,7 @@ class School(commands.Cog):
|
|||
embed = Deadlines(deadlines).to_embed()
|
||||
await ctx.reply(embed=embed, mention_author=False, ephemeral=False)
|
||||
|
||||
@commands.hybrid_command(name="les", aliases=["sched", "schedule"])
|
||||
@commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type]
|
||||
@app_commands.rename(day_dt="date")
|
||||
async def les(
|
||||
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
|
||||
|
@ -72,10 +72,7 @@ class School(commands.Cog):
|
|||
except NotInMainGuildException:
|
||||
return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False)
|
||||
|
||||
@commands.hybrid_command(
|
||||
name="menu",
|
||||
aliases=["eten", "food"],
|
||||
)
|
||||
@commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type]
|
||||
@app_commands.rename(day_dt="date")
|
||||
async def menu(
|
||||
self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None
|
||||
|
@ -96,7 +93,7 @@ class School(commands.Cog):
|
|||
embed = no_menu_found(day_dt)
|
||||
await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@commands.hybrid_command(
|
||||
@commands.hybrid_command( # type: ignore[arg-type]
|
||||
name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"]
|
||||
)
|
||||
@app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)")
|
||||
|
@ -124,7 +121,7 @@ class School(commands.Cog):
|
|||
mention_author=False,
|
||||
)
|
||||
|
||||
@commands.hybrid_command(name="ufora")
|
||||
@commands.hybrid_command(name="ufora") # type: ignore[arg-type]
|
||||
async def ufora(self, ctx: commands.Context, course: str):
|
||||
"""Link the Ufora page for a course."""
|
||||
async with self.client.postgres_session as session:
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
|
||||
import discord
|
||||
|
@ -10,7 +12,6 @@ from database import enums
|
|||
from database.crud.birthdays import get_birthdays_on_day
|
||||
from database.crud.reminders import get_all_reminders_for_category
|
||||
from database.crud.ufora_announcements import remove_old_announcements
|
||||
from database.crud.wordle import set_daily_word
|
||||
from database.schemas import Reminder
|
||||
from didier import Didier
|
||||
from didier.data.embeds.schedules import (
|
||||
|
@ -21,9 +22,12 @@ from didier.data.embeds.schedules import (
|
|||
from didier.data.rss_feeds.free_games import fetch_free_games
|
||||
from didier.data.rss_feeds.ufora import fetch_ufora_announcements
|
||||
from didier.decorators.tasks import timed_task
|
||||
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
||||
from didier.utils.discord.checks import is_owner
|
||||
from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# datetime.time()-instances for when every task should run
|
||||
DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||
|
@ -54,11 +58,10 @@ class Tasks(commands.Cog):
|
|||
"reminders": self.reminders,
|
||||
"ufora": self.pull_ufora_announcements,
|
||||
"remove_ufora": self.remove_old_ufora_announcements,
|
||||
"wordle": self.reset_wordle_word,
|
||||
}
|
||||
|
||||
@overrides
|
||||
def cog_load(self) -> None:
|
||||
async def cog_load(self) -> None:
|
||||
# Only check birthdays if there's a channel to send it to
|
||||
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None:
|
||||
self.check_birthdays.start()
|
||||
|
@ -74,10 +77,10 @@ class Tasks(commands.Cog):
|
|||
|
||||
# Start other tasks
|
||||
self.reminders.start()
|
||||
self.reset_wordle_word.start()
|
||||
asyncio.create_task(self.get_error_channel())
|
||||
|
||||
@overrides
|
||||
def cog_unload(self) -> None:
|
||||
async def cog_unload(self) -> None:
|
||||
# Cancel all pending tasks
|
||||
for task in self._tasks.values():
|
||||
if task.is_running():
|
||||
|
@ -99,7 +102,7 @@ class Tasks(commands.Cog):
|
|||
|
||||
await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]")
|
||||
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type]
|
||||
async def force_task(self, ctx: commands.Context, name: str):
|
||||
"""Command to force-run a task without waiting for the specified run time"""
|
||||
name = name.lower()
|
||||
|
@ -110,23 +113,53 @@ class Tasks(commands.Cog):
|
|||
await task(forced=True)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
|
||||
async def get_error_channel(self):
|
||||
"""Get the configured channel from the cache"""
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
# Configure channel to send errors to
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
channel = self.client.get_channel(settings.ERRORS_CHANNEL)
|
||||
|
||||
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.")
|
||||
else:
|
||||
self.client.error_channel = channel
|
||||
elif self.client.owner_id is not None:
|
||||
self.client.error_channel = self.client.get_user(self.client.owner_id)
|
||||
|
||||
@tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME)
|
||||
@timed_task(enums.TaskType.BIRTHDAYS)
|
||||
async def check_birthdays(self, **kwargs):
|
||||
"""Check if it's currently anyone's birthday"""
|
||||
_ = kwargs
|
||||
|
||||
# Can't happen (task isn't started if this is None), but Mypy doesn't know
|
||||
if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None:
|
||||
return
|
||||
|
||||
now = tz_aware_now().date()
|
||||
async with self.client.postgres_session as session:
|
||||
birthdays = await get_birthdays_on_day(session, now)
|
||||
|
||||
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||
if channel is None:
|
||||
return await self.client.log_error("Unable to find channel for birthday announcements")
|
||||
return await self.client.log_error("Unable to fetch channel for birthday announcements.")
|
||||
|
||||
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
return await self.client.log_error(
|
||||
f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable."
|
||||
)
|
||||
|
||||
for birthday in birthdays:
|
||||
user = self.client.get_user(birthday.user_id)
|
||||
|
||||
if user is None:
|
||||
await self.client.log_error(
|
||||
f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement"
|
||||
)
|
||||
continue
|
||||
|
||||
await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention))
|
||||
|
||||
@check_birthdays.before_loop
|
||||
|
@ -146,6 +179,14 @@ class Tasks(commands.Cog):
|
|||
games = await fetch_free_games(self.client.http_session, session)
|
||||
channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL)
|
||||
|
||||
if channel is None:
|
||||
return await self.client.log_error("Unable to fetch channel for free games announcements.")
|
||||
|
||||
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
return await self.client.log_error(
|
||||
f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable."
|
||||
)
|
||||
|
||||
for game in games:
|
||||
await channel.send(embed=game.to_embed())
|
||||
|
||||
|
@ -207,6 +248,17 @@ class Tasks(commands.Cog):
|
|||
|
||||
async with self.client.postgres_session as db_session:
|
||||
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
||||
|
||||
if announcements_channel is None:
|
||||
return await self.client.log_error(
|
||||
f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)."
|
||||
)
|
||||
|
||||
if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
return await self.client.log_error(
|
||||
f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable."
|
||||
)
|
||||
|
||||
announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
|
||||
|
||||
for announcement in announcements:
|
||||
|
@ -266,34 +318,16 @@ class Tasks(commands.Cog):
|
|||
async with self.client.postgres_session as session:
|
||||
await remove_old_announcements(session)
|
||||
|
||||
@tasks.loop(time=DAILY_RESET_TIME)
|
||||
async def reset_wordle_word(self, forced: bool = False):
|
||||
"""Reset the daily Wordle word"""
|
||||
async with self.client.postgres_session as session:
|
||||
await set_daily_word(session, random.choice(tuple(self.client.wordle_words)), forced=forced)
|
||||
await self.client.database_caches.wordle_word.invalidate(session)
|
||||
|
||||
@reset_wordle_word.before_loop
|
||||
async def _before_reset_wordle_word(self):
|
||||
await self.client.wait_until_ready()
|
||||
|
||||
@check_birthdays.error
|
||||
@pull_schedules.error
|
||||
@pull_ufora_announcements.error
|
||||
@reminders.error
|
||||
@remove_old_ufora_announcements.error
|
||||
@reset_wordle_word.error
|
||||
async def _on_tasks_error(self, error: BaseException):
|
||||
"""Error handler for all tasks"""
|
||||
self.client.dispatch("task_error", error)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog
|
||||
|
||||
Initially fetch the wordle word from the database, or reset it
|
||||
if there hasn't been a reset yet today
|
||||
"""
|
||||
cog = Tasks(client)
|
||||
await client.add_cog(cog)
|
||||
await cog.reset_wordle_word()
|
||||
"""Load the cog"""
|
||||
await client.add_cog(Tasks(client))
|
||||
|
|
|
@ -19,7 +19,7 @@ async def get_country_info(http_session: ClientSession, country: str) -> CovidDa
|
|||
yesterday = response
|
||||
|
||||
data = {"today": today, "yesterday": yesterday}
|
||||
return CovidData.parse_obj(data)
|
||||
return CovidData.model_validate(data)
|
||||
|
||||
|
||||
async def get_global_info(http_session: ClientSession) -> CovidData:
|
||||
|
@ -35,4 +35,4 @@ async def get_global_info(http_session: ClientSession) -> CovidData:
|
|||
yesterday = response
|
||||
|
||||
data = {"today": today, "yesterday": yesterday}
|
||||
return CovidData.parse_obj(data)
|
||||
return CovidData.model_validate(data)
|
||||
|
|
|
@ -12,4 +12,4 @@ async def fetch_menu(http_session: ClientSession, day_dt: date) -> Menu:
|
|||
"""Fetch the menu for a given day"""
|
||||
endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json"
|
||||
async with ensure_get(http_session, endpoint, log_exceptions=False) as response:
|
||||
return Menu.parse_obj(response)
|
||||
return Menu.model_validate(response)
|
||||
|
|
|
@ -14,4 +14,4 @@ async def lookup(http_session: ClientSession, query: str) -> list[Definition]:
|
|||
url = "https://api.urbandictionary.com/v0/define"
|
||||
|
||||
async with ensure_get(http_session, url, params={"term": query}) as response:
|
||||
return list(map(Definition.parse_obj, response["list"]))
|
||||
return list(map(Definition.model_validate, response["list"]))
|
||||
|
|
|
@ -13,4 +13,4 @@ async def fetch_xkcd_post(http_session: ClientSession, *, num: Optional[int] = N
|
|||
url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json"
|
||||
|
||||
async with ensure_get(http_session, url) as response:
|
||||
return XKCDPost.parse_obj(response)
|
||||
return XKCDPost.model_validate(response)
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
import discord
|
||||
from overrides import overrides
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from didier.data.embeds.base import EmbedPydantic
|
||||
|
||||
|
@ -24,7 +24,7 @@ class _CovidNumbers(BaseModel):
|
|||
active: int
|
||||
tests: int
|
||||
|
||||
@validator("updated")
|
||||
@field_validator("updated")
|
||||
def updated_to_seconds(cls, value: int) -> int:
|
||||
"""Turn the updated field into seconds instead of milliseconds"""
|
||||
return int(value) // 1000
|
||||
|
|
|
@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) ->
|
|||
embed = discord.Embed(title="Error", colour=discord.Colour.red())
|
||||
|
||||
if ctx is not None:
|
||||
if ctx.guild is None:
|
||||
if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel):
|
||||
origin = "DM"
|
||||
else:
|
||||
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
|
||||
origin = f"<#{ctx.channel.id}> ({ctx.guild.name})"
|
||||
|
||||
invocation = f"{ctx.author.display_name} in {origin}"
|
||||
|
||||
|
|
|
@ -4,18 +4,17 @@ from typing import Optional
|
|||
import discord
|
||||
from aiohttp import ClientSession
|
||||
from overrides import overrides
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
from didier.data.embeds.base import EmbedPydantic
|
||||
from didier.data.scrapers.common import GameStorePage
|
||||
from didier.data.scrapers.steam import get_steam_webpage_info
|
||||
from didier.utils.discord import colours
|
||||
|
||||
__all__ = ["SEPARATOR", "FreeGameEmbed"]
|
||||
|
||||
from didier.utils.discord.constants import Limits
|
||||
from didier.utils.types.string import abbreviate
|
||||
|
||||
__all__ = ["SEPARATOR", "FreeGameEmbed"]
|
||||
|
||||
SEPARATOR = " • Free • "
|
||||
|
||||
|
||||
|
@ -58,7 +57,7 @@ class FreeGameEmbed(EmbedPydantic):
|
|||
|
||||
store_page: Optional[GameStorePage] = None
|
||||
|
||||
@validator("title")
|
||||
@field_validator("title")
|
||||
def _clean_title(cls, value: str) -> str:
|
||||
return html.unescape(value)
|
||||
|
||||
|
@ -107,7 +106,6 @@ class FreeGameEmbed(EmbedPydantic):
|
|||
embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})")
|
||||
|
||||
if self.store_page.xdg_open_url is not None:
|
||||
|
||||
embed.add_field(
|
||||
name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})"
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"]
|
|||
def create_logging_embed(level: int, message: str) -> discord.Embed:
|
||||
"""Create an embed to send to the logging channel"""
|
||||
colours = {
|
||||
logging.DEBUG: discord.Colour.light_gray,
|
||||
logging.DEBUG: discord.Colour.light_grey(),
|
||||
logging.ERROR: discord.Colour.red(),
|
||||
logging.INFO: discord.Colour.blue(),
|
||||
logging.WARNING: discord.Colour.yellow(),
|
||||
|
|
|
@ -2,7 +2,7 @@ from datetime import datetime
|
|||
|
||||
import discord
|
||||
from overrides import overrides
|
||||
from pydantic import validator
|
||||
from pydantic import field_validator
|
||||
|
||||
from didier.data.embeds.base import EmbedPydantic
|
||||
from didier.utils.discord import colours
|
||||
|
@ -39,8 +39,8 @@ class Definition(EmbedPydantic):
|
|||
total_votes = self.thumbs_up + self.thumbs_down
|
||||
return round(100 * self.thumbs_up / total_votes, 2)
|
||||
|
||||
@validator("definition", "example")
|
||||
def modify_long_text(cls, field):
|
||||
@field_validator("definition", "example")
|
||||
def modify_long_text(cls, field: str):
|
||||
"""Remove brackets from fields & cut them off if they are too long"""
|
||||
field = field.replace("[", "").replace("]", "")
|
||||
return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH)
|
||||
|
|
|
@ -1,142 +0,0 @@
|
|||
import enum
|
||||
from dataclasses import dataclass
|
||||
|
||||
import discord
|
||||
from overrides import overrides
|
||||
|
||||
from database.constants import WORDLE_GUESS_COUNT, WORDLE_WORD_LENGTH
|
||||
from database.schemas import WordleWord
|
||||
from didier.data.embeds.base import EmbedBaseModel
|
||||
from didier.utils.types.datetime import int_to_weekday, tz_aware_now
|
||||
|
||||
__all__ = ["is_wordle_game_over", "WordleEmbed", "WordleErrorEmbed"]
|
||||
|
||||
|
||||
def is_wordle_game_over(guesses: list[str], word: str) -> bool:
|
||||
"""Check if the current game is over or not"""
|
||||
if not guesses:
|
||||
return False
|
||||
|
||||
if len(guesses) == WORDLE_GUESS_COUNT:
|
||||
return True
|
||||
|
||||
return word.lower() in guesses
|
||||
|
||||
|
||||
def footer() -> str:
|
||||
"""Create the footer to put on the embed"""
|
||||
today = tz_aware_now()
|
||||
return f"{int_to_weekday(today.weekday())} {today.strftime('%d/%m/%Y')}"
|
||||
|
||||
|
||||
class WordleColour(enum.IntEnum):
|
||||
"""Colours for the Wordle embed"""
|
||||
|
||||
EMPTY = 0
|
||||
WRONG_LETTER = 1
|
||||
WRONG_POSITION = 2
|
||||
CORRECT = 3
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordleEmbed(EmbedBaseModel):
|
||||
"""Embed for a Wordle game"""
|
||||
|
||||
guesses: list[str]
|
||||
word: WordleWord
|
||||
|
||||
def _letter_colour(self, guess: str, index: int) -> WordleColour:
|
||||
"""Get the colour for a guess at a given position"""
|
||||
if guess[index] == self.word.word[index]:
|
||||
return WordleColour.CORRECT
|
||||
|
||||
wrong_letter = 0
|
||||
wrong_position = 0
|
||||
|
||||
for i, letter in enumerate(self.word.word):
|
||||
if letter == guess[index] and guess[i] != guess[index]:
|
||||
wrong_letter += 1
|
||||
|
||||
if i <= index and guess[i] == guess[index] and letter != guess[index]:
|
||||
wrong_position += 1
|
||||
|
||||
if i >= index:
|
||||
if wrong_position == 0:
|
||||
break
|
||||
|
||||
if wrong_position <= wrong_letter:
|
||||
return WordleColour.WRONG_POSITION
|
||||
|
||||
return WordleColour.WRONG_LETTER
|
||||
|
||||
def _guess_colours(self, guess: str) -> list[WordleColour]:
|
||||
"""Create the colour codes for a specific guess"""
|
||||
return [self._letter_colour(guess, i) for i in range(WORDLE_WORD_LENGTH)]
|
||||
|
||||
def colour_code_game(self) -> list[list[WordleColour]]:
|
||||
"""Create the colour codes for an entire game"""
|
||||
colours = []
|
||||
|
||||
# Add all the guesses
|
||||
for guess in self.guesses:
|
||||
colours.append(self._guess_colours(guess))
|
||||
|
||||
# Fill the rest with empty spots
|
||||
for _ in range(WORDLE_GUESS_COUNT - len(colours)):
|
||||
colours.append([WordleColour.EMPTY] * WORDLE_WORD_LENGTH)
|
||||
|
||||
return colours
|
||||
|
||||
def _colours_to_emojis(self, colours: list[list[WordleColour]]) -> list[list[str]]:
|
||||
"""Turn the colours of the board into Discord emojis"""
|
||||
colour_map = {
|
||||
WordleColour.EMPTY: ":white_large_square:",
|
||||
WordleColour.WRONG_LETTER: ":black_large_square:",
|
||||
WordleColour.WRONG_POSITION: ":orange_square:",
|
||||
WordleColour.CORRECT: ":green_square:",
|
||||
}
|
||||
|
||||
emojis = []
|
||||
for row in colours:
|
||||
emojis.append(list(map(lambda char: colour_map[char], row)))
|
||||
|
||||
return emojis
|
||||
|
||||
@overrides
|
||||
def to_embed(self, **kwargs) -> discord.Embed:
|
||||
only_colours = kwargs.get("only_colours", False)
|
||||
|
||||
colours = self.colour_code_game()
|
||||
|
||||
embed = discord.Embed(colour=discord.Colour.blue(), title=f"Wordle #{self.word.word_id + 1}")
|
||||
emojis = self._colours_to_emojis(colours)
|
||||
|
||||
rows = [" ".join(row) for row in emojis]
|
||||
|
||||
# Don't reveal anything if we only want to show the colours
|
||||
if not only_colours and self.guesses:
|
||||
for i, guess in enumerate(self.guesses):
|
||||
rows[i] += f" ||{guess.upper()}||"
|
||||
|
||||
# If the game is over, reveal the word
|
||||
if is_wordle_game_over(self.guesses, self.word.word):
|
||||
rows.append(f"\n\nThe word was **{self.word.word.upper()}**!")
|
||||
|
||||
embed.description = "\n\n".join(rows)
|
||||
embed.set_footer(text=footer())
|
||||
|
||||
return embed
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordleErrorEmbed(EmbedBaseModel):
|
||||
"""Embed to send error messages to the user"""
|
||||
|
||||
message: str
|
||||
|
||||
@overrides
|
||||
def to_embed(self, **kwargs) -> discord.Embed:
|
||||
embed = discord.Embed(colour=discord.Colour.red(), title="Wordle")
|
||||
embed.description = self.message
|
||||
embed.set_footer(text=footer())
|
||||
return embed
|
|
@ -32,7 +32,7 @@ async def fetch_free_games(http_session: ClientSession, database_session: AsyncS
|
|||
if SEPARATOR not in entry["title"]:
|
||||
continue
|
||||
|
||||
game = FreeGameEmbed.parse_obj(entry)
|
||||
game = FreeGameEmbed.model_validate(entry)
|
||||
games.append(game)
|
||||
game_ids.append(game.dc_identifier)
|
||||
|
||||
|
|
|
@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]:
|
|||
return list(dict.fromkeys(results))
|
||||
|
||||
|
||||
async def google_search(http_client: ClientSession, query: str):
|
||||
async def google_search(http_session: ClientSession, query: str):
|
||||
"""Get the first 10 Google search results"""
|
||||
query = urlencode({"q": query})
|
||||
|
||||
# Request 20 results in case of duplicates, bad matches, ...
|
||||
async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response:
|
||||
async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response:
|
||||
# Something went wrong
|
||||
if response.status != http.HTTPStatus.OK:
|
||||
return SearchData(query, response.status)
|
||||
|
|
|
@ -17,7 +17,7 @@ from database.utils.caches import CacheManager
|
|||
from didier.data.embeds.error_embed import create_error_embed
|
||||
from didier.data.embeds.logging_embed import create_logging_embed
|
||||
from didier.data.embeds.schedules import Schedule, parse_schedule
|
||||
from didier.exceptions import HTTPException, NoMatch
|
||||
from didier.exceptions import GetNoneException, HTTPException, NoMatch
|
||||
from didier.utils.discord.prefix import get_prefix
|
||||
from didier.utils.discord.snipe import should_snipe
|
||||
from didier.utils.easter_eggs import detect_easter_egg
|
||||
|
@ -33,12 +33,11 @@ class Didier(commands.Bot):
|
|||
"""DIDIER <3"""
|
||||
|
||||
database_caches: CacheManager
|
||||
error_channel: discord.abc.Messageable
|
||||
error_channel: Optional[discord.abc.Messageable] = None
|
||||
initial_extensions: tuple[str, ...] = ()
|
||||
http_session: ClientSession
|
||||
schedules: dict[settings.ScheduleType, Schedule] = {}
|
||||
sniped: dict[int, tuple[discord.Message, Optional[discord.Message]]] = {}
|
||||
wordle_words: set[str] = set()
|
||||
|
||||
def __init__(self):
|
||||
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
|
||||
|
@ -57,12 +56,17 @@ class Didier(commands.Bot):
|
|||
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
|
||||
)
|
||||
|
||||
self.tree.on_error = self.on_app_command_error
|
||||
# I'm not creating a custom tree, this is the way to do it
|
||||
self.tree.on_error = self.on_app_command_error # type: ignore[method-assign]
|
||||
|
||||
@cached_property
|
||||
def main_guild(self) -> discord.Guild:
|
||||
"""Obtain a reference to the main guild"""
|
||||
return self.get_guild(settings.DISCORD_MAIN_GUILD)
|
||||
guild = self.get_guild(settings.DISCORD_MAIN_GUILD)
|
||||
if guild is None:
|
||||
raise GetNoneException("Main guild could not be found in the bot's cache")
|
||||
|
||||
return guild
|
||||
|
||||
@property
|
||||
def postgres_session(self) -> AsyncSession:
|
||||
|
@ -77,9 +81,6 @@ class Didier(commands.Bot):
|
|||
# Create directories that are ignored on GitHub
|
||||
self._create_ignored_directories()
|
||||
|
||||
# Load the Wordle dictionary
|
||||
self._load_wordle_words()
|
||||
|
||||
# Initialize caches
|
||||
self.database_caches = CacheManager()
|
||||
async with self.postgres_session as session:
|
||||
|
@ -97,12 +98,6 @@ class Didier(commands.Bot):
|
|||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
|
||||
# Configure channel to send errors to
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
self.error_channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
else:
|
||||
self.error_channel = self.get_user(self.owner_id)
|
||||
|
||||
def _create_ignored_directories(self):
|
||||
"""Create directories that store ignored data"""
|
||||
ignored = ["files/schedules"]
|
||||
|
@ -137,12 +132,6 @@ class Didier(commands.Bot):
|
|||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_extensions(new_path)
|
||||
|
||||
def _load_wordle_words(self):
|
||||
"""Load the dictionary of Wordle words"""
|
||||
with open("files/dictionaries/words-english-wordle.txt", "r") as fp:
|
||||
for line in fp:
|
||||
self.wordle_words.add(line.strip())
|
||||
|
||||
async def load_schedules(self):
|
||||
"""Parse & load all schedules into memory"""
|
||||
self.schedules = {}
|
||||
|
@ -162,18 +151,27 @@ class Didier(commands.Bot):
|
|||
original message instead
|
||||
"""
|
||||
if ctx.message.reference is not None:
|
||||
return await self.resolve_message(ctx.message.reference)
|
||||
return await self.resolve_message(ctx.message.reference) or ctx.message
|
||||
|
||||
return ctx.message
|
||||
|
||||
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||
async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]:
|
||||
"""Fetch a message from a reference"""
|
||||
# Message is in the cache, return it
|
||||
if reference.cached_message is not None:
|
||||
return reference.cached_message
|
||||
|
||||
if reference.message_id is None:
|
||||
return None
|
||||
|
||||
# For older messages: fetch them from the API
|
||||
channel = self.get_channel(reference.channel_id)
|
||||
if channel is None or isinstance(
|
||||
channel,
|
||||
(discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel),
|
||||
): # Logically this can't happen, but we have to please Mypy
|
||||
return None
|
||||
|
||||
return await channel.fetch_message(reference.message_id)
|
||||
|
||||
async def confirm_message(self, message: discord.Message):
|
||||
|
@ -194,7 +192,7 @@ class Didier(commands.Bot):
|
|||
}
|
||||
|
||||
methods.get(level, logger.error)(message)
|
||||
if log_to_discord:
|
||||
if log_to_discord and self.error_channel is not None:
|
||||
embed = create_logging_embed(level, message)
|
||||
await self.error_channel.send(embed=embed)
|
||||
|
||||
|
@ -263,10 +261,9 @@ class Didier(commands.Bot):
|
|||
|
||||
await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True)
|
||||
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
if self.error_channel is not None:
|
||||
embed = create_error_embed(await commands.Context.from_interaction(interaction), exception)
|
||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
await channel.send(embed=embed)
|
||||
await self.error_channel.send(embed=embed)
|
||||
|
||||
async def on_command_completion(self, ctx: commands.Context):
|
||||
"""Event triggered when a message command completes successfully"""
|
||||
|
@ -291,7 +288,7 @@ class Didier(commands.Bot):
|
|||
|
||||
# Hybrid command errors are wrapped in an additional error, so wrap it back out
|
||||
if isinstance(exception, commands.HybridCommandError):
|
||||
exception = exception.original
|
||||
exception = exception.original # type: ignore[assignment]
|
||||
|
||||
# Ignore exceptions that aren't important
|
||||
if isinstance(
|
||||
|
@ -342,10 +339,9 @@ class Didier(commands.Bot):
|
|||
# Print everything that we care about to the logs/stderr
|
||||
await super().on_command_error(ctx, exception)
|
||||
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
if self.error_channel is not None:
|
||||
embed = create_error_embed(ctx, exception)
|
||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
await channel.send(embed=embed)
|
||||
await self.error_channel.send(embed=embed)
|
||||
|
||||
async def on_message(self, message: discord.Message, /) -> None:
|
||||
"""Event triggered when a message is sent"""
|
||||
|
@ -354,7 +350,7 @@ class Didier(commands.Bot):
|
|||
return
|
||||
|
||||
# Boos react to people that say Dider
|
||||
if "dider" in message.content.lower() and message.author.id != self.user.id:
|
||||
if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id:
|
||||
await message.add_reaction(settings.DISCORD_BOOS_REACT)
|
||||
|
||||
# Potential custom command
|
||||
|
@ -384,7 +380,7 @@ class Didier(commands.Bot):
|
|||
|
||||
# If the edited message is currently present in the snipe cache,
|
||||
# don't update the <before>, but instead change the <after>
|
||||
existing = self.sniped.get(before.channel.id, None)
|
||||
existing = self.sniped.get(before.channel.id)
|
||||
if existing is not None and existing[0].id == before.id:
|
||||
before = existing[0]
|
||||
|
||||
|
@ -399,10 +395,9 @@ class Didier(commands.Bot):
|
|||
|
||||
async def on_task_error(self, exception: Exception):
|
||||
"""Event triggered when a task raises an exception"""
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
if self.error_channel:
|
||||
embed = create_error_embed(None, exception)
|
||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
await channel.send(embed=embed)
|
||||
await self.error_channel.send(embed=embed)
|
||||
|
||||
async def on_thread_create(self, thread: discord.Thread):
|
||||
"""Event triggered when a new thread is created"""
|
||||
|
|
|
@ -1,6 +1,14 @@
|
|||
from .get_none_exception import GetNoneException
|
||||
from .http_exception import HTTPException
|
||||
from .missing_env import MissingEnvironmentVariable
|
||||
from .no_match import NoMatch, expect
|
||||
from .not_in_main_guild_exception import NotInMainGuildException
|
||||
|
||||
__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"]
|
||||
__all__ = [
|
||||
"GetNoneException",
|
||||
"HTTPException",
|
||||
"MissingEnvironmentVariable",
|
||||
"NoMatch",
|
||||
"expect",
|
||||
"NotInMainGuildException",
|
||||
]
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
__all__ = ["GetNoneException"]
|
||||
|
||||
|
||||
class GetNoneException(RuntimeError):
|
||||
"""Exception raised when a Bot.get()-method returned None"""
|
|
@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError):
|
|||
|
||||
def __init__(self, user: Union[discord.User, discord.Member]):
|
||||
super().__init__(
|
||||
f"User {user.display_name} (id {user.id}) "
|
||||
f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})."
|
||||
f"User {user.display_name} (id `{user.id}`) "
|
||||
f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)."
|
||||
)
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
import discord
|
||||
|
||||
__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"]
|
||||
|
||||
NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel)
|
|
@ -15,11 +15,14 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]:
|
|||
This is done dynamically through regexes to allow case-insensitivity
|
||||
and variable amounts of whitespace among other things.
|
||||
"""
|
||||
mention = f"<@!?{client.user.id}>"
|
||||
mention = f"<@!?{client.user.id}>" if client.user else None
|
||||
regex = r"^({})\s*"
|
||||
|
||||
# Check which prefix was used
|
||||
for prefix in [*constants.PREFIXES, mention]:
|
||||
if prefix is None:
|
||||
continue
|
||||
|
||||
match = re.match(regex.format(prefix), message.content, flags=re.I)
|
||||
|
||||
if match is not None:
|
||||
|
|
|
@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"):
|
|||
|
||||
@overrides
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
label = self.name.value.strip()
|
||||
|
||||
try:
|
||||
async with self.client.postgres_session as session:
|
||||
bm = await create_bookmark(session, interaction.user.id, label, self.jump_url)
|
||||
return await interaction.response.send_message(
|
||||
f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True
|
||||
return await interaction.followup.send(
|
||||
f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)."
|
||||
)
|
||||
except DuplicateInsertException:
|
||||
# Label is already in use
|
||||
return await interaction.response.send_message(
|
||||
f"You already have a bookmark named `{label}`.", ephemeral=True
|
||||
)
|
||||
return await interaction.followup.send(f"You already have a bookmark named `{label}`.")
|
||||
except ForbiddenNameException:
|
||||
# Label isn't allowed
|
||||
return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True)
|
||||
return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.")
|
||||
|
||||
@overrides
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||
await interaction.followup.send("Something went wrong.", ephemeral=True)
|
||||
traceback.print_tb(error.__traceback__)
|
||||
|
|
|
@ -26,12 +26,14 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
|
|||
|
||||
@overrides
|
||||
async def on_submit(self, interaction: discord.Interaction):
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
joke = await add_dad_joke(session, str(self.joke.value))
|
||||
|
||||
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
|
||||
await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}")
|
||||
|
||||
@overrides
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||
await interaction.followup.send("Something went wrong.", ephemeral=True)
|
||||
traceback.print_tb(error.__traceback__)
|
||||
|
|
|
@ -10,6 +10,8 @@ from didier import Didier
|
|||
|
||||
__all__ = ["AddEvent"]
|
||||
|
||||
from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES
|
||||
|
||||
|
||||
class AddEvent(discord.ui.Modal, title="Add Event"):
|
||||
"""Modal to add a new event"""
|
||||
|
@ -33,15 +35,20 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
|
|||
|
||||
@overrides
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
|
||||
try:
|
||||
parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels"))
|
||||
except ParserError:
|
||||
return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True)
|
||||
return await interaction.followup.send("Unable to parse date argument.")
|
||||
|
||||
if self.client.get_channel(int(self.channel.value)) is None:
|
||||
return await interaction.response.send_message(
|
||||
f"Unable to find channel `{self.channel.value}`", ephemeral=True
|
||||
)
|
||||
channel = self.client.get_channel(int(self.channel.value))
|
||||
|
||||
if channel is None:
|
||||
return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`")
|
||||
|
||||
if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES):
|
||||
return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.")
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
event = await add_event(
|
||||
|
@ -52,10 +59,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"):
|
|||
channel_id=int(self.channel.value),
|
||||
)
|
||||
|
||||
await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True)
|
||||
await interaction.followup.send(f"Successfully added event `{event.event_id}`.")
|
||||
self.client.dispatch("event_create", event)
|
||||
|
||||
@overrides
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore
|
||||
await interaction.response.send_message("Something went wrong.", ephemeral=True)
|
||||
await interaction.followup.send("Something went wrong.", ephemeral=True)
|
||||
traceback.print_tb(error.__traceback__)
|
||||
|
|
File diff suppressed because it is too large
Load Diff
1
main.py
1
main.py
|
@ -36,6 +36,7 @@ def setup_logging():
|
|||
|
||||
# Configure discord handler
|
||||
discord_log = logging.getLogger("discord")
|
||||
discord_handler: logging.StreamHandler
|
||||
|
||||
# Make dev print to stderr instead, so you don't have to watch the file
|
||||
if settings.SANDBOX:
|
||||
|
|
|
@ -28,6 +28,7 @@ omit = [
|
|||
profile = "black"
|
||||
|
||||
[tool.mypy]
|
||||
check_untyped_defs = true
|
||||
files = [
|
||||
"database/**/*.py",
|
||||
"didier/**/*.py",
|
||||
|
@ -35,7 +36,6 @@ files = [
|
|||
]
|
||||
plugins = [
|
||||
"pydantic.mypy",
|
||||
"sqlalchemy.ext.mypy.plugin"
|
||||
]
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"]
|
||||
|
|
|
@ -1,22 +1,21 @@
|
|||
black==22.3.0
|
||||
coverage[toml]==6.4.1
|
||||
freezegun==1.2.1
|
||||
black==23.3.0
|
||||
coverage[toml]==7.2.7
|
||||
freezegun==1.2.2
|
||||
isort==5.12.0
|
||||
mypy==0.961
|
||||
pre-commit==2.20.0
|
||||
pytest==7.1.2
|
||||
pytest-asyncio==0.18.3
|
||||
pytest-env==0.6.2
|
||||
sqlalchemy2-stubs==0.0.2a23
|
||||
types-beautifulsoup4==4.11.3
|
||||
types-python-dateutil==2.8.19
|
||||
mypy==1.4.1
|
||||
pre-commit==3.3.3
|
||||
pytest==7.4.0
|
||||
pytest-asyncio==0.21.0
|
||||
pytest-env==0.8.2
|
||||
types-beautifulsoup4==4.12.0.5
|
||||
types-python-dateutil==2.8.19.13
|
||||
|
||||
# Flake8 + plugins
|
||||
flake8==4.0.1
|
||||
flake8-bandit==3.0.0
|
||||
flake8-bugbear==22.7.1
|
||||
flake8-docstrings==1.6.0
|
||||
flake8-dunder-all==0.2.1
|
||||
flake8-eradicate==1.2.1
|
||||
flake8-isort==4.1.1
|
||||
flake8-simplify==0.19.2
|
||||
flake8==6.0.0
|
||||
flake8-bandit==4.1.1
|
||||
flake8-bugbear==23.6.5
|
||||
flake8-docstrings==1.7.0
|
||||
flake8-dunder-all==0.3.0
|
||||
flake8-eradicate==1.5.0
|
||||
flake8-isort==6.0.0
|
||||
flake8-simplify==0.20.0
|
||||
|
|
|
@ -1,13 +1,13 @@
|
|||
aiohttp==3.8.1
|
||||
alembic==1.8.0
|
||||
asyncpg==0.25.0
|
||||
beautifulsoup4==4.11.1
|
||||
discord.py==2.0.1
|
||||
aiohttp==3.8.4
|
||||
alembic==1.11.1
|
||||
asyncpg==0.28.0
|
||||
beautifulsoup4==4.12.2
|
||||
discord.py==2.3.1
|
||||
environs==9.5.0
|
||||
feedparser==6.0.10
|
||||
ics==0.7.2
|
||||
markdownify==0.11.2
|
||||
overrides==6.1.0
|
||||
pydantic==1.9.1
|
||||
markdownify==0.11.6
|
||||
overrides==7.3.1
|
||||
pydantic==2.0.2
|
||||
python-dateutil==2.8.2
|
||||
sqlalchemy[asyncio]==1.4.37
|
||||
sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18
|
||||
|
|
|
@ -111,7 +111,7 @@ class ScheduleInfo:
|
|||
|
||||
role_id: Optional[int]
|
||||
schedule_url: Optional[str]
|
||||
name: Optional[str] = None
|
||||
name: ScheduleType
|
||||
|
||||
|
||||
SCHEDULE_DATA = [
|
||||
|
|
|
@ -1,138 +0,0 @@
|
|||
from datetime import date, timedelta
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import wordle as crud
|
||||
from database.schemas import User, WordleGuess, WordleWord
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def wordle_guesses(postgres: AsyncSession, user: User) -> list[WordleGuess]:
|
||||
"""Fixture to generate some guesses"""
|
||||
guesses = []
|
||||
|
||||
for guess in ["TEST", "WORDLE", "WORDS"]:
|
||||
guess = WordleGuess(user_id=user.user_id, guess=guess)
|
||||
postgres.add(guess)
|
||||
await postgres.commit()
|
||||
|
||||
guesses.append(guess)
|
||||
|
||||
return guesses
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_active_wordle_game_none(postgres: AsyncSession, user: User):
|
||||
"""Test getting an active game when there is none"""
|
||||
result = await crud.get_active_wordle_game(postgres, user.user_id)
|
||||
assert not result
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_active_wordle_game(postgres: AsyncSession, wordle_guesses: list[WordleGuess]):
|
||||
"""Test getting an active game when there is one"""
|
||||
result = await crud.get_active_wordle_game(postgres, wordle_guesses[0].user_id)
|
||||
assert result == wordle_guesses
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_daily_word_none(postgres: AsyncSession):
|
||||
"""Test getting the daily word when the database is empty"""
|
||||
result = await crud.get_daily_word(postgres)
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_not_today(postgres: AsyncSession):
|
||||
"""Test getting the daily word when there is an entry, but not for today"""
|
||||
day = date.today() - timedelta(days=1)
|
||||
|
||||
word = "testword"
|
||||
word_instance = WordleWord(word=word, day=day)
|
||||
postgres.add(word_instance)
|
||||
await postgres.commit()
|
||||
|
||||
assert await crud.get_daily_word(postgres) is None
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_get_daily_word_present(postgres: AsyncSession):
|
||||
"""Test getting the daily word when there is one for today"""
|
||||
day = date.today()
|
||||
|
||||
word = "testword"
|
||||
word_instance = WordleWord(word=word, day=day)
|
||||
postgres.add(word_instance)
|
||||
await postgres.commit()
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_none_present(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there is none"""
|
||||
assert await crud.get_daily_word(postgres) is None
|
||||
word = "testword"
|
||||
await crud.set_daily_word(postgres, word)
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_present(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there already is one"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(postgres, word)
|
||||
await crud.set_daily_word(postgres, "another word")
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_set_daily_word_force_overwrite(postgres: AsyncSession):
|
||||
"""Test setting the daily word when there already is one, but "forced" is set to True"""
|
||||
word = "testword"
|
||||
await crud.set_daily_word(postgres, word)
|
||||
word = "anotherword"
|
||||
await crud.set_daily_word(postgres, word, forced=True)
|
||||
|
||||
daily_word = await crud.get_daily_word(postgres)
|
||||
assert daily_word is not None
|
||||
assert daily_word.word == word
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_make_wordle_guess(postgres: AsyncSession, user: User):
|
||||
"""Test making a guess in your current game"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
guess = "guess"
|
||||
await crud.make_wordle_guess(postgres, test_user_id, guess)
|
||||
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess]
|
||||
|
||||
other_guess = "otherguess"
|
||||
await crud.make_wordle_guess(postgres, test_user_id, other_guess)
|
||||
assert await crud.get_wordle_guesses(postgres, test_user_id) == [guess, other_guess]
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_reset_wordle_games(postgres: AsyncSession, wordle_guesses: list[WordleGuess], user: User):
|
||||
"""Test dropping the collection of active games"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
assert await crud.get_active_wordle_game(postgres, test_user_id)
|
||||
await crud.reset_wordle_games(postgres)
|
||||
assert not await crud.get_active_wordle_game(postgres, test_user_id)
|
|
@ -1,72 +0,0 @@
|
|||
import datetime
|
||||
|
||||
import pytest
|
||||
from freezegun import freeze_time
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud import wordle_stats as crud
|
||||
from database.schemas import User, WordleStats
|
||||
|
||||
|
||||
async def insert_game_stats(postgres: AsyncSession, stats: WordleStats):
|
||||
"""Helper function to insert some stats"""
|
||||
postgres.add(stats)
|
||||
await postgres.commit()
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_stats_non_existent_creates(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's stats when the db is empty"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
statement = select(WordleStats).where(WordleStats.user_id == test_user_id)
|
||||
assert (await postgres.execute(statement)).scalar_one_or_none() is None
|
||||
|
||||
await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert (await postgres.execute(statement)).scalar_one_or_none() is not None
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
async def test_get_stats_existing_returns(postgres: AsyncSession, user: User):
|
||||
"""Test getting a user's stats when there's already an entry present"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
stats = WordleStats(user_id=test_user_id)
|
||||
stats.games = 20
|
||||
await insert_game_stats(postgres, stats)
|
||||
found_stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert found_stats.games == 20
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_won(postgres: AsyncSession, user: User):
|
||||
"""Test completing a wordle game when you win"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
await crud.complete_wordle_game(postgres, test_user_id, win=True)
|
||||
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
assert stats.games == 1
|
||||
assert stats.wins == 1
|
||||
assert stats.current_streak == 1
|
||||
assert stats.highest_streak == 1
|
||||
assert stats.last_win == datetime.date.today()
|
||||
|
||||
|
||||
@pytest.mark.postgres
|
||||
@freeze_time("2022-07-30")
|
||||
async def test_complete_wordle_game_lost(postgres: AsyncSession, user: User):
|
||||
"""Test completing a wordle game when you lose"""
|
||||
test_user_id = user.user_id
|
||||
|
||||
stats = WordleStats(user_id=test_user_id)
|
||||
stats.current_streak = 10
|
||||
await insert_game_stats(postgres, stats)
|
||||
|
||||
await crud.complete_wordle_game(postgres, test_user_id, win=False)
|
||||
stats = await crud.get_wordle_stats(postgres, test_user_id)
|
||||
|
||||
# Check that streak was broken
|
||||
assert stats.current_streak == 0
|
||||
assert stats.games == 1
|
Loading…
Reference in New Issue