mirror of https://github.com/stijndcl/didier
Fix bug in prefix, pin command, pin context menu
parent
304ad850b7
commit
de0d543bf8
|
@ -0,0 +1,63 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord import app_commands
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from didier import Didier
|
||||||
|
|
||||||
|
|
||||||
|
class School(commands.Cog):
|
||||||
|
"""School-related commands"""
|
||||||
|
|
||||||
|
client: Didier
|
||||||
|
|
||||||
|
# Context-menu references
|
||||||
|
_pin_ctx_menu: app_commands.ContextMenu
|
||||||
|
|
||||||
|
def __init__(self, client: Didier):
|
||||||
|
self.client = client
|
||||||
|
|
||||||
|
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx)
|
||||||
|
self.client.tree.add_command(self._pin_ctx_menu)
|
||||||
|
|
||||||
|
async def cog_unload(self) -> None:
|
||||||
|
"""Remove the commands when the cog is unloaded"""
|
||||||
|
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||||
|
|
||||||
|
@commands.command(name="Pin", usage="[Message]")
|
||||||
|
async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None):
|
||||||
|
"""Pin a message in the current channel"""
|
||||||
|
# If no message was passed, allow replying to the message that should be pinned
|
||||||
|
if message is None and ctx.message.reference is not None:
|
||||||
|
message = await self.client.resolve_message(ctx.message.reference)
|
||||||
|
|
||||||
|
# Didn't fix it, sad
|
||||||
|
if message is None:
|
||||||
|
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10)
|
||||||
|
|
||||||
|
if message.is_system():
|
||||||
|
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
|
||||||
|
|
||||||
|
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}")
|
||||||
|
await message.add_reaction("📌")
|
||||||
|
|
||||||
|
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
||||||
|
"""Pin a message in the current channel"""
|
||||||
|
# Is already pinned
|
||||||
|
if message.pinned:
|
||||||
|
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True)
|
||||||
|
|
||||||
|
if message.is_system():
|
||||||
|
return await interaction.response.send_message(
|
||||||
|
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
|
||||||
|
)
|
||||||
|
|
||||||
|
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}")
|
||||||
|
await message.add_reaction("📌")
|
||||||
|
return await interaction.response.send_message("📌", ephemeral=True)
|
||||||
|
|
||||||
|
|
||||||
|
async def setup(client: Didier):
|
||||||
|
"""Load the cog"""
|
||||||
|
await client.add_cog(School(client))
|
|
@ -1,14 +0,0 @@
|
||||||
from discord.ext import commands
|
|
||||||
|
|
||||||
from didier import Didier
|
|
||||||
|
|
||||||
|
|
||||||
class TestCog(commands.Cog):
|
|
||||||
client: Didier
|
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
|
||||||
self.client = client
|
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
|
||||||
await client.add_cog(TestCog(client))
|
|
|
@ -1,4 +1,7 @@
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
|
import traceback
|
||||||
|
from typing import Union, Optional
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
@ -33,29 +36,78 @@ class Didier(commands.Bot):
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Hook called once the bot is initialised"""
|
"""Hook called once the bot is initialised"""
|
||||||
await self._load_initial_cogs()
|
# Load extensions
|
||||||
await self._load_directory_cogs("didier/cogs")
|
await self._load_initial_extensions()
|
||||||
|
await self._load_directory_extensions("didier/cogs")
|
||||||
|
|
||||||
|
# Sync application commands to the test guild
|
||||||
|
for guild in settings.DISCORD_TEST_GUILDS:
|
||||||
|
guild_object = discord.Object(id=guild)
|
||||||
|
|
||||||
|
self.tree.copy_global_to(guild=guild_object)
|
||||||
|
await self.tree.sync(guild=guild_object)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_session(self) -> AsyncSession:
|
def db_session(self) -> AsyncSession:
|
||||||
"""Obtain a database session"""
|
"""Obtain a database session"""
|
||||||
return DBSession()
|
return DBSession()
|
||||||
|
|
||||||
|
async def _load_initial_extensions(self):
|
||||||
|
"""Load all extensions that should be loaded before the others"""
|
||||||
|
for extension in self.initial_extensions:
|
||||||
|
await self.load_extension(f"didier.{extension}")
|
||||||
|
|
||||||
|
async def _load_directory_extensions(self, path: str):
|
||||||
|
"""Load all extensions in a given directory"""
|
||||||
|
load_path = path.removeprefix("./").replace("/", ".")
|
||||||
|
parent_path = load_path.removeprefix("didier.")
|
||||||
|
|
||||||
|
# Check every file in the directory
|
||||||
|
for file in os.listdir(path):
|
||||||
|
# Construct a path that includes all parent packages in order to
|
||||||
|
# Allow checking against initial extensions more easily
|
||||||
|
full_name = parent_path + file
|
||||||
|
|
||||||
|
# Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__)
|
||||||
|
# Also ignore the files that we have already loaded previously
|
||||||
|
if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions):
|
||||||
|
await self.load_extension(f"{load_path}.{file[:-3]}")
|
||||||
|
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||||
|
await self._load_directory_extensions(new_path)
|
||||||
|
|
||||||
|
async def respond(
|
||||||
|
self,
|
||||||
|
context: Union[commands.Context, discord.Interaction],
|
||||||
|
message: str,
|
||||||
|
mention_author: bool = False,
|
||||||
|
ephemeral: bool = True,
|
||||||
|
embeds: Optional[list[discord.Embed]] = None,
|
||||||
|
):
|
||||||
|
"""Function to respond to both a normal message and an interaction"""
|
||||||
|
if isinstance(context, commands.Context):
|
||||||
|
return await context.reply(message, mention_author=mention_author, embeds=embeds)
|
||||||
|
|
||||||
|
if isinstance(context, discord.Interaction):
|
||||||
|
return await context.response.send_message(message, ephemeral=ephemeral, embeds=embeds)
|
||||||
|
|
||||||
|
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||||
|
"""Fetch a message from a reference"""
|
||||||
|
# Message is in the cache, return it
|
||||||
|
if reference.cached_message is not None:
|
||||||
|
return reference.cached_message
|
||||||
|
|
||||||
|
# For older messages: fetch them from the API
|
||||||
|
channel = self.get_channel(reference.channel_id)
|
||||||
|
return await channel.fetch_message(reference.message_id)
|
||||||
|
|
||||||
async def on_ready(self):
|
async def on_ready(self):
|
||||||
"""Event triggered when the bot is ready"""
|
"""Event triggered when the bot is ready"""
|
||||||
print(settings.DISCORD_READY_MESSAGE)
|
print(settings.DISCORD_READY_MESSAGE)
|
||||||
|
|
||||||
async def _load_initial_cogs(self):
|
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||||
"""Load all cogs"""
|
"""Event triggered when a regular command errors"""
|
||||||
for extension in self.initial_extensions:
|
# If developing, print everything to stdout so you don't have to
|
||||||
await self.load_extension(f"didier.cogs.{extension}")
|
# check the logs all the time
|
||||||
|
if settings.SANDBOX:
|
||||||
async def _load_directory_cogs(self, path: str):
|
print(traceback.format_exc(), file=sys.stderr)
|
||||||
"""Load all cogs in a given directory"""
|
return
|
||||||
load_path = path.removeprefix("./").replace("/", ".")
|
|
||||||
|
|
||||||
for file in os.listdir(path):
|
|
||||||
if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions):
|
|
||||||
await self.load_extension(f"{load_path}.{file[:-3]}")
|
|
||||||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
|
||||||
await self._load_directory_cogs(new_path)
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str:
|
||||||
This is done dynamically to allow variable amounts of whitespace,
|
This is done dynamically to allow variable amounts of whitespace,
|
||||||
and through regexes to allow case-insensitivity among other things.
|
and through regexes to allow case-insensitivity among other things.
|
||||||
"""
|
"""
|
||||||
mention = f"<@!{client.user.id}>"
|
mention = f"<@!?{client.user.id}>"
|
||||||
regex = r"^({})\s*"
|
regex = r"^({})\s*"
|
||||||
|
|
||||||
# Check which prefix was used
|
# Check which prefix was used
|
||||||
|
|
Loading…
Reference in New Issue