mirror of https://github.com/stijndcl/didier
Fix bug in prefix, pin command, pin context menu
parent
304ad850b7
commit
de0d543bf8
|
@ -0,0 +1,63 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from didier import Didier
|
||||
|
||||
|
||||
class School(commands.Cog):
|
||||
"""School-related commands"""
|
||||
|
||||
client: Didier
|
||||
|
||||
# Context-menu references
|
||||
_pin_ctx_menu: app_commands.ContextMenu
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx)
|
||||
self.client.tree.add_command(self._pin_ctx_menu)
|
||||
|
||||
async def cog_unload(self) -> None:
|
||||
"""Remove the commands when the cog is unloaded"""
|
||||
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||
|
||||
@commands.command(name="Pin", usage="[Message]")
|
||||
async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None):
|
||||
"""Pin a message in the current channel"""
|
||||
# If no message was passed, allow replying to the message that should be pinned
|
||||
if message is None and ctx.message.reference is not None:
|
||||
message = await self.client.resolve_message(ctx.message.reference)
|
||||
|
||||
# Didn't fix it, sad
|
||||
if message is None:
|
||||
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10)
|
||||
|
||||
if message.is_system():
|
||||
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
|
||||
|
||||
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}")
|
||||
await message.add_reaction("📌")
|
||||
|
||||
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
||||
"""Pin a message in the current channel"""
|
||||
# Is already pinned
|
||||
if message.pinned:
|
||||
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True)
|
||||
|
||||
if message.is_system():
|
||||
return await interaction.response.send_message(
|
||||
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
|
||||
)
|
||||
|
||||
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}")
|
||||
await message.add_reaction("📌")
|
||||
return await interaction.response.send_message("📌", ephemeral=True)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(School(client))
|
|
@ -1,14 +0,0 @@
|
|||
from discord.ext import commands
|
||||
|
||||
from didier import Didier
|
||||
|
||||
|
||||
class TestCog(commands.Cog):
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
await client.add_cog(TestCog(client))
|
|
@ -1,4 +1,7 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from typing import Union, Optional
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
@ -33,29 +36,78 @@ class Didier(commands.Bot):
|
|||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Hook called once the bot is initialised"""
|
||||
await self._load_initial_cogs()
|
||||
await self._load_directory_cogs("didier/cogs")
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
|
||||
# Sync application commands to the test guild
|
||||
for guild in settings.DISCORD_TEST_GUILDS:
|
||||
guild_object = discord.Object(id=guild)
|
||||
|
||||
self.tree.copy_global_to(guild=guild_object)
|
||||
await self.tree.sync(guild=guild_object)
|
||||
|
||||
@property
|
||||
def db_session(self) -> AsyncSession:
|
||||
"""Obtain a database session"""
|
||||
return DBSession()
|
||||
|
||||
async def _load_initial_extensions(self):
|
||||
"""Load all extensions that should be loaded before the others"""
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(f"didier.{extension}")
|
||||
|
||||
async def _load_directory_extensions(self, path: str):
|
||||
"""Load all extensions in a given directory"""
|
||||
load_path = path.removeprefix("./").replace("/", ".")
|
||||
parent_path = load_path.removeprefix("didier.")
|
||||
|
||||
# Check every file in the directory
|
||||
for file in os.listdir(path):
|
||||
# Construct a path that includes all parent packages in order to
|
||||
# Allow checking against initial extensions more easily
|
||||
full_name = parent_path + file
|
||||
|
||||
# Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__)
|
||||
# Also ignore the files that we have already loaded previously
|
||||
if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions):
|
||||
await self.load_extension(f"{load_path}.{file[:-3]}")
|
||||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_extensions(new_path)
|
||||
|
||||
async def respond(
|
||||
self,
|
||||
context: Union[commands.Context, discord.Interaction],
|
||||
message: str,
|
||||
mention_author: bool = False,
|
||||
ephemeral: bool = True,
|
||||
embeds: Optional[list[discord.Embed]] = None,
|
||||
):
|
||||
"""Function to respond to both a normal message and an interaction"""
|
||||
if isinstance(context, commands.Context):
|
||||
return await context.reply(message, mention_author=mention_author, embeds=embeds)
|
||||
|
||||
if isinstance(context, discord.Interaction):
|
||||
return await context.response.send_message(message, ephemeral=ephemeral, embeds=embeds)
|
||||
|
||||
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||
"""Fetch a message from a reference"""
|
||||
# Message is in the cache, return it
|
||||
if reference.cached_message is not None:
|
||||
return reference.cached_message
|
||||
|
||||
# For older messages: fetch them from the API
|
||||
channel = self.get_channel(reference.channel_id)
|
||||
return await channel.fetch_message(reference.message_id)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
||||
async def _load_initial_cogs(self):
|
||||
"""Load all cogs"""
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(f"didier.cogs.{extension}")
|
||||
|
||||
async def _load_directory_cogs(self, path: str):
|
||||
"""Load all cogs in a given directory"""
|
||||
load_path = path.removeprefix("./").replace("/", ".")
|
||||
|
||||
for file in os.listdir(path):
|
||||
if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions):
|
||||
await self.load_extension(f"{load_path}.{file[:-3]}")
|
||||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_cogs(new_path)
|
||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||
"""Event triggered when a regular command errors"""
|
||||
# If developing, print everything to stdout so you don't have to
|
||||
# check the logs all the time
|
||||
if settings.SANDBOX:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return
|
||||
|
|
|
@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str:
|
|||
This is done dynamically to allow variable amounts of whitespace,
|
||||
and through regexes to allow case-insensitivity among other things.
|
||||
"""
|
||||
mention = f"<@!{client.user.id}>"
|
||||
mention = f"<@!?{client.user.id}>"
|
||||
regex = r"^({})\s*"
|
||||
|
||||
# Check which prefix was used
|
||||
|
|
Loading…
Reference in New Issue