mirror of https://github.com/stijndcl/didier
				
				
				
			Rework links
							parent
							
								
									94de47082b
								
							
						
					
					
						commit
						a614e9a9f1
					
				|  | @ -0,0 +1,35 @@ | |||
| """Add custom links | ||||
| 
 | ||||
| Revision ID: 3962636f3a3d | ||||
| Revises: 346b408c362a | ||||
| Create Date: 2022-08-10 00:54:05.668255 | ||||
| 
 | ||||
| """ | ||||
| import sqlalchemy as sa | ||||
| 
 | ||||
| from alembic import op | ||||
| 
 | ||||
| # revision identifiers, used by Alembic. | ||||
| revision = "3962636f3a3d" | ||||
| down_revision = "346b408c362a" | ||||
| branch_labels = None | ||||
| depends_on = None | ||||
| 
 | ||||
| 
 | ||||
| def upgrade() -> None: | ||||
|     # ### commands auto generated by Alembic - please adjust! ### | ||||
|     op.create_table( | ||||
|         "links", | ||||
|         sa.Column("link_id", sa.Integer(), nullable=False), | ||||
|         sa.Column("name", sa.Text(), nullable=False), | ||||
|         sa.Column("url", sa.Text(), nullable=False), | ||||
|         sa.PrimaryKeyConstraint("link_id"), | ||||
|         sa.UniqueConstraint("name"), | ||||
|     ) | ||||
|     # ### end Alembic commands ### | ||||
| 
 | ||||
| 
 | ||||
| def downgrade() -> None: | ||||
|     # ### commands auto generated by Alembic - please adjust! ### | ||||
|     op.drop_table("links") | ||||
|     # ### end Alembic commands ### | ||||
|  | @ -0,0 +1,45 @@ | |||
| from typing import Optional | ||||
| 
 | ||||
| from sqlalchemy import func, select | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.exceptions import NoResultFoundException | ||||
| from database.schemas.relational import Link | ||||
| 
 | ||||
| __all__ = ["add_link", "edit_link", "get_all_links", "get_link_by_name"] | ||||
| 
 | ||||
| 
 | ||||
| async def get_all_links(session: AsyncSession) -> list[Link]: | ||||
|     """Get a list of all links""" | ||||
|     statement = select(Link) | ||||
|     return (await session.execute(statement)).scalars().all() | ||||
| 
 | ||||
| 
 | ||||
| async def add_link(session: AsyncSession, name: str, url: str) -> Link: | ||||
|     """Add a new link into the database""" | ||||
|     if name.islower(): | ||||
|         name = name.capitalize() | ||||
| 
 | ||||
|     instance = Link(name=name, url=url) | ||||
|     session.add(instance) | ||||
|     await session.commit() | ||||
| 
 | ||||
|     return instance | ||||
| 
 | ||||
| 
 | ||||
| async def get_link_by_name(session: AsyncSession, name: str) -> Optional[Link]: | ||||
|     """Get a link by its name""" | ||||
|     statement = select(Link).where(func.lower(Link.name) == name.lower()) | ||||
|     return (await session.execute(statement)).scalar_one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| async def edit_link(session: AsyncSession, name: str, new_url: str): | ||||
|     """Edit an existing link""" | ||||
|     link: Optional[Link] = await get_link_by_name(session, name) | ||||
| 
 | ||||
|     if link is None: | ||||
|         raise NoResultFoundException | ||||
| 
 | ||||
|     link.url = new_url | ||||
|     session.add(link) | ||||
|     await session.commit() | ||||
|  | @ -0,0 +1,5 @@ | |||
| from .constraints import DuplicateInsertException | ||||
| from .currency import DoubleNightly, NotEnoughDinks | ||||
| from .not_found import NoResultFoundException | ||||
| 
 | ||||
