mirror of https://github.com/stijndcl/didier
Fix db typing
parent
d52c80aa8b
commit
feccb88cfd
|
@ -64,7 +64,7 @@ async def get_bookmarks(session: AsyncSession, user_id: int, *, query: Optional[
|
|||
if query is not None:
|
||||
statement = statement.where(Bookmark.label.ilike(f"%{query.lower()}%"))
|
||||
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_bookmark_by_name(session: AsyncSession, user_id: int, query: str) -> Optional[Bookmark]:
|
||||
|
|
|
@ -59,7 +59,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
|
|||
async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
|
||||
"""Get a list of all commands"""
|
||||
statement = select(CustomCommand)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
|
|
|
@ -38,4 +38,4 @@ async def get_deadlines(
|
|||
statement = statement.where(Deadline.course_id == course.course_id)
|
||||
|
||||
statement = statement.options(selectinload(Deadline.course))
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -9,4 +9,4 @@ __all__ = ["get_all_easter_eggs"]
|
|||
async def get_all_easter_eggs(session: AsyncSession) -> list[EasterEgg]:
|
||||
"""Return a list of all easter eggs"""
|
||||
statement = select(EasterEgg)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -41,7 +41,7 @@ async def get_event_by_id(session: AsyncSession, event_id: int) -> Optional[Even
|
|||
async def get_events(session: AsyncSession, *, now: datetime.datetime) -> list[Event]:
|
||||
"""Get a list of all upcoming events"""
|
||||
statement = select(Event).where(Event.timestamp > now)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_next_event(session: AsyncSession, *, now: datetime.datetime) -> Optional[Event]:
|
||||
|
|
|
@ -16,5 +16,5 @@ async def add_free_games(session: AsyncSession, game_ids: list[int]):
|
|||
async def filter_present_games(session: AsyncSession, game_ids: list[int]) -> list[int]:
|
||||
"""Filter a list of game IDs down to the ones that aren't in the database yet"""
|
||||
statement = select(FreeGame.free_game_id).where(FreeGame.free_game_id.in_(game_ids))
|
||||
matches: list[int] = (await session.execute(statement)).scalars().all()
|
||||
matches: list[int] = list((await session.execute(statement)).scalars().all())
|
||||
return list(set(game_ids).difference(matches))
|
||||
|
|
|
@ -48,4 +48,4 @@ async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id:
|
|||
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
|
||||
"""Get a user's GitHub links"""
|
||||
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
|
|
@ -12,7 +12,7 @@ __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
|
|||
async def get_all_links(session: AsyncSession) -> list[Link]:
|
||||
"""Get a list of all links"""
|
||||
statement = select(Link)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def add_link(session: AsyncSession, name: str, url: str) -> Link:
|
||||
|
|
|
@ -23,7 +23,7 @@ async def add_meme(session: AsyncSession, name: str, template_id: int, field_cou
|
|||
async def get_all_memes(session: AsyncSession) -> list[MemeTemplate]:
|
||||
"""Get a list of all memes"""
|
||||
statement = select(MemeTemplate)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def get_meme_by_name(session: AsyncSession, query: str) -> Optional[MemeTemplate]:
|
||||
|
|
|
@ -13,7 +13,7 @@ __all__ = ["get_all_reminders_for_category", "toggle_reminder"]
|
|||
async def get_all_reminders_for_category(session: AsyncSession, category: ReminderCategory) -> list[Reminder]:
|
||||
"""Get a list of all Reminders for a given category"""
|
||||
statement = select(Reminder).where(Reminder.category == category)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def toggle_reminder(session: AsyncSession, user_id: int, category: ReminderCategory) -> bool:
|
||||
|
|
|
@ -11,7 +11,7 @@ __all__ = ["create_new_announcement", "get_courses_with_announcements", "remove_
|
|||
async def get_courses_with_announcements(session: AsyncSession) -> list[UforaCourse]:
|
||||
"""Get all courses where announcements are enabled"""
|
||||
statement = select(UforaCourse).where(UforaCourse.log_announcements)
|
||||
return (await session.execute(statement)).scalars().all()
|
||||
return list((await session.execute(statement)).scalars().all())
|
||||
|
||||
|
||||
async def create_new_announcement(
|
||||
|
|
|
@ -28,11 +28,11 @@ async def get_course_by_name(session: AsyncSession, query: str) -> Optional[Ufor
|
|||
# Search case-insensitively
|
||||
query = query.lower()
|
||||
|
||||
statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
|
||||
result = (await session.execute(statement)).scalars().first()
|
||||
if result:
|
||||
return result
|
||||
course_statement = select(UforaCourse).where(UforaCourse.name.ilike(f"%{query}%"))
|
||||
course_result = (await session.execute(course_statement)).scalars().first()
|
||||
if course_result:
|
||||
return course_result
|
||||
|
||||
statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
|
||||
result = (await session.execute(statement)).scalars().first()
|
||||
return result.course if result else None
|
||||
alias_statement = select(UforaCourseAlias).where(UforaCourseAlias.alias.ilike(f"%{query}%"))
|
||||
alias_result = (await session.execute(alias_statement)).scalars().first()
|
||||
return alias_result.course if alias_result else None
|
||||
|
|
|
@ -28,6 +28,7 @@ omit = [
|
|||
profile = "black"
|
||||
|
||||
[tool.mypy]
|
||||
check_untyped_defs = true
|
||||
files = [
|
||||
"database/**/*.py",
|
||||
"didier/**/*.py",
|
||||
|
@ -35,7 +36,6 @@ files = [
|
|||
]
|
||||
plugins = [
|
||||
"pydantic.mypy",
|
||||
"sqlalchemy.ext.mypy.plugin"
|
||||
]
|
||||
[[tool.mypy.overrides]]
|
||||
module = ["discord.*", "feedparser.*", "ics.*", "markdownify.*"]
|
||||
|
|
|
@ -10,4 +10,4 @@ markdownify==0.11.6
|
|||
overrides==7.3.1
|
||||
pydantic==2.0.2
|
||||
python-dateutil==2.8.2
|
||||
sqlalchemy[asyncio]==2.0.18
|
||||
sqlalchemy[asyncio,postgresql_asyncpg]==2.0.18
|
||||
|
|
Loading…
Reference in New Issue