mirror of https://github.com/stijndcl/didier
				
				
				
			Rework links
							parent
							
								
									94de47082b
								
							
						
					
					
						commit
						a614e9a9f1
					
				| 
						 | 
				
			
			@ -0,0 +1,35 @@
 | 
			
		|||
"""Add custom links
 | 
			
		||||
 | 
			
		||||
Revision ID: 3962636f3a3d
 | 
			
		||||
Revises: 346b408c362a
 | 
			
		||||
Create Date: 2022-08-10 00:54:05.668255
 | 
			
		||||
 | 
			
		||||
"""
 | 
			
		||||
import sqlalchemy as sa
 | 
			
		||||
 | 
			
		||||
from alembic import op
 | 
			
		||||
 | 
			
		||||
# revision identifiers, used by Alembic.
 | 
			
		||||
revision = "3962636f3a3d"
 | 
			
		||||
down_revision = "346b408c362a"
 | 
			
		||||
branch_labels = None
 | 
			
		||||
depends_on = None
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def upgrade() -> None:
 | 
			
		||||
    # ### commands auto generated by Alembic - please adjust! ###
 | 
			
		||||
    op.create_table(
 | 
			
		||||
        "links",
 | 
			
		||||
        sa.Column("link_id", sa.Integer(), nullable=False),
 | 
			
		||||
        sa.Column("name", sa.Text(), nullable=False),
 | 
			
		||||
        sa.Column("url", sa.Text(), nullable=False),
 | 
			
		||||
        sa.PrimaryKeyConstraint("link_id"),
 | 
			
		||||
        sa.UniqueConstraint("name"),
 | 
			
		||||
    )
 | 
			
		||||
    # ### end Alembic commands ###
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def downgrade() -> None:
 | 
			
		||||
    # ### commands auto generated by Alembic - please adjust! ###
 | 
			
		||||
    op.drop_table("links")
 | 
			
		||||
    # ### end Alembic commands ###
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,45 @@
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
from sqlalchemy import func, select
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.exceptions import NoResultFoundException
 | 
			
		||||
from database.schemas.relational import Link
 | 
			
		||||
 | 
			
		||||
__all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_all_links(session: AsyncSession) -> list[Link]:
 | 
			
		||||
    """Get a list of all links"""
 | 
			
		||||
    statement = select(Link)
 | 
			
		||||
    return (await session.execute(statement)).scalars().all()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def add_link(session: AsyncSession, name: str, url: str) -> Link:
 | 
			
		||||
    """Add a new link into the database"""
 | 
			
		||||
    if name.islower():
 | 
			
		||||
        name = name.capitalize()
 | 
			
		||||
 | 
			
		||||
    instance = Link(name=name, url=url)
 | 
			
		||||
    session.add(instance)
 | 
			
		||||
    await session.commit()
 | 
			
		||||
 | 
			
		||||
    return instance
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def get_link_by_name(session: AsyncSession, name: str) -> Optional[Link]:
 | 
			
		||||
    """Get a link by its name"""
 | 
			
		||||
    statement = select(Link).where(func.lower(Link.name) == name.lower())
 | 
			
		||||
    return (await session.execute(statement)).scalar_one_or_none()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def edit_link(session: AsyncSession, name: str, new_url: str):
 | 
			
		||||
    """Edit an existing link"""
 | 
			
		||||
    link: Optional[Link] = await get_link_by_name(session, name)
 | 
			
		||||
 | 
			
		||||
    if link is None:
 | 
			
		||||
        raise NoResultFoundException
 | 
			
		||||
 | 
			
		||||
    link.url = new_url
 | 
			
		||||
    session.add(link)
 | 
			
		||||
    await session.commit()
 | 
			
		||||
| 
						 | 
				
			
			@ -0,0 +1,5 @@
 | 
			
		|||
from .constraints import DuplicateInsertException
 | 
			
		||||
