Compare commits

...

4 Commits

Author SHA1 Message Date
stijndcl 97e815cbff Adding & removing github links 2022-09-20 17:55:59 +02:00
stijndcl 9e3527ae8a Fix relationship, add github links, improve error messages 2022-09-20 17:34:49 +02:00
stijndcl 41c8c9d0ab Put tracebacks in a codeblock for readability & to escape markdown 2022-09-20 16:45:02 +02:00
stijndcl 23edc51dbf Command stats 2022-09-20 14:47:26 +02:00
9 changed files with 344 additions and 39 deletions

View File

@ -0,0 +1,54 @@
"""Command stats & GitHub links
Revision ID: c1f9ee875616
Revises: b84bb10fb8de
Create Date: 2022-09-20 17:18:02.289593
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "c1f9ee875616"
down_revision = "b84bb10fb8de"
branch_labels = None
depends_on = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"command_stats",
sa.Column("command_stats_id", sa.Integer(), nullable=False),
sa.Column("command", sa.Text(), nullable=False),
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("slash", sa.Boolean(), nullable=False),
sa.Column("context_menu", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("command_stats_id"),
)
op.create_table(
"github_links",
sa.Column("github_link_id", sa.Integer(), nullable=False),
sa.Column("url", sa.Text(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("github_link_id"),
sa.UniqueConstraint("url"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("github_links")
op.drop_table("command_stats")
# ### end Alembic commands ###

View File

@ -0,0 +1,41 @@
from datetime import datetime
from typing import Optional, Union
from discord import app_commands
from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import CommandStats
__all__ = ["register_command_invocation"]
CommandT = Union[commands.Command, app_commands.Command, app_commands.ContextMenu]
async def register_command_invocation(
session: AsyncSession, ctx: commands.Context, command: Optional[CommandT], timestamp: datetime
):
"""Create an entry for a command invocation"""
if command is None:
return
await get_or_add_user(session, ctx.author.id)
# Check the type of invocation
context_menu = isinstance(command, app_commands.ContextMenu)
# (This is a bit uglier but it accounts for hybrid commands)
slash = isinstance(command, app_commands.Command) or (ctx.interaction is not None and not context_menu)
stats = CommandStats(
command=command.qualified_name.lower(),
timestamp=timestamp,
user_id=ctx.author.id,
slash=slash,
context_menu=context_menu,
)
session.add(stats)
await session.commit()

View File

@ -0,0 +1,51 @@
from typing import Optional
import sqlalchemy.exc
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions import (
DuplicateInsertException,
Forbidden,
NoResultFoundException,
)
from database.schemas import GitHubLink
__all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"]
async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink:
"""Add a new GitHub link into the database"""
try:
gh_link = GitHubLink(user_id=user_id, url=url)
session.add(gh_link)
await session.commit()
await session.refresh(gh_link)
except sqlalchemy.exc.IntegrityError:
raise DuplicateInsertException
return gh_link
async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: int):
"""Remove an existing link from the database
You can only remove links owned by you
"""
select_statement = select(GitHubLink).where(GitHubLink.github_link_id == link_id)
gh_link: Optional[GitHubLink] = (await session.execute(select_statement)).scalar_one_or_none()
if gh_link is None:
raise NoResultFoundException
if gh_link.user_id != user_id:
raise Forbidden
delete_statement = delete(GitHubLink).where(GitHubLink.github_link_id == gh_link.github_link_id)
await session.execute(delete_statement)
await session.commit()
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
"""Get a user's GitHub links"""
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
return (await session.execute(statement)).scalars().all()

View File

