didier/database/crud/custom_commands.py

106 lines
3.4 KiB
Python
Raw Normal View History

2022-06-21 23:58:21 +02:00
from typing import Optional
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions.constraints import DuplicateInsertException
from database.exceptions.not_found import NoResultFoundException
2022-08-29 20:24:42 +02:00
from database.schemas import CustomCommand, CustomCommandAlias
2022-06-21 23:58:21 +02:00
2022-07-11 22:23:38 +02:00
# Public API of this module: normalisation, creation, lookup and editing of custom commands.
__all__ = [
    "clean_name",
    "create_alias",
    "create_command",
    "edit_command",
    "get_all_commands",
    "get_command",
    "get_command_by_alias",
    "get_command_by_name",
]


def clean_name(name: str) -> str:
    """Normalise a command name or alias so that lookups are forgiving.

    The text is lowercased and every space character (``" "``) is removed,
    so ``"Foo Bar"`` and ``"foobar"`` produce the same key. Other whitespace
    such as tabs or newlines is left as-is.
    """
    lowered = name.lower()
    # Splitting on an explicit " " and re-joining drops every single space
    return "".join(lowered.split(" "))


async def create_command(session: AsyncSession, name: str, response: str) -> CustomCommand:
    """Create and persist a new custom command.

    Args:
        session: database session to use; it is committed after the insert.
        name: display name of the new command.
        response: text the command replies with.

    Returns:
        The newly created ``CustomCommand``.

    Raises:
        DuplicateInsertException: if ``name`` already matches an existing
            command name or alias (compared via ``get_command``).
    """
    # get_command checks both command names and aliases, so either kind of clash is rejected
    if await get_command(session, name) is not None:
        raise DuplicateInsertException

    new_command = CustomCommand(name=name, indexed_name=clean_name(name), response=response)
    session.add(new_command)
    await session.commit()

    return new_command


async def create_alias(session: AsyncSession, command: str, alias: str) -> CustomCommandAlias:
    """Attach a new alias to an existing command.

    Args:
        session: database session to use; it is committed after the insert.
        command: name (or existing alias) of the command to alias.
        alias: the new alias text.

    Returns:
        The newly created ``CustomCommandAlias``.

    Raises:
        NoResultFoundException: if ``command`` does not resolve to a command.
        DuplicateInsertException: if ``alias`` already resolves to a command,
            either as a command name or as an alias.
    """
    target = await get_command(session, command)
    if target is None:
        raise NoResultFoundException

    # The alias must not collide with any existing name or alias
    clash = await get_command(session, alias)
    if clash is not None:
        raise DuplicateInsertException

    new_alias = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=target)
    session.add(new_alias)
    await session.commit()
    return new_alias


async def get_all_commands(session: AsyncSession) -> list[CustomCommand]:
    """Return every custom command stored in the database, as a plain list."""
    result = await session.execute(select(CustomCommand))
    return list(result.scalars().all())


async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
    """Resolve raw user input to a command.

    The input is normalised with ``clean_name`` (lowercased, spaces removed)
    and matched first against command names, then against aliases.

    Returns:
        The matching ``CustomCommand``, or ``None`` if neither lookup matches.
    """
    key = clean_name(message)

    by_name = await get_command_by_name(session, key)
    if by_name:
        return by_name

    # Only fall back to an alias lookup when no command name matched
    return await get_command_by_alias(session, key)


async def get_command_by_name(session: AsyncSession, message: str) -> Optional[CustomCommand]:
    """Fetch the command whose ``indexed_name`` equals ``message``.

    ``message`` is compared as given and is not normalised here. Callers should
    pass an already-cleaned name, as ``get_command`` does.

    Returns:
        The matching command, or ``None`` if there is none.
    """
    query = select(CustomCommand).where(CustomCommand.indexed_name == message)
    result = await session.execute(query)
    return result.scalar_one_or_none()


async def get_command_by_alias(session: AsyncSession, message: str) -> Optional[CustomCommand]:
    """Fetch the command owning the alias whose ``indexed_alias`` equals ``message``.

    ``message`` is compared as given and is not normalised here. Callers should
    pass an already-cleaned alias, as ``get_command`` does.

    Returns:
        The command the alias points to, or ``None`` if no alias matches.
    """
    query = select(CustomCommandAlias).where(CustomCommandAlias.indexed_alias == message)
    result = await session.execute(query)
    match = result.scalar_one_or_none()
    # NOTE(review): reading `.command` must not trigger lazy IO under an AsyncSession;
    # presumably the relationship is eagerly loaded in the schema — confirm.
    return None if match is None else match.command


async def edit_command(
    session: AsyncSession, original_name: str, new_name: Optional[str] = None, new_response: Optional[str] = None
) -> CustomCommand:
    """Edit an existing command's name and/or response.

    Args:
        session: database session to use; it is committed after the changes.
        original_name: current name or alias of the command to edit
            (resolved via ``get_command``, so it is normalised with ``clean_name``).
        new_name: new display name, or ``None`` to keep the current one.
        new_response: new response text, or ``None`` to keep the current one.

    Returns:
        The updated ``CustomCommand``.

    Raises:
        NoResultFoundException: if ``original_name`` does not resolve to a command.
        DuplicateInsertException: if ``new_name`` already resolves to a different
            command (as a name or as an alias).
    """
    command = await get_command(session, original_name)
    if command is None:
        raise NoResultFoundException

    if new_name is not None:
        # Refuse to rename onto another command's name or alias, mirroring create_command.
        # Re-spelling the command's own name (or one of its own aliases) resolves to the
        # same identity-mapped object and is allowed.
        existing = await get_command(session, new_name)
        if existing is not None and existing is not command:
            raise DuplicateInsertException

        command.name = new_name
        # Lookups match on indexed_name, so it must follow the display name;
        # otherwise the command would only be findable under its old name.
        command.indexed_name = clean_name(new_name)

    if new_response is not None:
        command.response = new_response

    session.add(command)
    await session.commit()
    return command