| __all__ = ["DuplicateInsertException", "DoubleNightly", "NotEnoughDinks", "NoResultFoundException"] | ||||
|  | @ -28,6 +28,7 @@ __all__ = [ | |||
|     "CustomCommand", | ||||
|     "CustomCommandAlias", | ||||
|     "DadJoke", | ||||
|     "Link", | ||||
|     "NightlyData", | ||||
|     "Task", | ||||
|     "UforaAnnouncement", | ||||
|  | @ -109,6 +110,16 @@ class DadJoke(Base): | |||
|     joke: str = Column(Text, nullable=False) | ||||
| 
 | ||||
| 
 | ||||
| class Link(Base): | ||||
|     """Useful links that go useful places""" | ||||
| 
 | ||||
|     __tablename__ = "links" | ||||
| 
 | ||||
|     link_id: int = Column(Integer, primary_key=True) | ||||
|     name: str = Column(Text, nullable=False, unique=True) | ||||
|     url: str = Column(Text, nullable=False) | ||||
| 
 | ||||
| 
 | ||||
| class NightlyData(Base): | ||||
|     """Data for a user's Nightly stats""" | ||||
| 
 | ||||
|  |  | |||
|  | @ -4,10 +4,10 @@ from typing import Generic, TypeVar | |||
| from overrides import overrides | ||||
| from sqlalchemy.ext.asyncio import AsyncSession | ||||
| 
 | ||||
| from database.crud import ufora_courses, wordle | ||||
| from database.crud import links, ufora_courses, wordle | ||||
| from database.mongo_types import MongoDatabase | ||||
| 
 | ||||
| __all__ = ["CacheManager", "UforaCourseCache"] | ||||
| __all__ = ["CacheManager", "LinkCache", "UforaCourseCache"] | ||||
| 
 | ||||
| T = TypeVar("T") | ||||
| 
 | ||||
|  | @ -35,12 +35,8 @@ class DatabaseCache(ABC, Generic[T]): | |||
|         self.data.clear() | ||||
| 
 | ||||
|     @abstractmethod | ||||
|     async def refresh(self, database_session: T): | ||||
|         """Refresh the data stored in this cache""" | ||||
| 
 | ||||
|     async def invalidate(self, database_session: T): | ||||
|         """Invalidate the data stored in this cache""" | ||||
|         await self.refresh(database_session) | ||||
| 
 | ||||
|     def get_autocomplete_suggestions(self, query: str): | ||||
|         """Filter the cache to find everything that matches the search query""" | ||||
|  | @ -49,6 +45,19 @@ class DatabaseCache(ABC, Generic[T]): | |||
|         return [self.data[index] for index, value in enumerate(self.data_transformed) if query in value] | ||||
| 
 | ||||
| 
 | ||||
| class LinkCache(DatabaseCache[AsyncSession]): | ||||
|     """Cache to store the names of links""" | ||||
| 
 | ||||
|     @overrides | ||||
|     async def invalidate(self, database_session: AsyncSession): | ||||
|         self.clear() | ||||
| 
 | ||||
|         all_links = await links.get_all_links(database_session) | ||||
|         self.data = list(map(lambda l: l.name, all_links)) | ||||
|         self.data.sort() | ||||
|         self.data_transformed = list(map(str.lower, self.data)) | ||||
| 
 | ||||
| 
 | ||||
| class UforaCourseCache(DatabaseCache[AsyncSession]): | ||||
|     """Cache to store the names of Ufora courses""" | ||||
| 
 | ||||
|  | @ -61,11 +70,10 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): | |||
|         super().clear() | ||||
| 
 | ||||
|     @overrides | ||||
|     async def refresh(self, database_session: AsyncSession): | ||||
|     async def invalidate(self, database_session: AsyncSession): | ||||
|         self.clear() | ||||
| 
 | ||||
|         courses = await ufora_courses.get_all_courses(database_session) | ||||
| 
 | ||||
|         self.data = list(map(lambda c: c.name, courses)) | ||||
| 
 | ||||
|         # Load the aliases | ||||
|  | @ -97,7 +105,7 @@ class UforaCourseCache(DatabaseCache[AsyncSession]): | |||
| class WordleCache(DatabaseCache[MongoDatabase]): | ||||
|     """Cache to store the current daily Wordle word""" | ||||
| 
 | ||||
|     async def refresh(self, database_session: MongoDatabase): | ||||
|     async def invalidate(self, database_session: MongoDatabase): | ||||
|         word = await wordle.get_daily_word(database_session) | ||||
|         if word is not None: | ||||
|             self.data = [word] | ||||
|  | @ -106,14 +114,17 @@ class WordleCache(DatabaseCache[MongoDatabase]): | |||
| class CacheManager: | ||||
|     """Class that keeps track of all caches""" | ||||
| 
 | ||||
|     links: LinkCache | ||||
|     ufora_courses: UforaCourseCache | ||||
|     wordle_word: WordleCache | ||||
| 
 | ||||
