diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index acdc6a2..ba38355 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -3,12 +3,12 @@ default_language_version: repos: - repo: https://github.com/ambv/black - rev: 23.3.0 + rev: 22.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 + rev: v4.3.0 hooks: - id: check-json - id: end-of-file-fixer @@ -21,7 +21,7 @@ repos: - id: isort - repo: https://github.com/PyCQA/autoflake - rev: v2.2.0 + rev: v1.4 hooks: - id: autoflake name: autoflake (python) @@ -31,7 +31,7 @@ repos: - "--ignore-init-module-imports" - repo: https://github.com/PyCQA/flake8 - rev: 6.0.0 + rev: 4.0.1 hooks: - id: flake8 exclude: ^(alembic|.github) diff --git a/alembic/versions/09128b6e34dd_migrate_to_2_x.py b/alembic/versions/09128b6e34dd_migrate_to_2_x.py deleted file mode 100644 index ef4ee2c..0000000 --- a/alembic/versions/09128b6e34dd_migrate_to_2_x.py +++ /dev/null @@ -1,94 +0,0 @@ -"""Migrate to 2.x - -Revision ID: 09128b6e34dd -Revises: 1e3e7f4192c4 -Create Date: 2023-07-07 16:23:15.990231 - -""" -import sqlalchemy as sa - -from alembic import op - -# revision identifiers, used by Alembic. -revision = "09128b6e34dd" -down_revision = "1e3e7f4192c4" -branch_labels = None -depends_on = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("bank", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("birthdays", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("bookmarks", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("command_stats", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op: - batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=False) - - with op.batch_alter_table("deadlines", schema=None) as batch_op: - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) - - with op.batch_alter_table("github_links", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("nightly_data", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("reminders", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=False) - - with op.batch_alter_table("ufora_announcements", schema=None) as batch_op: - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) - batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=False) - - with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op: - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=False) - - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - with op.batch_alter_table("ufora_course_aliases", schema=None) as batch_op: - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) - - with op.batch_alter_table("ufora_announcements", schema=None) as batch_op: - batch_op.alter_column("publication_date", existing_type=sa.DATE(), nullable=True) - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) - - with op.batch_alter_table("reminders", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("nightly_data", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("github_links", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("deadlines", schema=None) as batch_op: - batch_op.alter_column("course_id", existing_type=sa.INTEGER(), nullable=True) - - with op.batch_alter_table("custom_command_aliases", schema=None) as batch_op: - batch_op.alter_column("command_id", existing_type=sa.INTEGER(), nullable=True) - - with op.batch_alter_table("command_stats", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("bookmarks", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("birthdays", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - with op.batch_alter_table("bank", schema=None) as batch_op: - batch_op.alter_column("user_id", existing_type=sa.BIGINT(), nullable=True) - - # ### end Alembic commands ### diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index 925e645..d696e50 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[ if query is not None: statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index eac9c7e..efbc689 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: """Get a list of all commands""" statement = select(CustomCommand) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index a539518..78d623f 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -38,4 +38,4 @@ async def get_deadlines( statement = statement.where(Deadline.course_id == course.course_id) statement = statement.options(selectinload(Deadline.course)) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py index f4a55ed..d4c25d9 100644 --- a/database/crud/easter_eggs.py +++ b/database/crud/easter_eggs.py @@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"] async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: """Return a list of all easter eggs""" statement = select(EasterEgg) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/events.py b/database/crud/events.py index e8dfac5..19887e6 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" statement = select(Event).where(Event.timestamp > now) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: diff --git a/database/crud/free_games.py b/database/crud/free_games.py index e0dab5b..b2d835d 100644 --- a/database/crud/free_games.py +++ b/database/crud/free_games.py @@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]): async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]: """Filter a list of game IDs down to the ones that aren't in the database yet""" statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids)) - matches: list[int] = list((await session.execute(statement)).scalars().all()) + matches: list[int] = (await session.execute(statement)).scalars().all() return list(set(game_ids).difference(matches)) diff --git a/database/crud/github.py b/database/crud/github.py index a352bae..0d32377 100644 --- a/database/crud/github.py +++ b/database/crud/github.py @@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: """Get a user's GitHub links""" statement = select(GitHubLink).where(GitHubLink.user_id == user_id) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() diff --git a/database/crud/links.py b/database/crud/links.py index 20bcab3..495e0f3 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] async def get_all_links(session: AsyncSession) -> list[Link]: """Get a list of all links""" statement = select(Link) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def add_link(session: AsyncSession, name: str, url: str) -> Link: diff --git a/database/crud/memes.py b/database/crud/memes.py index b1ed1e0..ab288aa 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: """Get a list of all memes""" statement = select(MemeTemplate) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: diff --git a/database/crud/reminders.py b/database/crud/reminders.py index 78350e6..007a779 100644 --- a/database/crud/reminders.py +++ b/database/crud/reminders.py @@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"] async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: """Get a list of all Reminders for a given category""" statement = select(Reminder).where(Reminder.category == category) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 06c2b58..688bcc7 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]: """Get all courses where announcements are enabled""" statement = select(UforaCourse).where(UforaCourse.log_announcements) - return list((await session.execute(statement)).scalars().all()) + return (await session.execute(statement)).scalars().all() async def create_new_announcement( diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index d4cf728..5374c07 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor # Search case-insensitively query = query.lower() - course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) - course_result = (await session.execute(course_statement)).scalars().first() - if course_result: - return course_result + statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + if result: + return result - alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) - alias_result = (await session.execute(alias_statement)).scalars().first() - return alias_result.course if alias_result else None + statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) + result = (await session.execute(statement)).scalars().first() + return result.course if result else None diff --git a/database/engine.py b/database/engine.py index ec81bfb..23e5b89 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,7 +1,8 @@ from urllib.parse import quote_plus from sqlalchemy.engine import URL -from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine +from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine +from sqlalchemy.orm import sessionmaker import settings @@ -21,4 +22,6 @@ postgres_engine = create_async_engine( future=True, ) -DBSession = async_sessionmaker(autocommit=False, autoflush=False, bind=postgres_engine, expire_on_commit=False) +DBSession = sessionmaker( + autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False +) diff --git a/database/schemas.py b/database/schemas.py index b92b2a2..34efcbc 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -1,14 +1,27 @@ from __future__ import annotations from datetime import date, datetime -from typing import List, Optional +from typing import Optional -from sqlalchemy import BigInteger, ForeignKey, UniqueConstraint -from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship -from sqlalchemy.types import DateTime +from sqlalchemy import ( + BigInteger, + Boolean, + Column, + Date, + DateTime, + Enum, + ForeignKey, + Integer, + Text, + UniqueConstraint, +) +from sqlalchemy.orm import declarative_base, relationship from database import enums +Base = declarative_base() + + __all__ = [ "Base", "Bank", @@ -35,34 +48,27 @@ __all__ = [ ] -class Base(DeclarativeBase): - """Required base class for all tables""" - - # Make all DateTimes timezone-aware - type_annotation_map = {datetime: DateTime(timezone=True)} - - class Bank(Base): """A user's currency information""" __tablename__ = "bank" - bank_id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + bank_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - dinks: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False) - invested: Mapped[int] = mapped_column(BigInteger, server_default="0", nullable=False) + dinks: int = Column(BigInteger, server_default="0", nullable=False) + invested: int = Column(BigInteger, server_default="0", nullable=False) # Interest rate - interest_level: Mapped[int] = mapped_column(server_default="1", nullable=False) + interest_level: int = Column(Integer, server_default="1", nullable=False) # Maximum amount that can be stored in the bank - capacity_level: Mapped[int] = mapped_column(server_default="1", nullable=False) + capacity_level: int = Column(Integer, server_default="1", nullable=False) # Maximum amount that can be robbed - rob_level: Mapped[int] = mapped_column(server_default="1", nullable=False) + rob_level: int = Column(Integer, server_default="1", nullable=False) - user: Mapped[User] = relationship(uselist=False, back_populates="bank", lazy="selectin") + user: User = relationship("User", uselist=False, back_populates="bank", lazy="selectin") class Birthday(Base): @@ -70,11 +76,11 @@ class Birthday(Base): __tablename__ = "birthdays" - birthday_id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - birthday: Mapped[date] = mapped_column(nullable=False) + birthday_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + birthday: date = Column(Date, nullable=False) - user: Mapped[User] = relationship(uselist=False, back_populates="birthday", lazy="selectin") + user: User = relationship("User", uselist=False, back_populates="birthday", lazy="selectin") class Bookmark(Base): @@ -83,26 +89,26 @@ class Bookmark(Base): __tablename__ = "bookmarks" __table_args__ = (UniqueConstraint("user_id", "label"),) - bookmark_id: Mapped[int] = mapped_column(primary_key=True) - label: Mapped[str] = mapped_column(nullable=False) - jump_url: Mapped[str] = mapped_column(nullable=False) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + bookmark_id: int = Column(Integer, primary_key=True) + label: str = Column(Text, nullable=False) + jump_url: str = Column(Text, nullable=False) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - user: Mapped[User] = relationship(back_populates="bookmarks", uselist=False, lazy="selectin") + user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") class CommandStats(Base): """Metrics on how often commands are used""" __tablename__ = "command_stats" - command_stats_id: Mapped[int] = mapped_column(primary_key=True) - command: Mapped[str] = mapped_column(nullable=False) - timestamp: Mapped[datetime] = mapped_column(nullable=False) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - slash: Mapped[bool] = mapped_column(nullable=False) - context_menu: Mapped[bool] = mapped_column(nullable=False) + command_stats_id: int = Column(Integer, primary_key=True) + command: str = Column(Text, nullable=False) + timestamp: datetime = Column(DateTime(timezone=True), nullable=False) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + slash: bool = Column(Boolean, nullable=False) + context_menu: bool = Column(Boolean, nullable=False) - user: Mapped[User] = relationship(back_populates="command_stats", uselist=False, lazy="selectin") + user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin") class CustomCommand(Base): @@ -110,13 +116,13 @@ class CustomCommand(Base): __tablename__ = "custom_commands" - command_id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False, unique=True) - indexed_name: Mapped[str] = mapped_column(nullable=False, index=True) - response: Mapped[str] = mapped_column(nullable=False) + command_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + indexed_name: str = Column(Text, nullable=False, index=True) + response: str = Column(Text, nullable=False) - aliases: Mapped[List[CustomCommandAlias]] = relationship( - back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" + aliases: list[CustomCommandAlias] = relationship( + "CustomCommandAlias", back_populates="command", uselist=True, cascade="all, delete-orphan", lazy="selectin" ) @@ -125,12 +131,12 @@ class CustomCommandAlias(Base): __tablename__ = "custom_command_aliases" - alias_id: Mapped[int] = mapped_column(primary_key=True) - alias: Mapped[str] = mapped_column(nullable=False, unique=True) - indexed_alias: Mapped[str] = mapped_column(nullable=False, index=True) - command_id: Mapped[int] = mapped_column(ForeignKey("custom_commands.command_id")) + alias_id: int = Column(Integer, primary_key=True) + alias: str = Column(Text, nullable=False, unique=True) + indexed_alias: str = Column(Text, nullable=False, index=True) + command_id: int = Column(Integer, ForeignKey("custom_commands.command_id")) - command: Mapped[CustomCommand] = relationship(back_populates="aliases", uselist=False, lazy="selectin") + command: CustomCommand = relationship("CustomCommand", back_populates="aliases", uselist=False, lazy="selectin") class DadJoke(Base): @@ -138,8 +144,8 @@ class DadJoke(Base): __tablename__ = "dad_jokes" - dad_joke_id: Mapped[int] = mapped_column(primary_key=True) - joke: Mapped[str] = mapped_column(nullable=False) + dad_joke_id: int = Column(Integer, primary_key=True) + joke: str = Column(Text, nullable=False) class Deadline(Base): @@ -147,12 +153,12 @@ class Deadline(Base): __tablename__ = "deadlines" - deadline_id: Mapped[int] = mapped_column(primary_key=True) - course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) - name: Mapped[str] = mapped_column(nullable=False) - deadline: Mapped[datetime] = mapped_column(nullable=False) + deadline_id: int = Column(Integer, primary_key=True) + course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) + name: str = Column(Text, nullable=False) + deadline: datetime = Column(DateTime(timezone=True), nullable=False) - course: Mapped[UforaCourse] = relationship(back_populates="deadlines", uselist=False, lazy="selectin") + course: UforaCourse = relationship("UforaCourse", back_populates="deadlines", uselist=False, lazy="selectin") class EasterEgg(Base): @@ -160,11 +166,11 @@ class EasterEgg(Base): __tablename__ = "easter_eggs" - easter_egg_id: Mapped[int] = mapped_column(primary_key=True) - match: Mapped[str] = mapped_column(nullable=False) - response: Mapped[str] = mapped_column(nullable=False) - exact: Mapped[bool] = mapped_column(nullable=False, server_default="1") - startswith: Mapped[bool] = mapped_column(nullable=False, server_default="1") + easter_egg_id: int = Column(Integer, primary_key=True) + match: str = Column(Text, nullable=False) + response: str = Column(Text, nullable=False) + exact: bool = Column(Boolean, nullable=False, server_default="1") + startswith: bool = Column(Boolean, nullable=False, server_default="1") class Event(Base): @@ -172,11 +178,11 @@ class Event(Base): __tablename__ = "events" - event_id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False) - description: Mapped[Optional[str]] = mapped_column(nullable=True) - notification_channel: Mapped[int] = mapped_column(BigInteger, nullable=False) - timestamp: Mapped[datetime] = mapped_column(nullable=False) + event_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False) + description: Optional[str] = Column(Text, nullable=True) + notification_channel: int = Column(BigInteger, nullable=False) + timestamp: datetime = Column(DateTime(timezone=True), nullable=False) class FreeGame(Base): @@ -184,7 +190,7 @@ class FreeGame(Base): __tablename__ = "free_games" - free_game_id: Mapped[int] = mapped_column(primary_key=True) + free_game_id: int = Column(Integer, primary_key=True) class GitHubLink(Base): @@ -192,11 +198,11 @@ class GitHubLink(Base): __tablename__ = "github_links" - github_link_id: Mapped[int] = mapped_column(primary_key=True) - url: Mapped[str] = mapped_column(nullable=False, unique=True) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) + github_link_id: int = Column(Integer, primary_key=True) + url: str = Column(Text, nullable=False, unique=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) - user: Mapped[User] = relationship(back_populates="github_links", uselist=False, lazy="selectin") + user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin") class Link(Base): @@ -204,9 +210,9 @@ class Link(Base): __tablename__ = "links" - link_id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False, unique=True) - url: Mapped[str] = mapped_column(nullable=False) + link_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + url: str = Column(Text, nullable=False) class MemeTemplate(Base): @@ -214,10 +220,10 @@ class MemeTemplate(Base): __tablename__ = "meme" - meme_id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False, unique=True) - template_id: Mapped[int] = mapped_column(nullable=False, unique=True) - field_count: Mapped[int] = mapped_column(nullable=False) + meme_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + template_id: int = Column(Integer, nullable=False, unique=True) + field_count: int = Column(Integer, nullable=False) class NightlyData(Base): @@ -225,12 +231,12 @@ class NightlyData(Base): __tablename__ = "nightly_data" - nightly_id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - last_nightly: Mapped[Optional[date]] = mapped_column(nullable=True) - count: Mapped[int] = mapped_column(server_default="0", nullable=False) + nightly_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + last_nightly: Optional[date] = Column(Date, nullable=True) + count: int = Column(Integer, server_default="0", nullable=False) - user: Mapped[User] = relationship(back_populates="nightly_data", uselist=False, lazy="selectin") + user: User = relationship("User", back_populates="nightly_data", uselist=False, lazy="selectin") class Reminder(Base): @@ -238,11 +244,11 @@ class Reminder(Base): __tablename__ = "reminders" - reminder_id: Mapped[int] = mapped_column(primary_key=True) - user_id: Mapped[int] = mapped_column(BigInteger, ForeignKey("users.user_id")) - category: Mapped[enums.ReminderCategory] = mapped_column(nullable=False) + reminder_id: int = Column(Integer, primary_key=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + category: enums.ReminderCategory = Column(Enum(enums.ReminderCategory), nullable=False) - user: Mapped[User] = relationship(back_populates="reminders", uselist=False, lazy="selectin") + user: User = relationship("User", back_populates="reminders", uselist=False, lazy="selectin") class Task(Base): @@ -250,9 +256,9 @@ class Task(Base): __tablename__ = "tasks" - task_id: Mapped[int] = mapped_column(primary_key=True) - task: Mapped[enums.TaskType] = mapped_column(nullable=False, unique=True) - previous_run: Mapped[datetime] = mapped_column(nullable=True) + task_id: int = Column(Integer, primary_key=True) + task: enums.TaskType = Column(Enum(enums.TaskType), nullable=False, unique=True) + previous_run: datetime = Column(DateTime(timezone=True), nullable=True) class UforaCourse(Base): @@ -260,25 +266,25 @@ class UforaCourse(Base): __tablename__ = "ufora_courses" - course_id: Mapped[int] = mapped_column(primary_key=True) - name: Mapped[str] = mapped_column(nullable=False, unique=True) - code: Mapped[str] = mapped_column(nullable=False, unique=True) - year: Mapped[int] = mapped_column(nullable=False) - compulsory: Mapped[bool] = mapped_column(server_default="1", nullable=False) - role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) - overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) + course_id: int = Column(Integer, primary_key=True) + name: str = Column(Text, nullable=False, unique=True) + code: str = Column(Text, nullable=False, unique=True) + year: int = Column(Integer, nullable=False) + compulsory: bool = Column(Boolean, server_default="1", nullable=False) + role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) + overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) # This is not the greatest fix, but there can only ever be two, so it will do the job - alternative_overarching_role_id: Mapped[Optional[int]] = mapped_column(BigInteger, nullable=True, unique=False) - log_announcements: Mapped[bool] = mapped_column(server_default="0", nullable=False) + alternative_overarching_role_id: Optional[int] = Column(BigInteger, nullable=True, unique=False) + log_announcements: bool = Column(Boolean, server_default="0", nullable=False) - announcements: Mapped[List[UforaAnnouncement]] = relationship( - back_populates="course", cascade="all, delete-orphan", lazy="selectin" + announcements: list[UforaAnnouncement] = relationship( + "UforaAnnouncement", back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) - aliases: Mapped[List[UforaCourseAlias]] = relationship( - back_populates="course", cascade="all, delete-orphan", lazy="selectin" + aliases: list[UforaCourseAlias] = relationship( + "UforaCourseAlias", back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) - deadlines: Mapped[List[Deadline]] = relationship( - back_populates="course", cascade="all, delete-orphan", lazy="selectin" + deadlines: list[Deadline] = relationship( + "Deadline", back_populates="course", cascade="all, delete-orphan", lazy="selectin" ) @@ -287,11 +293,11 @@ class UforaCourseAlias(Base): __tablename__ = "ufora_course_aliases" - alias_id: Mapped[int] = mapped_column(primary_key=True) - alias: Mapped[str] = mapped_column(nullable=False, unique=True) - course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) + alias_id: int = Column(Integer, primary_key=True) + alias: str = Column(Text, nullable=False, unique=True) + course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) - course: Mapped[UforaCourse] = relationship(back_populates="aliases", uselist=False, lazy="selectin") + course: UforaCourse = relationship("UforaCourse", back_populates="aliases", uselist=False, lazy="selectin") class UforaAnnouncement(Base): @@ -299,11 +305,11 @@ class UforaAnnouncement(Base): __tablename__ = "ufora_announcements" - announcement_id: Mapped[int] = mapped_column(primary_key=True) - course_id: Mapped[int] = mapped_column(ForeignKey("ufora_courses.course_id")) - publication_date: Mapped[date] = mapped_column() + announcement_id: int = Column(Integer, primary_key=True) + course_id: int = Column(Integer, ForeignKey("ufora_courses.course_id")) + publication_date: date = Column(Date) - course: Mapped[UforaCourse] = relationship(back_populates="announcements", uselist=False, lazy="selectin") + course: UforaCourse = relationship("UforaCourse", back_populates="announcements", uselist=False, lazy="selectin") class User(Base): @@ -311,26 +317,26 @@ class User(Base): __tablename__ = "users" - user_id: Mapped[int] = mapped_column(BigInteger, primary_key=True) + user_id: int = Column(BigInteger, primary_key=True) - bank: Mapped[Bank] = relationship( - back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + bank: Bank = relationship( + "Bank", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - birthday: Mapped[Optional[Birthday]] = relationship( - back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + birthday: Optional[Birthday] = relationship( + "Birthday", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - bookmarks: Mapped[List[Bookmark]] = relationship( - back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + bookmarks: list[Bookmark] = relationship( + "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - command_stats: Mapped[List[CommandStats]] = relationship( - back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + command_stats: list[CommandStats] = relationship( + "CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - github_links: Mapped[List[GitHubLink]] = relationship( - back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + github_links: list[GitHubLink] = relationship( + "GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) - nightly_data: Mapped[NightlyData] = relationship( - back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" + nightly_data: NightlyData = relationship( + "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) - reminders: Mapped[List[Reminder]] = relationship( - back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + reminders: list[Reminder] = relationship( + "Reminder", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/database/utils/caches.py b/database/utils/caches.py index 2df9dac..248eb5f 100644 --- a/database/utils/caches.py +++ b/database/utils/caches.py @@ -69,7 +69,7 @@ class LinkCache(DatabaseCache): self.clear() all_links = await links.get_all_links(database_session) - self.data = list(map(lambda link: link.name, all_links)) + self.data = list(map(lambda l: l.name, all_links)) self.data.sort() self.data_transformed = list(map(str.lower, self.data)) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 4049654..709a461 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -25,7 +25,7 @@ class Currency(commands.Cog): super().__init__() self.client = client - @commands.command(name="award") # type: ignore[arg-type] + @commands.command(name="award") @commands.check(is_owner) async def award( self, @@ -49,9 +49,7 @@ class Currency(commands.Cog): bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(title=f"{ctx.author.display_name}'s Bank", colour=discord.Colour.blue()) - - if ctx.author.avatar is not None: - embed.set_thumbnail(url=ctx.author.avatar.url) + embed.set_thumbnail(url=ctx.author.avatar.url) embed.add_field(name="Interest level", value=bank.interest_level) embed.add_field(name="Capacity level", value=bank.capacity_level) @@ -59,9 +57,7 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank.group( # type: ignore[arg-type] - name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True - ) + @bank.group(name="upgrade", aliases=["u", "upgrades"], case_insensitive=True, invoke_without_command=True) async def bank_upgrades(self, ctx: commands.Context): """List the upgrades you can buy & their prices.""" async with self.client.postgres_session as session: @@ -81,7 +77,7 @@ class Currency(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @bank_upgrades.command(name="capacity", aliases=["c"]) # type: ignore[arg-type] + @bank_upgrades.command(name="capacity", aliases=["c"]) async def bank_upgrade_capacity(self, ctx: commands.Context): """Upgrade the capacity level of your bank.""" async with self.client.postgres_session as session: @@ -92,7 +88,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="interest", aliases=["i"]) # type: ignore[arg-type] + @bank_upgrades.command(name="interest", aliases=["i"]) async def bank_upgrade_interest(self, ctx: commands.Context): """Upgrade the interest level of your bank.""" async with self.client.postgres_session as session: @@ -103,7 +99,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @bank_upgrades.command(name="rob", aliases=["r"]) # type: ignore[arg-type] + @bank_upgrades.command(name="rob", aliases=["r"]) async def bank_upgrade_rob(self, ctx: commands.Context): """Upgrade the rob level of your bank.""" async with self.client.postgres_session as session: @@ -114,7 +110,7 @@ class Currency(commands.Cog): await ctx.reply("You don't have enough Didier Dinks to do this.", mention_author=False) await self.client.reject_message(ctx.message) - @commands.hybrid_command(name="dinks") # type: ignore[arg-type] + @commands.hybrid_command(name="dinks") async def dinks(self, ctx: commands.Context): """Check your Didier Dinks.""" async with self.client.postgres_session as session: @@ -122,7 +118,7 @@ class Currency(commands.Cog): plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** has **{bank.dinks}** {plural}.", mention_author=False) - @commands.command(name="invest", aliases=["deposit", "dep"]) # type: ignore[arg-type] + @commands.command(name="invest", aliases=["deposit", "dep"]) async def invest(self, ctx: commands.Context, amount: typing.Annotated[typing.Union[str, int], abbreviated_number]): """Invest `amount` Didier Dinks into your bank. @@ -148,7 +144,7 @@ class Currency(commands.Cog): f"**{ctx.author.display_name}** has invested **{invested}** {plural}.", mention_author=False ) - @commands.hybrid_command(name="nightly") # type: ignore[arg-type] + @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): """Claim nightly Didier Dinks.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/debug_cog.py b/didier/cogs/debug_cog.py index a0e4747..2d03b9f 100644 --- a/didier/cogs/debug_cog.py +++ b/didier/cogs/debug_cog.py @@ -13,7 +13,7 @@ class DebugCog(commands.Cog): self.client = client @overrides - async def cog_check(self, ctx: commands.Context) -> bool: # type:ignore[override] + async def cog_check(self, ctx: commands.Context) -> bool: return await self.client.is_owner(ctx.author) @commands.command(aliases=["Dev"]) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index fdfa05d..4d9b423 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, cast +from typing import Optional import discord from discord import app_commands @@ -17,7 +17,6 @@ from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar, get_user_avatar -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.constants import Limits from didier.utils.timer import Timer from didier.utils.types.datetime import localize, str_to_date, tz_aware_now @@ -61,19 +60,9 @@ class Discord(commands.Cog): event = await events.get_event_by_id(session, event_id) if event is None: - return await self.client.log_error(f"Unable to find event with id {event_id}.", log_to_discord=True) + return await self.client.log_error(f"Unable to find event with id {event_id}", log_to_discord=True) channel = self.client.get_channel(event.notification_channel) - if channel is None: - return await self.client.log_error( - f"Unable to fetch channel for event `#{event_id}` (id `{event.notification_channel}`)." - ) - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Channel for event `#{event_id}` (id `{event.notification_channel}`) is not messageable." - ) - human_readable_time = localize(event.timestamp).strftime("%A, %B %d %Y - %H:%M") embed = discord.Embed(title=event.name, colour=discord.Colour.blue()) @@ -92,7 +81,7 @@ class Discord(commands.Cog): self.client.loop.create_task(self.timer.update()) @commands.group(name="birthday", aliases=["bd", "birthdays"], case_insensitive=True, invoke_without_command=True) - async def birthday(self, ctx: commands.Context, user: Optional[discord.User] = None): + async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of `user`. Not passing an argument for `user` will show yours instead. @@ -109,10 +98,8 @@ class Discord(commands.Cog): day, month = leading("0", str(birthday.birthday.day)), leading("0", str(birthday.birthday.month)) return await ctx.reply(f"{name or 'Your'} birthday is set to **{day}/{month}**.", mention_author=False) - @birthday.command(name="set", aliases=["config"]) # type: ignore[arg-type] - async def birthday_set( - self, ctx: commands.Context, day: str, user: Optional[Union[discord.User, discord.Member]] = None - ): + @birthday.command(name="set", aliases=["config"]) + async def birthday_set(self, ctx: commands.Context, day: str, user: Optional[discord.User] = None): """Set your birthday to `day`. Parsing of the `day`-argument happens in the following order: `DD/MM/YYYY`, `DD/MM/YY`, `DD/MM`. @@ -126,9 +113,6 @@ class Discord(commands.Cog): if user is None: user = ctx.author - # Please Mypy - user = cast(Union[discord.User, discord.Member], user) - try: default_year = 2001 date = str_to_date(day, formats=["%d/%m/%Y", "%d/%m/%y", "%d/%m"]) @@ -157,7 +141,7 @@ class Discord(commands.Cog): """ # No label: shortcut to display bookmarks if label is None: - return await self.bookmark_search(ctx, query=None) # type: ignore[arg-type] + return await self.bookmark_search(ctx, query=None) async with self.client.postgres_session as session: result = expect( @@ -167,7 +151,7 @@ class Discord(commands.Cog): ) await ctx.reply(result.jump_url, mention_author=False) - @bookmark.command(name="create", aliases=["new"]) # type: ignore[arg-type] + @bookmark.command(name="create", aliases=["new"]) async def bookmark_create(self, ctx: commands.Context, label: str, message: Optional[discord.Message]): """Create a new bookmark for message `message` with label `label`. @@ -198,7 +182,7 @@ class Discord(commands.Cog): # Label isn't allowed return await ctx.reply(f"Bookmarks cannot be named `{label}`.", mention_author=False) - @bookmark.command(name="delete", aliases=["rm"]) # type: ignore[arg-type] + @bookmark.command(name="delete", aliases=["rm"]) async def bookmark_delete(self, ctx: commands.Context, bookmark_id: str): """Delete the bookmark with id `bookmark_id`. @@ -223,7 +207,7 @@ class Discord(commands.Cog): return await ctx.reply(f"Successfully deleted bookmark `#{bookmark_id_int}`.", mention_author=False) - @bookmark.command(name="search", aliases=["list", "ls"]) # type: ignore[arg-type] + @bookmark.command(name="search", aliases=["list", "ls"]) async def bookmark_search(self, ctx: commands.Context, *, query: Optional[str] = None): """Search through the list of bookmarks. @@ -252,7 +236,7 @@ class Discord(commands.Cog): modal = CreateBookmark(self.client, message.jump_url) await interaction.response.send_modal(modal) - @commands.hybrid_command(name="events") # type: ignore[arg-type] + @commands.hybrid_command(name="events") @app_commands.rename(event_id="id") @app_commands.describe(event_id="The id of the event to fetch. If not passed, all events are fetched instead.") async def events(self, ctx: commands.Context, event_id: Optional[int] = None): @@ -292,16 +276,16 @@ class Discord(commands.Cog): embed.add_field( name="Timer", value=discord.utils.format_dt(result_event.timestamp, style="R"), inline=True ) - - channel = self.client.get_channel(result_event.notification_channel) - if channel is not None and not isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - embed.add_field(name="Channel", value=channel.mention, inline=False) - + embed.add_field( + name="Channel", + value=self.client.get_channel(result_event.notification_channel).mention, + inline=False, + ) embed.description = result_event.description return await ctx.reply(embed=embed, mention_author=False) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) - async def github_group(self, ctx: commands.Context, user: Optional[Union[discord.User, discord.Member]] = None): + async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): """Show a user's GitHub links. If no user is provided, this shows your links instead. @@ -309,9 +293,6 @@ class Discord(commands.Cog): # Default to author user = user or ctx.author - # Please Mypy - user = cast(Union[discord.User, discord.Member], user) - embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") embed.set_author(name=user.display_name, icon_url=get_user_avatar(user)) @@ -343,7 +324,7 @@ class Discord(commands.Cog): return await ctx.reply(embed=embed, mention_author=False) - @github_group.command(name="add", aliases=["create", "insert"]) # type: ignore[arg-type] + @github_group.command(name="add", aliases=["create", "insert"]) async def github_add(self, ctx: commands.Context, link: str): """Add a new link into the database.""" # Remove wrapping which can be used to escape Discord embeds @@ -358,7 +339,7 @@ class Discord(commands.Cog): await self.client.confirm_message(ctx.message) return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False) - @github_group.command(name="delete", aliases=["del", "remove", "rm"]) # type: ignore[arg-type] + @github_group.command(name="delete", aliases=["del", "remove", "rm"]) async def github_delete(self, ctx: commands.Context, link_id: str): """Delete the link with it `link_id` from the database. @@ -430,7 +411,7 @@ class Discord(commands.Cog): await message.add_reaction("📌") return await interaction.response.send_message("📌", ephemeral=True) - @commands.hybrid_command(name="snipe") # type: ignore[arg-type] + @commands.hybrid_command(name="snipe") async def snipe(self, ctx: commands.Context): """Publicly shame people when they edit or delete one of their messages. @@ -439,7 +420,7 @@ class Discord(commands.Cog): if ctx.guild is None: return await ctx.reply("Snipe only works in servers.", mention_author=False, ephemeral=True) - sniped_data = self.client.sniped.get(ctx.channel.id) + sniped_data = self.client.sniped.get(ctx.channel.id, None) if sniped_data is None: return await ctx.reply( "There's no one to make fun of in this channel.", mention_author=False, ephemeral=True diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index 4ccfb2a..e824ab2 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -28,7 +28,7 @@ class Fun(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="clap") # type: ignore[arg-type] + @commands.hybrid_command(name="clap") async def clap(self, ctx: commands.Context, *, text: str): """Clap a message with emojis for extra dramatic effect""" chars = list(filter(lambda c: c in constants.EMOJI_MAP, text)) @@ -50,7 +50,10 @@ class Fun(commands.Cog): meme = await generate_meme(self.client.http_session, result, fields) return meme - @commands.hybrid_command(name="dadjoke", aliases=["dad", "dj"]) # type: ignore[arg-type] + @commands.hybrid_command( + name="dadjoke", + aliases=["dad", "dj"], + ) async def dad_joke(self, ctx: commands.Context): """Why does Yoda's code always crash? Because there is no try.""" async with self.client.postgres_session as session: @@ -80,13 +83,13 @@ class Fun(commands.Cog): return await self.memegen_ls_msg(ctx) if fields is None: - return await self.memegen_preview_msg(ctx, template) # type: ignore[arg-type] + return await self.memegen_preview_msg(ctx, template) async with ctx.typing(): meme = await self._do_generate_meme(template, shlex.split(fields)) return await ctx.reply(meme, mention_author=False) - @memegen_msg.command(name="list", aliases=["ls"]) # type: ignore[arg-type] + @memegen_msg.command(name="list", aliases=["ls"]) async def memegen_ls_msg(self, ctx: commands.Context): """Get a list of all available meme templates. @@ -97,14 +100,14 @@ class Fun(commands.Cog): await MemeSource(ctx, results).start() - @memegen_msg.command(name="preview", aliases=["p"]) # type: ignore[arg-type] + @memegen_msg.command(name="preview", aliases=["p"]) async def memegen_preview_msg(self, ctx: commands.Context, template: str): """Generate a preview for the meme template `template`, to see how the fields are structured.""" async with ctx.typing(): meme = await self._do_generate_meme(template, []) return await ctx.reply(meme, mention_author=False) - @memes_slash.command(name="generate") # type: ignore[arg-type] + @memes_slash.command(name="generate") async def memegen_slash(self, interaction: discord.Interaction, template: str): """Generate a meme.""" async with self.client.postgres_session as session: @@ -113,7 +116,7 @@ class Fun(commands.Cog): modal = GenerateMeme(self.client, result) await interaction.response.send_modal(modal) - @memes_slash.command(name="preview") # type: ignore[arg-type] + @memes_slash.command(name="preview") @app_commands.describe(template="The meme template to use in the preview.") async def memegen_preview_slash(self, interaction: discord.Interaction, template: str): """Generate a preview for a meme, to see how the fields are structured.""" @@ -131,7 +134,7 @@ class Fun(commands.Cog): """Autocompletion for the 'template'-parameter""" return self.client.database_caches.memes.get_autocomplete_suggestions(current) - @app_commands.command() # type: ignore[arg-type] + @app_commands.command() @app_commands.describe(message="The text to convert.") async def mock(self, interaction: discord.Interaction, message: str): """Mock a message. @@ -155,7 +158,7 @@ class Fun(commands.Cog): return await interaction.followup.send(mock(message)) - @commands.hybrid_command(name="xkcd") # type: ignore[arg-type] + @commands.hybrid_command(name="xkcd") @app_commands.rename(comic_id="id") async def xkcd(self, ctx: commands.Context, comic_id: Optional[int] = None): """Fetch comic `#id` from xkcd. diff --git a/didier/cogs/help.py b/didier/cogs/help.py index 57cad6d..459f802 100644 --- a/didier/cogs/help.py +++ b/didier/cogs/help.py @@ -159,9 +159,6 @@ class CustomHelpCommand(commands.MinimalHelpCommand): Code in codeblocks is ignored, as it is used to create examples. """ description = command.help - if description is None: - return "" - codeblocks = re_find_all(r"\n?```.*?```", description, flags=re.DOTALL) # Regex borrowed from https://stackoverflow.com/a/59843498/13568999 @@ -201,10 +198,13 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return None - async def _filter_cogs(self, cogs: list[Optional[commands.Cog]]) -> list[commands.Cog]: + async def _filter_cogs(self, cogs: list[commands.Cog]) -> list[commands.Cog]: """Filter the list of cogs down to all those that the user can see""" - async def _predicate(cog: commands.Cog) -> bool: + async def _predicate(cog: Optional[commands.Cog]) -> bool: + if cog is None: + return False + # Remove cogs that we never want to see in the help page because they # don't contain commands, or shouldn't be visible at all if not cog.get_commands(): @@ -220,12 +220,12 @@ class CustomHelpCommand(commands.MinimalHelpCommand): return True # Filter list of cogs down - filtered_cogs = [cog for cog in cogs if cog is not None and await _predicate(cog)] + filtered_cogs = [cog for cog in cogs if await _predicate(cog)] return list(sorted(filtered_cogs, key=lambda cog: cog.qualified_name)) def _get_flags_class(self, command: commands.Command) -> Optional[Type[PosixFlags]]: """Check if a command has flags""" - flag_param = command.params.get("flags") + flag_param = command.params.get("flags", None) if flag_param is None: return None diff --git a/didier/cogs/meta.py b/didier/cogs/meta.py index 861bf58..c330dbd 100644 --- a/didier/cogs/meta.py +++ b/didier/cogs/meta.py @@ -1,6 +1,6 @@ import inspect import os -from typing import Any, Optional, Union +from typing import Optional from discord.ext import commands @@ -76,24 +76,18 @@ class Meta(commands.Cog): if command_name is None: return await ctx.reply(repo_home, mention_author=False) - command: Optional[Union[commands.HelpCommand, commands.Command]] - src: Any - if command_name == "help": command = self.client.help_command - if command is None: - return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) - src = type(self.client.help_command) filename = inspect.getsourcefile(src) else: command = self.client.get_command(command_name) - if command is None: - return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) - src = command.callback.__code__ filename = src.co_filename + if command is None: + return await ctx.reply(f"Found no command named `{command_name}`.", mention_author=False) + lines, first_line = inspect.getsourcelines(src) if filename is None: diff --git a/didier/cogs/other.py b/didier/cogs/other.py index a48cb5e..02c0095 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -22,7 +22,7 @@ class Other(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) # type: ignore[arg-type] + @commands.hybrid_command(name="corona", aliases=["covid", "rona"]) async def covid(self, ctx: commands.Context, country: str = "Belgium"): """Show Covid-19 info for a specific country. @@ -43,7 +43,7 @@ class Other(commands.Cog): """Autocompletion for the 'country'-parameter""" return autocomplete_country(value)[:25] - @commands.hybrid_command( # type: ignore[arg-type] + @commands.hybrid_command( name="define", aliases=["ud", "urban"], description="Look up the definition of a word on the Urban Dictionary" ) async def define(self, ctx: commands.Context, *, query: str): @@ -55,7 +55,7 @@ class Other(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="google", description="Google search") # type: ignore[arg-type] + @commands.hybrid_command(name="google", description="Google search") @app_commands.describe(query="Search query") async def google(self, ctx: commands.Context, *, query: str): """Show the Google search results for `query`. @@ -71,7 +71,7 @@ class Other(commands.Cog): embed = GoogleSearch(results).to_embed() await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") # type: ignore[arg-type] + @commands.hybrid_command(name="inspire", description="Generate an InspiroBot quote.") async def inspire(self, ctx: commands.Context): """Generate an [InspiroBot](https://inspirobot.me/) quote.""" async with ctx.typing(): @@ -82,7 +82,7 @@ class Other(commands.Cog): async with self.client.postgres_session as session: return await get_link_by_name(session, name.lower()) - @commands.command(name="Link", aliases=["Links"]) # type: ignore[arg-type] + @commands.command(name="Link", aliases=["Links"]) async def link_msg(self, ctx: commands.Context, name: str): """Get the link to the resource named `name`.""" link = await self._get_link(name) @@ -92,7 +92,7 @@ class Other(commands.Cog): target_message = await self.client.get_reply_target(ctx) await target_message.reply(link.url, mention_author=False) - @app_commands.command(name="link") # type: ignore[arg-type] + @app_commands.command(name="link") @app_commands.describe(name="The name of the resource") async def link_slash(self, interaction: discord.Interaction, name: str): """Get the link to something.""" diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 1f72eff..139f02c 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -42,7 +42,7 @@ class Owner(commands.Cog): def __init__(self, client: Didier): self.client = client - async def cog_check(self, ctx: commands.Context) -> bool: # type: ignore[override] + async def cog_check(self, ctx: commands.Context) -> bool: """Global check for every command in this cog This means that we don't have to add is_owner() to every single command separately @@ -102,7 +102,7 @@ class Owner(commands.Cog): async def add_msg(self, ctx: commands.Context): """Command group for [add X] message commands""" - @add_msg.command(name="Alias") # type: ignore[arg-type] + @add_msg.command(name="Alias") async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" async with self.client.postgres_session as session: @@ -116,7 +116,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Custom") # type: ignore[arg-type] + @add_msg.command(name="Custom") async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" async with self.client.postgres_session as session: @@ -127,7 +127,7 @@ class Owner(commands.Cog): await ctx.reply("There is already a command with this name.") await self.client.reject_message(ctx.message) - @add_msg.command(name="Link") # type: ignore[arg-type] + @add_msg.command(name="Link") async def add_link_msg(self, ctx: commands.Context, name: str, url: str): """Add a new link""" async with self.client.postgres_session as session: @@ -136,7 +136,7 @@ class Owner(commands.Cog): await self.client.confirm_message(ctx.message) - @add_slash.command(name="custom", description="Add a custom command") # type: ignore[arg-type] + @add_slash.command(name="custom", description="Add a custom command") async def add_custom_slash(self, interaction: discord.Interaction): """Slash command to add a custom command""" if not await self.client.is_owner(interaction.user): @@ -145,7 +145,7 @@ class Owner(commands.Cog): modal = CreateCustomCommand(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="dadjoke", description="Add a dad joke") # type: ignore[arg-type] + @add_slash.command(name="dadjoke", description="Add a dad joke") async def add_dad_joke_slash(self, interaction: discord.Interaction): """Slash command to add a dad joke""" if not await self.client.is_owner(interaction.user): @@ -154,7 +154,7 @@ class Owner(commands.Cog): modal = AddDadJoke(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="deadline", description="Add a deadline") # type: ignore[arg-type] + @add_slash.command(name="deadline", description="Add a deadline") @app_commands.describe(course="The name of the course to add a deadline for (aliases work too)") async def add_deadline_slash(self, interaction: discord.Interaction, course: str): """Slash command to add a deadline""" @@ -174,7 +174,7 @@ class Owner(commands.Cog): """Autocompletion for the 'course'-parameter""" return self.client.database_caches.ufora_courses.get_autocomplete_suggestions(current) - @add_slash.command(name="event", description="Add a new event") # type: ignore[arg-type] + @add_slash.command(name="event", description="Add a new event") async def add_event_slash(self, interaction: discord.Interaction): """Slash command to add new events""" if not await self.client.is_owner(interaction.user): @@ -183,7 +183,7 @@ class Owner(commands.Cog): modal = AddEvent(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="link", description="Add a new link") # type: ignore[arg-type] + @add_slash.command(name="link", description="Add a new link") async def add_link_slash(self, interaction: discord.Interaction): """Slash command to add new links""" if not await self.client.is_owner(interaction.user): @@ -192,7 +192,7 @@ class Owner(commands.Cog): modal = AddLink(self.client) await interaction.response.send_modal(modal) - @add_slash.command(name="meme", description="Add a new meme") # type: ignore[arg-type] + @add_slash.command(name="meme", description="Add a new meme") async def add_meme_slash(self, interaction: discord.Interaction, name: str, imgflip_id: int, field_count: int): """Slash command to add new memes""" await interaction.response.defer(ephemeral=True) @@ -205,11 +205,11 @@ class Owner(commands.Cog): await interaction.followup.send(f"Added meme `{meme.meme_id}`.") await self.client.database_caches.memes.invalidate(session) - @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) # type: ignore[arg-type] + @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) async def edit_msg(self, ctx: commands.Context): """Command group for [edit X] commands""" - @edit_msg.command(name="Custom") # type: ignore[arg-type] + @edit_msg.command(name="Custom") async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): """Edit an existing custom command""" async with self.client.postgres_session as session: @@ -220,7 +220,7 @@ class Owner(commands.Cog): await ctx.reply(f"No command found matching `{command}`.") return await self.client.reject_message(ctx.message) - @edit_slash.command(name="custom", description="Edit a custom command") # type: ignore[arg-type] + @edit_slash.command(name="custom", description="Edit a custom command") @app_commands.describe(command="The name of the command to edit") async def edit_custom_slash(self, interaction: discord.Interaction, command: str): """Slash command to edit a custom command""" diff --git a/didier/cogs/school.py b/didier/cogs/school.py index cd9366a..7af8a81 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -27,7 +27,7 @@ class School(commands.Cog): def __init__(self, client: Didier): self.client = client - @commands.hybrid_command(name="deadlines") # type: ignore[arg-type] + @commands.hybrid_command(name="deadlines") async def deadlines(self, ctx: commands.Context): """Show upcoming deadlines.""" async with ctx.typing(): @@ -40,7 +40,7 @@ class School(commands.Cog): embed = Deadlines(deadlines).to_embed() await ctx.reply(embed=embed, mention_author=False, ephemeral=False) - @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) # type: ignore[arg-type] + @commands.hybrid_command(name="les", aliases=["sched", "schedule"]) @app_commands.rename(day_dt="date") async def les( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -72,7 +72,10 @@ class School(commands.Cog): except NotInMainGuildException: return await ctx.reply(f"You are not a member of {self.client.main_guild.name}.", mention_author=False) - @commands.hybrid_command(name="menu", aliases=["eten", "food"]) # type: ignore[arg-type] + @commands.hybrid_command( + name="menu", + aliases=["eten", "food"], + ) @app_commands.rename(day_dt="date") async def menu( self, ctx: commands.Context, *, day_dt: Optional[app_commands.Transform[date, DateTransformer]] = None @@ -93,7 +96,7 @@ class School(commands.Cog): embed = no_menu_found(day_dt) await ctx.reply(embed=embed, mention_author=False) - @commands.hybrid_command( # type: ignore[arg-type] + @commands.hybrid_command( name="fiche", description="Sends the link to study guides", aliases=["guide", "studiefiche"] ) @app_commands.describe(course="The name of the course to fetch the study guide for (aliases work too)") @@ -121,7 +124,7 @@ class School(commands.Cog): mention_author=False, ) - @commands.hybrid_command(name="ufora") # type: ignore[arg-type] + @commands.hybrid_command(name="ufora") async def ufora(self, ctx: commands.Context, course: str): """Link the Ufora page for a course.""" async with self.client.postgres_session as session: diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index f59d697..07d6508 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -1,6 +1,4 @@ -import asyncio import datetime -import logging import random import discord @@ -22,12 +20,9 @@ from didier.data.embeds.schedules import ( from didier.data.rss_feeds.free_games import fetch_free_games from didier.data.rss_feeds.ufora import fetch_ufora_announcements from didier.decorators.tasks import timed_task -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES from didier.utils.discord.checks import is_owner from didier.utils.types.datetime import LOCAL_TIMEZONE, tz_aware_now -logger = logging.getLogger(__name__) - # datetime.time()-instances for when every task should run DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE) SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE) @@ -61,7 +56,7 @@ class Tasks(commands.Cog): } @overrides - async def cog_load(self) -> None: + def cog_load(self) -> None: # Only check birthdays if there's a channel to send it to if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is not None: self.check_birthdays.start() @@ -77,10 +72,9 @@ class Tasks(commands.Cog): # Start other tasks self.reminders.start() - asyncio.create_task(self.get_error_channel()) @overrides - async def cog_unload(self) -> None: + def cog_unload(self) -> None: # Cancel all pending tasks for task in self._tasks.values(): if task.is_running(): @@ -102,7 +96,7 @@ class Tasks(commands.Cog): await ctx.reply(embed=embed, mention_author=False) - @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") # type: ignore[arg-type] + @tasks_group.command(name="Force", case_insensitive=True, usage="[Task]") async def force_task(self, ctx: commands.Context, name: str): """Command to force-run a task without waiting for the specified run time""" name = name.lower() @@ -113,53 +107,23 @@ class Tasks(commands.Cog): await task(forced=True) await self.client.confirm_message(ctx.message) - async def get_error_channel(self): - """Get the configured channel from the cache""" - await self.client.wait_until_ready() - - # Configure channel to send errors to - if settings.ERRORS_CHANNEL is not None: - channel = self.client.get_channel(settings.ERRORS_CHANNEL) - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - logger.error(f"Configured error channel (id `{settings.ERRORS_CHANNEL}`) is not messageable.") - else: - self.client.error_channel = channel - elif self.client.owner_id is not None: - self.client.error_channel = self.client.get_user(self.client.owner_id) - @tasks.loop(time=SOCIALLY_ACCEPTABLE_TIME) @timed_task(enums.TaskType.BIRTHDAYS) async def check_birthdays(self, **kwargs): """Check if it's currently anyone's birthday""" _ = kwargs - # Can't happen (task isn't started if this is None), but Mypy doesn't know - if settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL is None: - return - now = tz_aware_now().date() async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) if channel is None: - return await self.client.log_error("Unable to fetch channel for birthday announcements.") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Birthday announcement channel (id `{settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL}`) is not messageable." - ) + return await self.client.log_error("Unable to find channel for birthday announcements") for birthday in birthdays: user = self.client.get_user(birthday.user_id) - if user is None: - await self.client.log_error( - f"Unable to fetch user with id `{birthday.user_id}` for birthday announcement" - ) - continue - await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention)) @check_birthdays.before_loop @@ -179,14 +143,6 @@ class Tasks(commands.Cog): games = await fetch_free_games(self.client.http_session, session) channel = self.client.get_channel(settings.FREE_GAMES_CHANNEL) - if channel is None: - return await self.client.log_error("Unable to fetch channel for free games announcements.") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Free games channel (id `{settings.FREE_GAMES_CHANNEL}`) is not messageable." - ) - for game in games: await channel.send(embed=game.to_embed()) @@ -248,17 +204,6 @@ class Tasks(commands.Cog): async with self.client.postgres_session as db_session: announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) - - if announcements_channel is None: - return await self.client.log_error( - f"Unable to fetch channel for ufora announcements (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`)." - ) - - if isinstance(announcements_channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await self.client.log_error( - f"Ufora announcements channel (id `{settings.UFORA_ANNOUNCEMENTS_CHANNEL}`) is not messageable." - ) - announcements = await fetch_ufora_announcements(self.client.http_session, db_session) for announcement in announcements: diff --git a/didier/data/apis/disease_sh.py b/didier/data/apis/disease_sh.py index a32daea..809cdcb 100644 --- a/didier/data/apis/disease_sh.py +++ b/didier/data/apis/disease_sh.py @@ -19,7 +19,7 @@ async def get_country_info(http_session: ClientSession, country: str) -> CovidDa yesterday = response data = {"today": today, "yesterday": yesterday} - return CovidData.model_validate(data) + return CovidData.parse_obj(data) async def get_global_info(http_session: ClientSession) -> CovidData: @@ -35,4 +35,4 @@ async def get_global_info(http_session: ClientSession) -> CovidData: yesterday = response data = {"today": today, "yesterday": yesterday} - return CovidData.model_validate(data) + return CovidData.parse_obj(data) diff --git a/didier/data/apis/hydra.py b/didier/data/apis/hydra.py index 8d7889b..620e0df 100644 --- a/didier/data/apis/hydra.py +++ b/didier/data/apis/hydra.py @@ -12,4 +12,4 @@ async def fetch_menu(http_session: ClientSession, day_dt: date) -> Menu: """Fetch the menu for a given day""" endpoint = f"https://hydra.ugent.be/api/2.0/resto/menu/nl/{day_dt.year}/{day_dt.month}/{day_dt.day}.json" async with ensure_get(http_session, endpoint, log_exceptions=False) as response: - return Menu.model_validate(response) + return Menu.parse_obj(response) diff --git a/didier/data/apis/urban_dictionary.py b/didier/data/apis/urban_dictionary.py index a6b5cd6..6d81934 100644 --- a/didier/data/apis/urban_dictionary.py +++ b/didier/data/apis/urban_dictionary.py @@ -14,4 +14,4 @@ async def lookup(http_session: ClientSession, query: str) -> list[Definition]: url = "https://api.urbandictionary.com/v0/define" async with ensure_get(http_session, url, params={"term": query}) as response: - return list(map(Definition.model_validate, response["list"])) + return list(map(Definition.parse_obj, response["list"])) diff --git a/didier/data/apis/xkcd.py b/didier/data/apis/xkcd.py index bf8ff4d..c0ad766 100644 --- a/didier/data/apis/xkcd.py +++ b/didier/data/apis/xkcd.py @@ -13,4 +13,4 @@ async def fetch_xkcd_post(http_session: ClientSession, *, num: Optional[int] = N url = "https://xkcd.com" + (f"/{num}" if num is not None else "") + "/info.0.json" async with ensure_get(http_session, url) as response: - return XKCDPost.model_validate(response) + return XKCDPost.parse_obj(response) diff --git a/didier/data/embeds/disease_sh.py b/didier/data/embeds/disease_sh.py index a344828..45e8895 100644 --- a/didier/data/embeds/disease_sh.py +++ b/didier/data/embeds/disease_sh.py @@ -1,6 +1,6 @@ import discord from overrides import overrides -from pydantic import BaseModel, Field, field_validator +from pydantic import BaseModel, Field, validator from didier.data.embeds.base import EmbedPydantic @@ -24,7 +24,7 @@ class _CovidNumbers(BaseModel): active: int tests: int - @field_validator("updated") + @validator("updated") def updated_to_seconds(cls, value: int) -> int: """Turn the updated field into seconds instead of milliseconds""" return int(value) // 1000 diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py index 696dd80..ea03bfe 100644 --- a/didier/data/embeds/error_embed.py +++ b/didier/data/embeds/error_embed.py @@ -38,10 +38,10 @@ def create_error_embed(ctx: Optional[commands.Context], exception: Exception) -> embed = discord.Embed(title="Error", colour=discord.Colour.red()) if ctx is not None: - if ctx.guild is None or isinstance(ctx.channel, discord.DMChannel): + if ctx.guild is None: origin = "DM" else: - origin = f"<#{ctx.channel.id}> ({ctx.guild.name})" + origin = f"{ctx.channel.mention} ({ctx.guild.name})" invocation = f"{ctx.author.display_name} in {origin}" diff --git a/didier/data/embeds/free_games.py b/didier/data/embeds/free_games.py index f159157..d37e0b7 100644 --- a/didier/data/embeds/free_games.py +++ b/didier/data/embeds/free_games.py @@ -4,17 +4,18 @@ from typing import Optional import discord from aiohttp import ClientSession from overrides import overrides -from pydantic import field_validator +from pydantic import validator from didier.data.embeds.base import EmbedPydantic from didier.data.scrapers.common import GameStorePage from didier.data.scrapers.steam import get_steam_webpage_info from didier.utils.discord import colours -from didier.utils.discord.constants import Limits -from didier.utils.types.string import abbreviate __all__ = ["SEPARATOR", "FreeGameEmbed"] +from didier.utils.discord.constants import Limits +from didier.utils.types.string import abbreviate + SEPARATOR = " • Free • " @@ -57,7 +58,7 @@ class FreeGameEmbed(EmbedPydantic): store_page: Optional[GameStorePage] = None - @field_validator("title") + @validator("title") def _clean_title(cls, value: str) -> str: return html.unescape(value) @@ -106,6 +107,7 @@ class FreeGameEmbed(EmbedPydantic): embed.add_field(name="Open in browser", value=f"[{self.link}]({self.link})") if self.store_page.xdg_open_url is not None: + embed.add_field( name="Open in app", value=f"[{self.store_page.xdg_open_url}]({self.store_page.xdg_open_url})" ) diff --git a/didier/data/embeds/logging_embed.py b/didier/data/embeds/logging_embed.py index 3a803a1..40556f2 100644 --- a/didier/data/embeds/logging_embed.py +++ b/didier/data/embeds/logging_embed.py @@ -11,7 +11,7 @@ __all__ = ["create_logging_embed"] def create_logging_embed(level: int, message: str) -> discord.Embed: """Create an embed to send to the logging channel""" colours = { - logging.DEBUG: discord.Colour.light_grey(), + logging.DEBUG: discord.Colour.light_gray, logging.ERROR: discord.Colour.red(), logging.INFO: discord.Colour.blue(), logging.WARNING: discord.Colour.yellow(), diff --git a/didier/data/embeds/urban_dictionary.py b/didier/data/embeds/urban_dictionary.py index ad3f90f..6dfeaac 100644 --- a/didier/data/embeds/urban_dictionary.py +++ b/didier/data/embeds/urban_dictionary.py @@ -2,7 +2,7 @@ from datetime import datetime import discord from overrides import overrides -from pydantic import field_validator +from pydantic import validator from didier.data.embeds.base import EmbedPydantic from didier.utils.discord import colours @@ -39,8 +39,8 @@ class Definition(EmbedPydantic): total_votes = self.thumbs_up + self.thumbs_down return round(100 * self.thumbs_up / total_votes, 2) - @field_validator("definition", "example") - def modify_long_text(cls, field: str): + @validator("definition", "example") + def modify_long_text(cls, field): """Remove brackets from fields & cut them off if they are too long""" field = field.replace("[", "").replace("]", "") return string_utils.abbreviate(field, max_length=Limits.EMBED_FIELD_VALUE_LENGTH) diff --git a/didier/data/rss_feeds/free_games.py b/didier/data/rss_feeds/free_games.py index 24a1ecf..fcc02c9 100644 --- a/didier/data/rss_feeds/free_games.py +++ b/didier/data/rss_feeds/free_games.py @@ -32,7 +32,7 @@ async def fetch_free_games(http_session: ClientSession, database_session: AsyncS if SEPARATOR not in entry["title"]: continue - game = FreeGameEmbed.model_validate(entry) + game = FreeGameEmbed.parse_obj(entry) games.append(game) game_ids.append(game.dc_identifier) diff --git a/didier/data/scrapers/google.py b/didier/data/scrapers/google.py index 9c10716..389e9ae 100644 --- a/didier/data/scrapers/google.py +++ b/didier/data/scrapers/google.py @@ -72,12 +72,12 @@ def get_search_results(bs: BeautifulSoup) -> list[str]: return list(dict.fromkeys(results)) -async def google_search(http_session: ClientSession, query: str): +async def google_search(http_client: ClientSession, query: str): """Get the first 10 Google search results""" query = urlencode({"q": query}) # Request 20 results in case of duplicates, bad matches, ... - async with http_session.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: + async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en") as response: # Something went wrong if response.status != http.HTTPStatus.OK: return SearchData(query, response.status) diff --git a/didier/didier.py b/didier/didier.py index 33e6e4b..cf9ed1d 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -17,7 +17,7 @@ from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.logging_embed import create_logging_embed from didier.data.embeds.schedules import Schedule, parse_schedule -from didier.exceptions import GetNoneException, HTTPException, NoMatch +from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix from didier.utils.discord.snipe import should_snipe from didier.utils.easter_eggs import detect_easter_egg @@ -33,7 +33,7 @@ class Didier(commands.Bot): """DIDIER <3""" database_caches: CacheManager - error_channel: Optional[discord.abc.Messageable] = None + error_channel: discord.abc.Messageable initial_extensions: tuple[str, ...] = () http_session: ClientSession schedules: dict[settings.ScheduleType, Schedule] = {} @@ -56,17 +56,12 @@ class Didier(commands.Bot): command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status ) - # I'm not creating a custom tree, this is the way to do it - self.tree.on_error = self.on_app_command_error # type: ignore[method-assign] + self.tree.on_error = self.on_app_command_error @cached_property def main_guild(self) -> discord.Guild: """Obtain a reference to the main guild""" - guild = self.get_guild(settings.DISCORD_MAIN_GUILD) - if guild is None: - raise GetNoneException("Main guild could not be found in the bot's cache") - - return guild + return self.get_guild(settings.DISCORD_MAIN_GUILD) @property def postgres_session(self) -> AsyncSession: @@ -98,6 +93,12 @@ class Didier(commands.Bot): await self._load_initial_extensions() await self._load_directory_extensions("didier/cogs") + # Configure channel to send errors to + if settings.ERRORS_CHANNEL is not None: + self.error_channel = self.get_channel(settings.ERRORS_CHANNEL) + else: + self.error_channel = self.get_user(self.owner_id) + def _create_ignored_directories(self): """Create directories that store ignored data""" ignored = ["files/schedules"] @@ -151,27 +152,18 @@ class Didier(commands.Bot): original message instead """ if ctx.message.reference is not None: - return await self.resolve_message(ctx.message.reference) or ctx.message + return await self.resolve_message(ctx.message.reference) return ctx.message - async def resolve_message(self, reference: discord.MessageReference) -> Optional[discord.Message]: + async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: """Fetch a message from a reference""" # Message is in the cache, return it if reference.cached_message is not None: return reference.cached_message - if reference.message_id is None: - return None - # For older messages: fetch them from the API channel = self.get_channel(reference.channel_id) - if channel is None or isinstance( - channel, - (discord.CategoryChannel, discord.ForumChannel, discord.abc.PrivateChannel), - ): # Logically this can't happen, but we have to please Mypy - return None - return await channel.fetch_message(reference.message_id) async def confirm_message(self, message: discord.Message): @@ -192,7 +184,7 @@ class Didier(commands.Bot): } methods.get(level, logger.error)(message) - if log_to_discord and self.error_channel is not None: + if log_to_discord: embed = create_logging_embed(level, message) await self.error_channel.send(embed=embed) @@ -261,9 +253,10 @@ class Didier(commands.Bot): await interaction.response.send_message("Something went wrong processing this command.", ephemeral=True) - if self.error_channel is not None: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(await commands.Context.from_interaction(interaction), exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_command_completion(self, ctx: commands.Context): """Event triggered when a message command completes successfully""" @@ -288,7 +281,7 @@ class Didier(commands.Bot): # Hybrid command errors are wrapped in an additional error, so wrap it back out if isinstance(exception, commands.HybridCommandError): - exception = exception.original # type: ignore[assignment] + exception = exception.original # Ignore exceptions that aren't important if isinstance( @@ -339,9 +332,10 @@ class Didier(commands.Bot): # Print everything that we care about to the logs/stderr await super().on_command_error(ctx, exception) - if self.error_channel is not None: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(ctx, exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_message(self, message: discord.Message, /) -> None: """Event triggered when a message is sent""" @@ -350,7 +344,7 @@ class Didier(commands.Bot): return # Boos react to people that say Dider - if "dider" in message.content.lower() and self.user is not None and message.author.id != self.user.id: + if "dider" in message.content.lower() and message.author.id != self.user.id: await message.add_reaction(settings.DISCORD_BOOS_REACT) # Potential custom command @@ -380,7 +374,7 @@ class Didier(commands.Bot): # If the edited message is currently present in the snipe cache, # don't update the , but instead change the - existing = self.sniped.get(before.channel.id) + existing = self.sniped.get(before.channel.id, None) if existing is not None and existing[0].id == before.id: before = existing[0] @@ -395,9 +389,10 @@ class Didier(commands.Bot): async def on_task_error(self, exception: Exception): """Event triggered when a task raises an exception""" - if self.error_channel: + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(None, exception) - await self.error_channel.send(embed=embed) + channel = self.get_channel(settings.ERRORS_CHANNEL) + await channel.send(embed=embed) async def on_thread_create(self, thread: discord.Thread): """Event triggered when a new thread is created""" diff --git a/didier/exceptions/__init__.py b/didier/exceptions/__init__.py index fa5ad13..1335dd4 100644 --- a/didier/exceptions/__init__.py +++ b/didier/exceptions/__init__.py @@ -1,14 +1,6 @@ -from .get_none_exception import GetNoneException from .http_exception import HTTPException from .missing_env import MissingEnvironmentVariable from .no_match import NoMatch, expect from .not_in_main_guild_exception import NotInMainGuildException -__all__ = [ - "GetNoneException", - "HTTPException", - "MissingEnvironmentVariable", - "NoMatch", - "expect", - "NotInMainGuildException", -] +__all__ = ["HTTPException", "MissingEnvironmentVariable", "NoMatch", "expect", "NotInMainGuildException"] diff --git a/didier/exceptions/get_none_exception.py b/didier/exceptions/get_none_exception.py deleted file mode 100644 index cbd2f77..0000000 --- a/didier/exceptions/get_none_exception.py +++ /dev/null @@ -1,5 +0,0 @@ -__all__ = ["GetNoneException"] - - -class GetNoneException(RuntimeError): - """Exception raised when a Bot.get()-method returned None""" diff --git a/didier/exceptions/not_in_main_guild_exception.py b/didier/exceptions/not_in_main_guild_exception.py index 5279686..5572c44 100644 --- a/didier/exceptions/not_in_main_guild_exception.py +++ b/didier/exceptions/not_in_main_guild_exception.py @@ -12,6 +12,6 @@ class NotInMainGuildException(ValueError): def __init__(self, user: Union[discord.User, discord.Member]): super().__init__( - f"User {user.display_name} (id `{user.id}`) " - f"is not a member of the configured main guild (id `{settings.DISCORD_MAIN_GUILD}`)." + f"User {user.display_name} (id {user.id}) " + f"is not a member of the configured main guild (id {settings.DISCORD_MAIN_GUILD})." ) diff --git a/didier/utils/discord/channels.py b/didier/utils/discord/channels.py deleted file mode 100644 index 26739f8..0000000 --- a/didier/utils/discord/channels.py +++ /dev/null @@ -1,5 +0,0 @@ -import discord - -__all__ = ["NON_MESSAGEABLE_CHANNEL_TYPES"] - -NON_MESSAGEABLE_CHANNEL_TYPES = (discord.ForumChannel, discord.CategoryChannel, discord.abc.PrivateChannel) diff --git a/didier/utils/discord/prefix.py b/didier/utils/discord/prefix.py index 694a4b6..f3fa7c4 100644 --- a/didier/utils/discord/prefix.py +++ b/didier/utils/discord/prefix.py @@ -15,14 +15,11 @@ def match_prefix(client: commands.Bot, message: Message) -> Optional[str]: This is done dynamically through regexes to allow case-insensitivity and variable amounts of whitespace among other things. """ - mention = f"<@!?{client.user.id}>" if client.user else None + mention = f"<@!?{client.user.id}>" regex = r"^({})\s*" # Check which prefix was used for prefix in [*constants.PREFIXES, mention]: - if prefix is None: - continue - match = re.match(regex.format(prefix), message.content, flags=re.I) if match is not None: diff --git a/didier/views/modals/bookmarks.py b/didier/views/modals/bookmarks.py index acd4c8f..f77b608 100644 --- a/didier/views/modals/bookmarks.py +++ b/didier/views/modals/bookmarks.py @@ -25,24 +25,24 @@ class CreateBookmark(discord.ui.Modal, title="Create Bookmark"): @overrides async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) - label = self.name.value.strip() try: async with self.client.postgres_session as session: bm = await create_bookmark(session, interaction.user.id, label, self.jump_url) - return await interaction.followup.send( - f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`)." + return await interaction.response.send_message( + f"Bookmark `{label}` successfully created (`#{bm.bookmark_id}`).", ephemeral=True ) except DuplicateInsertException: # Label is already in use - return await interaction.followup.send(f"You already have a bookmark named `{label}`.") + return await interaction.response.send_message( + f"You already have a bookmark named `{label}`.", ephemeral=True + ) except ForbiddenNameException: # Label isn't allowed - return await interaction.followup.send(f"Bookmarks cannot be named `{label}`.") + return await interaction.response.send_message(f"Bookmarks cannot be named `{label}`.", ephemeral=True) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py index 5ebfab7..c3b2f67 100644 --- a/didier/views/modals/dad_jokes.py +++ b/didier/views/modals/dad_jokes.py @@ -26,14 +26,12 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): @overrides async def on_submit(self, interaction: discord.Interaction): - await interaction.response.defer(ephemeral=True) - async with self.client.postgres_session as session: joke = await add_dad_joke(session, str(self.joke.value)) - await interaction.followup.send(f"Successfully added joke #{joke.dad_joke_id}") + await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/didier/views/modals/events.py b/didier/views/modals/events.py index 71acea6..e7b92b4 100644 --- a/didier/views/modals/events.py +++ b/didier/views/modals/events.py @@ -10,8 +10,6 @@ from didier import Didier __all__ = ["AddEvent"] -from didier.utils.discord.channels import NON_MESSAGEABLE_CHANNEL_TYPES - class AddEvent(discord.ui.Modal, title="Add Event"): """Modal to add a new event""" @@ -35,20 +33,15 @@ class AddEvent(discord.ui.Modal, title="Add Event"): @overrides async def on_submit(self, interaction: discord.Interaction) -> None: - await interaction.response.defer(ephemeral=True) - try: parse(self.timestamp.value, dayfirst=True).replace(tzinfo=ZoneInfo("Europe/Brussels")) except ParserError: - return await interaction.followup.send("Unable to parse date argument.") + return await interaction.response.send_message("Unable to parse date argument.", ephemeral=True) - channel = self.client.get_channel(int(self.channel.value)) - - if channel is None: - return await interaction.followup.send(f"Unable to find channel with id `{self.channel.value}`") - - if isinstance(channel, NON_MESSAGEABLE_CHANNEL_TYPES): - return await interaction.followup.send(f"Channel with id `{self.channel.value}` is not messageable.") + if self.client.get_channel(int(self.channel.value)) is None: + return await interaction.response.send_message( + f"Unable to find channel `{self.channel.value}`", ephemeral=True + ) async with self.client.postgres_session as session: event = await add_event( @@ -59,10 +52,10 @@ class AddEvent(discord.ui.Modal, title="Add Event"): channel_id=int(self.channel.value), ) - await interaction.followup.send(f"Successfully added event `{event.event_id}`.") + await interaction.response.send_message(f"Successfully added event `{event.event_id}`.", ephemeral=True) self.client.dispatch("event_create", event) @overrides async def on_error(self, interaction: discord.Interaction, error: Exception): # type: ignore - await interaction.followup.send("Something went wrong.", ephemeral=True) + await interaction.response.send_message("Something went wrong.", ephemeral=True) traceback.print_tb(error.__traceback__) diff --git a/main.py b/main.py index a1e2655..dd9b030 100644 --- a/main.py +++ b/main.py @@ -36,7 +36,6 @@ def setup_logging(): # Configure discord handler discord_log = logging.getLogger("discord") - discord_handler: logging.StreamHandler # Make dev print to stderr instead, so you don't have to watch the file if settings.SANDBOX: diff --git a/pyproject.toml b/pyproject.toml index d28abe3..acd06c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,6 @@ omit = [ profile = "black" [tool.mypy] -check_untyped_defs = true files = [ "database/**/*.py", "didier/**/*.py", @@ -36,6 +35,7 @@ files = [ ] plugins = [ "pydantic.mypy", + "sqlalchemy.ext.mypy.plugin" ] [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"] diff --git a/requirements-dev.txt b/requirements-dev.txt index 8d2d1b3..a9f7109 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,21 +1,22 @@ -black==23.3.0 -coverage[toml]==7.2.7 -freezegun==1.2.2 +black==22.3.0 +coverage[toml]==6.4.1 +freezegun==1.2.1 isort==5.12.0 -mypy==1.4.1 -pre-commit==3.3.3 -pytest==7.4.0 -pytest-asyncio==0.21.0 -pytest-env==0.8.2 -types-beautifulsoup4==4.12.0.5 -types-python-dateutil==2.8.19.13 +mypy==0.961 +pre-commit==2.20.0 +pytest==7.1.2 +pytest-asyncio==0.18.3 +pytest-env==0.6.2 +sqlalchemy2-stubs==0.0.2a23 +types-beautifulsoup4==4.11.3 +types-python-dateutil==2.8.19 # Flake8 + plugins -flake8==6.0.0 -flake8-bandit==4.1.1 -flake8-bugbear==23.6.5 -flake8-docstrings==1.7.0 -flake8-dunder-all==0.3.0 -flake8-eradicate==1.5.0 -flake8-isort==6.0.0 -flake8-simplify==0.20.0 +flake8==4.0.1 +flake8-bandit==3.0.0 +flake8-bugbear==22.7.1 +flake8-docstrings==1.6.0 +flake8-dunder-all==0.2.1 +flake8-eradicate==1.2.1 +flake8-isort==4.1.1 +flake8-simplify==0.19.2 diff --git a/requirements.txt b/requirements.txt index 4b0ffa3..a29f1cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,13 +1,13 @@ -aiohttp==3.8.4 -alembic==1.11.1 -asyncpg==0.28.0 -beautifulsoup4==4.12.2 -discord.py==2.3.1 +aiohttp==3.8.1 +alembic==1.8.0 +asyncpg==0.25.0 +beautifulsoup4==4.11.1 +discord.py==2.0.1 environs==9.5.0 feedparser==6.0.10 ics==0.7.2 -markdownify==0.11.6 -overrides==7.3.1 -pydantic==2.0.2 +markdownify==0.11.2 +overrides==6.1.0 +pydantic==1.9.1 python-dateutil==2.8.2 -sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18 +sqlalchemy[asyncio]==1.4.37 diff --git a/settings.py b/settings.py index a862fde..32bd5e0 100644 --- a/settings.py +++ b/settings.py @@ -111,7 +111,7 @@ class ScheduleInfo: role_id: Optional[int] schedule_url: Optional[str] - name: ScheduleType + name: Optional[str] = None SCHEDULE_DATA = [