@ -27,11 +27,13 @@ __all__ = [
"Bank", "Bank",
"Birthday", "Birthday",
"Bookmark", "Bookmark",
"CommandStats",
"CustomCommand", "CustomCommand",
"CustomCommandAlias", "CustomCommandAlias",
"DadJoke", "DadJoke",
"Deadline", "Deadline",
"EasterEgg", "EasterEgg",
"GitHubLink",
"Link", "Link",
"MemeTemplate", "MemeTemplate",
"NightlyData", "NightlyData",
@ -95,6 +97,20 @@ class Bookmark(Base):
user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin")
class CommandStats(Base):
"""Metrics on how often commands are used"""
__tablename__ = "command_stats"
command_stats_id: int = Column(Integer, primary_key=True)
command: str = Column(Text, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
slash: bool = Column(Boolean, nullable=False)
context_menu: bool = Column(Boolean, nullable=False)
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
class CustomCommand(Base): class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't""" """Custom commands to fill the hole Dyno couldn't"""
@ -157,6 +173,18 @@ class EasterEgg(Base):
startswith: bool = Column(Boolean, nullable=False, server_default="1") startswith: bool = Column(Boolean, nullable=False, server_default="1")
class GitHubLink(Base):
"""A user's GitHub link"""
__tablename__ = "github_links"
github_link_id: int = Column(Integer, primary_key=True)
url: str = Column(Text, nullable=False, unique=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
class Link(Base): class Link(Base):
"""Useful links that go useful places""" """Useful links that go useful places"""
@ -266,6 +294,12 @@ class User(Base):
bookmarks: list[Bookmark] = relationship( bookmarks: list[Bookmark] = relationship(
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
command_stats: list[CommandStats] = relationship(
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
github_links: list[GitHubLink] = relationship(
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )

View File

@ -4,7 +4,7 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud import birthdays, bookmarks from database.crud import birthdays, bookmarks, github
from database.exceptions import ( from database.exceptions import (
DuplicateInsertException, DuplicateInsertException,
Forbidden, Forbidden,
@ -15,6 +15,7 @@ from didier import Didier
from didier.exceptions import expect from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.menus.common import Menu from didier.menus.common import Menu
from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar from didier.utils.discord.assets import get_author_avatar
from didier.utils.types.datetime import str_to_date from didier.utils.types.datetime import str_to_date
from didier.utils.types.string import leading from didier.utils.types.string import leading
@ -193,6 +194,84 @@ class Discord(commands.Cog):
modal = CreateBookmark(self.client, message.jump_url) modal = CreateBookmark(self.client, message.jump_url)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None):
"""Show a user's GitHub links.
If no user is provided, this shows your links instead.
"""
# Default to author
user = user or ctx.author
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
embed.set_author(
name=user.display_name, icon_url=user.avatar.url if user.avatar is not None else user.default_avatar.url
)
embed.set_footer(text="Links can be added using didier github add <link>.")
async with self.client.postgres_session as session:
links = await github.get_github_links(session, user.id)
if not links:
embed.description = "This user has not set any GitHub links yet."
else:
regular_links = []
ugent_links = []
for link in links:
if "github.ugent.be" in link.url.lower():
ugent_links.append(f"`#{link.github_link_id}`: {link.url}")
else:
regular_links.append(f"`#{link.github_link_id}`: {link.url}")
regular_links.sort()
ugent_links.sort()
if ugent_links:
embed.add_field(name="Ghent University", value="\n".join(ugent_links), inline=False)
if regular_links:
embed.add_field(name="Other", value="\n".join(regular_links), inline=False)
return await ctx.reply(embed=embed, mention_author=False)
@github_group.command(name="add", aliases=["create", "insert"])
async def github_add(self, ctx: commands.Context, link: str):
"""Add a new link into the database."""
# Remove wrapping <brackets> which can be used to escape Discord embeds
link = link.removeprefix("<").removesuffix(">")
async with self.client.postgres_session as session:
try:
gh_link = await github.add_github_link(session, ctx.author.id, link)
except DuplicateInsertException:
return await ctx.reply("This link has already been registered by someone.", mention_author=False)
await self.client.confirm_message(ctx.message)
return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False)
@github_group.command(name="delete", aliases=["del", "remove", "rm"])
async def github_delete(self, ctx: commands.Context, link_id: str):
"""Delete the link with it `link_id` from the database.
You can only delete your own links.
"""
try:
link_id_int = int(link_id.removeprefix("#"))
except ValueError:
return await ctx.reply(f"`{link_id}` is not a valid link id.", mention_author=False)
async with self.client.postgres_session as session:
try:
await github.delete_github_link_by_id(session, ctx.author.id, link_id_int)
except NoResultFoundException:
return await ctx.reply(f"Found no GitHub link with id `#{link_id_int}`.", mention_author=False)
except Forbidden:
return await ctx.reply(f"You don't own GitHub link `#{link_id_int}`.", mention_author=False)
return await ctx.reply(f"Successfully deleted GitHub link `#{link_id_int}`.", mention_author=False)
@commands.command(name="join") @commands.command(name="join")
async def join(self, ctx: commands.Context, thread: discord.Thread): async def join(self, ctx: commands.Context, thread: discord.Thread):
"""Make Didier join `thread`. """Make Didier join `thread`.

View File

@ -188,7 +188,7 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
embed.add_field(name="Signature", value=signature, inline=False) embed.add_field(name="Signature", value=signature, inline=False)
if command.aliases: if command.aliases:
embed.add_field(name="Aliases", value=", ".join(command.aliases), inline=False) embed.add_field(name="Aliases", value=", ".join(sorted(command.aliases)), inline=False)
def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]: def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]:
"""Try to find a cog, case-insensitively""" """Try to find a cog, case-insensitively"""

View File

@ -23,12 +23,14 @@ def _get_traceback(exception: Exception) -> str:
if line.strip(): if line.strip():
error_string += "\n" error_string += "\n"
return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH) return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8)
def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed: def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed:
"""Create an embed for the traceback of an exception""" """Create an embed for the traceback of an exception"""
description = _get_traceback(exception) # Wrap the traceback in a codeblock for readability
description = _get_traceback(exception).strip()
description = f"```\n{description}\n```"
if ctx.guild is None: if ctx.guild is None:
origin = "DM" origin = "DM"
@ -40,7 +42,7 @@ def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.E
embed = discord.Embed(title="Error", colour=discord.Colour.red()) embed = discord.Embed(title="Error", colour=discord.Colour.red())
embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True) embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True)
embed.add_field(name="Context", value=invocation, inline=True) embed.add_field(name="Context", value=invocation, inline=True)
embed.add_field(name="Exception", value=abbreviate(str(exception), Limits.EMBED_FIELD_VALUE_LENGTH), inline=False) embed.add_field(name="Exception", value=str(exception), inline=False)
embed.add_field(name="Traceback", value=description, inline=False) embed.add_field(name="Traceback", value=description, inline=False)
return embed return embed