|     def __init__(self): | ||||
|         self.links = LinkCache() | ||||
|         self.ufora_courses = UforaCourseCache() | ||||
|         self.wordle_word = WordleCache() | ||||
| 
 | ||||
|     async def initialize_caches(self, postgres_session: AsyncSession, mongo_db: MongoDatabase): | ||||
|         """Initialize the contents of all caches""" | ||||
|         await self.ufora_courses.refresh(postgres_session) | ||||
|         await self.wordle_word.refresh(mongo_db) | ||||
|         await self.links.invalidate(postgres_session) | ||||
|         await self.ufora_courses.invalidate(postgres_session) | ||||
|         await self.wordle_word.invalidate(mongo_db) | ||||
|  |  | |||
|  | @ -1,6 +1,11 @@ | |||
| from typing import Optional | ||||
| 
 | ||||
| import discord | ||||
| from discord import app_commands | ||||
| from discord.ext import commands | ||||
| 
 | ||||
| from database.crud.links import get_link_by_name | ||||
| from database.schemas.relational import Link | ||||
| from didier import Didier | ||||
| from didier.data.apis import urban_dictionary | ||||
| from didier.data.embeds.google import GoogleSearch | ||||
|  | @ -34,6 +39,38 @@ class Other(commands.Cog): | |||
|             embed = GoogleSearch(results).to_embed() | ||||
|             await ctx.reply(embed=embed, mention_author=False) | ||||
| 
 | ||||
|     async def _get_link(self, name: str) -> Optional[Link]: | ||||
|         async with self.client.postgres_session as session: | ||||
|             return await get_link_by_name(session, name.lower()) | ||||
| 
 | ||||
|     @commands.command(name="Link", aliases=["Links"], usage="[Name]") | ||||
|     async def link_msg(self, ctx: commands.Context, name: str): | ||||
|         """Message command to get the link to something""" | ||||
|         link = await self._get_link(name) | ||||
|         if link is None: | ||||
|             return await ctx.reply(f"Found no links matching `{name}`.", mention_author=False) | ||||
| 
 | ||||
|         target_message = await self.client.get_reply_target(ctx) | ||||
|         await target_message.reply(link.url, mention_author=False) | ||||
| 
 | ||||
|     @app_commands.command(name="link", description="Get the link to something") | ||||
|     @app_commands.describe(name="The name of the link") | ||||
|     async def link_slash(self, interaction: discord.Interaction, name: str): | ||||
|         """Slash command to get the link to something""" | ||||
|         link = await self._get_link(name) | ||||
|         if link is None: | ||||
|             return await interaction.response.send_message(f"Found no links matching `{name}`.", ephemeral=True) | ||||
| 
 | ||||
|         return await interaction.response.send_message(link.url) | ||||
| 
 | ||||
|     @link_slash.autocomplete("name") | ||||
|     async def _link_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: | ||||
|         """Autocompletion for the 'name'-parameter""" | ||||
|         return [ | ||||
|             app_commands.Choice(name=name, value=name.lower()) | ||||
|             for name in self.client.database_caches.links.get_autocomplete_suggestions(current) | ||||
|         ] | ||||
| 
 | ||||
| 
 | ||||
| async def setup(client: Didier): | ||||
|     """Load the cog""" | ||||
|  |  | |||
|  | @ -5,12 +5,17 @@ from discord import app_commands | |||
| from discord.ext import commands | ||||
| 
 | ||||
| import settings | ||||
| from database.crud import custom_commands | ||||
| from database.crud import custom_commands, links | ||||
| from database.exceptions.constraints import DuplicateInsertException | ||||
| from database.exceptions.not_found import NoResultFoundException | ||||
| from didier import Didier | ||||
| from didier.utils.discord.flags.owner import EditCustomFlags, SyncOptionFlags | ||||
| from didier.views.modals import AddDadJoke, CreateCustomCommand, EditCustomCommand | ||||
| from didier.views.modals import ( | ||||
|     AddDadJoke, | ||||
|     AddLink, | ||||
|     CreateCustomCommand, | ||||
|     EditCustomCommand, | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| class Owner(commands.Cog): | ||||
|  | @ -80,17 +85,6 @@ class Owner(commands.Cog): | |||
|     async def add_msg(self, ctx: commands.Context): | ||||
|         """Command group for [add X] message commands""" | ||||
| 
 | ||||
|     @add_msg.command(name="Custom") | ||||
|     async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): | ||||
|         """Add a new custom command""" | ||||
|         async with self.client.postgres_session as session: | ||||
|             try: | ||||
|                 await custom_commands.create_command(session, name, response) | ||||
|                 await self.client.confirm_message(ctx.message) | ||||
|             except DuplicateInsertException: | ||||
|                 await ctx.reply("There is already a command with this name.") | ||||
|                 await self.client.reject_message(ctx.message) | ||||
| 
 | ||||
