mirror of https://github.com/stijndcl/didier
Compare commits
4 Commits
8227190a8d
...
f4056d8af6
| Author | SHA1 | Date |
|---|---|---|
|
|
f4056d8af6 | |
|
|
c9dd275860 | |
|
|
0c810d84e9 | |
|
|
1aeaa71ef8 |
|
|
@ -1,12 +1,10 @@
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.exceptions.not_found import NoResultFoundException
|
from database.exceptions.not_found import NoResultFoundException
|
||||||
from database.models import DadJoke
|
from database.models import DadJoke
|
||||||
|
|
||||||
__all__ = ["add_dad_joke", "edit_dad_joke", "get_random_dad_joke"]
|
__all__ = ["add_dad_joke", "get_random_dad_joke"]
|
||||||
|
|
||||||
|
|
||||||
async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
|
async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
|
||||||
|
|
@ -18,20 +16,6 @@ async def add_dad_joke(session: AsyncSession, joke: str) -> DadJoke:
|
||||||
return dad_joke
|
return dad_joke
|
||||||
|
|
||||||
|
|
||||||
async def edit_dad_joke(session: AsyncSession, joke_id: int, new_joke: str) -> DadJoke:
|
|
||||||
"""Edit an existing dad joke"""
|
|
||||||
statement = select(DadJoke).where(DadJoke.dad_joke_id == joke_id)
|
|
||||||
dad_joke: Optional[DadJoke] = (await session.execute(statement)).scalar_one_or_none()
|
|
||||||
if dad_joke is None:
|
|
||||||
raise NoResultFoundException
|
|
||||||
|
|
||||||
dad_joke.joke = new_joke
|
|
||||||
session.add(dad_joke)
|
|
||||||
await session.commit()
|
|
||||||
|
|
||||||
return dad_joke
|
|
||||||
|
|
||||||
|
|
||||||
async def get_random_dad_joke(session: AsyncSession) -> DadJoke:
|
async def get_random_dad_joke(session: AsyncSession) -> DadJoke:
|
||||||
"""Return a random database entry"""
|
"""Return a random database entry"""
|
||||||
statement = select(DadJoke).order_by(func.random())
|
statement = select(DadJoke).order_by(func.random())
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,10 @@
|
||||||
|
from discord import app_commands
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
from didier.data.apis import urban_dictionary
|
from didier.data.apis import urban_dictionary
|
||||||
|
from didier.data.embeds.google import GoogleSearch
|
||||||
|
from didier.data.scrapers import google
|
||||||
|
|
||||||
|
|
||||||
class Other(commands.Cog):
|
class Other(commands.Cog):
|
||||||
|
|
@ -15,8 +18,18 @@ class Other(commands.Cog):
|
||||||
@commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]")
|
@commands.hybrid_command(name="define", description="Urban Dictionary", aliases=["Ud", "Urban"], usage="[Woord]")
|
||||||
async def define(self, ctx: commands.Context, *, query: str):
|
async def define(self, ctx: commands.Context, *, query: str):
|
||||||
"""Look up the definition of a word on the Urban Dictionary"""
|
"""Look up the definition of a word on the Urban Dictionary"""
|
||||||
definitions = await urban_dictionary.lookup(self.client.http_session, query)
|
async with ctx.typing():
|
||||||
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
|
definitions = await urban_dictionary.lookup(self.client.http_session, query)
|
||||||
|
await ctx.reply(embed=definitions[0].to_embed(), mention_author=False)
|
||||||
|
|
||||||
|
@commands.hybrid_command(name="google", description="Google search", usage="[Query]")
|
||||||
|
@app_commands.describe(query="Search query")
|
||||||
|
async def google(self, ctx: commands.Context, *, query: str):
|
||||||
|
"""Google something"""
|
||||||
|
async with ctx.typing():
|
||||||
|
results = await google.google_search(self.client.http_session, query)
|
||||||
|
embed = GoogleSearch(results).to_embed()
|
||||||
|
await ctx.reply(embed=embed, mention_author=False)
|
||||||
|
|
||||||
|
|
||||||
async def setup(client: Didier):
|
async def setup(client: Didier):
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,3 @@
|
||||||
|
from .google_search import GoogleSearch
|
||||||
|
|
||||||
|
__all__ = ["GoogleSearch"]
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
from http import HTTPStatus
|
||||||
|
|
||||||
|
import discord
|
||||||
|
from overrides import overrides
|
||||||
|
|
||||||
|
from didier.data.embeds.base import EmbedBaseModel
|
||||||
|
from didier.data.scrapers.google import SearchData
|
||||||
|
|
||||||
|
__all__ = ["GoogleSearch"]
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleSearch(EmbedBaseModel):
|
||||||
|
"""Embed to display Google search results"""
|
||||||
|
|
||||||
|
data: SearchData
|
||||||
|
|
||||||
|
def __init__(self, data: SearchData):
|
||||||
|
self.data = data
|
||||||
|
|
||||||
|
def _error_embed(self) -> discord.Embed:
|
||||||
|
"""Custom embed for unsuccessful requests"""
|
||||||
|
embed = discord.Embed(colour=discord.Colour.red())
|
||||||
|
embed.set_author(name="Google Search")
|
||||||
|
|
||||||
|
# Empty embed
|
||||||
|
if not self.data.results:
|
||||||
|
embed.description = "Geen resultaten gevonden"
|
||||||
|
return embed
|
||||||
|
|
||||||
|
# Error embed
|
||||||
|
embed.description = f"Status {self.data.status_code}"
|
||||||
|
|
||||||
|
return embed
|
||||||
|
|
||||||
|
@overrides
|
||||||
|
def to_embed(self) -> discord.Embed:
|
||||||
|
if not self.data.results or self.data.status_code != HTTPStatus.OK:
|
||||||
|
return self._error_embed()
|
||||||
|
|
||||||
|
embed = discord.Embed(colour=discord.Colour.blue())
|
||||||
|
embed.set_author(name="Google Search")
|
||||||
|
embed.set_footer(text=self.data.result_stats or None)
|
||||||
|
|
||||||
|
# Add all results into the description
|
||||||
|
results = []
|
||||||
|
for index, url in enumerate(self.data.results):
|
||||||
|
results.append(f"{index + 1}: {url}")
|
||||||
|
|
||||||
|
embed.description = "\n".join(results)
|
||||||
|
return embed
|
||||||
|
|
@ -0,0 +1,94 @@
|
||||||
|
import http
|
||||||
|
import typing
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Optional
|
||||||
|
from urllib.parse import unquote_plus, urlencode
|
||||||
|
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
from bs4.element import Tag
|
||||||
|
|
||||||
|
__all__ = ["google_search", "SearchData"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SearchData:
|
||||||
|
"""Dataclass to store some data about a search query"""
|
||||||
|
|
||||||
|
query: str
|
||||||
|
status_code: int
|
||||||
|
results: list[str] = field(default_factory=list)
|
||||||
|
result_stats: str = ""
|
||||||
|
|
||||||
|
def __post_init__(self):
|
||||||
|
self.query = unquote_plus(self.query)
|
||||||
|
|
||||||
|
|
||||||
|
def get_result_stats(bs: BeautifulSoup) -> Optional[str]:
|
||||||
|
"""Parse the result stats
|
||||||
|
|
||||||
|
Example result: "About 16.570.000 results (0,84 seconds)"
|
||||||
|
"""
|
||||||
|
stats = bs.find("div", id="result-stats")
|
||||||
|
if stats is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return stats.text.removesuffix("\xa0")
|
||||||
|
|
||||||
|
|
||||||
|
def parse_result(element: Tag) -> Optional[str]:
|
||||||
|
"""Parse 1 wrapper into a link"""
|
||||||
|
a_tag = element.find("a", href=True)
|
||||||
|
if a_tag is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
url = a_tag["href"] # type: ignore
|
||||||
|
title = typing.cast(Tag, a_tag.find("h3"))
|
||||||
|
|
||||||
|
if (
|
||||||
|
url is None
|
||||||
|
or not str(url).startswith(
|
||||||
|
(
|
||||||
|
"http://",
|
||||||
|
"https://",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
or title is None
|
||||||
|
):
|
||||||
|
return None
|
||||||
|
|
||||||
|
text = unquote_plus(title.text)
|
||||||
|
return f"[{text}]({url})"
|
||||||
|
|
||||||
|
|
||||||
|
def get_search_results(bs: BeautifulSoup) -> list[str]:
|
||||||
|
"""Parse the search results"""
|
||||||
|
result_wrappers = bs.find_all("div", class_="g")
|
||||||
|
|
||||||
|
results: list[str] = list(result for result in map(parse_result, result_wrappers) if result is not None)
|
||||||
|
|
||||||
|
# Remove duplicates
|
||||||
|
# (sets don't preserve the order!)
|
||||||
|
return list(dict.fromkeys(results))
|
||||||
|
|
||||||
|
|
||||||
|
async def google_search(http_client: ClientSession, query: str):
|
||||||
|
"""Get the first 10 Google search results"""
|
||||||
|
headers = {
|
||||||
|
"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_4) AppleWebKit/537.36 (KHTML, like Gecko) "
|
||||||
|
"Chrome/83.0.4103.97 Safari/537.36"
|
||||||
|
}
|
||||||
|
|
||||||
|
query = urlencode({"q": query})
|
||||||
|
|
||||||
|
# Request 20 results in case of duplicates, bad matches, ...
|
||||||
|
async with http_client.get(f"https://www.google.com/search?{query}&num=20&hl=en", headers=headers) as response:
|
||||||
|
# Something went wrong
|
||||||
|
if response.status != http.HTTPStatus.OK:
|
||||||
|
return SearchData(query, response.status)
|
||||||
|
|
||||||
|
bs = BeautifulSoup(await response.text(), "html.parser")
|
||||||
|
result_stats = get_result_stats(bs)
|
||||||
|
results = get_search_results(bs)
|
||||||
|
|
||||||
|
return SearchData(query, 200, results[:10], result_stats or "")
|
||||||
|
|
@ -6,6 +6,7 @@ pytest==7.1.2
|
||||||
pytest-asyncio==0.18.3
|
pytest-asyncio==0.18.3
|
||||||
pytest-env==0.6.2
|
pytest-env==0.6.2
|
||||||
sqlalchemy2-stubs==0.0.2a23
|
sqlalchemy2-stubs==0.0.2a23
|
||||||
|
types-beautifulsoup4==4.11.3
|
||||||
types-pytz==2021.3.8
|
types-pytz==2021.3.8
|
||||||
|
|
||||||
# Flake8 + plugins
|
# Flake8 + plugins
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,7 @@
|
||||||
aiohttp==3.8.1
|
aiohttp==3.8.1
|
||||||
alembic==1.8.0
|
alembic==1.8.0
|
||||||
asyncpg==0.25.0
|
asyncpg==0.25.0
|
||||||
|
beautifulsoup4==4.11.1
|
||||||
# Dev version of dpy
|
# Dev version of dpy
|
||||||
git+https://github.com/Rapptz/discord.py
|
git+https://github.com/Rapptz/discord.py
|
||||||
environs==9.5.0
|
environs==9.5.0
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,16 @@
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from database.crud import dad_jokes as crud
|
||||||
|
from database.models import DadJoke
|
||||||
|
|
||||||
|
|
||||||
|
async def test_add_dad_joke(database_session: AsyncSession):
|
||||||
|
"""Test creating a new joke"""
|
||||||
|
statement = select(DadJoke)
|
||||||
|
result = (await database_session.execute(statement)).scalars().all()
|
||||||
|
assert len(result) == 0
|
||||||
|
|
||||||
|
await crud.add_dad_joke(database_session, "joke")
|
||||||
|
result = (await database_session.execute(statement)).scalars().all()
|
||||||
|
assert len(result) == 1
|
||||||
|
|
@ -9,8 +9,9 @@ async def test_ufora_course_cache_refresh_empty(database_session: AsyncSession,
|
||||||
cache = UforaCourseCache()
|
cache = UforaCourseCache()
|
||||||
await cache.refresh(database_session)
|
await cache.refresh(database_session)
|
||||||
|
|
||||||
assert len(cache.data) == 2
|
assert len(cache.data) == 1
|
||||||
assert cache.data == ["alias", "test"]
|
assert cache.data == ["test"]
|
||||||
|
assert cache.aliases == {"alias": "test"}
|
||||||
|
|
||||||
|
|
||||||
async def test_ufora_course_cache_refresh_not_empty(
|
async def test_ufora_course_cache_refresh_not_empty(
|
||||||
|
|
@ -23,5 +24,6 @@ async def test_ufora_course_cache_refresh_not_empty(
|
||||||
|
|
||||||
await cache.refresh(database_session)
|
await cache.refresh(database_session)
|
||||||
|
|
||||||
assert len(cache.data) == 2
|
assert len(cache.data) == 1
|
||||||
assert cache.data == ["alias", "test"]
|
assert cache.data == ["test"]
|
||||||
|
assert cache.aliases == {"alias": "test"}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue