Fix relationship, add github links, improve error messages

pull/133/head
stijndcl 2022-09-20 17:34:49 +02:00
parent 41c8c9d0ab
commit 9e3527ae8a
8 changed files with 152 additions and 13 deletions

View File

@ -1,8 +1,8 @@
"""Command stats """Command stats & GitHub links
Revision ID: 3c94051821f8 Revision ID: c1f9ee875616
Revises: b84bb10fb8de Revises: b84bb10fb8de
Create Date: 2022-09-20 14:38:41.737628 Create Date: 2022-09-20 17:18:02.289593
""" """
import sqlalchemy as sa import sqlalchemy as sa
@ -10,7 +10,7 @@ import sqlalchemy as sa
from alembic import op from alembic import op
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
revision = "3c94051821f8" revision = "c1f9ee875616"
down_revision = "b84bb10fb8de" down_revision = "b84bb10fb8de"
branch_labels = None branch_labels = None
depends_on = None depends_on = None
@ -23,15 +23,32 @@ def upgrade() -> None:
sa.Column("command_stats_id", sa.Integer(), nullable=False), sa.Column("command_stats_id", sa.Integer(), nullable=False),
sa.Column("command", sa.Text(), nullable=False), sa.Column("command", sa.Text(), nullable=False),
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False), sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=False), sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.Column("slash", sa.Boolean(), nullable=False), sa.Column("slash", sa.Boolean(), nullable=False),
sa.Column("context_menu", sa.Boolean(), nullable=False), sa.Column("context_menu", sa.Boolean(), nullable=False),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("command_stats_id"), sa.PrimaryKeyConstraint("command_stats_id"),
) )
op.create_table(
"github_links",
sa.Column("github_link_id", sa.Integer(), nullable=False),
sa.Column("url", sa.Text(), nullable=False),
sa.Column("user_id", sa.BigInteger(), nullable=True),
sa.ForeignKeyConstraint(
["user_id"],
["users.user_id"],
),
sa.PrimaryKeyConstraint("github_link_id"),
sa.UniqueConstraint("url"),
)
# ### end Alembic commands ### # ### end Alembic commands ###
def downgrade() -> None: def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ### # ### commands auto generated by Alembic - please adjust! ###
op.drop_table("github_links")
op.drop_table("command_stats") op.drop_table("command_stats")
# ### end Alembic commands ### # ### end Alembic commands ###

View File

@ -5,6 +5,7 @@ from discord import app_commands
from discord.ext import commands from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.crud.users import get_or_add_user
from database.schemas import CommandStats from database.schemas import CommandStats
__all__ = ["register_command_invocation"] __all__ = ["register_command_invocation"]
@ -20,6 +21,8 @@ async def register_command_invocation(
if command is None: if command is None:
return return
await get_or_add_user(session, ctx.author.id)
# Check the type of invocation # Check the type of invocation
context_menu = isinstance(command, app_commands.ContextMenu) context_menu = isinstance(command, app_commands.ContextMenu)

View File

@ -0,0 +1,43 @@
from typing import Optional
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from database.exceptions import Forbidden, NoResultFoundException
from database.schemas import GitHubLink
__all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"]
async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink:
"""Add a new GitHub link into the database"""
gh_link = GitHubLink(user_id=user_id, url=url)
session.add(gh_link)
await session.commit()
await session.refresh(gh_link)
return gh_link
async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: int):
"""Remove an existing link from the database
You can only remove links owned by you
"""
select_statement = select(GitHubLink).where(GitHubLink.github_link_id == link_id)
gh_link: Optional[GitHubLink] = (await session.execute(select_statement)).scalar_one_or_none()
if gh_link is None:
raise NoResultFoundException
if gh_link.user_id != user_id:
raise Forbidden
delete_statement = delete(GitHubLink).where(GitHubLink.github_link_id == gh_link.github_link_id)
await session.execute(delete_statement)
await session.commit()
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
"""Get a user's GitHub links"""
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
return (await session.execute(statement)).scalars().all()

View File