from .currency import DoubleNightly, NotEnoughDinks
 | 
			
		||||
from .not_found import NoResultFoundException
 | 
			
		||||
 | 
			
		||||
__all__ = ["DuplicateInsertException", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException"]
 | 
			
		||||
| 
						 | 
				
			
			@ -28,6 +28,7 @@ __all__ = [
 | 
			
		|||
    "CustomCommand",
 | 
			
		||||
    "CustomCommandAlias",
 | 
			
		||||
    "DadJoke",
 | 
			
		||||
    "Link",
 | 
			
		||||
    "NightlyData",
 | 
			
		||||
    "Task",
 | 
			
		||||
    "UforaAnnouncement",
 | 
			
		||||
| 
						 | 
				
			
			@ -109,6 +110,16 @@ class DadJoke(Base):
 | 
			
		|||
    joke: str = Column(Text, nullable=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Link(Base):
 | 
			
		||||
    """Useful links that go useful places"""
 | 
			
		||||
 | 
			
		||||
    __tablename__ = "links"
 | 
			
		||||
 | 
			
		||||
    link_id: int = Column(Integer, primary_key=True)
 | 
			
		||||
    name: str = Column(Text, nullable=False, unique=True)
 | 
			
		||||
    url: str = Column(Text, nullable=False)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class NightlyData(Base):
 | 
			
		||||
    """Data for a user's Nightly stats"""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -4,10 +4,10 @@ from typing import Generic, TypeVar
 | 
			
		|||
from overrides import overrides
 | 
			
		||||
from sqlalchemy.ext.asyncio import AsyncSession
 | 
			
		||||
 | 
			
		||||
from database.crud import ufora_courses, wordle
 | 
			
		||||
from database.crud import links, ufora_courses, wordle
 | 
			
		||||
from database.mongo_types import MongoDatabase
 | 
			
		||||
 | 
			
		||||
__all__ = ["CacheManager", "UforaCourseCache"]
 | 
			
		||||
__all__ = ["CacheManager", "LinkCache", "UforaCourseCache"]
 | 
			
		||||
 | 
			
		||||
T = TypeVar("T")
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -35,12 +35,8 @@ class DatabaseCache(ABC, Generic[T]):
 | 
			
		|||
        self.data.clear()
 | 
			
		||||
 | 
			
		||||
    @abstractmethod
 | 
			
		||||
    async def refresh(self, database_session: T):
 | 
			
		||||
        """Refresh the data stored in this cache"""
 | 
			
		||||
 | 
			
		||||
    async def invalidate(self, database_session: T):
 | 
			
		||||
        """Invalidate the data stored in this cache"""
 | 
			
		||||
        await self.refresh(database_session)
 | 
			
		||||
 | 
			
		||||
    def get_autocomplete_suggestions(self, query: str):
 | 
			
		||||
        """Filter the cache to find everything that matches the search query"""
 | 
			
		||||
| 
						 | 
				
			
			@ -49,6 +45,19 @@ class DatabaseCache(ABC, Generic[T]):
 | 
			
		|||
        return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class LinkCache(DatabaseCache[AsyncSession]):
 | 
			
		||||
    """Cache to store the names of links"""
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    async def invalidate(self, database_session: AsyncSession):
 | 
			
		||||
        self.clear()
 | 
			
		||||
 | 
			
		||||
        all_links = await links.get_all_links(database_session)
 | 
			
		||||
        self.data = list(map(lambda l: l.name, all_links))
 | 
			
		||||
        self.data.sort()
 | 
			
		||||
        self.data_transformed = list(map(str.lower, self.data))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class UforaCourseCache(DatabaseCache[AsyncSession]):
 | 
			
		||||
    """Cache to store the names of Ufora courses"""
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -61,11 +70,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
 | 
			
		|||
        super().clear()
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    async def refresh(self, database_session: AsyncSession):
 | 
			
		||||
    async def invalidate(self, database_session: AsyncSession):
 | 
			
		||||
        self.clear()
 | 
			
		||||
 | 
			
		||||
        courses = await ufora_courses.get_all_courses(database_session)
 | 
			
		||||
 | 
			
		||||
        self.data = list(map(lambda c: c.name, courses))
 | 
			
		||||
 | 
			
		||||
        # Load the aliases
 | 
			
		||||
| 
						 | 
				
			
			@ -97,7 +105,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]):
 | 
			
		|||
class WordleCache(DatabaseCache[MongoDatabase]):
 | 
			
		||||
    """Cache to store the current daily Wordle word"""
 | 
			
		||||
 | 
			
		||||
    async def refresh(self, database_session: MongoDatabase):
 | 
			
		||||
    async def invalidate(self, database_session: MongoDatabase):
 | 
			
		||||
        word = await wordle.get_daily_word(database_session)
 | 
			
		||||
        if word is not None:
 | 
			
		||||
            self.data = [word]
 | 
			
		||||
| 
						 | 
				
			
			@ -106,14 +114,17 @@ class WordleCache(DatabaseCache[MongoDatabase]):
 | 
			
		|||
class CacheManager:
 | 
			
		||||
    """Class that keeps track of all caches"""
 | 
			
		||||
 | 
			
		||||
    links: LinkCache
 | 
			
		||||
    ufora_courses: UforaCourseCache
 | 
			
		||||
    wordle_word: WordleCache
 | 
			
		||||
 | 
			
		||||
    def __init__(self):
 | 
			
		||||
        self.links = LinkCache()
 | 
			
		||||
        self.ufora_courses = UforaCourseCache()
 | 
			
		||||
        self.wordle_word = WordleCache()
 | 
			
		||||
 | 
			
		||||
    async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase):
 | 
			
		||||
        """Initialize the contents of all caches"""
 | 
			
		||||
        await self.ufora_courses.refresh(postgres_session)
 | 
			
		||||
        await self.wordle_word.refresh(mongo_db)
 | 
			
		||||
        await self.links.invalidate(postgres_session)
 | 
			
		||||
        await self.ufora_courses.invalidate(postgres_session)
 | 
			
		||||
        await self.wordle_word.invalidate(mongo_db)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,11 @@
 | 
			
		|||