View File

@ -1,7 +1,9 @@
import logging import logging
import os import os
import pathlib import pathlib
import re
from functools import cached_property from functools import cached_property
from typing import Union
import discord import discord
from aiohttp import ClientSession from aiohttp import ClientSession
@ -10,7 +12,7 @@ from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import custom_commands from database.crud import command_stats, custom_commands
from database.engine import DBSession from database.engine import DBSession
from database.utils.caches import CacheManager from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.error_embed import create_error_embed
@ -18,6 +20,7 @@ from didier.data.embeds.schedules import Schedule, parse_schedule
from didier.exceptions import HTTPException, NoMatch from didier.exceptions import HTTPException, NoMatch
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
from didier.utils.easter_eggs import detect_easter_egg from didier.utils.easter_eggs import detect_easter_egg
from didier.utils.types.datetime import tz_aware_now
__all__ = ["Didier"] __all__ = ["Didier"]
@ -194,30 +197,6 @@ class Didier(commands.Bot):
"""Log a warning message""" """Log a warning message"""
await self._log(logging.WARNING, message, log_to_discord) await self._log(logging.WARNING, message, log_to_discord)
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_message(self, message: discord.Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if await self._try_invoke_custom_command(message):
return
await self.process_commands(message)
easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs)
if easter_egg is not None:
await message.reply(easter_egg, mention_author=False)
async def _try_invoke_custom_command(self, message: discord.Message) -> bool: async def _try_invoke_custom_command(self, message: discord.Message) -> bool:
"""Check if the message tries to invoke a custom command """Check if the message tries to invoke a custom command
@ -241,9 +220,16 @@ class Didier(commands.Bot):
# Nothing found # Nothing found
return False return False
async def on_thread_create(self, thread: discord.Thread): async def on_app_command_completion(
"""Event triggered when a new thread is created""" self,
await thread.join() interaction: discord.Interaction,
command: Union[discord.app_commands.Command, discord.app_commands.ContextMenu],
):
"""Event triggered when an app command completes successfully"""
ctx = await commands.Context.from_interaction(interaction)
async with self.postgres_session as session:
await command_stats.register_command_invocation(session, ctx, command, tz_aware_now())
async def on_app_command_error(self, interaction: discord.Interaction, exception: AppCommandError): async def on_app_command_error(self, interaction: discord.Interaction, exception: AppCommandError):
"""Event triggered when an application command errors""" """Event triggered when an application command errors"""
@ -257,8 +243,18 @@ class Didier(commands.Bot):
else: else:
return await interaction.followup.send(str(exception.original), ephemeral=True) return await interaction.followup.send(str(exception.original), ephemeral=True)
async def on_command_completion(self, ctx: commands.Context):
"""Event triggered when a message command completes successfully"""
# Hybrid command invocation triggers both this handler and on_app_command_completion
# We handle it in the correct place
if ctx.interaction is not None:
return
async with self.postgres_session as session:
await command_stats.register_command_invocation(session, ctx, ctx.command, tz_aware_now())
async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /): async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /):
"""Event triggered when a regular command errors""" """Event triggered when a message command errors"""
# If working locally, print everything to your console # If working locally, print everything to your console
if settings.SANDBOX: if settings.SANDBOX:
await super().on_command_error(ctx, exception) await super().on_command_error(ctx, exception)
@ -289,24 +285,61 @@ class Didier(commands.Bot):
): ):
return await ctx.reply(str(exception.original), mention_author=False) return await ctx.reply(str(exception.original), mention_author=False)
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if isinstance(exception, commands.MessageNotFound): if isinstance(exception, commands.MessageNotFound):
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10) return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
if isinstance(exception, (commands.MissingRequiredArgument,)):
message = str(exception)
match = re.search(r"(.*) is a required argument that is missing\.", message)
if match is not None and match.groups():
message = f"Found no value for the `{match.groups()[0]}`-argument."
return await ctx.reply(message, ephemeral=True, delete_after=10)
if isinstance( if isinstance(
exception, exception,
( (
commands.BadArgument, commands.BadArgument,
commands.MissingRequiredArgument,
commands.UnexpectedQuoteError, commands.UnexpectedQuoteError,
commands.ExpectedClosingQuoteError, commands.ExpectedClosingQuoteError,
), ),
): ):
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10) return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if settings.ERRORS_CHANNEL is not None: if settings.ERRORS_CHANNEL is not None:
embed = create_error_embed(ctx, exception) embed = create_error_embed(ctx, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL) channel = self.get_channel(settings.ERRORS_CHANNEL)
await channel.send(embed=embed) await channel.send(embed=embed)
async def on_message(self, message: discord.Message, /) -> None:
"""Event triggered when a message is sent"""
# Ignore messages by bots
if message.author.bot:
return
# Boos react to people that say Dider
if "dider" in message.content.lower() and message.author.id != self.user.id:
await message.add_reaction(settings.DISCORD_BOOS_REACT)
# Potential custom command
if await self._try_invoke_custom_command(message):
return
await self.process_commands(message)
easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs)
if easter_egg is not None:
await message.reply(easter_egg, mention_author=False)
async def on_ready(self):
"""Event triggered when the bot is ready"""
print(settings.DISCORD_READY_MESSAGE)
async def on_thread_create(self, thread: discord.Thread):
"""Event triggered when a new thread is created"""
# Join threads automatically
await thread.join()

View File

@ -1,12 +1,23 @@
import discord import discord
__all__ = ["error_red", "ghent_university_blue", "ghent_university_yellow", "google_blue", "urban_dictionary_green"] __all__ = [
"error_red",
"github_white",
"ghent_university_blue",
"ghent_university_yellow",
"google_blue",
"urban_dictionary_green",
]
def error_red() -> discord.Colour: def error_red() -> discord.Colour:
return discord.Colour.red() return discord.Colour.red()
def github_white() -> discord.Colour:
return discord.Colour.from_rgb(250, 250, 250)
def ghent_university_blue() -> discord.Colour: def ghent_university_blue() -> discord.Colour:
return discord.Colour.from_rgb(30, 100, 200) return discord.Colour.from_rgb(30, 100, 200)