@ -33,6 +33,7 @@ __all__ = [
"DadJoke", "DadJoke",
"Deadline", "Deadline",
"EasterEgg", "EasterEgg",
"GitHubLink",
"Link", "Link",
"MemeTemplate", "MemeTemplate",
"NightlyData", "NightlyData",
@ -103,10 +104,12 @@ class CommandStats(Base):
command_stats_id: int = Column(Integer, primary_key=True) command_stats_id: int = Column(Integer, primary_key=True)
command: str = Column(Text, nullable=False) command: str = Column(Text, nullable=False)
timestamp: datetime = Column(DateTime(timezone=True), nullable=False) timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
user_id: int = Column(BigInteger, nullable=False) user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
slash: bool = Column(Boolean, nullable=False) slash: bool = Column(Boolean, nullable=False)
context_menu: bool = Column(Boolean, nullable=False) context_menu: bool = Column(Boolean, nullable=False)
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
class CustomCommand(Base): class CustomCommand(Base):
"""Custom commands to fill the hole Dyno couldn't""" """Custom commands to fill the hole Dyno couldn't"""
@ -170,6 +173,18 @@ class EasterEgg(Base):
startswith: bool = Column(Boolean, nullable=False, server_default="1") startswith: bool = Column(Boolean, nullable=False, server_default="1")
class GitHubLink(Base):
"""A user's GitHub link"""
__tablename__ = "github_links"
github_link_id: int = Column(Integer, primary_key=True)
url: str = Column(Text, nullable=False, unique=True)
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
class Link(Base): class Link(Base):
"""Useful links that go useful places""" """Useful links that go useful places"""
@ -279,6 +294,12 @@ class User(Base):
bookmarks: list[Bookmark] = relationship( bookmarks: list[Bookmark] = relationship(
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan" "Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
) )
command_stats: list[CommandStats] = relationship(
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
github_links: list[GitHubLink] = relationship(
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
)
nightly_data: NightlyData = relationship( nightly_data: NightlyData = relationship(
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan" "NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
) )

View File

@ -4,7 +4,7 @@ import discord
from discord import app_commands from discord import app_commands
from discord.ext import commands from discord.ext import commands
from database.crud import birthdays, bookmarks from database.crud import birthdays, bookmarks, github
from database.exceptions import ( from database.exceptions import (
DuplicateInsertException, DuplicateInsertException,
Forbidden, Forbidden,
@ -15,6 +15,7 @@ from didier import Didier
from didier.exceptions import expect from didier.exceptions import expect
from didier.menus.bookmarks import BookmarkSource from didier.menus.bookmarks import BookmarkSource
from didier.menus.common import Menu from didier.menus.common import Menu
from didier.utils.discord import colours
from didier.utils.discord.assets import get_author_avatar from didier.utils.discord.assets import get_author_avatar
from didier.utils.types.datetime import str_to_date from didier.utils.types.datetime import str_to_date
from didier.utils.types.string import leading from didier.utils.types.string import leading
@ -193,6 +194,40 @@ class Discord(commands.Cog):
modal = CreateBookmark(self.client, message.jump_url) modal = CreateBookmark(self.client, message.jump_url)
await interaction.response.send_modal(modal) await interaction.response.send_modal(modal)
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
async def github(self, ctx: commands.Context, user: discord.User):
"""Show a user's GitHub links"""
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
embed.set_author(name=user.display_name, icon_url=user.avatar.url or user.default_avatar.url)
embed.set_footer(text="Links can be added using `didier github add <link>`.")
async with self.client.postgres_session as session:
links = await github.get_github_links(session, user.id)
if not links:
embed.description = "This user has not set any GitHub links yet."
else:
regular_links = []
ugent_links = []
for link in links:
if "github.ugent.be" in link.url.lower():
ugent_links.append(link)
else:
regular_links.append(link)
regular_links.sort()
ugent_links.sort()
if ugent_links:
embed.add_field(name="Ghent University", value="\n".join(ugent_links), inline=False)
if regular_links:
embed.add_field(name="Other", value="\n".join(regular_links), inline=False)
return await ctx.reply(embed=embed, mention_author=False)
@commands.command(name="join") @commands.command(name="join")
async def join(self, ctx: commands.Context, thread: discord.Thread): async def join(self, ctx: commands.Context, thread: discord.Thread):
"""Make Didier join `thread`. """Make Didier join `thread`.

View File

@ -188,7 +188,7 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
embed.add_field(name="Signature", value=signature, inline=False) embed.add_field(name="Signature", value=signature, inline=False)
if command.aliases: if command.aliases:
embed.add_field(name="Aliases", value=", ".join(command.aliases), inline=False) embed.add_field(name="Aliases", value=", ".join(sorted(command.aliases)), inline=False)
def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]: def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]:
"""Try to find a cog, case-insensitively""" """Try to find a cog, case-insensitively"""

View File

@ -1,6 +1,7 @@
import logging import logging
import os import os
import pathlib import pathlib
import re
from functools import cached_property from functools import cached_property
from typing import Union from typing import Union
@ -284,23 +285,31 @@ class Didier(commands.Bot):
): ):
return await ctx.reply(str(exception.original), mention_author=False) return await ctx.reply(str(exception.original), mention_author=False)
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if isinstance(exception, commands.MessageNotFound): if isinstance(exception, commands.MessageNotFound):
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10) return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
if isinstance(exception, (commands.MissingRequiredArgument,)):
message = str(exception)
match = re.search(r"(.*) is a required argument that is missing\.", message)
if match.groups():
message = f"Found no value for the `{match.groups()[0]}`-argument."
return await ctx.reply(message, ephemeral=True, delete_after=10)
if isinstance( if isinstance(
exception, exception,
( (
commands.BadArgument, commands.BadArgument,
commands.MissingRequiredArgument,
commands.UnexpectedQuoteError, commands.UnexpectedQuoteError,
commands.ExpectedClosingQuoteError, commands.ExpectedClosingQuoteError,
), ),
): ):
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10) return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
# Print everything that we care about to the logs/stderr
await super().on_command_error(ctx, exception)
if settings.ERRORS_CHANNEL is not None: if settings.ERRORS_CHANNEL is not None:
embed = create_error_embed(ctx, exception) embed = create_error_embed(ctx, exception)
channel = self.get_channel(settings.ERRORS_CHANNEL) channel = self.get_channel(settings.ERRORS_CHANNEL)

View File

@ -1,12 +1,23 @@
import discord import discord
__all__ = ["error_red", "ghent_university_blue", "ghent_university_yellow", "google_blue", "urban_dictionary_green"] __all__ = [
"error_red",
"github_white",
"ghent_university_blue",
"ghent_university_yellow",
"google_blue",
"urban_dictionary_green",
]
def error_red() -> discord.Colour: def error_red() -> discord.Colour:
return discord.Colour.red() return discord.Colour.red()
def github_white() -> discord.Colour:
return discord.Colour.from_rgb(250, 250, 250)
def ghent_university_blue() -> discord.Colour: def ghent_university_blue() -> discord.Colour:
return discord.Colour.from_rgb(30, 100, 200) return discord.Colour.from_rgb(30, 100, 200)