Fix db typing

pull/176/head
Stijn De Clercq 2023-07-08 01:23:47 +02:00
parent d52c80aa8b
commit feccb88cfd
14 changed files with 20 additions and 20 deletions

View File

@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[
if query is not None:
statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%"))
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]:

View File

@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
"""Get a list of all commands"""
statement = select(CustomCommand)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:

View File

@ -38,4 +38,4 @@ async def get_deadlines(
statement = statement.where(Deadline.course_id == course.course_id)
statement = statement.options(selectinload(Deadline.course))
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"]
async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]:
"""Return a list of all easter eggs"""
statement = select(EasterEgg)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
"""Get a list of all upcoming events"""
statement = select(Event).where(Event.timestamp > now)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:

View File

@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]):
async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]:
"""Filter a list of game IDs down to the ones that aren't in the database yet"""
statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids))
matches: list[int] = (await session.execute(statement)).scalars().all()
matches: list[int] = list((await session.execute(statement)).scalars().all())
return list(set(game_ids).difference(matches))

View File

@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id:
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
"""Get a user's GitHub links"""
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())

View File

@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
async def get_all_links(session: AsyncSession) -> list[Link]:
"""Get a list of all links"""
statement = select(Link)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def add_link(session: AsyncSession, name: str, url: str) -> Link:

View File

@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
"""Get a list of all memes"""
statement = select(MemeTemplate)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:

View File

@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"]
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
"""Get a list of all Reminders for a given category"""
statement = select(Reminder).where(Reminder.category == category)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:

View File

@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_
async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]:
"""Get all courses where announcements are enabled"""
statement = select(UforaCourse).where(UforaCourse.log_announcements)
return (await session.execute(statement)).scalars().all()
return list((await session.execute(statement)).scalars().all())
async def create_new_announcement(

View File

@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor
# Search case-insensitively
query = query.lower()
statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first()
if result:
return result
course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
course_result = (await session.execute(course_statement)).scalars().first()
if course_result:
return course_result
statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
result = (await session.execute(statement)).scalars().first()
return result.course if result else None
alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
alias_result = (await session.execute(alias_statement)).scalars().first()
return alias_result.course if alias_result else None

View File

@ -28,6 +28,7 @@ omit = [
profile = "black"
[tool.mypy]
check_untyped_defs = true
files = [
"database/**/*.py",
"didier/**/*.py",
@ -35,7 +36,6 @@ files = [
]
plugins = [
"pydantic.mypy",
"sqlalchemy.ext.mypy.plugin"
]
[[tool.mypy.overrides]]
module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"]

View File

@ -10,4 +10,4 @@ markdownify==0.11.6
overrides==7.3.1
pydantic==2.0.2
python-dateutil==2.8.2
sqlalchemy[asyncio]==2.0.18
sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18