mirror of https://github.com/stijndcl/didier
Compare commits
3 Commits
304ad850b7
...
a1449a4c9c
| Author | SHA1 | Date |
|---|---|---|
|
|
a1449a4c9c | |
|
|
3d1aabf77c | |
|
|
de0d543bf8 |
|
|
@ -42,7 +42,7 @@ jobs:
|
|||
- name: Run Pytest
|
||||
run: pytest tests
|
||||
linting:
|
||||
needs: [tests]
|
||||
needs: [dependencies]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
|
@ -62,7 +62,7 @@ jobs:
|
|||
- name: Linting
|
||||
run: pylint didier database
|
||||
typing:
|
||||
needs: [tests]
|
||||
needs: [dependencies]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
|
@ -82,7 +82,7 @@ jobs:
|
|||
- name: Typing
|
||||
run: mypy didier database
|
||||
formatting:
|
||||
needs: [tests]
|
||||
needs: [dependencies]
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- uses: actions/checkout@v2
|
||||
|
|
|
|||
|
|
@ -0,0 +1,63 @@
|
|||
from typing import Optional
|
||||
|
||||
import discord
|
||||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from didier import Didier
|
||||
|
||||
|
||||
class School(commands.Cog):
|
||||
"""School-related commands"""
|
||||
|
||||
client: Didier
|
||||
|
||||
# Context-menu references
|
||||
_pin_ctx_menu: app_commands.ContextMenu
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
self._pin_ctx_menu = app_commands.ContextMenu(name="Pin", callback=self.pin_ctx)
|
||||
self.client.tree.add_command(self._pin_ctx_menu)
|
||||
|
||||
async def cog_unload(self) -> None:
|
||||
"""Remove the commands when the cog is unloaded"""
|
||||
self.client.tree.remove_command(self._pin_ctx_menu.name, type=self._pin_ctx_menu.type)
|
||||
|
||||
@commands.command(name="Pin", usage="[Message]")
|
||||
async def pin(self, ctx: commands.Context, message: Optional[discord.Message] = None):
|
||||
"""Pin a message in the current channel"""
|
||||
# If no message was passed, allow replying to the message that should be pinned
|
||||
if message is None and ctx.message.reference is not None:
|
||||
message = await self.client.resolve_message(ctx.message.reference)
|
||||
|
||||
# Didn't fix it, sad
|
||||
if message is None:
|
||||
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10)
|
||||
|
||||
if message.is_system():
|
||||
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
|
||||
|
||||
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}")
|
||||
await message.add_reaction("📌")
|
||||
|
||||
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
||||
"""Pin a message in the current channel"""
|
||||
# Is already pinned
|
||||
if message.pinned:
|
||||
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True)
|
||||
|
||||
if message.is_system():
|
||||
return await interaction.response.send_message(
|
||||
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
|
||||
)
|
||||
|
||||
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}")
|
||||
await message.add_reaction("📌")
|
||||
return await interaction.response.send_message("📌", ephemeral=True)
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
"""Load the cog"""
|
||||
await client.add_cog(School(client))
|
||||
|
|
@ -1,14 +0,0 @@
|
|||
from discord.ext import commands
|
||||
|
||||
from didier import Didier
|
||||
|
||||
|
||||
class TestCog(commands.Cog):
|
||||
client: Didier
|
||||
|
||||
def __init__(self, client: Didier):
|
||||
self.client = client
|
||||
|
||||
|
||||
async def setup(client: Didier):
|
||||
await client.add_cog(TestCog(client))
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
|
@ -12,7 +14,7 @@ from didier.utils.prefix import get_prefix
|
|||
class Didier(commands.Bot):
|
||||
"""DIDIER <3"""
|
||||
|
||||
initial_extensions: tuple[str] = ()
|
||||
initial_extensions: tuple[str, ...] = ()
|
||||
|
||||
def __init__(self):
|
||||
activity = discord.Activity(type=discord.ActivityType.playing, name=settings.DISCORD_STATUS_MESSAGE)
|
||||
|
|
@ -33,29 +35,63 @@ class Didier(commands.Bot):
|
|||
|
||||
async def setup_hook(self) -> None:
|
||||
"""Hook called once the bot is initialised"""
|
||||
await self._load_initial_cogs()
|
||||
await self._load_directory_cogs("didier/cogs")
|
||||
# Load extensions
|
||||
await self._load_initial_extensions()
|
||||
await self._load_directory_extensions("didier/cogs")
|
||||
|
||||
# Sync application commands to the test guild
|
||||
for guild in settings.DISCORD_TEST_GUILDS:
|
||||
guild_object = discord.Object(id=guild)
|
||||
|
||||
self.tree.copy_global_to(guild=guild_object)
|
||||
await self.tree.sync(guild=guild_object)
|
||||
|
||||
@property
|
||||
def db_session(self) -> AsyncSession:
|
||||
"""Obtain a database session"""
|
||||
return DBSession()
|
||||
|
||||
async def _load_initial_extensions(self):
|
||||
"""Load all extensions that should be loaded before the others"""
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(f"didier.{extension}")
|
||||
|
||||
async def _load_directory_extensions(self, path: str):
|
||||
"""Load all extensions in a given directory"""
|
||||
load_path = path.removeprefix("./").replace("/", ".")
|
||||
parent_path = load_path.removeprefix("didier.")
|
||||
|
||||
# Check every file in the directory
|
||||
for file in os.listdir(path):
|
||||
# Construct a path that includes all parent packages in order to
|
||||
# Allow checking against initial extensions more easily
|
||||
full_name = parent_path + file
|
||||
|
||||
# Only take Python files, and ignore the ones starting with an underscore (like __init__ and __pycache__)
|
||||
# Also ignore the files that we have already loaded previously
|
||||
if file.endswith(".py") and not file.startswith("_") and not full_name.startswith(self.initial_extensions):
|
||||
await self.load_extension(f"{load_path}.{file[:-3]}")
|
||||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_extensions(new_path)
|
||||
|
||||
async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
|
||||
"""Fetch a message from a reference"""
|
||||
# Message is in the cache, return it
|
||||
if reference.cached_message is not None:
|
||||
return reference.cached_message
|
||||
|
||||
# For older messages: fetch them from the API
|
||||
channel = self.get_channel(reference.channel_id)
|
||||
return await channel.fetch_message(reference.message_id)
|
||||
|
||||
async def on_ready(self):
|
||||
"""Event triggered when the bot is ready"""
|
||||
print(settings.DISCORD_READY_MESSAGE)
|
||||
|
||||
async def _load_initial_cogs(self):
|
||||
"""Load all cogs"""
|
||||
for extension in self.initial_extensions:
|
||||
await self.load_extension(f"didier.cogs.{extension}")
|
||||
|
||||
async def _load_directory_cogs(self, path: str):
|
||||
"""Load all cogs in a given directory"""
|
||||
load_path = path.removeprefix("./").replace("/", ".")
|
||||
|
||||
for file in os.listdir(path):
|
||||
if file.endswith(".py") and not file.startswith("_") and not file.startswith(self.initial_extensions):
|
||||
await self.load_extension(f"{load_path}.{file[:-3]}")
|
||||
elif os.path.isdir(new_path := f"{path}/{file}"):
|
||||
await self._load_directory_cogs(new_path)
|
||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
||||
"""Event triggered when a regular command errors"""
|
||||
# If developing, print everything to stdout so you don't have to
|
||||
# check the logs all the time
|
||||
if settings.SANDBOX:
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
return
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ def get_prefix(client: commands.Bot, message: Message) -> str:
|
|||
This is done dynamically to allow variable amounts of whitespace,
|
||||
and through regexes to allow case-insensitivity among other things.
|
||||
"""
|
||||
mention = f"<@!{client.user.id}>"
|
||||
mention = f"<@!?{client.user.id}>"
|
||||
regex = r"^({})\s*"
|
||||
|
||||
# Check which prefix was used
|
||||
|
|
|
|||
|
|
@ -9,7 +9,8 @@ plugins = [
|
|||
[tool.pylint.master]
|
||||
disable = [
|
||||
"missing-module-docstring",
|
||||
"too-few-public-methods"
|
||||
"too-few-public-methods",
|
||||
"too-many-arguments"
|
||||
]
|
||||
|
||||
[tool.pylint.format]
|
||||
|
|
|
|||
Loading…
Reference in New Issue