diff --git a/database/crud/bookmarks.py b/database/crud/bookmarks.py index d696e50..925e645 100644 --- a/database/crud/bookmarks.py +++ b/database/crud/bookmarks.py @@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[ if query is not None: statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%")) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]: diff --git a/database/crud/custom_commands.py b/database/crud/custom_commands.py index efbc689..eac9c7e 100644 --- a/database/crud/custom_commands.py +++ b/database/crud/custom_commands.py @@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo async def get_all_commands(session: AsyncSession) -> list[CustomCommand]: """Get a list of all commands""" statement = select(CustomCommand) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]: diff --git a/database/crud/deadlines.py b/database/crud/deadlines.py index 78d623f..a539518 100644 --- a/database/crud/deadlines.py +++ b/database/crud/deadlines.py @@ -38,4 +38,4 @@ async def get_deadlines( statement = statement.where(Deadline.course_id == course.course_id) statement = statement.options(selectinload(Deadline.course)) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/easter_eggs.py b/database/crud/easter_eggs.py index d4c25d9..f4a55ed 100644 --- a/database/crud/easter_eggs.py +++ b/database/crud/easter_eggs.py @@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"] async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]: """Return a list of all easter eggs""" statement = select(EasterEgg) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/events.py b/database/crud/events.py index 19887e6..e8dfac5 100644 --- a/database/crud/events.py +++ b/database/crud/events.py @@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]: """Get a list of all upcoming events""" statement = select(Event).where(Event.timestamp > now) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]: diff --git a/database/crud/free_games.py b/database/crud/free_games.py index b2d835d..e0dab5b 100644 --- a/database/crud/free_games.py +++ b/database/crud/free_games.py @@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]): async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]: """Filter a list of game IDs down to the ones that aren't in the database yet""" statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids)) - matches: list[int] = (await session.execute(statement)).scalars().all() + matches: list[int] = list((await session.execute(statement)).scalars().all()) return list(set(game_ids).difference(matches)) diff --git a/database/crud/github.py b/database/crud/github.py index 0d32377..a352bae 100644 --- a/database/crud/github.py +++ b/database/crud/github.py @@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: """Get a user's GitHub links""" statement = select(GitHubLink).where(GitHubLink.user_id == user_id) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) diff --git a/database/crud/links.py b/database/crud/links.py index 495e0f3..20bcab3 100644 --- a/database/crud/links.py +++ b/database/crud/links.py @@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] async def get_all_links(session: AsyncSession) -> list[Link]: """Get a list of all links""" statement = select(Link) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def add_link(session: AsyncSession, name: str, url: str) -> Link: diff --git a/database/crud/memes.py b/database/crud/memes.py index ab288aa..b1ed1e0 100644 --- a/database/crud/memes.py +++ b/database/crud/memes.py @@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]: """Get a list of all memes""" statement = select(MemeTemplate) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]: diff --git a/database/crud/reminders.py b/database/crud/reminders.py index 007a779..78350e6 100644 --- a/database/crud/reminders.py +++ b/database/crud/reminders.py @@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"] async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]: """Get a list of all Reminders for a given category""" statement = select(Reminder).where(Reminder.category == category) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool: diff --git a/database/crud/ufora_announcements.py b/database/crud/ufora_announcements.py index 688bcc7..06c2b58 100644 --- a/database/crud/ufora_announcements.py +++ b/database/crud/ufora_announcements.py @@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_ async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]: """Get all courses where announcements are enabled""" statement = select(UforaCourse).where(UforaCourse.log_announcements) - return (await session.execute(statement)).scalars().all() + return list((await session.execute(statement)).scalars().all()) async def create_new_announcement( diff --git a/database/crud/ufora_courses.py b/database/crud/ufora_courses.py index 5374c07..d4cf728 100644 --- a/database/crud/ufora_courses.py +++ b/database/crud/ufora_courses.py @@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor # Search case-insensitively query = query.lower() - statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) - result = (await session.execute(statement)).scalars().first() - if result: - return result + course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%")) + course_result = (await session.execute(course_statement)).scalars().first() + if course_result: + return course_result - statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) - result = (await session.execute(statement)).scalars().first() - return result.course if result else None + alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%")) + alias_result = (await session.execute(alias_statement)).scalars().first() + return alias_result.course if alias_result else None diff --git a/pyproject.toml b/pyproject.toml index acd06c2..d28abe3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ omit = [ profile = "black" [tool.mypy] +check_untyped_defs = true files = [ "database/**/*.py", "didier/**/*.py", @@ -35,7 +36,6 @@ files = [ ] plugins = [ "pydantic.mypy", - "sqlalchemy.ext.mypy.plugin" ] [[tool.mypy.overrides]] module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"] diff --git a/requirements.txt b/requirements.txt index a7b6db2..4b0ffa3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,4 +10,4 @@ markdownify==0.11.6 overrides==7.3.1 pydantic==2.0.2 python-dateutil==2.8.2 -sqlalchemy[asyncio]==2.0.18 +sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18