diff --git a/frank/module/decorators/classes.py b/frank/module/decorators/classes.py index c35ce09..9d10a82 100644 --- a/frank/module/decorators/classes.py +++ b/frank/module/decorators/classes.py @@ -5,13 +5,14 @@ from __future__ import annotations # Built-in imports import re import asyncio +import shlex # Typing imports from typing import TYPE_CHECKING if TYPE_CHECKING: # Built-in imports - from typing import Union + from typing import Union, List, Tuple class Simple: @@ -53,6 +54,14 @@ class Simple: return self + @property + def client(self): + """ + Returns the Frank client instance. + """ + + return self._obj._client + class SimpleCommand(Simple): """ @@ -72,15 +81,48 @@ class SimpleCommand(Simple): self.cmd = cmd self.help_str = help_str - def match(self, message: str) -> bool: + def match(self, message: str) -> Tuple[bool, List[str]]: """ - Returns wether the command matches the given message. + Returns wether the command matches the given message. If the arguments + can't be parsed (e.g. unmatched quotes), it will return False as well. Args: message: message to check """ - return self.cmd == message.split(" ")[0] + return self._match_full(message.split(" ")) + + def _match_full(self, parts: List[str]) -> Tuple[bool, List[str]]: + """ + Returns wether the message matches the full command. + + Args: + parts: parts of the message + """ + + # Can't match without 3 or more parts + if len(parts) < 3: + return False, None + + # Return False if it doesn't match + if not all( + ( + parts[0] == self.client.PREFIX, + parts[1] in self._obj.PREFIX, + parts[2] == self.cmd, + ) + ): + return False, None + + # Parse the output, and return True with the parsed items if it works, + # otherwise return False + try: + parsed = shlex.split(" ".join(parts[3:])) + + return True, parsed + + except ValueError: + return False, None class RegularCommand(SimpleCommand): @@ -99,14 +141,43 @@ class RegularCommand(SimpleCommand): super().__init__(self, func, cmd, help_str) self.alias = alias + + # This only matters for aliases self.requires_prefix = requires_prefix - def match(self, message: str) -> bool: - # This just makes it a bit easier to use in the function - module = self._obj - client = module._client + # TODO: make this return the right value + def match(self, message: str) -> Tuple[bool, List[str]]: + """ + Returns wether the message matches the current command. + """ - words = [word for word in message.split(" ") if word] + parts = message.split(" ") + + # If the alias doesn't match, return the full match, otherwise return + # alias + matches, parts = self._match_alias(parts) + + if matches: + return matches, parts + + return self._match_full(parts) + + # TODO: make this return the right value + def _match_alias(self, parts: List[str]) -> Tuple[bool, List[str]]: + """ + Returns wether the message matches an alias. + """ + + # Return False if there's only one part but a prefix is required + if self.requires_prefix and len(parts) == 1: + return False + + # Match with prefix + if self.requires_prefix: + return parts[0] == self.client.PREFIX and parts[1] in self.alias + + # Match without prefix + return parts[0] in self.alias class RegexCommand(SimpleCommand): @@ -115,7 +186,7 @@ class RegexCommand(SimpleCommand): prefix. """ - def match(self, prefix: str) -> bool: + def match(self, message: str) -> Tuple[str, List[str]]: """ Returns wether the regex pattern matches the given prefix. @@ -124,7 +195,20 @@ class RegexCommand(SimpleCommand): prefix """ - return bool(re.fullmatch(self.cmd, prefix)) + parts = message.split(" ") + matches = bool(re.fullmatch(self.cmd, parts[0])) + + # If it doesn't match, just return False, don't parse the rest + if not matches: + return False, None + + try: + parsed = shlex.split(" ".join(parts)) + + return True, parsed + + except ValueError: + return False, None class Daemon(Simple): diff --git a/frank/module/meta.py b/frank/module/meta.py index 4446092..c43f056 100644 --- a/frank/module/meta.py +++ b/frank/module/meta.py @@ -49,5 +49,6 @@ class ModuleMeta: @cached_property def default(self) -> Default: return next( - iter(self._filter_attrs(lambda val: isinstance(val, Default))), None + iter(self._filter_attrs(lambda val: isinstance(val, Default))), + None, ) diff --git a/frank/module/module.py b/frank/module/module.py index 03eb053..f666c25 100644 --- a/frank/module/module.py +++ b/frank/module/module.py @@ -118,7 +118,9 @@ class Module(ModuleMeta): ) else: - await func(cmd=cmd[1:], author=author, channel=channel, mid=mid) + await func( + cmd=cmd[1:], author=author, channel=channel, mid=mid + ) elif self.default: await self.default(author=author, channel=channel, mid=mid) diff --git a/pyproject.toml b/pyproject.toml index f96e88f..a7e0353 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [tool.black] -line-length = 80 +line-length = 79 target-version = ['py38'] include = '\.pyi?$' exclude = '''