|     @add_msg.command(name="Alias") | ||||
|     async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): | ||||
|         """Add a new alias for a custom command""" | ||||
|  | @ -105,6 +99,26 @@ class Owner(commands.Cog): | |||
|                 await ctx.reply("There is already a command with this name.") | ||||
|                 await self.client.reject_message(ctx.message) | ||||
| 
 | ||||
|     @add_msg.command(name="Custom") | ||||
|     async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): | ||||
|         """Add a new custom command""" | ||||
|         async with self.client.postgres_session as session: | ||||
|             try: | ||||
|                 await custom_commands.create_command(session, name, response) | ||||
|                 await self.client.confirm_message(ctx.message) | ||||
|             except DuplicateInsertException: | ||||
|                 await ctx.reply("There is already a command with this name.") | ||||
|                 await self.client.reject_message(ctx.message) | ||||
| 
 | ||||
|     @add_msg.command(name="Link") | ||||
|     async def add_link_msg(self, ctx: commands.Context, name: str, url: str): | ||||
|         """Add a new link""" | ||||
|         async with self.client.postgres_session as session: | ||||
|             await links.add_link(session, name, url) | ||||
|             await self.client.database_caches.links.invalidate(session) | ||||
| 
 | ||||
|         await self.client.confirm_message(ctx.message) | ||||
| 
 | ||||
|     @add_slash.command(name="custom", description="Add a custom command") | ||||
|     async def add_custom_slash(self, interaction: discord.Interaction): | ||||
|         """Slash command to add a custom command""" | ||||
|  | @ -123,6 +137,15 @@ class Owner(commands.Cog): | |||
|         modal = AddDadJoke(self.client) | ||||
|         await interaction.response.send_modal(modal) | ||||
| 
 | ||||
|     @add_slash.command(name="link", description="Add a new link") | ||||
|     async def add_link_slash(self, interaction: discord.Interaction): | ||||
|         """Slash command to add new links""" | ||||
|         if not await self.client.is_owner(interaction.user): | ||||
|             return interaction.response.send_message("You don't have permission to run this command.", ephemeral=True) | ||||
| 
 | ||||
|         modal = AddLink(self.client) | ||||
|         await interaction.response.send_modal(modal) | ||||
| 
 | ||||
|     @commands.group(name="Edit", case_insensitive=True, invoke_without_command=False) | ||||
|     async def edit_msg(self, ctx: commands.Context): | ||||
|         """Command group for [edit X] commands""" | ||||
|  | @ -135,7 +158,7 @@ class Owner(commands.Cog): | |||
|                 await custom_commands.edit_command(session, command, flags.name, flags.response) | ||||
|                 return await self.client.confirm_message(ctx.message) | ||||
|             except NoResultFoundException: | ||||
|                 await ctx.reply(f"No command found matching ``{command}``.") | ||||
|                 await ctx.reply(f"No command found matching `{command}`.") | ||||
|                 return await self.client.reject_message(ctx.message) | ||||
| 
 | ||||
|     @edit_slash.command(name="custom", description="Edit a custom command") | ||||
|  |  | |||
|  | @ -72,7 +72,7 @@ class School(commands.Cog): | |||
|             ufora_course = await ufora_courses.get_course_by_name(session, course) | ||||
| 
 | ||||
|         if ufora_course is None: | ||||
|             return await ctx.reply(f"Found no course matching ``{course}``", ephemeral=True) | ||||
|             return await ctx.reply(f"Found no course matching `{course}`", ephemeral=True) | ||||
| 
 | ||||
|         return await ctx.reply( | ||||
|             f"https://studiekiezer.ugent.be/studiefiche/nl/{ufora_course.code}/{flags.year}", | ||||
|  | @ -80,7 +80,7 @@ class School(commands.Cog): | |||
|         ) | ||||
| 
 | ||||
