mirror of https://github.com/stijndcl/didier
Fix relationship, add github links, improve error messages
parent
41c8c9d0ab
commit
9e3527ae8a
|
@ -1,8 +1,8 @@
|
|||
"""Command stats
|
||||
"""Command stats & GitHub links
|
||||
|
||||
Revision ID: 3c94051821f8
|
||||
Revision ID: c1f9ee875616
|
||||
Revises: b84bb10fb8de
|
||||
Create Date: 2022-09-20 14:38:41.737628
|
||||
Create Date: 2022-09-20 17:18:02.289593
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
|||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3c94051821f8"
|
||||
revision = "c1f9ee875616"
|
||||
down_revision = "b84bb10fb8de"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
@ -23,15 +23,32 @@ def upgrade() -> None:
|
|||
sa.Column("command_stats_id", sa.Integer(), nullable=False),
|
||||
sa.Column("command", sa.Text(), nullable=False),
|
||||
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.Column("slash", sa.Boolean(), nullable=False),
|
||||
sa.Column("context_menu", sa.Boolean(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("command_stats_id"),
|
||||
)
|
||||
op.create_table(
|
||||
"github_links",
|
||||
sa.Column("github_link_id", sa.Integer(), nullable=False),
|
||||
sa.Column("url", sa.Text(), nullable=False),
|
||||
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.user_id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("github_link_id"),
|
||||
sa.UniqueConstraint("url"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("github_links")
|
||||
op.drop_table("command_stats")
|
||||
# ### end Alembic commands ###
|
|
@ -5,6 +5,7 @@ from discord import app_commands
|
|||
from discord.ext import commands
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.crud.users import get_or_add_user
|
||||
from database.schemas import CommandStats
|
||||
|
||||
__all__ = ["register_command_invocation"]
|
||||
|
@ -20,6 +21,8 @@ async def register_command_invocation(
|
|||
if command is None:
|
||||
return
|
||||
|
||||
await get_or_add_user(session, ctx.author.id)
|
||||
|
||||
# Check the type of invocation
|
||||
context_menu = isinstance(command, app_commands.ContextMenu)
|
||||
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
from typing import Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from database.exceptions import Forbidden, NoResultFoundException
|
||||
from database.schemas import GitHubLink
|
||||
|
||||
__all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"]
|
||||
|
||||
|
||||
async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink:
|
||||
"""Add a new GitHub link into the database"""
|
||||
gh_link = GitHubLink(user_id=user_id, url=url)
|
||||
session.add(gh_link)
|
||||
await session.commit()
|
||||
await session.refresh(gh_link)
|
||||
|
||||
return gh_link
|
||||
|
||||
|
||||
async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: int):
|
||||
"""Remove an existing link from the database
|
||||
|
||||
You can only remove links owned by you
|
||||
"""
|
||||
select_statement = select(GitHubLink).where(GitHubLink.github_link_id == link_id)
|
||||
gh_link: Optional[GitHubLink] = (await session.execute(select_statement)).scalar_one_or_none()
|
||||
if gh_link is None:
|
||||
raise NoResultFoundException
|
||||
|
||||
if gh_link.user_id != user_id:
|
||||
raise Forbidden
|
||||
|
||||
delete_statement = delete(GitHubLink).where(GitHubLink.github_link_id == gh_link.github_link_id)
|
||||
await session.execute(delete_statement)
|
||||
await session.commit()
|
||||
|
||||
|
||||
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
|
||||
"""Get a user's GitHub links"""
|
||||
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
|
||||
return (await session.execute(statement)).scalars().all()
|
|
@ -33,6 +33,7 @@ __all__ = [
|
|||
"DadJoke",
|
||||
"Deadline",
|
||||
"EasterEgg",
|
||||
"GitHubLink",
|
||||
"Link",
|
||||
"MemeTemplate",
|
||||
"NightlyData",
|
||||
|
@ -103,10 +104,12 @@ class CommandStats(Base):
|
|||
command_stats_id: int = Column(Integer, primary_key=True)
|
||||
command: str = Column(Text, nullable=False)
|
||||
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
|
||||
user_id: int = Column(BigInteger, nullable=False)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
slash: bool = Column(Boolean, nullable=False)
|
||||
context_menu: bool = Column(Boolean, nullable=False)
|
||||
|
||||
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class CustomCommand(Base):
|
||||
"""Custom commands to fill the hole Dyno couldn't"""
|
||||
|
@ -170,6 +173,18 @@ class EasterEgg(Base):
|
|||
startswith: bool = Column(Boolean, nullable=False, server_default="1")
|
||||
|
||||
|
||||
class GitHubLink(Base):
|
||||
"""A user's GitHub link"""
|
||||
|
||||
__tablename__ = "github_links"
|
||||
|
||||
github_link_id: int = Column(Integer, primary_key=True)
|
||||
url: str = Column(Text, nullable=False, unique=True)
|
||||
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||
|
||||
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
|
||||
|
||||
|
||||
class Link(Base):
|
||||
"""Useful links that go useful places"""
|
||||
|
||||
|
@ -279,6 +294,12 @@ class User(Base):
|
|||
bookmarks: list[Bookmark] = relationship(
|
||||
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
command_stats: list[CommandStats] = relationship(
|
||||
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
github_links: list[GitHubLink] = relationship(
|
||||
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
nightly_data: NightlyData = relationship(
|
||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@ import discord
|
|||
from discord import app_commands
|
||||
from discord.ext import commands
|
||||
|
||||
from database.crud import birthdays, bookmarks
|
||||
from database.crud import birthdays, bookmarks, github
|
||||
from database.exceptions import (
|
||||
DuplicateInsertException,
|
||||
Forbidden,
|
||||
|
@ -15,6 +15,7 @@ from didier import Didier
|
|||
from didier.exceptions import expect
|
||||
from didier.menus.bookmarks import BookmarkSource
|
||||
from didier.menus.common import Menu
|
||||
from didier.utils.discord import colours
|
||||
from didier.utils.discord.assets import get_author_avatar
|
||||
from didier.utils.types.datetime import str_to_date
|
||||
from didier.utils.types.string import leading
|
||||
|
@ -193,6 +194,40 @@ class Discord(commands.Cog):
|
|||
modal = CreateBookmark(self.client, message.jump_url)
|
||||
await interaction.response.send_modal(modal)
|
||||
|
||||
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
|
||||
async def github(self, ctx: commands.Context, user: discord.User):
|
||||
"""Show a user's GitHub links"""
|
||||
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
|
||||
embed.set_author(name=user.display_name, icon_url=user.avatar.url or user.default_avatar.url)
|
||||
|
||||
embed.set_footer(text="Links can be added using `didier github add <link>`.")
|
||||
|
||||
async with self.client.postgres_session as session:
|
||||
links = await github.get_github_links(session, user.id)
|
||||
|
||||
if not links:
|
||||
embed.description = "This user has not set any GitHub links yet."
|
||||
else:
|
||||
regular_links = []
|
||||
ugent_links = []
|
||||
|
||||
for link in links:
|
||||
if "github.ugent.be" in link.url.lower():
|
||||
ugent_links.append(link)
|
||||
else:
|
||||
regular_links.append(link)
|
||||
|
||||
regular_links.sort()
|
||||
ugent_links.sort()
|
||||
|
||||
if ugent_links:
|
||||
embed.add_field(name="Ghent University", value="\n".join(ugent_links), inline=False)
|
||||
|
||||
if regular_links:
|
||||
embed.add_field(name="Other", value="\n".join(regular_links), inline=False)
|
||||
|
||||
return await ctx.reply(embed=embed, mention_author=False)
|
||||
|
||||
@commands.command(name="join")
|
||||
async def join(self, ctx: commands.Context, thread: discord.Thread):
|
||||
"""Make Didier join `thread`.
|
||||
|
|
|
@ -188,7 +188,7 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
|
|||
embed.add_field(name="Signature", value=signature, inline=False)
|
||||
|
||||
if command.aliases:
|
||||
embed.add_field(name="Aliases", value=", ".join(command.aliases), inline=False)
|
||||
embed.add_field(name="Aliases", value=", ".join(sorted(command.aliases)), inline=False)
|
||||
|
||||
def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]:
|
||||
"""Try to find a cog, case-insensitively"""
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import logging
|
||||
import os
|
||||
import pathlib
|
||||
import re
|
||||
from functools import cached_property
|
||||
from typing import Union
|
||||
|
||||
|
@ -284,23 +285,31 @@ class Didier(commands.Bot):
|
|||
):
|
||||
return await ctx.reply(str(exception.original), mention_author=False)
|
||||
|
||||
# Print everything that we care about to the logs/stderr
|
||||
await super().on_command_error(ctx, exception)
|
||||
|
||||
if isinstance(exception, commands.MessageNotFound):
|
||||
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
|
||||
|
||||
if isinstance(exception, (commands.MissingRequiredArgument,)):
|
||||
message = str(exception)
|
||||
|
||||
match = re.search(r"(.*) is a required argument that is missing\.", message)
|
||||
if match.groups():
|
||||
message = f"Found no value for the `{match.groups()[0]}`-argument."
|
||||
|
||||
return await ctx.reply(message, ephemeral=True, delete_after=10)
|
||||
|
||||
if isinstance(
|
||||
exception,
|
||||
(
|
||||
commands.BadArgument,
|
||||
commands.MissingRequiredArgument,
|
||||
commands.UnexpectedQuoteError,
|
||||
commands.ExpectedClosingQuoteError,
|
||||
),
|
||||
):
|
||||
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
|
||||
|
||||
# Print everything that we care about to the logs/stderr
|
||||
await super().on_command_error(ctx, exception)
|
||||
|
||||
if settings.ERRORS_CHANNEL is not None:
|
||||
embed = create_error_embed(ctx, exception)
|
||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||
|
|
|
@ -1,12 +1,23 @@
|
|||
import discord
|
||||
|
||||
__all__ = ["error_red", "ghent_university_blue", "ghent_university_yellow", "google_blue", "urban_dictionary_green"]
|
||||
__all__ = [
|
||||
"error_red",
|
||||
"github_white",
|
||||
"ghent_university_blue",
|
||||
"ghent_university_yellow",
|
||||
"google_blue",
|
||||
"urban_dictionary_green",
|
||||
]
|
||||
|
||||
|
||||
def error_red() -> discord.Colour:
|
||||
return discord.Colour.red()
|
||||
|
||||
|
||||
def github_white() -> discord.Colour:
|
||||
return discord.Colour.from_rgb(250, 250, 250)
|
||||
|
||||
|
||||
def ghent_university_blue() -> discord.Colour:
|
||||
return discord.Colour.from_rgb(30, 100, 200)
|
||||
|
||||
|
|
Loading…
Reference in New Issue