from typing import Optional
 | 
			
		||||
 | 
			
		||||
import discord
 | 
			
		||||
from discord import app_commands
 | 
			
		||||
from discord.ext import commands
 | 
			
		||||
 | 
			
		||||
from database.crud.links import get_link_by_name
 | 
			
		||||
from database.schemas.relational import Link
 | 
			
		||||
from didier import Didier
 | 
			
		||||
from didier.data.apis import urban_dictionary
 | 
			
		||||
from didier.data.embeds.google import GoogleSearch
 | 
			
		||||
| 
						 | 
				
			
			@ -34,6 +39,38 @@ class Other(commands.Cog):
 | 
			
		|||
            embed = GoogleSearch(results).to_embed()
 | 
			
		||||
            await ctx.reply(embed=embed, mention_author=False)
 | 
			
		||||
 | 
			
		||||
    async def _get_link(self, name: str) -> Optional[Link]:
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            return await get_link_by_name(session, name.lower())
 | 
			
		||||
 | 
			
		||||
    @commands.command(name="Link", aliases=["Links"], usage="[Name]")
 | 
			
		||||
    async def link_msg(self, ctx: commands.Context, name: str):
 | 
			
		||||
        """Message command to get the link to something"""
 | 
			
		||||
        link = await self._get_link(name)
 | 
			
		||||
        if link is None:
 | 
			
		||||
            return await ctx.reply(f"Found no links matching `{name}`.", mention_author=False)
 | 
			
		||||
 | 
			
		||||
        target_message = await self.client.get_reply_target(ctx)
 | 
			
		||||
        await target_message.reply(link.url, mention_author=False)
 | 
			
		||||
 | 
			
		||||
    @app_commands.command(name="link", description="Get the link to something")
 | 
			
		||||
    @app_commands.describe(name="The name of the link")
 | 
			
		||||
    async def link_slash(self, interaction: discord.Interaction, name: str):
 | 
			
		||||
        """Slash command to get the link to something"""
 | 
			
		||||
        link = await self._get_link(name)
 | 
			
		||||
        if link is None:
 | 
			
		||||
            return await interaction.response.send_message(f"Found no links matching `{name}`.", ephemeral=True)
 | 
			
		||||
 | 
			
		||||
        return await interaction.response.send_message(link.url)
 | 
			
		||||
 | 
			
		||||
    @link_slash.autocomplete("name")
 | 
			
		||||
    async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
			
		||||
        """Autocompletion for the 'name'-parameter"""
 | 
			
		||||
        return [
 | 
			
		||||
            app_commands.Choice(name=name, value=name.lower())
 | 
			
		||||
            for name in self.client.database_caches.links.get_autocomplete_suggestions(current)
 | 
			
		||||
        ]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def setup(client: Didier):
 | 
			
		||||
    """Load the cog"""
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -5,12 +5,17 @@ from discord import app_commands
 | 
			
		|||