|     @study_guide.autocomplete("course") | ||||
|     async def study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: | ||||
|     async def _study_guide_autocomplete(self, _: discord.Interaction, current: str) -> list[app_commands.Choice[str]]: | ||||
|         """Autocompletion for the 'course'-parameter""" | ||||
|         return [ | ||||
|             app_commands.Choice(name=course, value=course) | ||||
|  |  | |||
|  | @ -111,6 +111,17 @@ class Didier(commands.Bot): | |||
|             for line in fp: | ||||
|                 self.wordle_words.add(line.strip()) | ||||
| 
 | ||||
|     async def get_reply_target(self, ctx: commands.Context) -> discord.Message: | ||||
|         """Get the target message that should be replied to | ||||
| 
 | ||||
|         In case the invoking message is a reply to something, reply to the | ||||
|         original message instead | ||||
|         """ | ||||
|         if ctx.message.reference is not None: | ||||
|             return await self.resolve_message(ctx.message.reference) | ||||
| 
 | ||||
|         return ctx.message | ||||
| 
 | ||||
|     async def resolve_message(self, reference: discord.MessageReference) -> discord.Message: | ||||
|         """Fetch a message from a reference""" | ||||
|         # Message is in the cache, return it | ||||
|  |  | |||
|  | @ -1,4 +1,5 @@ | |||
| from .custom_commands import CreateCustomCommand, EditCustomCommand | ||||
| from .dad_jokes import AddDadJoke | ||||
| from .links import AddLink | ||||
| 
 | ||||
| __all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand"] | ||||
| __all__ = ["AddDadJoke", "CreateCustomCommand", "EditCustomCommand", "AddLink"] | ||||
|  |  | |||
|  | @ -0,0 +1,37 @@ | |||
| import traceback | ||||
| 
 | ||||
| import discord.ui | ||||
| from overrides import overrides | ||||
| 
 | ||||
| from database.crud.links import add_link | ||||
| from didier import Didier | ||||
| 
 | ||||
| __all__ = ["AddLink"] | ||||
| 
 | ||||
| 
 | ||||
| class AddLink(discord.ui.Modal, title="Add Link"): | ||||
|     """Modal to add a new link""" | ||||
| 
 | ||||
|     name = discord.ui.TextInput(label="Name", style=discord.TextStyle.short, placeholder="Source") | ||||
|     url = discord.ui.TextInput( | ||||
|         label="URL", style=discord.TextStyle.short, placeholder="https://github.com/stijndcl/didier" | ||||
|     ) | ||||
| 
 | ||||
|     client: Didier | ||||
| 
 | ||||
|     def __init__(self, client: Didier, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.client = client | ||||
| 
 | ||||
|     @overrides | ||||
|     async def on_submit(self, interaction: discord.Interaction): | ||||
|         async with self.client.postgres_session as session: | ||||
|             await add_link(session, self.name.value, self.url.value) | ||||
|             await self.client.database_caches.links.invalidate(session) | ||||
| 
 | ||||
|         await interaction.response.send_message(f"Successfully added `{self.name.value.capitalize()}`.", ephemeral=True) | ||||
| 
 | ||||
|     @overrides | ||||
|     async def on_error(self, interaction: discord.Interaction, error: Exception):  # type: ignore | ||||
|         await interaction.response.send_message("Something went wrong.", ephemeral=True) | ||||
|         traceback.print_tb(error.__traceback__) | ||||
|  | @ -7,7 +7,7 @@ from database.utils.caches import UforaCourseCache | |||
| async def test_ufora_course_cache_refresh_empty(postgres: AsyncSession, ufora_course_with_alias: UforaCourse): | ||||
|     """Test loading the data for the Ufora Course cache when it's empty""" | ||||
|     cache = UforaCourseCache() | ||||
|     await cache.refresh(postgres) | ||||
|     await cache.invalidate(postgres) | ||||
| 
 | ||||
|     assert len(cache.data) == 1 | ||||
|     assert cache.data == ["test"] | ||||
|  | @ -20,7 +20,7 @@ async def test_ufora_course_cache_refresh_not_empty(postgres: AsyncSession, ufor | |||
|     cache.data = ["Something"] | ||||
|     cache.data_transformed = ["something"] | ||||
| 
 | ||||
|     await cache.refresh(postgres) | ||||
|     await cache.invalidate(postgres) | ||||
| 
 | ||||
|     assert len(cache.data) == 1 | ||||
|     assert cache.data == ["test"] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue