mirror of https://github.com/stijndcl/didier
Fix relationship, add github links, improve error messages
parent
41c8c9d0ab
commit
9e3527ae8a
|
@ -1,8 +1,8 @@
|
||||||
"""Command stats
|
"""Command stats & GitHub links
|
||||||
|
|
||||||
Revision ID: 3c94051821f8
|
Revision ID: c1f9ee875616
|
||||||
Revises: b84bb10fb8de
|
Revises: b84bb10fb8de
|
||||||
Create Date: 2022-09-20 14:38:41.737628
|
Create Date: 2022-09-20 17:18:02.289593
|
||||||
|
|
||||||
"""
|
"""
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
@ -10,7 +10,7 @@ import sqlalchemy as sa
|
||||||
from alembic import op
|
from alembic import op
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
# revision identifiers, used by Alembic.
|
||||||
revision = "3c94051821f8"
|
revision = "c1f9ee875616"
|
||||||
down_revision = "b84bb10fb8de"
|
down_revision = "b84bb10fb8de"
|
||||||
branch_labels = None
|
branch_labels = None
|
||||||
depends_on = None
|
depends_on = None
|
||||||
|
@ -23,15 +23,32 @@ def upgrade() -> None:
|
||||||
sa.Column("command_stats_id", sa.Integer(), nullable=False),
|
sa.Column("command_stats_id", sa.Integer(), nullable=False),
|
||||||
sa.Column("command", sa.Text(), nullable=False),
|
sa.Column("command", sa.Text(), nullable=False),
|
||||||
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
|
sa.Column("timestamp", sa.DateTime(timezone=True), nullable=False),
|
||||||
sa.Column("user_id", sa.BigInteger(), nullable=False),
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
sa.Column("slash", sa.Boolean(), nullable=False),
|
sa.Column("slash", sa.Boolean(), nullable=False),
|
||||||
sa.Column("context_menu", sa.Boolean(), nullable=False),
|
sa.Column("context_menu", sa.Boolean(), nullable=False),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.user_id"],
|
||||||
|
),
|
||||||
sa.PrimaryKeyConstraint("command_stats_id"),
|
sa.PrimaryKeyConstraint("command_stats_id"),
|
||||||
)
|
)
|
||||||
|
op.create_table(
|
||||||
|
"github_links",
|
||||||
|
sa.Column("github_link_id", sa.Integer(), nullable=False),
|
||||||
|
sa.Column("url", sa.Text(), nullable=False),
|
||||||
|
sa.Column("user_id", sa.BigInteger(), nullable=True),
|
||||||
|
sa.ForeignKeyConstraint(
|
||||||
|
["user_id"],
|
||||||
|
["users.user_id"],
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("github_link_id"),
|
||||||
|
sa.UniqueConstraint("url"),
|
||||||
|
)
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
# ### commands auto generated by Alembic - please adjust! ###
|
# ### commands auto generated by Alembic - please adjust! ###
|
||||||
|
op.drop_table("github_links")
|
||||||
op.drop_table("command_stats")
|
op.drop_table("command_stats")
|
||||||
# ### end Alembic commands ###
|
# ### end Alembic commands ###
|
|
@ -5,6 +5,7 @@ from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud.users import get_or_add_user
|
||||||
from database.schemas import CommandStats
|
from database.schemas import CommandStats
|
||||||
|
|
||||||
__all__ = ["register_command_invocation"]
|
__all__ = ["register_command_invocation"]
|
||||||
|
@ -20,6 +21,8 @@ async def register_command_invocation(
|
||||||
if command is None:
|
if command is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
await get_or_add_user(session, ctx.author.id)
|
||||||
|
|
||||||
# Check the type of invocation
|
# Check the type of invocation
|
||||||
context_menu = isinstance(command, app_commands.ContextMenu)
|
context_menu = isinstance(command, app_commands.ContextMenu)
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,43 @@
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from sqlalchemy import delete, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.exceptions import Forbidden, NoResultFoundException
|
||||||
|
from database.schemas import GitHubLink
|
||||||
|
|
||||||
|
__all__ = ["add_github_link", "delete_github_link_by_id", "get_github_links"]
|
||||||
|
|
||||||
|
|
||||||
|
async def add_github_link(session: AsyncSession, user_id: int, url: str) -> GitHubLink:
|
||||||
|
"""Add a new GitHub link into the database"""
|
||||||
|
gh_link = GitHubLink(user_id=user_id, url=url)
|
||||||
|
session.add(gh_link)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(gh_link)
|
||||||
|
|
||||||
|
return gh_link
|
||||||
|
|
||||||
|
|
||||||
|
async def delete_github_link_by_id(session: AsyncSession, user_id: int, link_id: int):
|
||||||
|
"""Remove an existing link from the database
|
||||||
|
|
||||||
|
You can only remove links owned by you
|
||||||
|
"""
|
||||||
|
select_statement = select(GitHubLink).where(GitHubLink.github_link_id == link_id)
|
||||||
|
gh_link: Optional[GitHubLink] = (await session.execute(select_statement)).scalar_one_or_none()
|
||||||
|
if gh_link is None:
|
||||||
|
raise NoResultFoundException
|
||||||
|
|
||||||
|
if gh_link.user_id != user_id:
|
||||||
|
raise Forbidden
|
||||||
|
|
||||||
|
delete_statement = delete(GitHubLink).where(GitHubLink.github_link_id == gh_link.github_link_id)
|
||||||
|
await session.execute(delete_statement)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_github_links(session: AsyncSession, user_id: int) -> list[GitHubLink]:
|
||||||
|
"""Get a user's GitHub links"""
|
||||||
|
statement = select(GitHubLink).where(GitHubLink.user_id == user_id)
|
||||||
|
return (await session.execute(statement)).scalars().all()
|
|
@ -33,6 +33,7 @@ __all__ = [
|
||||||
"DadJoke",
|
"DadJoke",
|
||||||
"Deadline",
|
"Deadline",
|
||||||
"EasterEgg",
|
"EasterEgg",
|
||||||
|
"GitHubLink",
|
||||||
"Link",
|
"Link",
|
||||||
"MemeTemplate",
|
"MemeTemplate",
|
||||||
"NightlyData",
|
"NightlyData",
|
||||||
|
@ -103,10 +104,12 @@ class CommandStats(Base):
|
||||||
command_stats_id: int = Column(Integer, primary_key=True)
|
command_stats_id: int = Column(Integer, primary_key=True)
|
||||||
command: str = Column(Text, nullable=False)
|
command: str = Column(Text, nullable=False)
|
||||||
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
|
timestamp: datetime = Column(DateTime(timezone=True), nullable=False)
|
||||||
user_id: int = Column(BigInteger, nullable=False)
|
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||||
slash: bool = Column(Boolean, nullable=False)
|
slash: bool = Column(Boolean, nullable=False)
|
||||||
context_menu: bool = Column(Boolean, nullable=False)
|
context_menu: bool = Column(Boolean, nullable=False)
|
||||||
|
|
||||||
|
user: User = relationship("User", back_populates="command_stats", uselist=False, lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
class CustomCommand(Base):
|
class CustomCommand(Base):
|
||||||
"""Custom commands to fill the hole Dyno couldn't"""
|
"""Custom commands to fill the hole Dyno couldn't"""
|
||||||
|
@ -170,6 +173,18 @@ class EasterEgg(Base):
|
||||||
startswith: bool = Column(Boolean, nullable=False, server_default="1")
|
startswith: bool = Column(Boolean, nullable=False, server_default="1")
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubLink(Base):
|
||||||
|
"""A user's GitHub link"""
|
||||||
|
|
||||||
|
__tablename__ = "github_links"
|
||||||
|
|
||||||
|
github_link_id: int = Column(Integer, primary_key=True)
|
||||||
|
url: str = Column(Text, nullable=False, unique=True)
|
||||||
|
user_id: int = Column(BigInteger, ForeignKey("users.user_id"))
|
||||||
|
|
||||||
|
user: User = relationship("User", back_populates="github_links", uselist=False, lazy="selectin")
|
||||||
|
|
||||||
|
|
||||||
class Link(Base):
|
class Link(Base):
|
||||||
"""Useful links that go useful places"""
|
"""Useful links that go useful places"""
|
||||||
|
|
||||||
|
@ -279,6 +294,12 @@ class User(Base):
|
||||||
bookmarks: list[Bookmark] = relationship(
|
bookmarks: list[Bookmark] = relationship(
|
||||||
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
"Bookmark", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
command_stats: list[CommandStats] = relationship(
|
||||||
|
"CommandStats", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
github_links: list[GitHubLink] = relationship(
|
||||||
|
"GitHubLink", back_populates="user", uselist=True, lazy="selectin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
nightly_data: NightlyData = relationship(
|
nightly_data: NightlyData = relationship(
|
||||||
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
"NightlyData", back_populates="user", uselist=False, lazy="selectin", cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
|
|
@ -4,7 +4,7 @@ import discord
|
||||||
from discord import app_commands
|
from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from database.crud import birthdays, bookmarks
|
from database.crud import birthdays, bookmarks, github
|
||||||
from database.exceptions import (
|
from database.exceptions import (
|
||||||
DuplicateInsertException,
|
DuplicateInsertException,
|
||||||
Forbidden,
|
Forbidden,
|
||||||
|
@ -15,6 +15,7 @@ from didier import Didier
|
||||||
from didier.exceptions import expect
|
from didier.exceptions import expect
|
||||||
from didier.menus.bookmarks import BookmarkSource
|
from didier.menus.bookmarks import BookmarkSource
|
||||||
from didier.menus.common import Menu
|
from didier.menus.common import Menu
|
||||||
|
from didier.utils.discord import colours
|
||||||
from didier.utils.discord.assets import get_author_avatar
|
from didier.utils.discord.assets import get_author_avatar
|
||||||
from didier.utils.types.datetime import str_to_date
|
from didier.utils.types.datetime import str_to_date
|
||||||
from didier.utils.types.string import leading
|
from didier.utils.types.string import leading
|
||||||
|
@ -193,6 +194,40 @@ class Discord(commands.Cog):
|
||||||
modal = CreateBookmark(self.client, message.jump_url)
|
modal = CreateBookmark(self.client, message.jump_url)
|
||||||
await interaction.response.send_modal(modal)
|
await interaction.response.send_modal(modal)
|
||||||
|
|
||||||
|
@commands.group(name="github", aliases=["gh", "git"], case_insensitive=True, invoke_without_command=True)
|
||||||
|
async def github(self, ctx: commands.Context, user: discord.User):
|
||||||
|
"""Show a user's GitHub links"""
|
||||||
|
embed = discord.Embed(colour=colours.github_white(), title="GitHub Links")
|
||||||
|
embed.set_author(name=user.display_name, icon_url=user.avatar.url or user.default_avatar.url)
|
||||||
|
|
||||||
|
embed.set_footer(text="Links can be added using `didier github add <link>`.")
|
||||||
|
|
||||||
|
async with self.client.postgres_session as session:
|
||||||
|
links = await github.get_github_links(session, user.id)
|
||||||
|
|
||||||
|
if not links:
|
||||||
|
embed.description = "This user has not set any GitHub links yet."
|
||||||
|
else:
|
||||||
|
regular_links = []
|
||||||
|
ugent_links = []
|
||||||
|
|
||||||
|
for link in links:
|
||||||
|
if "github.ugent.be" in link.url.lower():
|
||||||
|
ugent_links.append(link)
|
||||||
|
else:
|
||||||
|
regular_links.append(link)
|
||||||
|
|
||||||
|
regular_links.sort()
|
||||||
|
ugent_links.sort()
|
||||||
|
|
||||||
|
if ugent_links:
|
||||||
|
embed.add_field(name="Ghent University", value="\n".join(ugent_links), inline=False)
|
||||||
|
|
||||||
|
if regular_links:
|
||||||
|
embed.add_field(name="Other", value="\n".join(regular_links), inline=False)
|
||||||
|
|
||||||
|
return await ctx.reply(embed=embed, mention_author=False)
|
||||||
|
|
||||||
@commands.command(name="join")
|
@commands.command(name="join")
|
||||||
async def join(self, ctx: commands.Context, thread: discord.Thread):
|
async def join(self, ctx: commands.Context, thread: discord.Thread):
|
||||||
"""Make Didier join `thread`.
|
"""Make Didier join `thread`.
|
||||||
|
|
|
@ -188,7 +188,7 @@ class CustomHelpCommand(commands.MinimalHelpCommand):
|
||||||
embed.add_field(name="Signature", value=signature, inline=False)
|
embed.add_field(name="Signature", value=signature, inline=False)
|
||||||
|
|
||||||
if command.aliases:
|
if command.aliases:
|
||||||
embed.add_field(name="Aliases", value=", ".join(command.aliases), inline=False)
|
embed.add_field(name="Aliases", value=", ".join(sorted(command.aliases)), inline=False)
|
||||||
|
|
||||||
def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]:
|
def _get_cog(self, cogs: list[commands.Cog], name: str) -> Optional[commands.Cog]:
|
||||||
"""Try to find a cog, case-insensitively"""
|
"""Try to find a cog, case-insensitively"""
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import re
|
||||||
from functools import cached_property
|
from functools import cached_property
|
||||||
from typing import Union
|
from typing import Union
|
||||||
|
|
||||||
|
@ -284,23 +285,31 @@ class Didier(commands.Bot):
|
||||||
):
|
):
|
||||||
return await ctx.reply(str(exception.original), mention_author=False)
|
return await ctx.reply(str(exception.original), mention_author=False)
|
||||||
|
|
||||||
# Print everything that we care about to the logs/stderr
|
|
||||||
await super().on_command_error(ctx, exception)
|
|
||||||
|
|
||||||
if isinstance(exception, commands.MessageNotFound):
|
if isinstance(exception, commands.MessageNotFound):
|
||||||
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
|
return await ctx.reply("This message could not be found.", ephemeral=True, delete_after=10)
|
||||||
|
|
||||||
|
if isinstance(exception, (commands.MissingRequiredArgument,)):
|
||||||
|
message = str(exception)
|
||||||
|
|
||||||
|
match = re.search(r"(.*) is a required argument that is missing\.", message)
|
||||||
|
if match.groups():
|
||||||
|
message = f"Found no value for the `{match.groups()[0]}`-argument."
|
||||||
|
|
||||||
|
return await ctx.reply(message, ephemeral=True, delete_after=10)
|
||||||
|
|
||||||
if isinstance(
|
if isinstance(
|
||||||
exception,
|
exception,
|
||||||
(
|
(
|
||||||
commands.BadArgument,
|
commands.BadArgument,
|
||||||
commands.MissingRequiredArgument,
|
|
||||||
commands.UnexpectedQuoteError,
|
commands.UnexpectedQuoteError,
|
||||||
commands.ExpectedClosingQuoteError,
|
commands.ExpectedClosingQuoteError,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
|
return await ctx.reply("Invalid arguments.", ephemeral=True, delete_after=10)
|
||||||
|
|
||||||
|
# Print everything that we care about to the logs/stderr
|
||||||
|
await super().on_command_error(ctx, exception)
|
||||||
|
|
||||||
if settings.ERRORS_CHANNEL is not None:
|
if settings.ERRORS_CHANNEL is not None:
|
||||||
embed = create_error_embed(ctx, exception)
|
embed = create_error_embed(ctx, exception)
|
||||||
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
channel = self.get_channel(settings.ERRORS_CHANNEL)
|
||||||
|
|
|
@ -1,12 +1,23 @@
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
__all__ = ["error_red", "ghent_university_blue", "ghent_university_yellow", "google_blue", "urban_dictionary_green"]
|
__all__ = [
|
||||||
|
"error_red",
|
||||||
|
"github_white",
|
||||||
|
"ghent_university_blue",
|
||||||
|
"ghent_university_yellow",
|
||||||
|
"google_blue",
|
||||||
|
"urban_dictionary_green",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
def error_red() -> discord.Colour:
|
def error_red() -> discord.Colour:
|
||||||
return discord.Colour.red()
|
return discord.Colour.red()
|
||||||
|
|
||||||
|
|
||||||
|
def github_white() -> discord.Colour:
|
||||||
|
return discord.Colour.from_rgb(250, 250, 250)
|
||||||
|
|
||||||
|
|
||||||
def ghent_university_blue() -> discord.Colour:
|
def ghent_university_blue() -> discord.Colour:
|
||||||
return discord.Colour.from_rgb(30, 100, 200)
|
return discord.Colour.from_rgb(30, 100, 200)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue