mirror of https://github.com/stijndcl/didier
Compare commits
5 Commits
fd57b5a79b
...
add9399944
| Author | SHA1 | Date |
|---|---|---|
|
|
add9399944 | |
|
|
57e805e31c | |
|
|
fc195e40b3 | |
|
|
d8192cfa0a | |
|
|
efdc966611 |
|
|
@ -34,8 +34,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
|
|||
raise NoResultFoundException
|
||||
|
||||
# Check if the alias exists (either as an alias or as a name)
|
||||
alias_instance = await get_command(session, alias)
|
||||
if alias_instance is not None:
|
||||
if await get_command(session, alias) is not None:
|
||||
raise DuplicateInsertException
|
||||
|
||||
alias_instance = CustomCommandAlias(alias=alias, indexed_alias=clean_name(alias), command=command_instance)
|
||||
|
|
@ -47,7 +46,7 @@ async def create_alias(session: AsyncSession, command: str, alias: str) -> Custo
|
|||
|
||||
async def get_command(session: AsyncSession, message: str) -> Optional[CustomCommand]:
|
||||
"""Try to get a command out of a message"""
|
||||
# Search lowercase & without spaces, and strip the prefix
|
||||
# Search lowercase & without spaces
|
||||
message = clean_name(message)
|
||||
return (await get_command_by_name(session, message)) or (await get_command_by_alias(session, message))
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,14 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from database.crud import custom_commands
|
||||
from database.exceptions.constraints import DuplicateInsertException
|
||||
from database.exceptions.not_found import NoResultFoundException
|
||||
from didier import Didier
|
||||
from didier.data.modals.custom_commands import CreateCustomCommand
|
||||
|
||||
|
||||
class Owner(commands.Cog):
|
||||
|
|
@ -11,16 +16,76 @@ class Owner(commands.Cog):
|
|||
|
||||
client: Didier
|
||||
|
||||
# Slash groups
|
||||
add_slash = app_commands.Group(name="add", description="Add something new to the database")
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
async def cog_check(self, ctx: commands.Context) -> bool:
|
||||
"""Global check for every command in this cog, so we don't have to add
|
||||
is_owner() to every single command separately
|
||||
"""
|
||||
# pylint: disable=W0236 # Pylint thinks this can't be async, but it can
|
||||
return await self.client.is_owner(ctx.author)
|
||||
|
||||
@commands.command(name="Sync")
|
||||
@commands.is_owner()
|
||||
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
|
||||
"""Sync all application-commands in Discord"""
|
||||
await self.client.tree.sync(guild=guild)
|
||||
if guild is not None:
|
||||
self.client.tree.copy_global_to(guild=guild)
|
||||
await self.client.tree.sync(guild=guild)
|
||||
else:
|
||||
self.client.tree.clear_commands(guild=None)
|
||||
await self.client.tree.sync()
|
||||
|
||||
await ctx.message.add_reaction("🔄")
|
||||
|
||||
@commands.group(name="Add", case_insensitive=True, invoke_without_command=False)
|
||||
async def add_msg(self, ctx: commands.Context):
|
||||
"""Command group for [add X] message commands"""
|
||||
|
||||
@add_msg.command(name="Custom")
|
||||
async def add_custom(self, ctx: commands.Context, name: str, *, response: str):
|
||||
"""Add a new custom command"""
|
||||
async with self.client.db_session as session:
|
||||
try:
|
||||
await custom_commands.create_command(session, name, response)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
except DuplicateInsertException:
|
||||
await ctx.reply("Er bestaat al een commando met deze naam.")
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@add_msg.command(name="Alias")
|
||||
async def add_alias(self, ctx: commands.Context, command: str, alias: str):
|
||||
"""Add a new alias for a custom command"""
|
||||
async with self.client.db_session as session:
|
||||
try:
|
||||
await custom_commands.create_alias(session, command, alias)
|
||||
await self.client.confirm_message(ctx.message)
|
||||
except NoResultFoundException:
|
||||
await ctx.reply(f'Geen commando gevonden voor "{command}".')
|
||||
await self.client.reject_message(ctx.message)
|
||||
except DuplicateInsertException:
|
||||
await ctx.reply("Er bestaat al een commando met deze naam.")
|
||||
await self.client.reject_message(ctx.message)
|
||||
|
||||
@add_slash.command(name="custom", description="Add a custom command")
|
||||
async def add_custom_slash(self, interaction: discord.Interaction):
|
||||
"""Slash command to add a custom command"""
|
||||
if not self.client.is_owner(interaction.user):
|
||||
return interaction.response.send_message(
|
||||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
||||
)
|
||||
|
||||
await interaction.response.defer(ephemeral=True)
|
||||
modal = CreateCustomCommand()
|
||||
await interaction.response.send_message(modal)
|
||||
|
||||
@commands.group(name="Edit")
|
||||
async def edit(self, ctx: commands.Context):
|
||||
"""Command group for [edit X] commands"""
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
|
|
|
|||
|
|
@ -0,0 +1,20 @@
|
|||
import traceback
|
||||
|
||||
import discord
|
||||
|
||||
|
||||
class CreateCustomCommand(discord.ui.Modal, title="Custom Command"):
|
||||
"""Modal shown to visually create custom commands"""
|
||||
|
||||
name: discord.ui.TextInput = discord.ui.TextInput(label="Name", placeholder="Name of the command...")
|
||||
|
||||
response: discord.ui.TextInput = discord.ui.TextInput(
|
||||
label="Response", style=discord.TextStyle.long, placeholder="Response of the command...", max_length=2000
|
||||
)
|
||||
|
||||
async def on_submit(self, interaction: discord.Interaction) -> None:
|
||||
await interaction.response.send_message("Submitted", ephemeral=True)
|
||||
|
||||
async def on_error(self, interaction: discord.Interaction, error: Exception) -> None: # type: ignore
|
||||
await interaction.response.send_message("Errored", ephemeral=True)
|
||||
traceback.print_tb(error.__traceback__)
|
||||
|
|
@ -3,11 +3,11 @@ import sys
|
|||
import traceback
|
||||
|
||||
import discord
|
||||
from discord import Message
|
||||
from discord.ext import commands
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
import settings
|
||||
from database.crud import custom_commands
|
||||
from database.engine import DBSession
|
||||
from didier.utils.discord.prefix import get_prefix
|
||||
|
||||
|
|
@ -34,24 +34,17 @@ class Didier(commands.Bot):
|
|||
command_prefix=get_prefix, case_insensitive=True, intents=intents, activity=activity, status=status
|
||||
)
|
||||
|
||||
@property
|
||||
def db_session(self) -> AsyncSession:
|
||||
"""Obtain a database session"""
|
||||
return DBSession()
|
||||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Hook called once the bot is initialised"""
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
|
||||
# Sync application commands to the test guild
|
||||
for guild in settings.DISCORD_TEST_GUILDS:
|
||||
guild_object = discord.Object(id=guild)
|
||||
|
||||
self.tree.copy_global_to(guild=guild_object)
|
||||
await self.tree.sync(guild=guild_object)
|
||||
|
||||
@property
|
||||
def db_session(self) -> AsyncSession:
|
||||
"""Obtain a database session"""
|
||||
return DBSession()
|
||||
|
||||
async def _load_initial_extensions(self):
|
||||
"""Load all extensions that should be loaded before the others"""
|
||||
for extension in self.initial_extensions:
|
||||
|
|
@ -85,11 +78,19 @@ class Didier(commands.Bot):
|
|||
channel = self.get_channel(reference.channel_id)
|
||||
return await channel.fetch_message(reference.message_id)
|
||||
|
||||
async def confirm_message(self, message: discord.Message):
|
||||
"""Add a checkmark to a message"""
|
||||
await message.add_reaction("✅")
|
||||
|
||||
async def reject_message(self, message: discord.Message):
|
||||
"""Add an X to a message"""
|
||||
await message.add_reaction("❌")
|
||||
|
||||
async def on_ready(self):
|
||||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
||||
async def on_message(self, message: Message, /) -> None:
|
||||
async def on_message(self, message: discord.Message, /) -> None:
|
||||
"""Event triggered when a message is sent"""
|
||||
# Ignore messages by bots
|
||||
if message.author.bot:
|
||||
|
|
@ -100,18 +101,32 @@ class Didier(commands.Bot):
|
|||
await message.add_reaction(settings.DISCORD_BOOS_REACT)
|
||||
|
||||
# Potential custom command
|
||||
if self._try_invoke_custom_command(message):
|
||||
if await self._try_invoke_custom_command(message):
|
||||
return
|
||||
|
||||
await self.process_commands(message)
|
||||
|
||||
async def _try_invoke_custom_command(self, message: Message) -> bool:
|
||||
async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
|
||||
"""Check if the message tries to invoke a custom command
|
||||
If it does, send the reply associated with it
|
||||
"""
|
||||
# Doesn't start with the custom command prefix
|
||||
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
||||
return False
|
||||
|
||||
async with self.db_session as session:
|
||||
# Remove the prefix
|
||||
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
|
||||
command = await custom_commands.get_command(session, content)
|
||||
|
||||
# Command found
|
||||
if command is not None:
|
||||
await message.reply(command.response, mention_author=False)
|
||||
return True
|
||||
|
||||
# Nothing found
|
||||
return False
|
||||
|
||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||
"""Event triggered when a regular command errors"""
|
||||
# If developing, print everything to stdout so you don't have to
|
||||
|
|
|
|||
|
|
@ -1,4 +1,5 @@
|
|||
from typing import AsyncGenerator
|
||||
import asyncio
|
||||
from typing import AsyncGenerator, Generator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
|
@ -11,7 +12,14 @@ from didier import Didier
|
|||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tables():
|
||||
def event_loop() -> Generator:
|
||||
loop = asyncio.get_event_loop_policy().new_event_loop()
|
||||
yield loop
|
||||
loop.close()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def tables(event_loop):
|
||||
"""Initialize a database before the tests, and then tear it down again
|
||||
Starts from an empty database and runs through all the migrations to check those as well
|
||||
while we're at it
|
||||
|
|
@ -23,7 +31,7 @@ def tables():
|
|||
|
||||
|
||||
@pytest.fixture
|
||||
async def database_session(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||
async def database_session(tables, event_loop) -> AsyncGenerator[AsyncSession, None]:
|
||||
"""Fixture to create a session for every test
|
||||
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||
"""
|
||||
|
|
|
|||
Loading…
Reference in New Issue