Fix bug in prefix, pin command, pin context menu

pull/115/head
stijndcl 2022-06-16 00:29:38 +02:00
parent 304ad850b7
commit de0d543bf8
4 changed files with 132 additions and 31 deletions

View File

@ -0,0 +1,63 @@
from typing import Optional
import discord
from discord import app_commands
from discord.ext import commands
from didier import Didier
class School(commands.Cog):
"""School-related commands"""
client: Didier
# Context-menu references
_pin_ctx_menu: app_commands.ContextMenu
def __init__(self, client: Didier):
self.client = client
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx)
self.client.tree.add_command(self._pin_ctx_menu)
async def cog_unload(self) -> None:
"""Remove the commands when the cog is unloaded"""
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
@commands.command(name="Pin", usage="[Message]")
async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None):
"""Pin a message in the current channel"""
# If no message was passed, allow replying to the message that should be pinned
if message is None and ctx.message.reference is not None:
message = await self.client.resolve_message(ctx.message.reference)
# Didn't fix it, sad
if message is None:
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10)
if message.is_system():
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}")
await message.add_reaction("📌")
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
"""Pin a message in the current channel"""
# Is already pinned
if message.pinned:
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True)
if message.is_system():
return await interaction.response.send_message(
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
)
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}")
await message.add_reaction("📌")
return await interaction.response.send_message("📌", ephemeral=True)
async def setup(client: Didier):
"""Load the cog"""
await client.add_cog(School(client))

View File

@ -1,14 +0,0 @@
from discord.ext import commands
from didier import Didier
class TestCog(commands.Cog):
client: Didier
def __init__(self, client: Didier):
self.client = client
async def setup(client: Didier):
await client.add_cog(TestCog(client))

View File

@ -1,4 +1,7 @@
import os
import sys
import traceback
from typing import Union, Optional
import discord
from discord.ext import commands
@ -33,29 +36,78 @@ class Didier(commands.Bot):
async def setup_hook(self) -> None:
"""Hook called once the bot is initialised"""
await self._load_initial_cogs()
await self._load_directory_cogs("didier/cogs")
# Load extensions
await self._load_initial_extensions()
await self._load_directory_extensions("didier/cogs")
# Sync application commands to the test guild
for guild in settings.DISCORD_TEST_GUILDS:
guild_object = discord.Object(id=guild)
self.tree.copy_global_to(guild=guild_object)
await self.tree.sync(guild=guild_object)
@property
def db_session(self) -> AsyncSession:
"""Obtain a database session"""
return DBSession()
async def _load_initial_extensions(self):
"""Load all extensions that should be loaded before the others"""
for extension in self.initial_extensions:
await self.load_extension(f"didier.{extension}")
async def _load_directory_extensions(self, path: str):
"""Load all extensions in a given directory"""
load_path = path.removeprefix("./").replace("/", ".")
parent_path = load_path.removeprefix("didier.")
# Check every file in the directory
for file in os.listdir(path):
# Construct a path that includes all parent packages in order to
# Allow checking against initial extensions more easily
full_name = parent_path + file
# Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__)
# Also ignore the files that we have already loaded previously
if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions):
await self.load_extension(f"{load_path}.{file[:-3]}")
elif os.path.isdir(new_path := f"{path}/{file}"):
await self._load_directory_extensions(new_path)
async def respond(
self,
context: Union[commands.Context, discord.Interaction],
message: str,
mention_author: bool = False,
ephemeral: bool = True,
embeds: Optional[list[discord.Embed]] = None,
):
"""Function to respond to both a normal message and an interaction"""
if isinstance(context, commands.Context):
return await context.reply(message, mention_author=mention_author, embeds=embeds)
if isinstance(context, discord.Interaction):
return await context.response.send_message(message, ephemeral=ephemeral, embeds=embeds)
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
"""Fetch a message from a reference"""
# Message is in the cache, return it
if reference.cached_message is not None:
return reference.cached_message
# For older messages: fetch them from the API
channel = self.get_channel(reference.channel_id)
return await channel.fetch_message(reference.message_id)
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def _load_initial_cogs(self):
"""Load all cogs"""
for extension in self.initial_extensions:
await self.load_extension(f"didier.cogs.{extension}")
async def _load_directory_cogs(self, path: str):
"""Load all cogs in a given directory"""
load_path = path.removeprefix("./").replace("/", ".")
for file in os.listdir(path):
if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions):
await self.load_extension(f"{load_path}.{file[:-3]}")
elif os.path.isdir(new_path := f"{path}/{file}"):
await self._load_directory_cogs(new_path)
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
"""Event triggered when a regular command errors"""
# If developing, print everything to stdout so you don't have to
# check the logs all the time
if settings.SANDBOX:
print(traceback.format_exc(), file=sys.stderr)
return

View File

@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str:
This is done dynamically to allow variable amounts of whitespace,
and through regexes to allow case-insensitivity among other things.
"""
mention = f"<@!{client.user.id}>"
mention = f"<@!?{client.user.id}>"
regex = r"^({})\s*"
# Check which prefix was used