From 23edc51dbff8bbe93e276e81f240d3240ad64525 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 20 Sep 2022 14:47:26 +0200 Subject: [PATCH 1/4] Command stats --- .../versions/3c94051821f8_command_stats.py | 37 +++++++++ database/crud/command_stats.py | 38 +++++++++ database/schemas.py | 13 +++ didier/didier.py | 82 ++++++++++++------- 4 files changed, 141 insertions(+), 29 deletions(-) create mode 100644 alembic/versions/3c94051821f8_command_stats.py create mode 100644 database/crud/command_stats.py diff --git a/alembic/versions/3c94051821f8_command_stats.py b/alembic/versions/3c94051821f8_command_stats.py new file mode 100644 index 0000000..3dddc94 --- /dev/null +++ b/alembic/versions/3c94051821f8_command_stats.py @@ -0,0 +1,37 @@ +"""Command stats + +Revision ID: 3c94051821f8 +Revises: b84bb10fb8de +Create Date: 2022-09-20 14:38:41.737628 + +""" +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision = "3c94051821f8" +down_revision = "b84bb10fb8de" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "command_stats", + sa.Column("command_stats_id", sa.Integer(), nullable=False), + sa.Column("command", sa.Text(), nullable=False), + sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("slash", sa.Boolean(), nullable=False), + sa.Column("context_menu", sa.Boolean(), nullable=False), + sa.PrimaryKeyConstraint("command_stats_id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("command_stats") + # ### end Alembic commands ### diff --git a/database/crud/command_stats.py b/database/crud/command_stats.py new file mode 100644 index 0000000..91c9636 --- /dev/null +++ b/database/crud/command_stats.py @@ -0,0 +1,38 @@ +from datetime import datetime +from typing import Optional, Union + +from discord import app_commands +from discord.ext import commands +from sqlalchemy.ext.asyncio import AsyncSession + +from database.schemas import CommandStats + +__all__ = ["register_command_invocation"] + + +CommandT = Union[commands.Command, app_commands.Command, app_commands.ContextMenu] + + +async def register_command_invocation( + session: AsyncSession, ctx: commands.Context, command: Optional[CommandT], timestamp: datetime +): + """Create an entry for a command invocation""" + if command is None: + return + + # Check the type of invocation + context_menu = isinstance(command, app_commands.ContextMenu) + + # (This is a bit uglier but it accounts for hybrid commands) + slash = isinstance(command, app_commands.Command) or (ctx.interaction is not None and not context_menu) + + stats = CommandStats( + command=command.qualified_name.lower(), + timestamp=timestamp, + user_id=ctx.author.id, + slash=slash, + context_menu=context_menu, + ) + + session.add(stats) + await session.commit() diff --git a/database/schemas.py b/database/schemas.py index 8b952fd..03637f8 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -27,6 +27,7 @@ __all__ = [ "Bank", "Birthday", "Bookmark", + "CommandStats", "CustomCommand", "CustomCommandAlias", "DadJoke", @@ -95,6 +96,18 @@ class Bookmark(Base): user: User = relationship("User", back_populates="bookmarks", uselist=False, lazy="selectin") +class CommandStats(Base): + """Metrics on how often commands are used""" + + __tablename__ = "command_stats" + command_stats_id: int = Column(Integer, primary_key=True) + command: str = Column(Text, nullable=False) + timestamp: datetime = Column(DateTime(timezone=True), nullable=False) + user_id: int = Column(BigInteger, nullable=False) + slash: bool = Column(Boolean, nullable=False) + context_menu: bool = Column(Boolean, nullable=False) + + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" diff --git a/didier/didier.py b/didier/didier.py index a72196e..455284c 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -2,6 +2,7 @@ import logging import os import pathlib from functools import cached_property +from typing import Union import discord from aiohttp import ClientSession @@ -10,7 +11,7 @@ from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings -from database.crud import custom_commands +from database.crud import command_stats, custom_commands from database.engine import DBSession from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed @@ -18,6 +19,7 @@ from didier.data.embeds.schedules import Schedule, parse_schedule from didier.exceptions import HTTPException, NoMatch from didier.utils.discord.prefix import get_prefix from didier.utils.easter_eggs import detect_easter_egg +from didier.utils.types.datetime import tz_aware_now __all__ = ["Didier"] @@ -194,30 +196,6 @@ class Didier(commands.Bot): """Log a warning message""" await self._log(logging.WARNING, message, log_to_discord) - async def on_ready(self): - """Event triggered when the bot is ready""" - print(settings.DISCORD_READY_MESSAGE) - - async def on_message(self, message: discord.Message, /) -> None: - """Event triggered when a message is sent""" - # Ignore messages by bots - if message.author.bot: - return - - # Boos react to people that say Dider - if "dider" in message.content.lower() and message.author.id != self.user.id: - await message.add_reaction(settings.DISCORD_BOOS_REACT) - - # Potential custom command - if await self._try_invoke_custom_command(message): - return - - await self.process_commands(message) - - easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs) - if easter_egg is not None: - await message.reply(easter_egg, mention_author=False) - async def _try_invoke_custom_command(self, message: discord.Message) -> bool: """Check if the message tries to invoke a custom command @@ -241,9 +219,16 @@ class Didier(commands.Bot): # Nothing found return False - async def on_thread_create(self, thread: discord.Thread): - """Event triggered when a new thread is created""" - await thread.join() + async def on_app_command_completion( + self, + interaction: discord.Interaction, + command: Union[discord.app_commands.Command, discord.app_commands.ContextMenu], + ): + """Event triggered when an app command completes successfully""" + ctx = await commands.Context.from_interaction(interaction) + + async with self.postgres_session as session: + await command_stats.register_command_invocation(session, ctx, command, tz_aware_now()) async def on_app_command_error(self, interaction: discord.Interaction, exception: AppCommandError): """Event triggered when an application command errors""" @@ -257,8 +242,18 @@ class Didier(commands.Bot): else: return await interaction.followup.send(str(exception.original), ephemeral=True) + async def on_command_completion(self, ctx: commands.Context): + """Event triggered when a message command completes successfully""" + # Hybrid command invocation triggers both this handler and on_app_command_completion + # We handle it in the correct place + if ctx.interaction is not None: + return + + async with self.postgres_session as session: + await command_stats.register_command_invocation(session, ctx, ctx.command, tz_aware_now()) + async def on_command_error(self, ctx: commands.Context, exception: commands.CommandError, /): - """Event triggered when a regular command errors""" + """Event triggered when a message command errors""" # If working locally, print everything to your console if settings.SANDBOX: await super().on_command_error(ctx, exception) @@ -310,3 +305,32 @@ class Didier(commands.Bot): embed = create_error_embed(ctx, exception) channel = self.get_channel(settings.ERRORS_CHANNEL) await channel.send(embed=embed) + + async def on_message(self, message: discord.Message, /) -> None: + """Event triggered when a message is sent""" + # Ignore messages by bots + if message.author.bot: + return + + # Boos react to people that say Dider + if "dider" in message.content.lower() and message.author.id != self.user.id: + await message.add_reaction(settings.DISCORD_BOOS_REACT) + + # Potential custom command + if await self._try_invoke_custom_command(message): + return + + await self.process_commands(message) + + easter_egg = await detect_easter_egg(self, message, self.database_caches.easter_eggs) + if easter_egg is not None: + await message.reply(easter_egg, mention_author=False) + + async def on_ready(self): + """Event triggered when the bot is ready""" + print(settings.DISCORD_READY_MESSAGE) + + async def on_thread_create(self, thread: discord.Thread): + """Event triggered when a new thread is created""" + # Join threads automatically + await thread.join() From 41c8c9d0ab1f9b007e6c7a3f341b6a3bb89482eb Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 20 Sep 2022 16:45:02 +0200 Subject: [PATCH 2/4] Put tracebacks in a codeblock for readability & to escape markdown --- didier/data/embeds/error_embed.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/didier/data/embeds/error_embed.py b/didier/data/embeds/error_embed.py index 9118c6d..03edd25 100644 --- a/didier/data/embeds/error_embed.py +++ b/didier/data/embeds/error_embed.py @@ -23,12 +23,14 @@ def _get_traceback(exception: Exception) -> str: if line.strip(): error_string += "\n" - return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH) + return abbreviate(error_string, Limits.EMBED_FIELD_VALUE_LENGTH - 8) def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.Embed: """Create an embed for the traceback of an exception""" - description = _get_traceback(exception) + # Wrap the traceback in a codeblock for readability + description = _get_traceback(exception).strip() + description = f"```\n{description}\n```" if ctx.guild is None: origin = "DM" @@ -40,7 +42,7 @@ def create_error_embed(ctx: commands.Context, exception: Exception) -> discord.E embed = discord.Embed(title="Error", colour=discord.Colour.red()) embed.add_field(name="Command", value=f"{ctx.message.content}", inline=True) embed.add_field(name="Context", value=invocation, inline=True) - embed.add_field(name="Exception", value=abbreviate(str(exception), Limits.EMBED_FIELD_VALUE_LENGTH), inline=False) + embed.add_field(name="Exception", value=str(exception), inline=False) embed.add_field(name="Traceback", value=description, inline=False) return embed From 9e3527ae8a2b0c9528ffa1f9ae1649f721e82d60 Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 20 Sep 2022 17:34:49 +0200 Subject: [PATCH 3/4] Fix relationship, add github links, improve error messages --- ...1f9ee875616_command_stats_github_links.py} | 27 +++++++++--- database/crud/command_stats.py | 3 ++ database/crud/github.py | 43 +++++++++++++++++++ database/schemas.py | 23 +++++++++- didier/cogs/discord.py | 37 +++++++++++++++- didier/cogs/help.py | 2 +- didier/didier.py | 17 ++++++-- didier/utils/discord/colours.py | 13 +++++- 8 files changed, 152 insertions(+), 13 deletions(-) rename alembic/versions/{3c94051821f8_command_stats.py => c1f9ee875616_command_stats_github_links.py} (55%) create mode 100644 database/crud/github.py diff --git a/alembic/versions/3c94051821f8_command_stats.py b/alembic/versions/c1f9ee875616_command_stats_github_links.py similarity index 55% rename from alembic/versions/3c94051821f8_command_stats.py rename to alembic/versions/c1f9ee875616_command_stats_github_links.py index 3dddc94..3ef4b9f 100644 --- a/alembic/versions/3c94051821f8_command_stats.py +++ b/alembic/versions/c1f9ee875616_command_stats_github_links.py @@ -1,8 +1,8 @@ -"""Command stats +"""Command stats & GitHub links -Revision ID: 3c94051821f8 +Revision ID: c1f9ee875616 Revises: b84bb10fb8de -Create Date: 2022-09-20 14:38:41.737628 +Create Date: 2022-09-20 17:18:02.289593 """ import sqlalchemy as sa @@ -10,7 +10,7 @@ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. -revision = "3c94051821f8" +revision = "c1f9ee875616" down_revision = "b84bb10fb8de" branch_labels = None depends_on = None @@ -23,15 +23,32 @@ def upgrade() -> None: sa.Column("command_stats_id", sa.Integer(), nullable=False), sa.Column("command", sa.Text(), nullable=False), sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), - sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), sa.Column("slash", sa.Boolean(), nullable=False), sa.Column("context_menu", sa.Boolean(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), sa.PrimaryKeyConstraint("command_stats_id"), ) + op.create_table( + "github_links", + sa.Column("github_link_id", sa.Integer(), nullable=False), + sa.Column("url", sa.Text(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=True), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.user_id"], + ), + sa.PrimaryKeyConstraint("github_link_id"), + sa.UniqueConstraint("url"), + ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("github_links") op.drop_table("command_stats") # ### end Alembic commands ### diff --git a/database/crud/command_stats.py b/database/crud/command_stats.py index 91c9636..000eebe 100644 --- a/database/crud/command_stats.py +++ b/database/crud/command_stats.py @@ -5,6 +5,7 @@ from discord import app_commands from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession +from database.crud.users import get_or_add_user from database.schemas import CommandStats __all__ = ["register_command_invocation"] @@ -20,6 +21,8 @@ async def register_command_invocation( if command is None: return + await get_or_add_user(session, ctx.author.id) + # Check the type of invocation context_menu = isinstance(command, app_commands.ContextMenu) diff --git a/database/crud/github.py b/database/crud/github.py new file mode 100644 index 0000000..4dda51e --- /dev/null +++ b/database/crud/github.py @@ -0,0 +1,43 @@ +from typing import Optional + +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from database.exceptions import Forbidden, NoResultFoundException +from database.schemas import GitHubLink + +__all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"] + + +async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink: + """Add a new GitHub link into the database""" + gh_link = GitHubLink(user_id=user_id, url=url) + session.add(gh_link) + await session.commit() + await session.refresh(gh_link) + + return gh_link + + +async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: int): + """Remove an existing link from the database + + You can only remove links owned by you + """ + select_statement = select(GitHubLink).where(GitHubLink.github_link_id == link_id) + gh_link: Optional[GitHubLink] = (await session.execute(select_statement)).scalar_one_or_none() + if gh_link is None: + raise NoResultFoundException + + if gh_link.user_id != user_id: + raise Forbidden + + delete_statement = delete(GitHubLink).where(GitHubLink.github_link_id == gh_link.github_link_id) + await session.execute(delete_statement) + await session.commit() + + +async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]: + """Get a user's GitHub links""" + statement = select(GitHubLink).where(GitHubLink.user_id == user_id) + return (await session.execute(statement)).scalars().all() diff --git a/database/schemas.py b/database/schemas.py index 03637f8..f497b2d 100644 --- a/database/schemas.py +++ b/database/schemas.py @@ -33,6 +33,7 @@ __all__ = [ "DadJoke", "Deadline", "EasterEgg", + "GitHubLink", "Link", "MemeTemplate", "NightlyData", @@ -103,10 +104,12 @@ class CommandStats(Base): command_stats_id: int = Column(Integer, primary_key=True) command: str = Column(Text, nullable=False) timestamp: datetime = Column(DateTime(timezone=True), nullable=False) - user_id: int = Column(BigInteger, nullable=False) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) slash: bool = Column(Boolean, nullable=False) context_menu: bool = Column(Boolean, nullable=False) + user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin") + class CustomCommand(Base): """Custom commands to fill the hole Dyno couldn't""" @@ -170,6 +173,18 @@ class EasterEgg(Base): startswith: bool = Column(Boolean, nullable=False, server_default="1") +class GitHubLink(Base): + """A user's GitHub link""" + + __tablename__ = "github_links" + + github_link_id: int = Column(Integer, primary_key=True) + url: str = Column(Text, nullable=False, unique=True) + user_id: int = Column(BigInteger, ForeignKey("users.user_id")) + + user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin") + + class Link(Base): """Useful links that go useful places""" @@ -279,6 +294,12 @@ class User(Base): bookmarks: list[Bookmark] = relationship( "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" ) + command_stats: list[CommandStats] = relationship( + "CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + ) + github_links: list[GitHubLink] = relationship( + "GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" + ) nightly_data: NightlyData = relationship( "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" ) diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index c13c782..e73b0f6 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -4,7 +4,7 @@ import discord from discord import app_commands from discord.ext import commands -from database.crud import birthdays, bookmarks +from database.crud import birthdays, bookmarks, github from database.exceptions import ( DuplicateInsertException, Forbidden, @@ -15,6 +15,7 @@ from didier import Didier from didier.exceptions import expect from didier.menus.bookmarks import BookmarkSource from didier.menus.common import Menu +from didier.utils.discord import colours from didier.utils.discord.assets import get_author_avatar from didier.utils.types.datetime import str_to_date from didier.utils.types.string import leading @@ -193,6 +194,40 @@ class Discord(commands.Cog): modal = CreateBookmark(self.client, message.jump_url) await interaction.response.send_modal(modal) + @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) + async def github(self, ctx: commands.Context, user: discord.User): + """Show a user's GitHub links""" + embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") + embed.set_author(name=user.display_name, icon_url=user.avatar.url or user.default_avatar.url) + + embed.set_footer(text="Links can be added using `didier github add `.") + + async with self.client.postgres_session as session: + links = await github.get_github_links(session, user.id) + + if not links: + embed.description = "This user has not set any GitHub links yet." + else: + regular_links = [] + ugent_links = [] + + for link in links: + if "github.ugent.be" in link.url.lower(): + ugent_links.append(link) + else: + regular_links.append(link) + + regular_links.sort() + ugent_links.sort() + + if ugent_links: + embed.add_field(name="Ghent University", value="\n".join(ugent_links), inline=False) + + if regular_links: + embed.add_field(name="Other", value="\n".join(regular_links), inline=False) + + return await ctx.reply(embed=embed, mention_author=False) + @commands.command(name="join") async def join(self, ctx: commands.Context, thread: discord.Thread): """Make Didier join `thread`. diff --git a/didier/cogs/help.py b/didier/cogs/help.py index c6d88d6..459f802 100644 --- a/didier/cogs/help.py +++ b/didier/cogs/help.py @@ -188,7 +188,7 @@ class CustomHelpCommand(commands.MinimalHelpCommand): embed.add_field(name="Signature", value=signature, inline=False) if command.aliases: - embed.add_field(name="Aliases", value=", ".join(command.aliases), inline=False) + embed.add_field(name="Aliases", value=", ".join(sorted(command.aliases)), inline=False) def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]: """Try to find a cog, case-insensitively""" diff --git a/didier/didier.py b/didier/didier.py index 455284c..ad47e38 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -1,6 +1,7 @@ import logging import os import pathlib +import re from functools import cached_property from typing import Union @@ -284,23 +285,31 @@ class Didier(commands.Bot): ): return await ctx.reply(str(exception.original), mention_author=False) - # Print everything that we care about to the logs/stderr - await super().on_command_error(ctx, exception) - if isinstance(exception, commands.MessageNotFound): return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10) + if isinstance(exception, (commands.MissingRequiredArgument,)): + message = str(exception) + + match = re.search(r"(.*) is a required argument that is missing\.", message) + if match.groups(): + message = f"Found no value for the `{match.groups()[0]}`-argument." + + return await ctx.reply(message, ephemeral=True, delete_after=10) + if isinstance( exception, ( commands.BadArgument, - commands.MissingRequiredArgument, commands.UnexpectedQuoteError, commands.ExpectedClosingQuoteError, ), ): return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10) + # Print everything that we care about to the logs/stderr + await super().on_command_error(ctx, exception) + if settings.ERRORS_CHANNEL is not None: embed = create_error_embed(ctx, exception) channel = self.get_channel(settings.ERRORS_CHANNEL) diff --git a/didier/utils/discord/colours.py b/didier/utils/discord/colours.py index c8a55e6..e0ebb5c 100644 --- a/didier/utils/discord/colours.py +++ b/didier/utils/discord/colours.py @@ -1,12 +1,23 @@ import discord -__all__ = ["error_red", "ghent_university_blue", "ghent_university_yellow", "google_blue", "urban_dictionary_green"] +__all__ = [ + "error_red", + "github_white", + "ghent_university_blue", + "ghent_university_yellow", + "google_blue", + "urban_dictionary_green", +] def error_red() -> discord.Colour: return discord.Colour.red() +def github_white() -> discord.Colour: + return discord.Colour.from_rgb(250, 250, 250) + + def ghent_university_blue() -> discord.Colour: return discord.Colour.from_rgb(30, 100, 200) From 97e815cbff6f06d3ce95afd83c1a0e3e1b2b02fd Mon Sep 17 00:00:00 2001 From: stijndcl Date: Tue, 20 Sep 2022 17:55:59 +0200 Subject: [PATCH 4/4] Adding & removing github links --- database/crud/github.py | 18 +++++++++---- didier/cogs/discord.py | 58 ++++++++++++++++++++++++++++++++++++----- didier/didier.py | 2 +- 3 files changed, 65 insertions(+), 13 deletions(-) diff --git a/database/crud/github.py b/database/crud/github.py index 4dda51e..0d32377 100644 --- a/database/crud/github.py +++ b/database/crud/github.py @@ -1,9 +1,14 @@ from typing import Optional +import sqlalchemy.exc from sqlalchemy import delete, select from sqlalchemy.ext.asyncio import AsyncSession -from database.exceptions import Forbidden, NoResultFoundException +from database.exceptions import ( + DuplicateInsertException, + Forbidden, + NoResultFoundException, +) from database.schemas import GitHubLink __all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"] @@ -11,10 +16,13 @@ __all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"] async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink: """Add a new GitHub link into the database""" - gh_link = GitHubLink(user_id=user_id, url=url) - session.add(gh_link) - await session.commit() - await session.refresh(gh_link) + try: + gh_link = GitHubLink(user_id=user_id, url=url) + session.add(gh_link) + await session.commit() + await session.refresh(gh_link) + except sqlalchemy.exc.IntegrityError: + raise DuplicateInsertException return gh_link diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index e73b0f6..70b4414 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -195,12 +195,20 @@ class Discord(commands.Cog): await interaction.response.send_modal(modal) @commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True) - async def github(self, ctx: commands.Context, user: discord.User): - """Show a user's GitHub links""" - embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") - embed.set_author(name=user.display_name, icon_url=user.avatar.url or user.default_avatar.url) + async def github_group(self, ctx: commands.Context, user: Optional[discord.User] = None): + """Show a user's GitHub links. - embed.set_footer(text="Links can be added using `didier github add `.") + If no user is provided, this shows your links instead. + """ + # Default to author + user = user or ctx.author + + embed = discord.Embed(colour=colours.github_white(), title="GitHub Links") + embed.set_author( + name=user.display_name, icon_url=user.avatar.url if user.avatar is not None else user.default_avatar.url + ) + + embed.set_footer(text="Links can be added using didier github add .") async with self.client.postgres_session as session: links = await github.get_github_links(session, user.id) @@ -213,9 +221,9 @@ class Discord(commands.Cog): for link in links: if "github.ugent.be" in link.url.lower(): - ugent_links.append(link) + ugent_links.append(f"`#{link.github_link_id}`: {link.url}") else: - regular_links.append(link) + regular_links.append(f"`#{link.github_link_id}`: {link.url}") regular_links.sort() ugent_links.sort() @@ -228,6 +236,42 @@ class Discord(commands.Cog): return await ctx.reply(embed=embed, mention_author=False) + @github_group.command(name="add", aliases=["create", "insert"]) + async def github_add(self, ctx: commands.Context, link: str): + """Add a new link into the database.""" + # Remove wrapping which can be used to escape Discord embeds + link = link.removeprefix("<").removesuffix(">") + + async with self.client.postgres_session as session: + try: + gh_link = await github.add_github_link(session, ctx.author.id, link) + except DuplicateInsertException: + return await ctx.reply("This link has already been registered by someone.", mention_author=False) + + await self.client.confirm_message(ctx.message) + return await ctx.reply(f"Successfully inserted link `#{gh_link.github_link_id}`.", mention_author=False) + + @github_group.command(name="delete", aliases=["del", "remove", "rm"]) + async def github_delete(self, ctx: commands.Context, link_id: str): + """Delete the link with it `link_id` from the database. + + You can only delete your own links. + """ + try: + link_id_int = int(link_id.removeprefix("#")) + except ValueError: + return await ctx.reply(f"`{link_id}` is not a valid link id.", mention_author=False) + + async with self.client.postgres_session as session: + try: + await github.delete_github_link_by_id(session, ctx.author.id, link_id_int) + except NoResultFoundException: + return await ctx.reply(f"Found no GitHub link with id `#{link_id_int}`.", mention_author=False) + except Forbidden: + return await ctx.reply(f"You don't own GitHub link `#{link_id_int}`.", mention_author=False) + + return await ctx.reply(f"Successfully deleted GitHub link `#{link_id_int}`.", mention_author=False) + @commands.command(name="join") async def join(self, ctx: commands.Context, thread: discord.Thread): """Make Didier join `thread`. diff --git a/didier/didier.py b/didier/didier.py index ad47e38..9bec3f6 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -292,7 +292,7 @@ class Didier(commands.Bot): message = str(exception) match = re.search(r"(.*) is a required argument that is missing\.", message) - if match.groups(): + if match is not None and match.groups(): message = f"Found no value for the `{match.groups()[0]}`-argument." return await ctx.reply(message, ephemeral=True, delete_after=10)