diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ba38355..acdc6a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,12 +3,12 @@ default_language_version: repos: - repo: https://github.com/ambv/black - rev: 22.3.0 + rev: 23.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.4.0 hooks: - id: check-json - id: end-of-file-fixer @@ -21,7 +21,7 @@ repos: - id: isort - repo: https://github.com/PyCQA/autoflake - rev: v1.4 + rev: v2.2.0 hooks: - id: autoflake name: autoflake (python) @@ -31,7 +31,7 @@ repos: - "--ignore-init-module-imports" - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 6.0.0 hooks: - id: flake8 exclude: ^(alembic|.github) diff --git a/alembic/versions/09128b6e34dd_migrate_to_2_x.py b/alembic/versions/09128b6e34dd_migrate_to_2_x.py new file mode 100644 index 0000000..ef4ee2c --- /dev/null +++ b/alembic/versions/09128b6e34dd_migrate_to_2_x.py @@ -0,0 +1,94 @@ +"""Migrate to 2.x + +Revision ID: 09128b6e34dd +Revises: 1e3e7f4192c4 +Create Date: 2023-07-07 16:23:15.990231 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "09128b6e34dd" +down_revision = "1e3e7f4192c4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("bank", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("birthdays", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("bookmarks", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("command_stats", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op: + batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=False) + + with op.batch_alter_table("deadlines", schema=None) as batch_op: + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) + + with op.batch_alter_table("github_links", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("nightly_data", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("reminders", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) + + with op.batch_alter_table("ufora_announcements", schema=None) as batch_op: + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) + batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=False) + + with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op: + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) + + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op: + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("ufora_announcements", schema=None) as batch_op: + batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=True) + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("reminders", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("nightly_data", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("github_links", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("deadlines", schema=None) as batch_op: + batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op: + batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=True) + + with op.batch_alter_table("command_stats", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("bookmarks", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("birthdays", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + with op.batch_alter_table("bank", schema=None) as batch_op: + batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) + + # ### end Alembic commands ### diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index d696e50..925e645 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[ if query is not None: statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index efbc689..eac9c7e 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: """Get a list of all commands""" statement = select(CustomCommand) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index 78d623f..a539518 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -38,4 +38,4 @@ async def get_deadlines( statement = statement.where(Deadline.course_id == course.course_id) statement = statement.options(selectinload(Deadline.course)) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py index d4c25d9..f4a55ed 100644 --- a/database/crud/easter_eggs.py +++ b/database/crud/easter_eggs.py @@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"] async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: """Return a list of all easter eggs""" statement = select(EasterEgg) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/events.py b/database/crud/events.py index 19887e6..e8dfac5 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" statement = select(Event).where(Event.timestamp > now) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: diff --git a/database/crud/free_games.py b/database/crud/free_games.py index b2d835d..e0dab5b 100644 --- a/database/crud/free_games.py +++ b/database/crud/free_games.py @@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]): async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]: """Filter a list of game IDs down to the ones that aren't in the database yet""" statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids)) - matches: list[int] = (await session.execute(statement)).scalars().all() + matches: list[int] = list((await session.execute(statement)).scalars().all()) return list(set(game_ids).difference(matches)) diff --git a/database/crud/github.py b/database/crud/github.py index 0d32377..a352bae 100644 --- a/database/crud/github.py +++ b/database/crud/github.py @@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: """Get a user's GitHub links""" statement = select(GitHubLink).where(GitHubLink.user_id == user_id) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/links.py b/database/crud/links.py index 495e0f3..20bcab3 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] async def get_all_links(session: AsyncSession) -> list[Link]: """Get a list of all links""" statement = select(Link) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def add_link(session: AsyncSession, name: str, url: str) -> Link: diff --git a/database/crud/memes.py b/database/crud/memes.py index ab288aa..b1ed1e0 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: """Get a list of all memes""" statement = select(MemeTemplate) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: diff --git a/database/crud/reminders.py b/database/crud/reminders.py index 007a779..78350e6 100644 --- a/database/crud/reminders.py +++ b/database/crud/reminders.py @@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"] async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: """Get a list of all Reminders for a given category""" statement = select(Reminder).where(Reminder.category == category) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 688bcc7..06c2b58 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]: """Get all courses where announcements are enabled""" statement = select(UforaCourse).where(UforaCourse.log_announcements) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def create_new_announcement( diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index 5374c07..d4cf728 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor # Search case-insensitively query = query.lower() - statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) - result = (await session.execute(statement)).scalars().first() - if result: - return result + course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) + course_result = (await session.execute(course_statement)).scalars().first() + if course_result: + return course_result - statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) - result = (await session.execute(statement)).scalars().first() - return result.course if result else None + alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) + alias_result = (await session.execute(alias_statement)).scalars().first() + return alias_result.course if alias_result else None diff --git a/database/engine.py b/database/engine.py index 23e5b89..ec81bfb 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,8 +1,7 @@ from urllib.parse import quote_plus from sqlalchemy.engine import URL -from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine -from sqlalchemy.orm import sessionmaker +from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine import settings @@ -22,6 +21,4 @@ postgres_engine = create_async_engine( future=True, ) -DBSession = sessionmaker( - autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False -) +DBSession = async_sessionmaker(autocommit=False, autoflush=False, bind=postgres_engine, expire_on_commit=False) diff --git a/database/schemas.py b/database/schemas.py index 34efcbc..b92b2a2 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -1,27 +1,14 @@ from __future__ import annotations from datetime import date, datetime -from typing import Optional +from typing import List, Optional -from sqlalchemy import ( - BigInteger, - Boolean, - Column, - Date, - DateTime, - Enum, - ForeignKey, - Integer, - Text, - UniqueConstraint, -) -from sqlalchemy.orm import declarative_base, relationship +from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint +from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship +from sqlalchemy.types import DateTime from database import enums -Base = declarative_base() - - __all__ = [ "Base", "Bank", @@ -48,27 +35,34 @@ __all__ = [ ] +class Base(DeclarativeBase): + """Required base class for all tables""" + + # Make all DateTimes timezone-aware + type_annotation_map = {datetime: DateTime(timezone=True)} + + class Bank(Base): """A user's currency information""" __tablename__ = "bank" - bank_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + bank_id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - dinks: int = Column(BigInteger, server_default="0", nullable=False) - invested: int = Column(BigInteger, server_default="0", nullable=False) + dinks: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False) + invested: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False) # Interest rate - interest_level: int = Column(Integer, server_default="1", nullable=False) + interest_level: Mapped[int] = mapped_column(server_default="1", nullable=False) # Maximum amount that can be stored in the bank - capacity_level: int = Column(Integer, server_default="1", nullable=False) + capacity_level: Mapped[int] = mapped_column(server_default="1", nullable=False) # Maximum amount that can be robbed - rob_level: int = Column(Integer, server_default="1", nullable=False) + rob_level: Mapped[int] = mapped_column(server_default="1", nullable=False) - user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") + user: Mapped[User] = relationship(uselist=False, back_populates="bank", lazy="selectin") class Birthday(Base): @@ -76,11 +70,11 @@ class Birthday(Base): __tablename__ = "birthdays" - birthday_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - birthday: date = Column(Date, nullable=False) + birthday_id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + birthday: Mapped[date] = mapped_column(nullable=False) - user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") + user: Mapped[User] = relationship(uselist=False, back_populates="birthday", lazy="selectin") class Bookmark(Base): @@ -89,26 +83,26 @@ class Bookmark(Base): __tablename__ = "bookmarks" __table_args__ = (UniqueConstraint("user_id", "label"),) - bookmark_id: int = Column(Integer, primary_key=True) - label: str = Column(Text, nullable=False) - jump_url: str = Column(Text, nullable=False) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + bookmark_id: Mapped[int] = mapped_column(primary_key=True) + label: Mapped[str] = mapped_column(nullable=False) + jump_url: Mapped[str] = mapped_column(nullable=False) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") + user: Mapped[User] = relationship(back_populates="bookmarks", uselist=False, lazy="selectin") class CommandStats(Base): """Metrics on how often commands are used""" __tablename__ = "command_stats" - command_stats_id: int = Column(Integer, primary_key=True) - command: str = Column(Text, nullable=False) - timestamp: datetime = Column(DateTime(timezone=True), nullable=False) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - slash: bool = Column(Boolean, nullable=False) - context_menu: bool = Column(Boolean, nullable=False) + command_stats_id: Mapped[int] = mapped_column(primary_key=True) + command: Mapped[str] = mapped_column(nullable=False) + timestamp: Mapped[datetime] = mapped_column(nullable=False) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + slash: Mapped[bool] = mapped_column(nullable=False) + context_menu: Mapped[bool] = mapped_column(nullable=False) - user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin") + user: Mapped[User] = relationship(back_populates="command_stats", uselist=False, lazy="selectin") class CustomCommand(Base): @@ -116,13 +110,13 @@ class CustomCommand(Base): __tablename__ = "custom_commands" - command_id: int = Column(Integer, primary_key=True) - name: str = Column(Text, nullable=False, unique=True) - indexed_name: str = Column(Text, nullable=False, index=True) - response: str = Column(Text, nullable=False) + command_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False, unique=True) + indexed_name: Mapped[str] = mapped_column(nullable=False, index=True) + response: Mapped[str] = mapped_column(nullable=False) - aliases: list[CustomCommandAlias] = relationship( - "CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" + aliases: Mapped[List[CustomCommandAlias]] = relationship( + back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" ) @@ -131,12 +125,12 @@ class CustomCommandAlias(Base): __tablename__ = "custom_command_aliases" - alias_id: int = Column(Integer, primary_key=True) - alias: str = Column(Text, nullable=False, unique=True) - indexed_alias: str = Column(Text, nullable=False, index=True) - command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) + alias_id: Mapped[int] = mapped_column(primary_key=True) + alias: Mapped[str] = mapped_column(nullable=False, unique=True) + indexed_alias: Mapped[str] = mapped_column(nullable=False, index=True) + command_id: Mapped[int] = mapped_column(ForeignKey("custom_commands.command_id")) - command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") + command: Mapped[CustomCommand] = relationship(back_populates="aliases", uselist=False, lazy="selectin") class DadJoke(Base): @@ -144,8 +138,8 @@ class DadJoke(Base): __tablename__ = "dad_jokes" - dad_joke_id: int = Column(Integer, primary_key=True) - joke: str = Column(Text, nullable=False) + dad_joke_id: Mapped[int] = mapped_column(primary_key=True) + joke: Mapped[str] = mapped_column(nullable=False) class Deadline(Base): @@ -153,12 +147,12 @@ class Deadline(Base): __tablename__ = "deadlines" - deadline_id: int = Column(Integer, primary_key=True) - course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - name: str = Column(Text, nullable=False) - deadline: datetime = Column(DateTime(timezone=True), nullable=False) + deadline_id: Mapped[int] = mapped_column(primary_key=True) + course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) + name: Mapped[str] = mapped_column(nullable=False) + deadline: Mapped[datetime] = mapped_column(nullable=False) - course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") + course: Mapped[UforaCourse] = relationship(back_populates="deadlines", uselist=False, lazy="selectin") class EasterEgg(Base): @@ -166,11 +160,11 @@ class EasterEgg(Base): __tablename__ = "easter_eggs" - easter_egg_id: int = Column(Integer, primary_key=True) - match: str = Column(Text, nullable=False) - response: str = Column(Text, nullable=False) - exact: bool = Column(Boolean, nullable=False, server_default="1") - startswith: bool = Column(Boolean, nullable=False, server_default="1") + easter_egg_id: Mapped[int] = mapped_column(primary_key=True) + match: Mapped[str] = mapped_column(nullable=False) + response: Mapped[str] = mapped_column(nullable=False) + exact: Mapped[bool] = mapped_column(nullable=False, server_default="1") + startswith: Mapped[bool] = mapped_column(nullable=False, server_default="1") class Event(Base): @@ -178,11 +172,11 @@ class Event(Base): __tablename__ = "events" - event_id: int = Column(Integer, primary_key=True) - name: str = Column(Text, nullable=False) - description: Optional[str] = Column(Text, nullable=True) - notification_channel: int = Column(BigInteger, nullable=False) - timestamp: datetime = Column(DateTime(timezone=True), nullable=False) + event_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False) + description: Mapped[Optional[str]] = mapped_column(nullable=True) + notification_channel: Mapped[int] = mapped_column(BigInteger, nullable=False) + timestamp: Mapped[datetime] = mapped_column(nullable=False) class FreeGame(Base): @@ -190,7 +184,7 @@ class FreeGame(Base): __tablename__ = "free_games" - free_game_id: int = Column(Integer, primary_key=True) + free_game_id: Mapped[int] = mapped_column(primary_key=True) class GitHubLink(Base): @@ -198,11 +192,11 @@ class GitHubLink(Base): __tablename__ = "github_links" - github_link_id: int = Column(Integer, primary_key=True) - url: str = Column(Text, nullable=False, unique=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + github_link_id: Mapped[int] = mapped_column(primary_key=True) + url: Mapped[str] = mapped_column(nullable=False, unique=True) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin") + user: Mapped[User] = relationship(back_populates="github_links", uselist=False, lazy="selectin") class Link(Base): @@ -210,9 +204,9 @@ class Link(Base): __tablename__ = "links" - link_id: int = Column(Integer, primary_key=True) - name: str = Column(Text, nullable=False, unique=True) - url: str = Column(Text, nullable=False) + link_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False, unique=True) + url: Mapped[str] = mapped_column(nullable=False) class MemeTemplate(Base): @@ -220,10 +214,10 @@ class MemeTemplate(Base): __tablename__ = "meme" - meme_id: int = Column(Integer, primary_key=True) - name: str = Column(Text, nullable=False, unique=True) - template_id: int = Column(Integer, nullable=False, unique=True) - field_count: int = Column(Integer, nullable=False) + meme_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False, unique=True) + template_id: Mapped[int] = mapped_column(nullable=False, unique=True) + field_count: Mapped[int] = mapped_column(nullable=False) class NightlyData(Base): @@ -231,12 +225,12 @@ class NightlyData(Base): __tablename__ = "nightly_data" - nightly_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - last_nightly: Optional[date] = Column(Date, nullable=True) - count: int = Column(Integer, server_default="0", nullable=False) + nightly_id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + last_nightly: Mapped[Optional[date]] = mapped_column(nullable=True) + count: Mapped[int] = mapped_column(server_default="0", nullable=False) - user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") + user: Mapped[User] = relationship(back_populates="nightly_data", uselist=False, lazy="selectin") class Reminder(Base): @@ -244,11 +238,11 @@ class Reminder(Base): __tablename__ = "reminders" - reminder_id: int = Column(Integer, primary_key=True) - user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False) + reminder_id: Mapped[int] = mapped_column(primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + category: Mapped[enums.ReminderCategory] = mapped_column(nullable=False) - user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin") + user: Mapped[User] = relationship(back_populates="reminders", uselist=False, lazy="selectin") class Task(Base): @@ -256,9 +250,9 @@ class Task(Base): __tablename__ = "tasks" - task_id: int = Column(Integer, primary_key=True) - task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) - previous_run: datetime = Column(DateTime(timezone=True), nullable=True) + task_id: Mapped[int] = mapped_column(primary_key=True) + task: Mapped[enums.TaskType] = mapped_column(nullable=False, unique=True) + previous_run: Mapped[datetime] = mapped_column(nullable=True) class UforaCourse(Base): @@ -266,25 +260,25 @@ class UforaCourse(Base): __tablename__ = "ufora_courses" - course_id: int = Column(Integer, primary_key=True) - name: str = Column(Text, nullable=False, unique=True) - code: str = Column(Text, nullable=False, unique=True) - year: int = Column(Integer, nullable=False) - compulsory: bool = Column(Boolean, server_default="1", nullable=False) - role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) - overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) + course_id: Mapped[int] = mapped_column(primary_key=True) + name: Mapped[str] = mapped_column(nullable=False, unique=True) + code: Mapped[str] = mapped_column(nullable=False, unique=True) + year: Mapped[int] = mapped_column(nullable=False) + compulsory: Mapped[bool] = mapped_column(server_default="1", nullable=False) + role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) + overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) # This is not the greatest fix, but there can only ever be two, so it will do the job - alternative_overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) - log_announcements: bool = Column(Boolean, server_default="0", nullable=False) + alternative_overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) + log_announcements: Mapped[bool] = mapped_column(server_default="0", nullable=False) - announcements: list[UforaAnnouncement] = relationship( - "UforaAnnouncement", back_populates="course", cascade="all, delete-orphan", lazy="selectin" + announcements: Mapped[List[UforaAnnouncement]] = relationship( + back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) - aliases: list[UforaCourseAlias] = relationship( - "UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin" + aliases: Mapped[List[UforaCourseAlias]] = relationship( + back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) - deadlines: list[Deadline] = relationship( - "Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin" + deadlines: Mapped[List[Deadline]] = relationship( + back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) @@ -293,11 +287,11 @@ class UforaCourseAlias(Base): __tablename__ = "ufora_course_aliases" - alias_id: int = Column(Integer, primary_key=True) - alias: str = Column(Text, nullable=False, unique=True) - course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) + alias_id: Mapped[int] = mapped_column(primary_key=True) + alias: Mapped[str] = mapped_column(nullable=False, unique=True) + course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) - course: UforaCourse = relationship("UforaCourse", back_populates="aliases", uselist=False, lazy="selectin") + course: Mapped[UforaCourse] = relationship(back_populates="aliases", uselist=False, lazy="selectin") class UforaAnnouncement(Base): @@ -305,11 +299,11 @@ class UforaAnnouncement(Base): __tablename__ = "ufora_announcements" - announcement_id: int = Column(Integer, primary_key=True) - course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - publication_date: date = Column(Date) + announcement_id: Mapped[int] = mapped_column(primary_key=True) + course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) + publication_date: Mapped[date] = mapped_column() - course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") + course: Mapped[UforaCourse] = relationship(back_populates="announcements", uselist=False, lazy="selectin") class User(Base): @@ -317,26 +311,26 @@ class User(Base): __tablename__ = "users" - user_id: int = Column(BigInteger, primary_key=True) + user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) - bank: Bank = relationship( - "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + bank: Mapped[Bank] = relationship( + back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - birthday: Optional[Birthday] = relationship( - "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + birthday: Mapped[Optional[Birthday]] = relationship( + back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - bookmarks: list[Bookmark] = relationship( - "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + bookmarks: Mapped[List[Bookmark]] = relationship( + back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - command_stats: list[CommandStats] = relationship( - "CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + command_stats: Mapped[List[CommandStats]] = relationship( + back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - github_links: list[GitHubLink] = relationship( - "GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + github_links: Mapped[List[GitHubLink]] = relationship( + back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - nightly_data: NightlyData = relationship( - "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + nightly_data: Mapped[NightlyData] = relationship( + back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - reminders: list[Reminder] = relationship( - "Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + reminders: Mapped[List[Reminder]] = relationship( + back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/database/utils/caches.py b/database/utils/caches.py index 248eb5f..2df9dac 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -69,7 +69,7 @@ class LinkCache(DatabaseCache): self.clear() all_links = await links.get_all_links(database_session) - self.data = list(map(lambda l: l.name, all_links)) + self.data = list(map(lambda link: link.name, all_links)) self.data.sort() self.data_transformed = list(map(str.lower, self.data)) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 709a461..4049654 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -25,7 +25,7 @@ class Currency(commands.Cog): super().__init__() self.client = client - @commands.command(name="award") + @commands.command(name="award") # type: ignore[arg-type] @commands.check(is_owner) async def award( self, @@ -49,7 +49,9 @@ class Currency(commands.Cog): bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue()) - embed.set_thumbnail(url=ctx.author.avatar.url) + + if ctx.author.avatar is not None: + embed.set_thumbnail(url=ctx.author.avatar.url) embed.add_field(name="Interest level", value=bank.interest_level) embed.add_field(name="Capacity level", value=bank.capacity_level) @@ -57,7 +59,9 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True) + @bank.group( # type: ignore[arg-type] + name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True + ) async def bank_upgrades(self, ctx: commands.Context): """List the upgrades you can buy & their prices.""" async with self.client.postgres_session as session: @@ -77,7 +81,7 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank_upgrades.command(name="capacity", aliases=["c"]) + @bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type] async def bank_upgrade_capacity(self, ctx: commands.Context): """Upgrade the capacity level of your bank.""" async with self.client.postgres_session as session: @@ -88,7 +92,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="interest", aliases=["i"]) + @bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type] async def bank_upgrade_interest(self, ctx: commands.Context): """Upgrade the interest level of your bank.""" async with self.client.postgres_session as session: @@ -99,7 +103,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="rob", aliases=["r"]) + @bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type] async def bank_upgrade_rob(self, ctx: commands.Context): """Upgrade the rob level of your bank.""" async with self.client.postgres_session as session: @@ -110,7 +114,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @commands.hybrid_command(name="dinks") + @commands.hybrid_command(name="dinks") # type: ignore[arg-type] async def dinks(self, ctx: commands.Context): """Check your Didier Dinks.""" async with self.client.postgres_session as session: @@ -118,7 +122,7 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) - @commands.command(name="invest", aliases=["deposit", "dep"]) + @commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type] async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]): """Invest `amount` Didier Dinks into your bank. @@ -144,7 +148,7 @@ class Currency(commands.Cog): f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False ) - @commands.hybrid_command(name="nightly") + @commands.hybrid_command(name="nightly") # type: ignore[arg-type] async def nightly(self, ctx: commands.Context): """Claim nightly Didier Dinks.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/debug_cog.py b/didier/cogs/debug_cog.py index 2d03b9f..a0e4747 100644 --- a/didier/cogs/debug_cog.py +++ b/didier/cogs/debug_cog.py @@ -13,7 +13,7 @@ class DebugCog(commands.Cog): self.client = client @overrides - async def cog_check(self, ctx: commands.Context) -> bool: + async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override] return await self.client.is_owner(ctx.author) @commands.command(aliases=["Dev"]) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index 4d9b423..fdfa05d 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -1,4 +1,4 @@ -from typing import Optional +from typing import Optional, Union, cast import discord from discord import app_commands @@ -17,6 +17,7 @@ from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar +from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.constants import Limits from didier.utils.timer import Timer from didier.utils.types.datetime import localize, str_to_date, tz_aware_now @@ -60,9 +61,19 @@ class Discord(commands.Cog): event = await events.get_event_by_id(session, event_id) if event is None: - return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True) + return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True) channel = self.client.get_channel(event.notification_channel) + if channel is None: + return await self.client.log_error( + f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)." + ) + + if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + return await self.client.log_error( + f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable." + ) + human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M") embed = discord.Embed(title=event.name, colour=discord.Colour.blue()) @@ -81,7 +92,7 @@ class Discord(commands.Cog): self.client.loop.create_task(self.timer.update()) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) - async def birthday(self, ctx: commands.Context, user: discord.User = None): + async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None): """Command to check the birthday of `user`. Not passing an argument for `user` will show yours instead. @@ -98,8 +109,10 @@ class Discord(commands.Cog): day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False) - @birthday.command(name="set", aliases=["config"]) - async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None): + @birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type] + async def birthday_set( + self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None + ): """Set your birthday to `day`. Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`. @@ -113,6 +126,9 @@ class Discord(commands.Cog): if user is None: user = ctx.author + # Please Mypy + user = cast(Union[discord.User, discord.Member], user) + try: default_year = 2001 date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) @@ -141,7 +157,7 @@ class Discord(commands.Cog): """ # No label: shortcut to display bookmarks if label is None: - return await self.bookmark_search(ctx, query=None) + return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type] async with self.client.postgres_session as session: result = expect( @@ -151,7 +167,7 @@ class Discord(commands.Cog): ) await ctx.reply(result.jump_url, mention_author=False) - @bookmark.command(name="create", aliases=["new"]) + @bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type] async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]): """Create a new bookmark for message `message` with label `label`. @@ -182,7 +198,7 @@ class Discord(commands.Cog): # Label isn't allowed return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) - @bookmark.command(name="delete", aliases=["rm"]) + @bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type] async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): """Delete the bookmark with id `bookmark_id`. @@ -207,7 +223,7 @@ class Discord(commands.Cog): return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) - @bookmark.command(name="search", aliases=["list", "ls"]) + @bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type] async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks. @@ -236,7 +252,7 @@ class Discord(commands.Cog): modal = CreateBookmark(self.client, message.jump_url) await interaction.response.send_modal(modal) - @commands.hybrid_command(name="events") + @commands.hybrid_command(name="events") # type: ignore[arg-type] @app_commands.rename(event_id="id") @app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.") async def events(self, ctx: commands.Context, event_id: Optional[int] = None): @@ -276,16 +292,16 @@ class Discord(commands.Cog): embed.add_field( name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True ) - embed.add_field( - name="Channel", - value=self.client.get_channel(result_event.notification_channel).mention, - inline=False, - ) + + channel = self.client.get_channel(result_event.notification_channel) + if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + embed.add_field(name="Channel", value=channel.mention, inline=False) + embed.description = result_event.description return await ctx.reply(embed=embed, mention_author=False) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) - async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): + async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None): """Show a user's GitHub links. If no user is provided, this shows your links instead. @@ -293,6 +309,9 @@ class Discord(commands.Cog): # Default to author user = user or ctx.author + # Please Mypy + user = cast(Union[discord.User, discord.Member], user) + embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") embed.set_author(name=user.display_name, icon_url=get_user_avatar(user)) @@ -324,7 +343,7 @@ class Discord(commands.Cog): return await ctx.reply(embed=embed, mention_author=False) - @github_group.command(name="add", aliases=["create", "insert"]) + @github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type] async def github_add(self, ctx: commands.Context, link: str): """Add a new link into the database.""" # Remove wrapping which can be used to escape Discord embeds @@ -339,7 +358,7 @@ class Discord(commands.Cog): await self.client.confirm_message(ctx.message) return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False) - @github_group.command(name="delete", aliases=["del", "remove", "rm"]) + @github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type] async def github_delete(self, ctx: commands.Context, link_id: str): """Delete the link with it `link_id` from the database. @@ -411,7 +430,7 @@ class Discord(commands.Cog): await message.add_reaction("📌") return await interaction.response.send_message("📌", ephemeral=True) - @commands.hybrid_command(name="snipe") + @commands.hybrid_command(name="snipe") # type: ignore[arg-type] async def snipe(self, ctx: commands.Context): """Publicly shame people when they edit or delete one of their messages. @@ -420,7 +439,7 @@ class Discord(commands.Cog): if ctx.guild is None: return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True) - sniped_data = self.client.sniped.get(ctx.channel.id, None) + sniped_data = self.client.sniped.get(ctx.channel.id) if sniped_data is None: return await ctx.reply( "There's no one to make fun of in this channel.", mention_author=False, ephemeral=True diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index e824ab2..4ccfb2a 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -28,7 +28,7 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="clap") + @commands.hybrid_command(name="clap") # type: ignore[arg-type] async def clap(self, ctx: commands.Context, *, text: str): """Clap a message with emojis for extra dramatic effect""" chars = list(filter(lambda c: c in constants.EMOJI_MAP, text)) @@ -50,10 +50,7 @@ class Fun(commands.Cog): meme = await generate_meme(self.client.http_session, result, fields) return meme - @commands.hybrid_command( - name="dadjoke", - aliases=["dad", "dj"], - ) + @commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type] async def dad_joke(self, ctx: commands.Context): """Why does Yoda's code always crash? Because there is no try.""" async with self.client.postgres_session as session: @@ -83,13 +80,13 @@ class Fun(commands.Cog): return await self.memegen_ls_msg(ctx) if fields is None: - return await self.memegen_preview_msg(ctx, template) + return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type] async with ctx.typing(): meme = await self._do_generate_meme(template, shlex.split(fields)) return await ctx.reply(meme, mention_author=False) - @memegen_msg.command(name="list", aliases=["ls"]) + @memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type] async def memegen_ls_msg(self, ctx: commands.Context): """Get a list of all available meme templates. @@ -100,14 +97,14 @@ class Fun(commands.Cog): await MemeSource(ctx, results).start() - @memegen_msg.command(name="preview", aliases=["p"]) + @memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type] async def memegen_preview_msg(self, ctx: commands.Context, template: str): """Generate a preview for the meme template `template`, to see how the fields are structured.""" async with ctx.typing(): meme = await self._do_generate_meme(template, []) return await ctx.reply(meme, mention_author=False) - @memes_slash.command(name="generate") + @memes_slash.command(name="generate") # type: ignore[arg-type] async def memegen_slash(self, interaction: discord.Interaction, template: str): """Generate a meme.""" async with self.client.postgres_session as session: @@ -116,7 +113,7 @@ class Fun(commands.Cog): modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) - @memes_slash.command(name="preview") + @memes_slash.command(name="preview") # type: ignore[arg-type] @app_commands.describe(template="The meme template to use in the preview.") async def memegen_preview_slash(self, interaction: discord.Interaction, template: str): """Generate a preview for a meme, to see how the fields are structured.""" @@ -134,7 +131,7 @@ class Fun(commands.Cog): """Autocompletion for the 'template'-parameter""" return self.client.database_caches.memes.get_autocomplete_suggestions(current) - @app_commands.command() + @app_commands.command() # type: ignore[arg-type] @app_commands.describe(message="The text to convert.") async def mock(self, interaction: discord.Interaction, message: str): """Mock a message. @@ -158,7 +155,7 @@ class Fun(commands.Cog): return await interaction.followup.send(mock(message)) - @commands.hybrid_command(name="xkcd") + @commands.hybrid_command(name="xkcd") # type: ignore[arg-type] @app_commands.rename(comic_id="id") async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None): """Fetch comic `#id` from xkcd. diff --git a/didier/cogs/help.py b/didier/cogs/help.py index 459f802..57cad6d 100644 --- a/didier/cogs/help.py +++ b/didier/cogs/help.py @@ -159,6 +159,9 @@ class CustomHelpCommand(commands.MinimalHelpCommand): Code in codeblocks is ignored, as it is used to create examples. """ description = command.help + if description is None: + return "" + codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL) # Regex borrowed from https://stackoverflow.com/a/59843498/13568999 @@ -198,13 +201,10 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return None - async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]: + async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]: """Filter the list of cogs down to all those that the user can see""" - async def _predicate(cog: Optional[commands.Cog]) -> bool: - if cog is None: - return False - + async def _predicate(cog: commands.Cog) -> bool: # Remove cogs that we never want to see in the help page because they # don't contain commands, or shouldn't be visible at all if not cog.get_commands(): @@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return True # Filter list of cogs down - filtered_cogs = [cog for cog in cogs if await _predicate(cog)] + filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)] return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name)) def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]: """Check if a command has flags""" - flag_param = command.params.get("flags", None) + flag_param = command.params.get("flags") if flag_param is None: return None diff --git a/didier/cogs/meta.py b/didier/cogs/meta.py index c330dbd..861bf58 100644 --- a/didier/cogs/meta.py +++ b/didier/cogs/meta.py @@ -1,6 +1,6 @@ import inspect import os -from typing import Optional +from typing import Any, Optional, Union from discord.ext import commands @@ -76,18 +76,24 @@ class Meta(commands.Cog): if command_name is None: return await ctx.reply(repo_home, mention_author=False) + command: Optional[Union[commands.HelpCommand, commands.Command]] + src: Any + if command_name == "help": command = self.client.help_command + if command is None: + return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) + src = type(self.client.help_command) filename = inspect.getsourcefile(src) else: command = self.client.get_command(command_name) + if command is None: + return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) + src = command.callback.__code__ filename = src.co_filename - if command is None: - return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) - lines, first_line = inspect.getsourcelines(src) if filename is None: diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 02c0095..a48cb5e 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -22,7 +22,7 @@ class Other(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) + @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type] async def covid(self, ctx: commands.Context, country: str = "Belgium"): """Show Covid-19 info for a specific country. @@ -43,7 +43,7 @@ class Other(commands.Cog): """Autocompletion for the 'country'-parameter""" return autocomplete_country(value)[:25] - @commands.hybrid_command( + @commands.hybrid_command( # type: ignore[arg-type] name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary" ) async def define(self, ctx: commands.Context, *, query: str): @@ -55,7 +55,7 @@ class Other(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="google", description="Google search") + @commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type] @app_commands.describe(query="Search query") async def google(self, ctx: commands.Context, *, query: str): """Show the Google search results for `query`. @@ -71,7 +71,7 @@ class Other(commands.Cog): embed = GoogleSearch(results).to_embed() await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") + @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type] async def inspire(self, ctx: commands.Context): """Generate an [InspiroBot](https://inspirobot.me/) quote.""" async with ctx.typing(): @@ -82,7 +82,7 @@ class Other(commands.Cog): async with self.client.postgres_session as session: return await get_link_by_name(session, name.lower()) - @commands.command(name="Link", aliases=["Links"]) + @commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type] async def link_msg(self, ctx: commands.Context, name: str): """Get the link to the resource named `name`.""" link = await self._get_link(name) @@ -92,7 +92,7 @@ class Other(commands.Cog): target_message = await self.client.get_reply_target(ctx) await target_message.reply(link.url, mention_author=False) - @app_commands.command(name="link") + @app_commands.command(name="link") # type: ignore[arg-type] @app_commands.describe(name="The name of the resource") async def link_slash(self, interaction: discord.Interaction, name: str): """Get the link to something.""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 139f02c..1f72eff 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -42,7 +42,7 @@ class Owner(commands.Cog): def __init__(self, client: Didier): self.client = client - async def cog_check(self, ctx: commands.Context) -> bool: + async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override] """Global check for every command in this cog This means that we don't have to add is_owner() to every single command separately @@ -102,7 +102,7 @@ class Owner(commands.Cog): async def add_msg(self, ctx: commands.Context): """Command group for [add X] message commands""" - @add_msg.command(name="Alias") + @add_msg.command(name="Alias") # type: ignore[arg-type] async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" async with self.client.postgres_session as session: @@ -116,7 +116,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Custom") + @add_msg.command(name="Custom") # type: ignore[arg-type] async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" async with self.client.postgres_session as session: @@ -127,7 +127,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Link") + @add_msg.command(name="Link") # type: ignore[arg-type] async def add_link_msg(self, ctx: commands.Context, name: str, url: str): """Add a new link""" async with self.client.postgres_session as session: @@ -136,7 +136,7 @@ class Owner(commands.Cog): await self.client.confirm_message(ctx.message) - @add_slash.command(name="custom", description="Add a custom command") + @add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type] async def add_custom_slash(self, interaction: discord.Interaction): """Slash command to add a custom command""" if not await self.client.is_owner(interaction.user): @@ -145,7 +145,7 @@ class Owner(commands.Cog): modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="dadjoke", description="Add a dad joke") + @add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type] async def add_dad_joke_slash(self, interaction: discord.Interaction): """Slash command to add a dad joke""" if not await self.client.is_owner(interaction.user): @@ -154,7 +154,7 @@ class Owner(commands.Cog): modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="deadline", description="Add a deadline") + @add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type] @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") async def add_deadline_slash(self, interaction: discord.Interaction, course: str): """Slash command to add a deadline""" @@ -174,7 +174,7 @@ class Owner(commands.Cog): """Autocompletion for the 'course'-parameter""" return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) - @add_slash.command(name="event", description="Add a new event") + @add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type] async def add_event_slash(self, interaction: discord.Interaction): """Slash command to add new events""" if not await self.client.is_owner(interaction.user): @@ -183,7 +183,7 @@ class Owner(commands.Cog): modal = AddEvent(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="link", description="Add a new link") + @add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type] async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" if not await self.client.is_owner(interaction.user): @@ -192,7 +192,7 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="meme", description="Add a new meme") + @add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type] async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): """Slash command to add new memes""" await interaction.response.defer(ephemeral=True) @@ -205,11 +205,11 @@ class Owner(commands.Cog): await interaction.followup.send(f"Added meme `{meme.meme_id}`.") await self.client.database_caches.memes.invalidate(session) - @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type] async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" - @edit_msg.command(name="Custom") + @edit_msg.command(name="Custom") # type: ignore[arg-type] async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): """Edit an existing custom command""" async with self.client.postgres_session as session: @@ -220,7 +220,7 @@ class Owner(commands.Cog): await ctx.reply(f"No command found matching `{command}`.") return await self.client.reject_message(ctx.message) - @edit_slash.command(name="custom", description="Edit a custom command") + @edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type] @app_commands.describe(command="The name of the command to edit") async def edit_custom_slash(self, interaction: discord.Interaction, command: str): """Slash command to edit a custom command""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index 7af8a81..cd9366a 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -27,7 +27,7 @@ class School(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="deadlines") + @commands.hybrid_command(name="deadlines") # type: ignore[arg-type] async def deadlines(self, ctx: commands.Context): """Show upcoming deadlines.""" async with ctx.typing(): @@ -40,7 +40,7 @@ class School(commands.Cog): embed = Deadlines(deadlines).to_embed() await ctx.reply(embed=embed, mention_author=False, ephemeral=False) - @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) + @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type] @app_commands.rename(day_dt="date") async def les( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -72,10 +72,7 @@ class School(commands.Cog): except NotInMainGuildException: return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False) - @commands.hybrid_command( - name="menu", - aliases=["eten", "food"], - ) + @commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type] @app_commands.rename(day_dt="date") async def menu( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -96,7 +93,7 @@ class School(commands.Cog): embed = no_menu_found(day_dt) await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command( + @commands.hybrid_command( # type: ignore[arg-type] name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"] ) @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") @@ -124,7 +121,7 @@ class School(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="ufora") + @commands.hybrid_command(name="ufora") # type: ignore[arg-type] async def ufora(self, ctx: commands.Context, course: str): """Link the Ufora page for a course.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 07d6508..f59d697 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,4 +1,6 @@ +import asyncio import datetime +import logging import random import discord @@ -20,9 +22,12 @@ from didier.data.embeds.schedules import ( from didier.data.rss_feeds.free_games import fetch_free_games from didier.data.rss_feeds.ufora import fetch_ufora_announcements from didier.decorators.tasks import timed_task +from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.checks import is_owner from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now +logger = logging.getLogger(__name__) + # datetime.time()-instances for when every task should run DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) @@ -56,7 +61,7 @@ class Tasks(commands.Cog): } @overrides - def cog_load(self) -> None: + async def cog_load(self) -> None: # Only check birthdays if there's a channel to send it to if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: self.check_birthdays.start() @@ -72,9 +77,10 @@ class Tasks(commands.Cog): # Start other tasks self.reminders.start() + asyncio.create_task(self.get_error_channel()) @overrides - def cog_unload(self) -> None: + async def cog_unload(self) -> None: # Cancel all pending tasks for task in self._tasks.values(): if task.is_running(): @@ -96,7 +102,7 @@ class Tasks(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") + @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type] async def force_task(self, ctx: commands.Context, name: str): """Command to force-run a task without waiting for the specified run time""" name = name.lower() @@ -107,23 +113,53 @@ class Tasks(commands.Cog): await task(forced=True) await self.client.confirm_message(ctx.message) + async def get_error_channel(self): + """Get the configured channel from the cache""" + await self.client.wait_until_ready() + + # Configure channel to send errors to + if settings.ERRORS_CHANNEL is not None: + channel = self.client.get_channel(settings.ERRORS_CHANNEL) + + if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.") + else: + self.client.error_channel = channel + elif self.client.owner_id is not None: + self.client.error_channel = self.client.get_user(self.client.owner_id) + @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) @timed_task(enums.TaskType.BIRTHDAYS) async def check_birthdays(self, **kwargs): """Check if it's currently anyone's birthday""" _ = kwargs + # Can't happen (task isn't started if this is None), but Mypy doesn't know + if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None: + return + now = tz_aware_now().date() async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) if channel is None: - return await self.client.log_error("Unable to find channel for birthday announcements") + return await self.client.log_error("Unable to fetch channel for birthday announcements.") + + if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + return await self.client.log_error( + f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable." + ) for birthday in birthdays: user = self.client.get_user(birthday.user_id) + if user is None: + await self.client.log_error( + f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement" + ) + continue + await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention)) @check_birthdays.before_loop @@ -143,6 +179,14 @@ class Tasks(commands.Cog): games = await fetch_free_games(self.client.http_session, session) channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL) + if channel is None: + return await self.client.log_error("Unable to fetch channel for free games announcements.") + + if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + return await self.client.log_error( + f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable." + ) + for game in games: await channel.send(embed=game.to_embed()) @@ -204,6 +248,17 @@ class Tasks(commands.Cog): async with self.client.postgres_session as db_session: announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) + + if announcements_channel is None: + return await self.client.log_error( + f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)." + ) + + if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES): + return await self.client.log_error( + f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable." + ) + announcements = await fetch_ufora_announcements(self.client.http_session, db_session) for announcement in announcements: diff --git a/didier/data/apis/disease_sh.py b/didier/data/apis/disease_sh.py index 809cdcb..a32daea 100644 --- a/didier/data/apis/disease_sh.py +++ b/didier/data/apis/disease_sh.py @@ -19,7 +19,7 @@ async def get_country_info(http_session: ClientSession, country: str) -> CovidDa yesterday = response data = {"today": today, "yesterday": yesterday} - return CovidData.parse_obj(data) + return CovidData.model_validate(data) async def get_global_info(http_session: ClientSession) -> CovidData: @@ -35,4 +35,4 @@ async def get_global_info(http_session: ClientSession) -> CovidData: yesterday = response data = {"today": today, "yesterday": yesterday} - return CovidData.parse_obj(data) + return CovidData.model_validate(data) diff --git a/didier/data/apis/hydra.py b/didier/data/apis/hydra.py index 620e0df..8d7889b 100644 --- a/didier/data/apis/hydra.py +++ b/didier/data/apis/hydra.py @@ -12,4 +12,4 @@ async def fetch_menu(http_session: ClientSession, day_dt: date) -> Menu: """Fetch the menu for a given day""" endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json" async with ensure_get(http_session, endpoint, log_exceptions=False) as response: - return Menu.parse_obj(response) + return Menu.model_validate(response) diff --git a/didier/data/apis/urban_dictionary.py b/didier/data/apis/urban_dictionary.py index 6d81934..a6b5cd6 100644 --- a/didier/data/apis/urban_dictionary.py +++ b/didier/data/apis/urban_dictionary.py @@ -14,4 +14,4 @@ async def lookup(http_session: ClientSession, query: str) -> list[Definition]: url = "https://api.urbandictionary.com/v0/define" async with ensure_get(http_session, url, params={"term": query}) as response: - return list(map(Definition.parse_obj, response["list"])) + return list(map(Definition.model_validate, response["list"])) diff --git a/didier/data/apis/xkcd.py b/didier/data/apis/xkcd.py index c0ad766..bf8ff4d 100644 --- a/didier/data/apis/xkcd.py +++ b/didier/data/apis/xkcd.py @@ -13,4 +13,4 @@ async def fetch_xkcd_post(http_session: ClientSession, *, num: Optional[int] = N url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json" async with ensure_get(http_session, url) as response: - return XKCDPost.parse_obj(response) + return XKCDPost.model_validate(response) diff --git a/didier/data/embeds/disease_sh.py b/didier/data/embeds/disease_sh.py index 45e8895..a344828 100644 --- a/didier/data/embeds/disease_sh.py +++ b/didier/data/embeds/disease_sh.py @@ -1,6 +1,6 @@ import discord from overrides import overrides -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator from didier.data.embeds.base import EmbedPydantic @@ -24,7 +24,7 @@ class _CovidNumbers(BaseModel): active: int tests: int - @validator("updated") + @field_validator("updated") def updated_to_seconds(cls, value: int) -> int: """Turn the updated field into seconds instead of milliseconds""" return int(value) // 1000 diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py index ea03bfe..696dd80 100644 --- a/didier/data/embeds/error_embed.py +++ b/didier/data/embeds/error_embed.py @@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> embed = discord.Embed(title="Error", colour=discord.Colour.red()) if ctx is not None: - if ctx.guild is None: + if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel): origin = "DM" else: - origin = f"{ctx.channel.mention} ({ctx.guild.name})" + origin = f"<#{ctx.channel.id}> ({ctx.guild.name})" invocation = f"{ctx.author.display_name} in {origin}" diff --git a/didier/data/embeds/free_games.py b/didier/data/embeds/free_games.py index d37e0b7..f159157 100644 --- a/didier/data/embeds/free_games.py +++ b/didier/data/embeds/free_games.py @@ -4,18 +4,17 @@ from typing import Optional import discord from aiohttp import ClientSession from overrides import overrides -from pydantic import validator +from pydantic import field_validator from didier.data.embeds.base import EmbedPydantic from didier.data.scrapers.common import GameStorePage from didier.data.scrapers.steam import get_steam_webpage_info from didier.utils.discord import colours - -__all__ = ["SEPARATOR", "FreeGameEmbed"] - from didier.utils.discord.constants import Limits from didier.utils.types.string import abbreviate +__all__ = ["SEPARATOR", "FreeGameEmbed"] + SEPARATOR = " • Free • " @@ -58,7 +57,7 @@ class FreeGameEmbed(EmbedPydantic): store_page: Optional[GameStorePage] = None - @validator("title") + @field_validator("title") def _clean_title(cls, value: str) -> str: return html.unescape(value) @@ -107,7 +106,6 @@ class FreeGameEmbed(EmbedPydantic): embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})") if self.store_page.xdg_open_url is not None: - embed.add_field( name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})" ) diff --git a/didier/data/embeds/logging_embed.py b/didier/data/embeds/logging_embed.py index 40556f2..3a803a1 100644 --- a/didier/data/embeds/logging_embed.py +++ b/didier/data/embeds/logging_embed.py @@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"] def create_logging_embed(level: int, message: str) -> discord.Embed: """Create an embed to send to the logging channel""" colours = { - logging.DEBUG: discord.Colour.light_gray, + logging.DEBUG: discord.Colour.light_grey(), logging.ERROR: discord.Colour.red(), logging.INFO: discord.Colour.blue(), logging.WARNING: discord.Colour.yellow(), diff --git a/didier/data/embeds/urban_dictionary.py b/didier/data/embeds/urban_dictionary.py index 6dfeaac..ad3f90f 100644 --- a/didier/data/embeds/urban_dictionary.py +++ b/didier/data/embeds/urban_dictionary.py @@ -2,7 +2,7 @@ from datetime import datetime import discord from overrides import overrides -from pydantic import validator +from pydantic import field_validator from didier.data.embeds.base import EmbedPydantic from didier.utils.discord import colours @@ -39,8 +39,8 @@ class Definition(EmbedPydantic): total_votes = self.thumbs_up + self.thumbs_down return round(100 * self.thumbs_up / total_votes, 2) - @validator("definition", "example") - def modify_long_text(cls, field): + @field_validator("definition", "example") + def modify_long_text(cls, field: str): """Remove brackets from fields & cut them off if they are too long""" field = field.replace("[", "").replace("]", "") return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH) diff --git a/didier/data/rss_feeds/free_games.py b/didier/data/rss_feeds/free_games.py index fcc02c9..24a1ecf 100644 --- a/didier/data/rss_feeds/free_games.py +++ b/didier/data/rss_feeds/free_games.py @@ -32,7 +32,7 @@ async def fetch_free_games(http_session: ClientSession, database_session: AsyncS if SEPARATOR not in entry["title"]: continue - game = FreeGameEmbed.parse_obj(entry) + game = FreeGameEmbed.model_validate(entry) games.append(game) game_ids.append(game.dc_identifier) diff --git a/didier/data/scrapers/google.py b/didier/data/scrapers/google.py index 389e9ae..9c10716 100644 --- a/didier/data/scrapers/google.py +++ b/didier/data/scrapers/google.py @@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]: return list(dict.fromkeys(results)) -async def google_search(http_client: ClientSession, query: str): +async def google_search(http_session: ClientSession, query: str): """Get the first 10 Google search results""" query = urlencode({"q": query}) # Request 20 results in case of duplicates, bad matches, ... - async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: + async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: # Something went wrong if response.status != http.HTTPStatus.OK: return SearchData(query, response.status) diff --git a/didier/didier.py b/didier/didier.py index cf9ed1d..33e6e4b 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -17,7 +17,7 @@ from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.logging_embed import create_logging_embed from didier.data.embeds.schedules import Schedule, parse_schedule -from didier.exceptions import HTTPException, NoMatch +from didier.exceptions import GetNoneException, HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix from didier.utils.discord.snipe import should_snipe from didier.utils.easter_eggs import detect_easter_egg @@ -33,7 +33,7 @@ class Didier(commands.Bot): """DIDIER <3""" database_caches: CacheManager - error_channel: discord.abc.Messageable + error_channel: Optional[discord.abc.Messageable] = None initial_extensions: tuple[str, ...] = () http_session: ClientSession schedules: dict[settings.ScheduleType, Schedule] = {} @@ -56,12 +56,17 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) - self.tree.on_error = self.on_app_command_error + # I'm not creating a custom tree, this is the way to do it + self.tree.on_error = self.on_app_command_error # type: ignore[method-assign] @cached_property def main_guild(self) -> discord.Guild: """Obtain a reference to the main guild""" - return self.get_guild(settings.DISCORD_MAIN_GUILD) + guild = self.get_guild(settings.DISCORD_MAIN_GUILD) + if guild is None: + raise GetNoneException("Main guild could not be found in the bot's cache") + + return guild @property def postgres_session(self) -> AsyncSession: @@ -93,12 +98,6 @@ class Didier(commands.Bot): await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") - # Configure channel to send errors to - if settings.ERRORS_CHANNEL is not None: - self.error_channel = self.get_channel(settings.ERRORS_CHANNEL) - else: - self.error_channel = self.get_user(self.owner_id) - def _create_ignored_directories(self): """Create directories that store ignored data""" ignored = ["files/schedules"] @@ -152,18 +151,27 @@ class Didier(commands.Bot): original message instead """ if ctx.message.reference is not None: - return await self.resolve_message(ctx.message.reference) + return await self.resolve_message(ctx.message.reference) or ctx.message return ctx.message - async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: + async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]: """Fetch a message from a reference""" # Message is in the cache, return it if reference.cached_message is not None: return reference.cached_message + if reference.message_id is None: + return None + # For older messages: fetch them from the API channel = self.get_channel(reference.channel_id) + if channel is None or isinstance( + channel, + (discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel), + ): # Logically this can't happen, but we have to please Mypy + return None + return await channel.fetch_message(reference.message_id) async def confirm_message(self, message: discord.Message): @@ -184,7 +192,7 @@ class Didier(commands.Bot): } methods.get(level, logger.error)(message) - if log_to_discord: + if log_to_discord and self.error_channel is not None: embed = create_logging_embed(level, message) await self.error_channel.send(embed=embed) @@ -253,10 +261,9 @@ class Didier(commands.Bot): await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True) - if settings.ERRORS_CHANNEL is not None: + if self.error_channel is not None: embed = create_error_embed(await commands.Context.from_interaction(interaction), exception) - channel = self.get_channel(settings.ERRORS_CHANNEL) - await channel.send(embed=embed) + await self.error_channel.send(embed=embed) async def on_command_completion(self, ctx: commands.Context): """Event triggered when a message command completes successfully""" @@ -281,7 +288,7 @@ class Didier(commands.Bot): # Hybrid command errors are wrapped in an additional error, so wrap it back out if isinstance(exception, commands.HybridCommandError): - exception = exception.original + exception = exception.original # type: ignore[assignment] # Ignore exceptions that aren't important if isinstance( @@ -332,10 +339,9 @@ class Didier(commands.Bot): # Print everything that we care about to the logs/stderr await super().on_command_error(ctx, exception) - if settings.ERRORS_CHANNEL is not None: + if self.error_channel is not None: embed = create_error_embed(ctx, exception) - channel = self.get_channel(settings.ERRORS_CHANNEL) - await channel.send(embed=embed) + await self.error_channel.send(embed=embed) async def on_message(self, message: discord.Message, /) -> None: """Event triggered when a message is sent""" @@ -344,7 +350,7 @@ class Didier(commands.Bot): return # Boos react to people that say Dider - if "dider" in message.content.lower() and message.author.id != self.user.id: + if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id: await message.add_reaction(settings.DISCORD_BOOS_REACT) # Potential custom command @@ -374,7 +380,7 @@ class Didier(commands.Bot): # If the edited message is currently present in the snipe cache, # don't update the , but instead change the - existing = self.sniped.get(before.channel.id, None) + existing = self.sniped.get(before.channel.id) if existing is not None and existing[0].id == before.id: before = existing[0] @@ -389,10 +395,9 @@ class Didier(commands.Bot): async def on_task_error(self, exception: Exception): """Event triggered when a task raises an exception""" - if settings.ERRORS_CHANNEL is not None: + if self.error_channel: embed = create_error_embed(None, exception) - channel = self.get_channel(settings.ERRORS_CHANNEL) - await channel.send(embed=embed) + await self.error_channel.send(embed=embed) async def on_thread_create(self, thread: discord.Thread): """Event triggered when a new thread is created""" diff --git a/didier/exceptions/__init__.py b/didier/exceptions/__init__.py index 1335dd4..fa5ad13 100644 --- a/didier/exceptions/__init__.py +++ b/didier/exceptions/__init__.py @@ -1,6 +1,14 @@ +from .get_none_exception import GetNoneException from .http_exception import HTTPException from .missing_env import MissingEnvironmentVariable from .no_match import NoMatch, expect from .not_in_main_guild_exception import NotInMainGuildException -__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"] +__all__ = [ + "GetNoneException", + "HTTPException", + "MissingEnvironmentVariable", + "NoMatch", + "expect", + "NotInMainGuildException", +] diff --git a/didier/exceptions/get_none_exception.py b/didier/exceptions/get_none_exception.py new file mode 100644 index 0000000..cbd2f77 --- /dev/null +++ b/didier/exceptions/get_none_exception.py @@ -0,0 +1,5 @@ +__all__ = ["GetNoneException"] + + +class GetNoneException(RuntimeError): + """Exception raised when a Bot.get()-method returned None""" diff --git a/didier/exceptions/not_in_main_guild_exception.py b/didier/exceptions/not_in_main_guild_exception.py index 5572c44..5279686 100644 --- a/didier/exceptions/not_in_main_guild_exception.py +++ b/didier/exceptions/not_in_main_guild_exception.py @@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError): def __init__(self, user: Union[discord.User, discord.Member]): super().__init__( - f"User {user.display_name} (id {user.id}) " - f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})." + f"User {user.display_name} (id `{user.id}`) " + f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)." ) diff --git a/didier/utils/discord/channels.py b/didier/utils/discord/channels.py new file mode 100644 index 0000000..26739f8 --- /dev/null +++ b/didier/utils/discord/channels.py @@ -0,0 +1,5 @@ +import discord + +__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"] + +NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel) diff --git a/didier/utils/discord/prefix.py b/didier/utils/discord/prefix.py index f3fa7c4..694a4b6 100644 --- a/didier/utils/discord/prefix.py +++ b/didier/utils/discord/prefix.py @@ -15,11 +15,14 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]: This is done dynamically through regexes to allow case-insensitivity and variable amounts of whitespace among other things. """ - mention = f"<@!?{client.user.id}>" + mention = f"<@!?{client.user.id}>" if client.user else None regex = r"^({})\s*" # Check which prefix was used for prefix in [*constants.PREFIXES, mention]: + if prefix is None: + continue + match = re.match(regex.format(prefix), message.content, flags=re.I) if match is not None: diff --git a/didier/views/modals/bookmarks.py b/didier/views/modals/bookmarks.py index f77b608..acd4c8f 100644 --- a/didier/views/modals/bookmarks.py +++ b/didier/views/modals/bookmarks.py @@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"): @overrides async def on_submit(self, interaction: discord.Interaction): + await interaction.response.defer(ephemeral=True) + label = self.name.value.strip() try: async with self.client.postgres_session as session: bm = await create_bookmark(session, interaction.user.id, label, self.jump_url) - return await interaction.response.send_message( - f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True + return await interaction.followup.send( + f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)." ) except DuplicateInsertException: # Label is already in use - return await interaction.response.send_message( - f"You already have a bookmark named `{label}`.", ephemeral=True - ) + return await interaction.followup.send(f"You already have a bookmark named `{label}`.") except ForbiddenNameException: # Label isn't allowed - return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True) + return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.") @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.response.send_message("Something went wrong.", ephemeral=True) + await interaction.followup.send("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py index c3b2f67..5ebfab7 100644 --- a/didier/views/modals/dad_jokes.py +++ b/didier/views/modals/dad_jokes.py @@ -26,12 +26,14 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): @overrides async def on_submit(self, interaction: discord.Interaction): + await interaction.response.defer(ephemeral=True) + async with self.client.postgres_session as session: joke = await add_dad_joke(session, str(self.joke.value)) - await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) + await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}") @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.response.send_message("Something went wrong.", ephemeral=True) + await interaction.followup.send("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py index e7b92b4..71acea6 100644 --- a/didier/views/modals/events.py +++ b/didier/views/modals/events.py @@ -10,6 +10,8 @@ from didier import Didier __all__ = ["AddEvent"] +from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES + class AddEvent(discord.ui.Modal, title="Add Event"): """Modal to add a new event""" @@ -33,15 +35,20 @@ class AddEvent(discord.ui.Modal, title="Add Event"): @overrides async def on_submit(self, interaction: discord.Interaction) -> None: + await interaction.response.defer(ephemeral=True) + try: parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) except ParserError: - return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True) + return await interaction.followup.send("Unable to parse date argument.") - if self.client.get_channel(int(self.channel.value)) is None: - return await interaction.response.send_message( - f"Unable to find channel `{self.channel.value}`", ephemeral=True - ) + channel = self.client.get_channel(int(self.channel.value)) + + if channel is None: + return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`") + + if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): + return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.") async with self.client.postgres_session as session: event = await add_event( @@ -52,10 +59,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"): channel_id=int(self.channel.value), ) - await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) + await interaction.followup.send(f"Successfully added event `{event.event_id}`.") self.client.dispatch("event_create", event) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.response.send_message("Something went wrong.", ephemeral=True) + await interaction.followup.send("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/main.py b/main.py index dd9b030..a1e2655 100644 --- a/main.py +++ b/main.py @@ -36,6 +36,7 @@ def setup_logging(): # Configure discord handler discord_log = logging.getLogger("discord") + discord_handler: logging.StreamHandler # Make dev print to stderr instead, so you don't have to watch the file if settings.SANDBOX: diff --git a/pyproject.toml b/pyproject.toml index acd06c2..d28abe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ omit = [ profile = "black" [tool.mypy] +check_untyped_defs = true files = [ "database/**/*.py", "didier/**/*.py", @@ -35,7 +36,6 @@ files = [ ] plugins = [ "pydantic.mypy", - "sqlalchemy.ext.mypy.plugin" ] [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"] diff --git a/requirements-dev.txt b/requirements-dev.txt index a9f7109..8d2d1b3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,22 +1,21 @@ -black==22.3.0 -coverage[toml]==6.4.1 -freezegun==1.2.1 +black==23.3.0 +coverage[toml]==7.2.7 +freezegun==1.2.2 isort==5.12.0 -mypy==0.961 -pre-commit==2.20.0 -pytest==7.1.2 -pytest-asyncio==0.18.3 -pytest-env==0.6.2 -sqlalchemy2-stubs==0.0.2a23 -types-beautifulsoup4==4.11.3 -types-python-dateutil==2.8.19 +mypy==1.4.1 +pre-commit==3.3.3 +pytest==7.4.0 +pytest-asyncio==0.21.0 +pytest-env==0.8.2 +types-beautifulsoup4==4.12.0.5 +types-python-dateutil==2.8.19.13 # Flake8 + plugins -flake8==4.0.1 -flake8-bandit==3.0.0 -flake8-bugbear==22.7.1 -flake8-docstrings==1.6.0 -flake8-dunder-all==0.2.1 -flake8-eradicate==1.2.1 -flake8-isort==4.1.1 -flake8-simplify==0.19.2 +flake8==6.0.0 +flake8-bandit==4.1.1 +flake8-bugbear==23.6.5 +flake8-docstrings==1.7.0 +flake8-dunder-all==0.3.0 +flake8-eradicate==1.5.0 +flake8-isort==6.0.0 +flake8-simplify==0.20.0 diff --git a/requirements.txt b/requirements.txt index a29f1cc..4b0ffa3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ -aiohttp==3.8.1 -alembic==1.8.0 -asyncpg==0.25.0 -beautifulsoup4==4.11.1 -discord.py==2.0.1 +aiohttp==3.8.4 +alembic==1.11.1 +asyncpg==0.28.0 +beautifulsoup4==4.12.2 +discord.py==2.3.1 environs==9.5.0 feedparser==6.0.10 ics==0.7.2 -markdownify==0.11.2 -overrides==6.1.0 -pydantic==1.9.1 +markdownify==0.11.6 +overrides==7.3.1 +pydantic==2.0.2 python-dateutil==2.8.2 -sqlalchemy[asyncio]==1.4.37 +sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18 diff --git a/settings.py b/settings.py index 32bd5e0..a862fde 100644 --- a/settings.py +++ b/settings.py @@ -111,7 +111,7 @@ class ScheduleInfo: role_id: Optional[int] schedule_url: Optional[str] - name: Optional[str] = None + name: ScheduleType SCHEDULE_DATA = [