from discord.ext import commands
 | 
			
		||||
 | 
			
		||||
import settings
 | 
			
		||||
from database.crud import custom_commands
 | 
			
		||||
from database.crud import custom_commands, links
 | 
			
		||||
from database.exceptions.constraints import DuplicateInsertException
 | 
			
		||||
from database.exceptions.not_found import NoResultFoundException
 | 
			
		||||
from didier import Didier
 | 
			
		||||
from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags
 | 
			
		||||
from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand
 | 
			
		||||
from didier.views.modals import (
 | 
			
		||||
    AddDadJoke,
 | 
			
		||||
    AddLink,
 | 
			
		||||
    CreateCustomCommand,
 | 
			
		||||
    EditCustomCommand,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Owner(commands.Cog):
 | 
			
		||||
| 
						 | 
				
			
			@ -80,17 +85,6 @@ class Owner(commands.Cog):
 | 
			
		|||
    async def add_msg(self, ctx: commands.Context):
 | 
			
		||||
        """Command group for [add X] message commands"""
 | 
			
		||||
 | 
			
		||||
    @add_msg.command(name="Custom")
 | 
			
		||||
    async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
 | 
			
		||||
        """Add a new custom command"""
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            try:
 | 
			
		||||
                await custom_commands.create_command(session, name, response)
 | 
			
		||||
                await self.client.confirm_message(ctx.message)
 | 
			
		||||
            except DuplicateInsertException:
 | 
			
		||||
                await ctx.reply("There is already a command with this name.")
 | 
			
		||||
                await self.client.reject_message(ctx.message)
 | 
			
		||||
 | 
			
		||||
    @add_msg.command(name="Alias")
 | 
			
		||||
    async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
 | 
			
		||||
        """Add a new alias for a custom command"""
 | 
			
		||||
| 
						 | 
				
			
			@ -105,6 +99,26 @@ class Owner(commands.Cog):
 | 
			
		|||
                await ctx.reply("There is already a command with this name.")
 | 
			
		||||
                await self.client.reject_message(ctx.message)
 | 
			
		||||
 | 
			
		||||
    @add_msg.command(name="Custom")
 | 
			
		||||
    async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
 | 
			
		||||
        """Add a new custom command"""
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            try:
 | 
			
		||||
                await custom_commands.create_command(session, name, response)
 | 
			
		||||
                await self.client.confirm_message(ctx.message)
 | 
			
		||||
            except DuplicateInsertException:
 | 
			
		||||
                await ctx.reply("There is already a command with this name.")
 | 
			
		||||
                await self.client.reject_message(ctx.message)
 | 
			
		||||
 | 
			
		||||
    @add_msg.command(name="Link")
 | 
			
		||||
    async def add_link_msg(self, ctx: commands.Context, name: str, url: str):
 | 
			
		||||
        """Add a new link"""
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            await links.add_link(session, name, url)
 | 
			
		||||
            await self.client.database_caches.links.invalidate(session)
 | 
			
		||||
 | 
			
		||||
        await self.client.confirm_message(ctx.message)
 | 
			
		||||
 | 
			
		||||
    @add_slash.command(name="custom", description="Add a custom command")
 | 
			
		||||
    async def add_custom_slash(self, interaction: discord.Interaction):
 | 
			
		||||
        """Slash command to add a custom command"""
 | 
			
		||||
| 
						 | 
				
			
			@ -123,6 +137,15 @@ class Owner(commands.Cog):
 | 
			
		|||
        modal = AddDadJoke(self.client)
 | 
			
		||||
        await interaction.response.send_modal(modal)
 | 
			
		||||
 | 
			
		||||
    @add_slash.command(name="link", description="Add a new link")
 | 
			
		||||
    async def add_link_slash(self, interaction: discord.Interaction):
 | 
			
		||||
        """Slash command to add new links"""
 | 
			
		||||
        if not await self.client.is_owner(interaction.user):
 | 
			
		||||
            return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True)
 | 
			
		||||
 | 
			
		||||
        modal = AddLink(self.client)
 | 
			
		||||
        await interaction.response.send_modal(modal)
 | 
			
		||||
 | 
			
		||||
    @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False)
 | 
			
		||||
    async def edit_msg(self, ctx: commands.Context):
 | 
			
		||||
        """Command group for [edit X] commands"""
 | 
			
		||||
| 
						 | 
				
			
			@ -135,7 +158,7 @@ class Owner(commands.Cog):
 | 
			
		|||
                await custom_commands.edit_command(session, command, flags.name, flags.response)
 | 
			
		||||
                return await self.client.confirm_message(ctx.message)
 | 
			
		||||
            except NoResultFoundException:
 | 
			
		||||
                await ctx.reply(f"No command found matching ``{command}``.")
 | 
			
		||||
                await ctx.reply(f"No command found matching `{command}`.")
 | 
			
		||||
                return await self.client.reject_message(ctx.message)
 | 
			
		||||
 | 
			
		||||
    @edit_slash.command(name="custom", description="Edit a custom command")
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -72,7 +72,7 @@ class School(commands.Cog):
 | 
			
		|||
            ufora_course = await ufora_courses.get_course_by_name(session, course)
 | 
			
		||||
 | 
			
		||||
        if ufora_course is None:
 | 
			
		||||
            return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True)
 | 
			
		||||
            return await ctx.reply(f"Found no course matching `{course}`", ephemeral=True)
 | 
			
		||||
 | 
			
		||||
        return await ctx.reply(
 | 
			
		||||
            f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}",
 | 
			
		||||
| 
						 | 
				
			
			@ -80,7 +80,7 @@ class School(commands.Cog):
 | 
			
		|||
        )
 | 
			
		||||
 | 
			
		||||
    @study_guide.autocomplete("course")
 | 
			
		||||
    async def study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
			
		||||
    async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]:
 | 
			
		||||
        """Autocompletion for the 'course'-parameter"""
 | 
			
		||||
        return [
 | 
			
		||||
            app_commands.Choice(name=course, value=course)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -111,6 +111,17 @@ class Didier(commands.Bot):
 | 
			
		|||
            for line in fp:
 | 
			
		||||
                self.wordle_words.add(line.strip())
 | 
			
		||||
 | 
			
		||||
    async def get_reply_target(self, ctx: commands.Context) -> discord.Message:
 | 
			
		||||
        """Get the target message that should be replied to
 | 
			
		||||
 | 
			
		||||
        In case the invoking message is a reply to something, reply to the
 | 
			
		||||
        original message instead
 | 
			
		||||
        """
 | 
			
		||||
        if ctx.message.reference is not None:
 | 
			
		||||
            return await self.resolve_message(ctx.message.reference)
 | 
			
		||||
 | 
			
		||||
        return ctx.message
 | 
			
		||||
 | 
			
		||||
    async def resolve_message(self, reference: discord.MessageReference) -> discord.Message:
 | 
			
		||||
        """Fetch a message from a reference"""
 | 
			
		||||
        # Message is in the cache, return it
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,4 +1,5 @@
 | 
			
		|||
from .custom_commands import CreateCustomCommand, EditCustomCommand
 | 
			
		||||
from .dad_jokes import AddDadJoke
 | 
			
		||||
from .links import AddLink
 | 
			
		||||
 | 
			
		||||
__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand"]
 | 
			
		||||
__all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand", "AddLink"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -0,0 +1,37 @@
 | 
			
		|||
import traceback
 | 
			
		||||
 | 
			
		||||
import discord.ui
 | 
			
		||||
from overrides import overrides
 | 
			
		||||
 | 
			
		||||
from database.crud.links import add_link
 | 
			
		||||
from didier import Didier
 | 
			
		||||
 | 
			
		||||
__all__ = ["AddLink"]
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class AddLink(discord.ui.Modal, title="Add Link"):
 | 
			
		||||
    """Modal to add a new link"""
 | 
			
		||||
 | 
			
		||||
    name = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, placeholder="Source")
 | 
			
		||||
    url = discord.ui.TextInput(
 | 
			
		||||
        label="URL", style=discord.TextStyle.short, placeholder="https://github.com/stijndcl/didier"
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    client: Didier
 | 
			
		||||
 | 
			
		||||
    def __init__(self, client: Didier, *args, **kwargs):
 | 
			
		||||
        super().__init__(*args, **kwargs)
 | 
			
		||||
        self.client = client
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    async def on_submit(self, interaction: discord.Interaction):
 | 
			
		||||
        async with self.client.postgres_session as session:
 | 
			
		||||
            await add_link(session, self.name.value, self.url.value)
 | 
			
		||||
            await self.client.database_caches.links.invalidate(session)
 | 
			
		||||
 | 
			
		||||
        await interaction.response.send_message(f"Successfully added `{self.name.value.capitalize()}`.", ephemeral=True)
 | 
			
		||||
 | 
			
		||||
    @overrides
 | 
			
		||||
    async def on_error(self, interaction: discord.Interaction, error: Exception):  # type: ignore
 | 
			
		||||
        await interaction.response.send_message("Something went wrong.", ephemeral=True)
 | 
			
		||||
        traceback.print_tb(error.__traceback__)
 | 
			
		||||
| 
						 | 
				
			
			@ -7,7 +7,7 @@ from database.utils.caches import UforaCourseCache
 | 
			
		|||
async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse):
 | 
			
		||||
    """Test loading the data for the Ufora Course cache when it's empty"""
 | 
			
		||||
    cache = UforaCourseCache()
 | 
			
		||||
    await cache.refresh(postgres)
 | 
			
		||||
    await cache.invalidate(postgres)
 | 
			
		||||
 | 
			
		||||
    assert len(cache.data) == 1
 | 
			
		||||
    assert cache.data == ["test"]
 | 
			
		||||
| 
						 | 
				
			
			@ -20,7 +20,7 @@ async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufor
 | 
			
		|||
    cache.data = ["Something"]
 | 
			
		||||
    cache.data_transformed = ["something"]
 | 
			
		||||
 | 
			
		||||
    await cache.refresh(postgres)
 | 
			
		||||
    await cache.invalidate(postgres)
 | 
			
		||||
 | 
			
		||||
    assert len(cache.data) == 1
 | 
			
		||||
    assert cache.data == ["test"]
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue