mirror of https://github.com/stijndcl/didier
Compare commits
5 Commits
da0e60ac4f
...
b54aed24e8
| Author | SHA1 | Date |
|---|---|---|
|
|
b54aed24e8 | |
|
|
2e3b4823d0 | |
|
|
424399b88a | |
|
|
edc6343e12 | |
|
|
0834a4ccbc |
2
.flake8
2
.flake8
|
|
@ -23,6 +23,8 @@ extend-ignore =
|
||||||
D401,
|
D401,
|
||||||
# Whitespace before ":"
|
# Whitespace before ":"
|
||||||
E203,
|
E203,
|
||||||
|
# Standard pseudo-random generators are not suitable for security/cryptographic purposes.
|
||||||
|
S311,
|
||||||
# Don't require docstrings when overriding a method,
|
# Don't require docstrings when overriding a method,
|
||||||
# the base method should have a docstring but the rest not
|
# the base method should have a docstring but the rest not
|
||||||
ignore-decorators=overrides
|
ignore-decorators=overrides
|
||||||
|
|
|
||||||
|
|
@ -42,4 +42,4 @@ async def get_birthdays_on_day(session: AsyncSession, day: datetime.date) -> lis
|
||||||
months = extract("month", Birthday.birthday)
|
months = extract("month", Birthday.birthday)
|
||||||
|
|
||||||
statement = select(Birthday).where((days == day.day) & (months == day.month))
|
statement = select(Birthday).where((days == day.day) & (months == day.month))
|
||||||
return list((await session.execute(statement)).scalars())
|
return list((await session.execute(statement)).scalars().all())
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,12 @@ class Discord(commands.Cog):
|
||||||
await birthdays.add_birthday(session, ctx.author.id, date)
|
await birthdays.add_birthday(session, ctx.author.id, date)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
|
|
||||||
|
@commands.command(name="Join", usage="[Thread]")
|
||||||
|
async def join(self, ctx: commands.Context, thread: discord.Thread):
|
||||||
|
"""Make Didier join a thread"""
|
||||||
|
if thread.me is not None:
|
||||||
|
return await ctx.reply()
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
"""Load the cog"""
|
"""Load the cog"""
|
||||||
|
|
|
||||||
|
|
@ -15,11 +15,14 @@ class Other(commands.Cog):
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
||||||
@commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]")
|
@commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Term]")
|
||||||
async def define(self, ctx: commands.Context, *, query: str):
|
async def define(self, ctx: commands.Context, *, query: str):
|
||||||
"""Look up the definition of a word on the Urban Dictionary"""
|
"""Look up the definition of a word on the Urban Dictionary"""
|
||||||
async with ctx.typing():
|
async with ctx.typing():
|
||||||
definitions = await urban_dictionary.lookup(self.client.http_session, query)
|
status_code, definitions = await urban_dictionary.lookup(self.client.http_session, query)
|
||||||
|
if not definitions:
|
||||||
|
return await ctx.reply(f"Something went wrong (status {status_code})")
|
||||||
|
|
||||||
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
|
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
|
||||||
|
|
||||||
@commands.hybrid_command(name="google", description="Google search", usage="[Query]")
|
@commands.hybrid_command(name="google", description="Google search", usage="[Query]")
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,15 @@
|
||||||
from typing import Optional
|
from typing import Literal, Optional
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
|
import settings
|
||||||
from database.crud import custom_commands
|
from database.crud import custom_commands
|
||||||
from database.exceptions.constraints import DuplicateInsertException
|
from database.exceptions.constraints import DuplicateInsertException
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.utils.discord.flags.owner import EditCustomFlags
|
from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags
|
||||||
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
|
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -18,8 +19,18 @@ class Owner(commands.Cog):
|
||||||
client: Didier
|
client: Didier
|
||||||
|
|
||||||
# Slash groups
|
# Slash groups
|
||||||
add_slash = app_commands.Group(name="add", description="Add something new to the database")
|
add_slash = app_commands.Group(
|
||||||
edit_slash = app_commands.Group(name="edit", description="Edit an existing database entry")
|
name="add",
|
||||||
|
description="Add something new to the database",
|
||||||
|
guild_ids=settings.DISCORD_OWNER_GUILDS,
|
||||||
|
guild_only=True,
|
||||||
|
)
|
||||||
|
edit_slash = app_commands.Group(
|
||||||
|
name="edit",
|
||||||
|
description="Edit an existing database entry",
|
||||||
|
guild_ids=settings.DISCORD_OWNER_GUILDS,
|
||||||
|
guild_only=True,
|
||||||
|
)
|
||||||
|
|
||||||
def __init__(self, client: Didier):
|
def __init__(self, client: Didier):
|
||||||
self.client = client
|
self.client = client
|
||||||
|
|
@ -31,16 +42,34 @@ class Owner(commands.Cog):
|
||||||
"""
|
"""
|
||||||
return await self.client.is_owner(ctx.author)
|
return await self.client.is_owner(ctx.author)
|
||||||
|
|
||||||
@commands.command(name="Error")
|
@commands.command(name="Error", aliases=["Raise"])
|
||||||
async def _error(self, ctx: commands.Context):
|
async def _error(self, ctx: commands.Context, *, message: str = "Debug"):
|
||||||
"""Raise an exception for debugging purposes"""
|
"""Raise an exception for debugging purposes"""
|
||||||
raise Exception("Debug")
|
raise Exception(message)
|
||||||
|
|
||||||
@commands.command(name="Sync")
|
@commands.command(name="Sync")
|
||||||
async def sync(self, ctx: commands.Context, guild: Optional[discord.Guild] = None):
|
async def sync(
|
||||||
|
self,
|
||||||
|
ctx: commands.Context,
|
||||||
|
guild: Optional[discord.Guild] = None,
|
||||||
|
symbol: Optional[Literal["."]] = None,
|
||||||
|
*,
|
||||||
|
flags: SyncOptionFlags,
|
||||||
|
):
|
||||||
"""Sync all application-commands in Discord"""
|
"""Sync all application-commands in Discord"""
|
||||||
|
# Allow using "." to specify the current guild
|
||||||
|
# When passing flags, and no guild was specified, default to the current guild as well
|
||||||
|
# because these don't work on global syncs
|
||||||
|
if guild is None and (symbol == "." or flags.clear or flags.copy_globals):
|
||||||
|
guild = ctx.guild
|
||||||
|
|
||||||
if guild is not None:
|
if guild is not None:
|
||||||
self.client.tree.copy_global_to(guild=guild)
|
if flags.clear:
|
||||||
|
self.client.tree.clear_commands(guild=guild)
|
||||||
|
|
||||||
|
if flags.copy_globals:
|
||||||
|
self.client.tree.copy_global_to(guild=guild)
|
||||||
|
|
||||||
await self.client.tree.sync(guild=guild)
|
await self.client.tree.sync(guild=guild)
|
||||||
else:
|
else:
|
||||||
await self.client.tree.sync()
|
await self.client.tree.sync()
|
||||||
|
|
@ -52,37 +81,35 @@ class Owner(commands.Cog):
|
||||||
"""Command group for [add X] message commands"""
|
"""Command group for [add X] message commands"""
|
||||||
|
|
||||||
@add_msg.command(name="Custom")
|
@add_msg.command(name="Custom")
|
||||||
async def add_custom(self, ctx: commands.Context, name: str, *, response: str):
|
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
|
||||||
"""Add a new custom command"""
|
"""Add a new custom command"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.db_session as session:
|
||||||
try:
|
try:
|
||||||
await custom_commands.create_command(session, name, response)
|
await custom_commands.create_command(session, name, response)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
except DuplicateInsertException:
|
except DuplicateInsertException:
|
||||||
await ctx.reply("Er bestaat al een commando met deze naam.")
|
await ctx.reply("There is already a command with this name.")
|
||||||
await self.client.reject_message(ctx.message)
|
await self.client.reject_message(ctx.message)
|
||||||
|
|
||||||
@add_msg.command(name="Alias")
|
@add_msg.command(name="Alias")
|
||||||
async def add_alias(self, ctx: commands.Context, command: str, alias: str):
|
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
|
||||||
"""Add a new alias for a custom command"""
|
"""Add a new alias for a custom command"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.db_session as session:
|
||||||
try:
|
try:
|
||||||
await custom_commands.create_alias(session, command, alias)
|
await custom_commands.create_alias(session, command, alias)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
except NoResultFoundException:
|
except NoResultFoundException:
|
||||||
await ctx.reply(f'Geen commando gevonden voor "{command}".')
|
await ctx.reply(f"No command found matching `{command}`.")
|
||||||
await self.client.reject_message(ctx.message)
|
await self.client.reject_message(ctx.message)
|
||||||
except DuplicateInsertException:
|
except DuplicateInsertException:
|
||||||
await ctx.reply("Er bestaat al een commando met deze naam.")
|
await ctx.reply("There is already a command with this name.")
|
||||||
await self.client.reject_message(ctx.message)
|
await self.client.reject_message(ctx.message)
|
||||||
|
|
||||||
@add_slash.command(name="custom", description="Add a custom command")
|
@add_slash.command(name="custom", description="Add a custom command")
|
||||||
async def add_custom_slash(self, interaction: discord.Interaction):
|
async def add_custom_slash(self, interaction: discord.Interaction):
|
||||||
"""Slash command to add a custom command"""
|
"""Slash command to add a custom command"""
|
||||||
if not await self.client.is_owner(interaction.user):
|
if not await self.client.is_owner(interaction.user):
|
||||||
return interaction.response.send_message(
|
return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
|
||||||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
modal = CreateCustomCommand(self.client)
|
modal = CreateCustomCommand(self.client)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
@ -91,9 +118,7 @@ class Owner(commands.Cog):
|
||||||
async def add_dad_joke_slash(self, interaction: discord.Interaction):
|
async def add_dad_joke_slash(self, interaction: discord.Interaction):
|
||||||
"""Slash command to add a dad joke"""
|
"""Slash command to add a dad joke"""
|
||||||
if not await self.client.is_owner(interaction.user):
|
if not await self.client.is_owner(interaction.user):
|
||||||
return interaction.response.send_message(
|
return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
|
||||||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
|
||||||
)
|
|
||||||
|
|
||||||
modal = AddDadJoke(self.client)
|
modal = AddDadJoke(self.client)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
@ -110,7 +135,7 @@ class Owner(commands.Cog):
|
||||||
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
||||||
return await self.client.confirm_message(ctx.message)
|
return await self.client.confirm_message(ctx.message)
|
||||||
except NoResultFoundException:
|
except NoResultFoundException:
|
||||||
await ctx.reply(f"Geen commando gevonden voor ``{command}``.")
|
await ctx.reply(f"No command found matching ``{command}``.")
|
||||||
return await self.client.reject_message(ctx.message)
|
return await self.client.reject_message(ctx.message)
|
||||||
|
|
||||||
@edit_slash.command(name="custom", description="Edit a custom command")
|
@edit_slash.command(name="custom", description="Edit a custom command")
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from discord.ext import commands
|
||||||
|
|
||||||
from database.crud import ufora_courses
|
from database.crud import ufora_courses
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data import constants
|
from didier.utils.discord.flags.school import StudyGuideFlags
|
||||||
|
|
||||||
|
|
||||||
class School(commands.Cog):
|
class School(commands.Cog):
|
||||||
|
|
@ -36,43 +36,46 @@ class School(commands.Cog):
|
||||||
|
|
||||||
# Didn't fix it, sad
|
# Didn't fix it, sad
|
||||||
if message is None:
|
if message is None:
|
||||||
return await ctx.reply("Er is geen bericht om te pinnen.", delete_after=10)
|
return await ctx.reply("Found no message to pin.", delete_after=10)
|
||||||
|
|
||||||
|
if message.pinned:
|
||||||
|
return await ctx.reply("This message is already pinned.", delete_after=10)
|
||||||
|
|
||||||
if message.is_system():
|
if message.is_system():
|
||||||
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
|
return await ctx.reply("Dus jij wil system messages pinnen?\nMag niet.")
|
||||||
|
|
||||||
await message.pin(reason=f"Didier Pin door {ctx.author.display_name}")
|
await message.pin(reason=f"Didier Pin by {ctx.author.display_name}")
|
||||||
await message.add_reaction("📌")
|
await message.add_reaction("📌")
|
||||||
|
|
||||||
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
async def pin_ctx(self, interaction: discord.Interaction, message: discord.Message):
|
||||||
"""Pin a message in the current channel"""
|
"""Pin a message in the current channel"""
|
||||||
# Is already pinned
|
# Is already pinned
|
||||||
if message.pinned:
|
if message.pinned:
|
||||||
return await interaction.response.send_message("Dit bericht staat al gepind.", ephemeral=True)
|
return await interaction.response.send_message("This message is already pinned.", ephemeral=True)
|
||||||
|
|
||||||
if message.is_system():
|
if message.is_system():
|
||||||
return await interaction.response.send_message(
|
return await interaction.response.send_message(
|
||||||
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
|
"Dus jij wil system messages pinnen?\nMag niet.", ephemeral=True
|
||||||
)
|
)
|
||||||
|
|
||||||
await message.pin(reason=f"Didier Pin door {interaction.user.display_name}")
|
await message.pin(reason=f"Didier Pin by {interaction.user.display_name}")
|
||||||
await message.add_reaction("📌")
|
await message.add_reaction("📌")
|
||||||
return await interaction.response.send_message("📌", ephemeral=True)
|
return await interaction.response.send_message("📌", ephemeral=True)
|
||||||
|
|
||||||
@commands.hybrid_command(
|
@commands.hybrid_command(
|
||||||
name="fiche", description="Stuurt de link naar de studiefiche voor [Vak]", aliases=["guide", "studiefiche"]
|
name="fiche", description="Sends the link to the study guide for [Course]", aliases=["guide", "studiefiche"]
|
||||||
)
|
)
|
||||||
@app_commands.describe(course="vak")
|
@app_commands.describe(course="vak")
|
||||||
async def study_guide(self, ctx: commands.Context, course: str):
|
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
|
||||||
"""Create links to study guides"""
|
"""Create links to study guides"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.db_session as session:
|
||||||
ufora_course = await ufora_courses.get_course_by_name(session, course)
|
ufora_course = await ufora_courses.get_course_by_name(session, course)
|
||||||
|
|
||||||
if ufora_course is None:
|
if ufora_course is None:
|
||||||
return await ctx.reply(f"Geen vak gevonden voor ``{course}``", ephemeral=True)
|
return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True)
|
||||||
|
|
||||||
return await ctx.reply(
|
return await ctx.reply(
|
||||||
f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{constants.CURRENT_YEAR}",
|
f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}",
|
||||||
mention_author=False,
|
mention_author=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
import datetime
|
import datetime
|
||||||
|
import random
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
|
from discord.ext import commands, tasks # type: ignore # Strange & incorrect Mypy error
|
||||||
|
|
@ -18,6 +19,10 @@ DAILY_RESET_TIME = datetime.time(hour=0, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||||
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
|
SOCIALLY_ACCEPTABLE_TIME = datetime.time(hour=7, minute=0, tzinfo=LOCAL_TIMEZONE)
|
||||||
|
|
||||||
|
|
||||||
|
# TODO more messages?
|
||||||
|
BIRTHDAY_MESSAGES = ["Gelukkige verjaardag {mention}!", "Happy birthday {mention}!"]
|
||||||
|
|
||||||
|
|
||||||
class Tasks(commands.Cog):
|
class Tasks(commands.Cog):
|
||||||
"""Task loops that run periodically
|
"""Task loops that run periodically
|
||||||
|
|
||||||
|
|
@ -52,12 +57,12 @@ class Tasks(commands.Cog):
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
@tasks_group.command(name="Force", case_insensitive=True)
|
@tasks_group.command(name="Force", case_insensitive=True, usage="[Task]")
|
||||||
async def force_task(self, ctx: commands.Context, name: str):
|
async def force_task(self, ctx: commands.Context, name: str):
|
||||||
"""Command to force-run a task without waiting for the run time"""
|
"""Command to force-run a task without waiting for the specified run time"""
|
||||||
name = name.lower()
|
name = name.lower()
|
||||||
if name not in self._tasks:
|
if name not in self._tasks:
|
||||||
return await ctx.reply(f"Geen task gevonden voor `{name}`.", mention_author=False)
|
return await ctx.reply(f"Found no tasks matching `{name}`.", mention_author=False)
|
||||||
|
|
||||||
task = self._tasks[name]
|
task = self._tasks[name]
|
||||||
await task()
|
await task()
|
||||||
|
|
@ -76,8 +81,8 @@ class Tasks(commands.Cog):
|
||||||
|
|
||||||
for birthday in birthdays:
|
for birthday in birthdays:
|
||||||
user = self.client.get_user(birthday.user_id)
|
user = self.client.get_user(birthday.user_id)
|
||||||
# TODO more messages?
|
|
||||||
await channel.send(f"Gelukkig verjaardag {user.mention}!")
|
await channel.send(random.choice(BIRTHDAY_MESSAGES).format(mention=user.mention))
|
||||||
|
|
||||||
@check_birthdays.before_loop
|
@check_birthdays.before_loop
|
||||||
async def _before_check_birthdays(self):
|
async def _before_check_birthdays(self):
|
||||||
|
|
@ -114,6 +119,7 @@ class Tasks(commands.Cog):
|
||||||
async def _on_tasks_error(self, error: BaseException):
|
async def _on_tasks_error(self, error: BaseException):
|
||||||
"""Error handler for all tasks"""
|
"""Error handler for all tasks"""
|
||||||
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
|
print("".join(traceback.format_exception(type(error), error, error.__traceback__)))
|
||||||
|
self.client.dispatch("task_error")
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,5 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
from didier.data.embeds.urban_dictionary import Definition
|
from didier.data.embeds.urban_dictionary import Definition
|
||||||
|
|
@ -8,10 +10,13 @@ __all__ = ["lookup", "PER_PAGE"]
|
||||||
PER_PAGE = 10
|
PER_PAGE = 10
|
||||||
|
|
||||||
|
|
||||||
async def lookup(http_session: ClientSession, query: str) -> list[Definition]:
|
async def lookup(http_session: ClientSession, query: str) -> tuple[int, list[Definition]]:
|
||||||
"""Fetch the Urban Dictionary definitions for a given word"""
|
"""Fetch the Urban Dictionary definitions for a given word"""
|
||||||
url = "https://api.urbandictionary.com/v0/define"
|
url = "https://api.urbandictionary.com/v0/define"
|
||||||
|
|
||||||
async with http_session.get(url, params={"term": query}) as response:
|
async with http_session.get(url, params={"term": query}) as response:
|
||||||
|
if response.status != HTTPStatus.OK:
|
||||||
|
return response.status, []
|
||||||
|
|
||||||
response_json = await response.json()
|
response_json = await response.json()
|
||||||
return list(map(Definition.parse_obj, response_json["list"]))
|
return 200, list(map(Definition.parse_obj, response_json["list"]))
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,47 @@
|
||||||
|
import traceback
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
|
from didier.utils.discord.constants import Limits
|
||||||
|
from didier.utils.types.string import abbreviate
|
||||||
|
|
||||||
|
__all__ = ["create_error_embed"]
|
||||||
|
|
||||||
|
|
||||||
|
def _get_traceback(exception: Exception) -> str:
|
||||||
|
"""Get a proper representation of the exception"""
|
||||||
|
tb = traceback.format_exception(type(exception), exception, exception.__traceback__)
|
||||||
|
error_string = ""
|
||||||
|
for line in tb:
|
||||||
|
# Don't add endless tracebacks
|
||||||
|
if line.strip().startswith("The above exception was the direct cause of"):
|
||||||
|
break
|
||||||
|
|
||||||
|
# Escape Discord markdown formatting
|
||||||
|
error_string += line.replace(r"*", r"\*").replace(r"_", r"\_")
|
||||||
|
if line.strip():
|
||||||
|
error_string += "\n"
|
||||||
|
|
||||||
|
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH)
|
||||||
|
|
||||||
|
|
||||||
|
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed:
|
||||||
|
"""Create an embed for the traceback of an exception"""
|
||||||
|
description = _get_traceback(exception)
|
||||||
|
|
||||||
|
if ctx.guild is None:
|
||||||
|
origin = "DM"
|
||||||
|
else:
|
||||||
|
origin = f"{ctx.channel.mention} ({ctx.guild.name})"
|
||||||
|
|
||||||
|
invocation = f"{ctx.author.display_name} in {origin}"
|
||||||
|
|
||||||
|
embed = discord.Embed(colour=discord.Colour.red())
|
||||||
|
embed.set_author(name="Error")
|
||||||
|
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
|
||||||
|
embed.add_field(name="Context", value=invocation, inline=True)
|
||||||
|
embed.add_field(name="Exception", value=abbreviate(str(exception), Limits.EMBED_FIELD_VALUE_LENGTH), inline=False)
|
||||||
|
embed.add_field(name="Traceback", value=description, inline=False)
|
||||||
|
|
||||||
|
return embed
|
||||||
|
|
@ -24,11 +24,11 @@ class GoogleSearch(EmbedBaseModel):
|
||||||
|
|
||||||
# Empty embed
|
# Empty embed
|
||||||
if not self.data.results:
|
if not self.data.results:
|
||||||
embed.description = "Geen resultaten gevonden"
|
embed.description = "Found no results"
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
# Error embed
|
# Error embed
|
||||||
embed.description = f"Status {self.data.status_code}"
|
embed.description = f"Something went wrong (status {self.data.status_code})"
|
||||||
|
|
||||||
return embed
|
return embed
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,10 +50,10 @@ class Definition(EmbedPydantic):
|
||||||
embed = discord.Embed(colour=colours.urban_dictionary_green())
|
embed = discord.Embed(colour=colours.urban_dictionary_green())
|
||||||
embed.set_author(name="Urban Dictionary")
|
embed.set_author(name="Urban Dictionary")
|
||||||
|
|
||||||
embed.add_field(name="Woord", value=self.word, inline=True)
|
embed.add_field(name="Term", value=self.word, inline=True)
|
||||||
embed.add_field(name="Auteur", value=self.author, inline=True)
|
embed.add_field(name="Author", value=self.author, inline=True)
|
||||||
embed.add_field(name="Definitie", value=self.definition, inline=False)
|
embed.add_field(name="Definition", value=self.definition, inline=False)
|
||||||
embed.add_field(name="Voorbeeld", value=self.example or "\u200B", inline=False)
|
embed.add_field(name="Example", value=self.example or "\u200B", inline=False)
|
||||||
embed.add_field(
|
embed.add_field(
|
||||||
name="Rating", value=f"{self.ratio}% ({self.thumbs_up}/{self.thumbs_up + self.thumbs_down})", inline=True
|
name="Rating", value=f"{self.ratio}% ({self.thumbs_up}/{self.thumbs_up + self.thumbs_down})", inline=True
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@ import settings
|
||||||
from database.crud import custom_commands
|
from database.crud import custom_commands
|
||||||
from database.engine import DBSession
|
from database.engine import DBSession
|
||||||
from database.utils.caches import CacheManager
|
from database.utils.caches import CacheManager
|
||||||
|
from didier.data.embeds.error_embed import create_error_embed
|
||||||
from didier.utils.discord.prefix import get_prefix
|
from didier.utils.discord.prefix import get_prefix
|
||||||
|
|
||||||
__all__ = ["Didier"]
|
__all__ = ["Didier"]
|
||||||
|
|
@ -139,6 +140,9 @@ class Didier(commands.Bot):
|
||||||
|
|
||||||
await self.process_commands(message)
|
await self.process_commands(message)
|
||||||
|
|
||||||
|
# TODO easter eggs
|
||||||
|
# TODO stats
|
||||||
|
|
||||||
async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
|
async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
|
||||||
"""Check if the message tries to invoke a custom command
|
"""Check if the message tries to invoke a custom command
|
||||||
|
|
||||||
|
|
@ -162,11 +166,50 @@ class Didier(commands.Bot):
|
||||||
# Nothing found
|
# Nothing found
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def on_command_error(self, context: commands.Context, exception: commands.CommandError, /) -> None:
|
async def on_thread_create(self, thread: discord.Thread):
|
||||||
"""Event triggered when a regular command errors"""
|
"""Event triggered when a new thread is created"""
|
||||||
# Print everything to the logs/stderr
|
await thread.join()
|
||||||
await super().on_command_error(context, exception)
|
|
||||||
|
|
||||||
# If developing, do nothing special
|
async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /):
|
||||||
|
"""Event triggered when a regular command errors"""
|
||||||
|
# If working locally, print everything to your console
|
||||||
if settings.SANDBOX:
|
if settings.SANDBOX:
|
||||||
|
await super().on_command_error(ctx, exception)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
# If commands have their own error handler, let it handle the error instead
|
||||||
|
if hasattr(ctx.command, "on_error"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Ignore exceptions that aren't important
|
||||||
|
if isinstance(
|
||||||
|
exception,
|
||||||
|
(
|
||||||
|
commands.CommandNotFound,
|
||||||
|
commands.CheckFailure,
|
||||||
|
commands.TooManyArguments,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return
|
||||||
|
|
||||||
|
# Print everything that we care about to the logs/stderr
|
||||||
|
await super().on_command_error(ctx, exception)
|
||||||
|
|
||||||
|
if isinstance(exception, commands.MessageNotFound):
|
||||||
|
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
|
||||||
|
|
||||||
|
if isinstance(
|
||||||
|
exception,
|
||||||
|
(
|
||||||
|
commands.BadArgument,
|
||||||
|
commands.MissingRequiredArgument,
|
||||||
|
commands.UnexpectedQuoteError,
|
||||||
|
commands.ExpectedClosingQuoteError,
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
|
||||||
|
|
||||||
|
if settings.ERRORS_CHANNEL is not None:
|
||||||
|
embed = create_error_embed(ctx, exception)
|
||||||
|
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||||
|
await channel.send(embed=embed)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,10 @@
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
from discord.ext import commands
|
||||||
|
|
||||||
from didier.utils.discord.flags import PosixFlags
|
from didier.utils.discord.flags import PosixFlags
|
||||||
|
|
||||||
__all__ = ["EditCustomFlags"]
|
__all__ = ["EditCustomFlags", "SyncOptionFlags"]
|
||||||
|
|
||||||
|
|
||||||
class EditCustomFlags(PosixFlags):
|
class EditCustomFlags(PosixFlags):
|
||||||
|
|
@ -10,3 +12,10 @@ class EditCustomFlags(PosixFlags):
|
||||||
|
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
response: Optional[str] = None
|
response: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
|
class SyncOptionFlags(PosixFlags):
|
||||||
|
"""Flags for the sync command"""
|
||||||
|
|
||||||
|
clear: bool = False
|
||||||
|
copy_globals: bool = commands.flag(aliases=["copy_global", "copy"], default=False)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from didier.data import constants
|
||||||
|
from didier.utils.discord.flags import PosixFlags
|
||||||
|
|
||||||
|
__all__ = ["StudyGuideFlags"]
|
||||||
|
|
||||||
|
|
||||||
|
class StudyGuideFlags(PosixFlags):
|
||||||
|
"""Flags for the study guide command"""
|
||||||
|
|
||||||
|
year: Optional[int] = constants.CURRENT_YEAR
|
||||||
|
|
@ -18,7 +18,8 @@ omit = [
|
||||||
"./didier/data/embeds/*",
|
"./didier/data/embeds/*",
|
||||||
"./didier/data/flags/*",
|
"./didier/data/flags/*",
|
||||||
"./didier/utils/discord/colours.py",
|
"./didier/utils/discord/colours.py",
|
||||||
"./didier/utils/discord/constants.py"
|
"./didier/utils/discord/constants.py",
|
||||||
|
"./didier/utils/discord/flags/*",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.isort]
|
[tool.isort]
|
||||||
|
|
|
||||||
|
|
@ -61,3 +61,5 @@ black
|
||||||
flake8
|
flake8
|
||||||
mypy
|
mypy
|
||||||
```
|
```
|
||||||
|
|
||||||
|
It's also convenient to have code-formatting happen automatically on-save. The [`Black documentation`](https://black.readthedocs.io/en/stable/integrations/editors.html) explains how to set this up for different types of editors.
|
||||||
|
|
|
||||||
|
|
@ -44,6 +44,7 @@ DISCORD_TOKEN: str = env.str("DISCORD_TOKEN")
|
||||||
DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY")
|
DISCORD_READY_MESSAGE: str = env.str("DISCORD_READY_MESSAGE", "I'M READY I'M READY I'M READY")
|
||||||
DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.")
|
DISCORD_STATUS_MESSAGE: str = env.str("DISCORD_STATUS_MESSAGE", "with your Didier Dinks.")
|
||||||
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
DISCORD_TEST_GUILDS: list[int] = env.list("DISCORD_TEST_GUILDS", [], subcast=int)
|
||||||
|
DISCORD_OWNER_GUILDS: Optional[list[int]] = env.list("DISCORD_OWNER_GUILDS", [], subcast=int) or None
|
||||||
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
DISCORD_BOOS_REACT: str = env.str("DISCORD_BOOS_REACT", "<:boos:629603785840263179>")
|
||||||
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
DISCORD_CUSTOM_COMMAND_PREFIX: str = env.str("DISCORD_CUSTOM_COMMAND_PREFIX", "?")
|
||||||
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
|
BIRTHDAY_ANNOUNCEMENT_CHANNEL: Optional[int] = env.int("BIRTHDAY_ANNOUNCEMENT_CHANNEL", None)
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,63 @@
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from freezegun import freeze_time
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import tasks as crud
|
||||||
|
from database.enums import TaskType
|
||||||
|
from database.models import Task
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def task_type() -> TaskType:
|
||||||
|
"""Fixture to use the same TaskType in every test"""
|
||||||
|
return TaskType.BIRTHDAYS
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
async def task(database_session: AsyncSession, task_type: TaskType) -> Task:
|
||||||
|
"""Fixture to create a task"""
|
||||||
|
task = Task(task=task_type)
|
||||||
|
database_session.add(task)
|
||||||
|
await database_session.commit()
|
||||||
|
return task
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_task_by_enum_present(database_session: AsyncSession, task: Task, task_type: TaskType):
|
||||||
|
"""Test getting a task by its enum type when it exists"""
|
||||||
|
result = await crud.get_task_by_enum(database_session, task_type)
|
||||||
|
assert result is not None
|
||||||
|
assert result == task
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_task_by_enum_not_present(database_session: AsyncSession, task_type: TaskType):
|
||||||
|
"""Test getting a task by its enum type when it doesn't exist"""
|
||||||
|
result = await crud.get_task_by_enum(database_session, task_type)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2022/07/24")
|
||||||
|
async def test_set_execution_time_exists(database_session: AsyncSession, task: Task, task_type: TaskType):
|
||||||
|
"""Test setting the execution time of an existing task"""
|
||||||
|
await database_session.refresh(task)
|
||||||
|
assert task.previous_run is None
|
||||||
|
|
||||||
|
await crud.set_last_task_execution_time(database_session, task_type)
|
||||||
|
await database_session.refresh(task)
|
||||||
|
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
|
||||||
|
|
||||||
|
|
||||||
|
@freeze_time("2022/07/24")
|
||||||
|
async def test_set_execution_time_doesnt_exist(database_session: AsyncSession, task_type: TaskType):
|
||||||
|
"""Test setting the execution time of a non-existing task"""
|
||||||
|
statement = select(Task).where(Task.task == task_type)
|
||||||
|
results = list((await database_session.execute(statement)).scalars().all())
|
||||||
|
assert len(results) == 0
|
||||||
|
|
||||||
|
await crud.set_last_task_execution_time(database_session, task_type)
|
||||||
|
results = list((await database_session.execute(statement)).scalars().all())
|
||||||
|
assert len(results) == 1
|
||||||
|
task = results[0]
|
||||||
|
assert task.previous_run == datetime.datetime(year=2022, month=7, day=24)
|
||||||
Loading…
Reference in New Issue