diff --git a/database/crud/dad_jokes.py b/database/crud/dad_jokes.py index 871c34d..30aa010 100644 --- a/database/crud/dad_jokes.py +++ b/database/crud/dad_jokes.py @@ -1,10 +1,12 @@ +from typing import Optional + from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from database.exceptions.not_found import NoResultFoundException from database.models import DadJoke -__all__ = ["add_dad_joke", "get_random_dad_joke"] +__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"] async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: @@ -16,6 +18,20 @@ async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke: return dad_joke +async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke: + """Edit an existing dad joke""" + statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id) + dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none() + if dad_joke is None: + raise NoResultFoundException + + dad_joke.joke = new_joke + session.add(dad_joke) + await session.commit() + + return dad_joke + + async def get_random_dad_joke(session: AsyncSession) -> DadJoke: """Return a random database entry""" statement = select(DadJoke).order_by(func.random()) diff --git a/didier/cogs/other.py b/didier/cogs/other.py index 04175ed..6a036e2 100644 --- a/didier/cogs/other.py +++ b/didier/cogs/other.py @@ -1,10 +1,7 @@ -from discord import app_commands from discord.ext import commands from didier import Didier from didier.data.apis import urban_dictionary -from didier.data.embeds.google import GoogleSearch -from didier.data.scrapers import google class Other(commands.Cog): @@ -18,18 +15,8 @@ class Other(commands.Cog): @commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]") async def define(self, ctx: commands.Context, *, query: str): """Look up the definition of a word on the Urban Dictionary""" - async with ctx.typing(): - definitions = await urban_dictionary.lookup(self.client.http_session, query) - await ctx.reply(embed=definitions[0].to_embed(), mention_author=False) - - @commands.hybrid_command(name="google", description="Google search", usage="[Query]") - @app_commands.describe(query="Search query") - async def google(self, ctx: commands.Context, *, query: str): - """Google something""" - async with ctx.typing(): - results = await google.google_search(self.client.http_session, query) - embed = GoogleSearch(results).to_embed() - await ctx.reply(embed=embed, mention_author=False) + definitions = await urban_dictionary.lookup(self.client.http_session, query) + await ctx.reply(embed=definitions[0].to_embed(), mention_author=False) async def setup(client: Didier): diff --git a/didier/data/embeds/google/__init__.py b/didier/data/embeds/google/__init__.py deleted file mode 100644 index bc57985..0000000 --- a/didier/data/embeds/google/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .google_search import GoogleSearch - -__all__ = ["GoogleSearch"] diff --git a/didier/data/embeds/google/google_search.py b/didier/data/embeds/google/google_search.py deleted file mode 100644 index 9a8eefc..0000000 --- a/didier/data/embeds/google/google_search.py +++ /dev/null @@ -1,50 +0,0 @@ -from http import HTTPStatus - -import discord -from overrides import overrides - -from didier.data.embeds.base import EmbedBaseModel -from didier.data.scrapers.google import SearchData - -__all__ = ["GoogleSearch"] - - -class GoogleSearch(EmbedBaseModel): - """Embed to display Google search results""" - - data: SearchData - - def __init__(self, data: SearchData): - self.data = data - - def _error_embed(self) -> discord.Embed: - """Custom embed for unsuccessful requests""" - embed = discord.Embed(colour=discord.Colour.red()) - embed.set_author(name="Google Search") - - # Empty embed - if not self.data.results: - embed.description = "Geen resultaten gevonden" - return embed - - # Error embed - embed.description = f"Status {self.data.status_code}" - - return embed - - @overrides - def to_embed(self) -> discord.Embed: - if not self.data.results or self.data.status_code != HTTPStatus.OK: - return self._error_embed() - - embed = discord.Embed(colour=discord.Colour.blue()) - embed.set_author(name="Google Search") - embed.set_footer(text=self.data.result_stats or None) - - # Add all results into the description - results = [] - for index, url in enumerate(self.data.results): - results.append(f"{index + 1}: {url}") - - embed.description = "\n".join(results) - return embed diff --git a/didier/data/scrapers/__init__.py b/didier/data/scrapers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/didier/data/scrapers/google.py b/didier/data/scrapers/google.py deleted file mode 100644 index 9ebb003..0000000 --- a/didier/data/scrapers/google.py +++ /dev/null @@ -1,94 +0,0 @@ -import http -import typing -from dataclasses import dataclass, field -from typing import Optional -from urllib.parse import unquote_plus, urlencode - -from aiohttp import ClientSession -from bs4 import BeautifulSoup -from bs4.element import Tag - -__all__ = ["google_search", "SearchData"] - - -@dataclass -class SearchData: - """Dataclass to store some data about a search query""" - - query: str - status_code: int - results: list[str] = field(default_factory=list) - result_stats: str = "" - - def __post_init__(self): - self.query = unquote_plus(self.query) - - -def get_result_stats(bs: BeautifulSoup) -> Optional[str]: - """Parse the result stats - - Example result: "About 16.570.000 results (0,84 seconds)" - """ - stats = bs.find("div", id="result-stats") - if stats is None: - return None - - return stats.text.removesuffix("\xa0") - - -def parse_result(element: Tag) -> Optional[str]: - """Parse 1 wrapper into a link""" - a_tag = element.find("a", href=True) - if a_tag is None: - return None - - url = a_tag["href"] # type: ignore - title = typing.cast(Tag, a_tag.find("h3")) - - if ( - url is None - or not str(url).startswith( - ( - "http://", - "https://", - ) - ) - or title is None - ): - return None - - text = unquote_plus(title.text) - return f"[{text}]({url})" - - -def get_search_results(bs: BeautifulSoup) -> list[str]: - """Parse the search results""" - result_wrappers = bs.find_all("div", class_="g") - - results: list[str] = list(result for result in map(parse_result, result_wrappers) if result is not None) - - # Remove duplicates - # (sets don't preserve the order!) - return list(dict.fromkeys(results)) - - -async def google_search(http_client: ClientSession, query: str): - """Get the first 10 Google search results""" - headers = { - "User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/83.0.4103.97 Safari/537.36" - } - - query = urlencode({"q": query}) - - # Request 20 results in case of duplicates, bad matches, ... - async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en", headers=headers) as response: - # Something went wrong - if response.status != http.HTTPStatus.OK: - return SearchData(query, response.status) - - bs = BeautifulSoup(await response.text(), "html.parser") - result_stats = get_result_stats(bs) - results = get_search_results(bs) - - return SearchData(query, 200, results[:10], result_stats or "") diff --git a/requirements-dev.txt b/requirements-dev.txt index d82fde3..533740c 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,7 +6,6 @@ pytest==7.1.2 pytest-asyncio==0.18.3 pytest-env==0.6.2 sqlalchemy2-stubs==0.0.2a23 -types-beautifulsoup4==4.11.3 types-pytz==2021.3.8 # Flake8 + plugins diff --git a/requirements.txt b/requirements.txt index 285b936..0737e1f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,6 @@ aiohttp==3.8.1 alembic==1.8.0 asyncpg==0.25.0 -beautifulsoup4==4.11.1 # Dev version of dpy git+https://github.com/Rapptz/discord.py environs==9.5.0 diff --git a/tests/test_database/test_crud/test_dad_jokes.py b/tests/test_database/test_crud/test_dad_jokes.py deleted file mode 100644 index 0c499c8..0000000 --- a/tests/test_database/test_crud/test_dad_jokes.py +++ /dev/null @@ -1,16 +0,0 @@ -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from database.crud import dad_jokes as crud -from database.models import DadJoke - - -async def test_add_dad_joke(database_session: AsyncSession): - """Test creating a new joke""" - statement = select(DadJoke) - result = (await database_session.execute(statement)).scalars().all() - assert len(result) == 0 - - await crud.add_dad_joke(database_session, "joke") - result = (await database_session.execute(statement)).scalars().all() - assert len(result) == 1 diff --git a/tests/test_database/test_utils/test_caches.py b/tests/test_database/test_utils/test_caches.py index 2e10664..09583d3 100644 --- a/tests/test_database/test_utils/test_caches.py +++ b/tests/test_database/test_utils/test_caches.py @@ -9,9 +9,8 @@ async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession, cache = UforaCourseCache() await cache.refresh(database_session) - assert len(cache.data) == 1 - assert cache.data == ["test"] - assert cache.aliases == {"alias": "test"} + assert len(cache.data) == 2 + assert cache.data == ["alias", "test"] async def test_ufora_course_cache_refresh_not_empty( @@ -24,6 +23,5 @@ async def test_ufora_course_cache_refresh_not_empty( await cache.refresh(database_session) - assert len(cache.data) == 1 - assert cache.data == ["test"] - assert cache.aliases == {"alias": "test"} + assert len(cache.data) == 2 + assert cache.data == ["alias", "test"]