From 52b452c85ae420e71843914675e309431ae8423a Mon Sep 17 00:00:00 2001 From: stijndcl Date: Mon, 25 Jul 2022 20:33:20 +0200 Subject: [PATCH] Fix mongo connection --- alembic/env.py | 4 ++-- database/engine.py | 11 +++++++++-- database/migrations.py | 6 +++--- didier/cogs/currency.py | 18 +++++++++--------- didier/cogs/discord.py | 4 ++-- didier/cogs/fun.py | 2 +- didier/cogs/owner.py | 8 ++++---- didier/cogs/school.py | 2 +- didier/cogs/tasks.py | 6 +++--- didier/decorators/tasks.py | 2 +- didier/didier.py | 16 +++++++++++----- didier/views/modals/custom_commands.py | 4 ++-- didier/views/modals/dad_jokes.py | 2 +- docker-compose.yml | 1 + requirements.txt | 1 + tests/conftest.py | 4 ++-- 16 files changed, 53 insertions(+), 38 deletions(-) diff --git a/alembic/env.py b/alembic/env.py index 3cca2cf..72d5170 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -4,7 +4,7 @@ from logging.config import fileConfig from sqlalchemy.ext.asyncio import AsyncEngine from alembic import context -from database.engine import engine +from database.engine import postgres_engine from database.models import Base # this is the Alembic Config object, which provides @@ -40,7 +40,7 @@ def run_migrations_online() -> None: and associate a connection with the context. """ - connectable = context.config.attributes.get("connection", None) or engine + connectable = context.config.attributes.get("connection", None) or postgres_engine if isinstance(connectable, AsyncEngine): asyncio.run(run_async_migrations(connectable)) diff --git a/database/engine.py b/database/engine.py index 06bdb93..e73bc27 100644 --- a/database/engine.py +++ b/database/engine.py @@ -1,5 +1,6 @@ from urllib.parse import quote_plus +import motor.motor_asyncio from sqlalchemy.engine import URL from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.orm import sessionmaker @@ -8,7 +9,8 @@ import settings encoded_password = quote_plus(settings.POSTGRES_PASS) -engine = create_async_engine( +# PostgreSQL engine +postgres_engine = create_async_engine( URL.create( drivername="postgresql+asyncpg", username=settings.POSTGRES_USER, @@ -21,4 +23,9 @@ engine = create_async_engine( future=True, ) -DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False) +DBSession = sessionmaker( + autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False +) + +# MongoDB client +mongo_client = motor.motor_asyncio.AsyncIOMotorClient(settings.MONGO_HOST, settings.MONGO_PORT) diff --git a/database/migrations.py b/database/migrations.py index e842410..20c43c8 100644 --- a/database/migrations.py +++ b/database/migrations.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import Session from alembic import command, script from alembic.config import Config from alembic.runtime import migration -from database.engine import engine +from database.engine import postgres_engine __config_path__ = "alembic.ini" __migrations_path__ = "alembic/" @@ -22,7 +22,7 @@ async def ensure_latest_migration(): """Make sure we are currently on the latest revision, otherwise raise an exception""" alembic_script = script.ScriptDirectory.from_config(cfg) - async with engine.begin() as connection: + async with postgres_engine.begin() as connection: current_revision = await connection.run_sync( lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision() ) @@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session): async def migrate(up: bool): """Migrate the database upwards or downwards""" - async with engine.begin() as connection: + async with postgres_engine.begin() as connection: await connection.run_sync(__execute_upgrade if up else __execute_downgrade) diff --git a/didier/cogs/currency.py b/didier/cogs/currency.py index 4c76e58..1b3dea5 100644 --- a/didier/cogs/currency.py +++ b/didier/cogs/currency.py @@ -31,7 +31,7 @@ class Currency(commands.Cog): """Award a user a given amount of Didier Dinks""" amount = typing.cast(int, amount) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await crud.add_dinks(session, user.id, amount) plural = pluralize("Didier Dink", amount) await ctx.reply( @@ -42,7 +42,7 @@ class Currency(commands.Cog): @commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True) async def bank(self, ctx: commands.Context): """Show your Didier Bank information""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) @@ -58,7 +58,7 @@ class Currency(commands.Cog): @bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True) async def bank_upgrades(self, ctx: commands.Context): """List the upgrades you can buy & their prices""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) embed = discord.Embed(colour=discord.Colour.blue()) @@ -79,7 +79,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Capacity", aliases=["C"]) async def bank_upgrade_capacity(self, ctx: commands.Context): """Upgrade the capacity level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_capacity(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -90,7 +90,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Interest", aliases=["I"]) async def bank_upgrade_interest(self, ctx: commands.Context): """Upgrade the interest level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_interest(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -101,7 +101,7 @@ class Currency(commands.Cog): @bank_upgrades.command(name="Rob", aliases=["R"]) async def bank_upgrade_rob(self, ctx: commands.Context): """Upgrade the rob level of your bank""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.upgrade_rob(session, ctx.author.id) await ctx.message.add_reaction("⏫") @@ -112,7 +112,7 @@ class Currency(commands.Cog): @commands.hybrid_command(name="dinks") async def dinks(self, ctx: commands.Context): """Check your Didier Dinks""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: bank = await crud.get_bank(session, ctx.author.id) plural = pluralize("Didier Dink", bank.dinks) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) @@ -122,7 +122,7 @@ class Currency(commands.Cog): """Invest a given amount of Didier Dinks""" amount = typing.cast(typing.Union[str, int], amount) - async with self.client.db_session as session: + async with self.client.postgres_session as session: invested = await crud.invest(session, ctx.author.id, amount) plural = pluralize("Didier Dink", invested) @@ -136,7 +136,7 @@ class Currency(commands.Cog): @commands.hybrid_command(name="nightly") async def nightly(self, ctx: commands.Context): """Claim nightly Didier Dinks""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await crud.claim_nightly(session, ctx.author.id) await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") diff --git a/didier/cogs/discord.py b/didier/cogs/discord.py index db9ae7d..9dbcc9d 100644 --- a/didier/cogs/discord.py +++ b/didier/cogs/discord.py @@ -19,7 +19,7 @@ class Discord(commands.Cog): async def birthday(self, ctx: commands.Context, user: discord.User = None): """Command to check the birthday of a user""" user_id = (user and user.id) or ctx.author.id - async with self.client.db_session as session: + async with self.client.postgres_session as session: birthday = await birthdays.get_birthday_for_user(session, user_id) name = "Jouw" if user is None else f"{user.display_name}'s" @@ -45,7 +45,7 @@ class Discord(commands.Cog): except ValueError: return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await birthdays.add_birthday(session, ctx.author.id, date) await self.client.confirm_message(ctx.message) diff --git a/didier/cogs/fun.py b/didier/cogs/fun.py index ddc119b..0aade01 100644 --- a/didier/cogs/fun.py +++ b/didier/cogs/fun.py @@ -19,7 +19,7 @@ class Fun(commands.Cog): ) async def dad_joke(self, ctx: commands.Context): """Get a random dad joke""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: joke = await get_random_dad_joke(session) return await ctx.reply(joke.joke, mention_author=False) diff --git a/didier/cogs/owner.py b/didier/cogs/owner.py index 30090df..7d623b0 100644 --- a/didier/cogs/owner.py +++ b/didier/cogs/owner.py @@ -83,7 +83,7 @@ class Owner(commands.Cog): @add_msg.command(name="Custom") async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): """Add a new custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.create_command(session, name, response) await self.client.confirm_message(ctx.message) @@ -94,7 +94,7 @@ class Owner(commands.Cog): @add_msg.command(name="Alias") async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): """Add a new alias for a custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.create_alias(session, command, alias) await self.client.confirm_message(ctx.message) @@ -130,7 +130,7 @@ class Owner(commands.Cog): @edit_msg.command(name="Custom") async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): """Edit an existing custom command""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: try: await custom_commands.edit_command(session, command, flags.name, flags.response) return await self.client.confirm_message(ctx.message) @@ -147,7 +147,7 @@ class Owner(commands.Cog): "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True ) - async with self.client.db_session as session: + async with self.client.postgres_session as session: _command = await custom_commands.get_command(session, command) if _command is None: return await interaction.response.send_message( diff --git a/didier/cogs/school.py b/didier/cogs/school.py index ee13a0c..32eb8b9 100644 --- a/didier/cogs/school.py +++ b/didier/cogs/school.py @@ -68,7 +68,7 @@ class School(commands.Cog): @app_commands.describe(course="vak") async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): """Create links to study guides""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: ufora_course = await ufora_courses.get_course_by_name(session, course) if ufora_course is None: diff --git a/didier/cogs/tasks.py b/didier/cogs/tasks.py index 3764331..6a1e5c6 100644 --- a/didier/cogs/tasks.py +++ b/didier/cogs/tasks.py @@ -72,7 +72,7 @@ class Tasks(commands.Cog): async def check_birthdays(self): """Check if it's currently anyone's birthday""" now = tz_aware_now().date() - async with self.client.db_session as session: + async with self.client.postgres_session as session: birthdays = await get_birthdays_on_day(session, now) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) @@ -96,7 +96,7 @@ class Tasks(commands.Cog): if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: return - async with self.client.db_session as db_session: + async with self.client.postgres_session as db_session: announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) announcements = await fetch_ufora_announcements(self.client.http_session, db_session) @@ -110,7 +110,7 @@ class Tasks(commands.Cog): @tasks.loop(hours=24) async def remove_old_ufora_announcements(self): """Remove all announcements that are over 1 week old, once per day""" - async with self.client.db_session as session: + async with self.client.postgres_session as session: await remove_old_announcements(session) @check_birthdays.error diff --git a/didier/decorators/tasks.py b/didier/decorators/tasks.py index 36b4f89..67e3e19 100644 --- a/didier/decorators/tasks.py +++ b/didier/decorators/tasks.py @@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType): async def _wrapper(tasks_cog: Tasks, *args, **kwargs): await func(tasks_cog, *args, **kwargs) - async with tasks_cog.client.db_session as session: + async with tasks_cog.client.postgres_session as session: await set_last_task_execution_time(session, task) return _wrapper diff --git a/didier/didier.py b/didier/didier.py index d5745c2..ce81307 100644 --- a/didier/didier.py +++ b/didier/didier.py @@ -2,13 +2,14 @@ import logging import os import discord +import motor.motor_asyncio from aiohttp import ClientSession from discord.ext import commands from sqlalchemy.ext.asyncio import AsyncSession import settings from database.crud import custom_commands -from database.engine import DBSession +from database.engine import DBSession, mongo_client from database.utils.caches import CacheManager from didier.data.embeds.error_embed import create_error_embed from didier.utils.discord.prefix import get_prefix @@ -45,10 +46,15 @@ class Didier(commands.Bot): ) @property - def db_session(self) -> AsyncSession: - """Obtain a database session""" + def postgres_session(self) -> AsyncSession: + """Obtain a session for the PostgreSQL database""" return DBSession() + @property + def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase: + """Obtain a reference to the MongoDB database""" + return mongo_client[settings.MONGO_DB] + async def setup_hook(self) -> None: """Do some initial setup @@ -60,7 +66,7 @@ class Didier(commands.Bot): # Initialize caches self.database_caches = CacheManager() - async with self.db_session as session: + async with self.postgres_session as session: await self.database_caches.initialize_caches(session) # Create aiohttp session @@ -153,7 +159,7 @@ class Didier(commands.Bot): if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): return False - async with self.db_session as session: + async with self.postgres_session as session: # Remove the prefix content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] command = await custom_commands.get_command(session, content) diff --git a/didier/views/modals/custom_commands.py b/didier/views/modals/custom_commands.py index 35ac158..2116bd9 100644 --- a/didier/views/modals/custom_commands.py +++ b/didier/views/modals/custom_commands.py @@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"): @overrides async def on_submit(self, interaction: discord.Interaction): - async with self.client.db_session as session: + async with self.client.postgres_session as session: command = await create_command(session, str(self.name.value), str(self.response.value)) await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) @@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"): name_field = typing.cast(discord.ui.TextInput, self.children[0]) response_field = typing.cast(discord.ui.TextInput, self.children[1]) - async with self.client.db_session as session: + async with self.client.postgres_session as session: await edit_command(session, self.original_name, name_field.value, response_field.value) await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) diff --git a/didier/views/modals/dad_jokes.py b/didier/views/modals/dad_jokes.py index 9632197..f52b051 100644 --- a/didier/views/modals/dad_jokes.py +++ b/didier/views/modals/dad_jokes.py @@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"): @overrides async def on_submit(self, interaction: discord.Interaction): - async with self.client.db_session as session: + async with self.client.postgres_session as session: joke = await add_dad_joke(session, str(self.name.value)) await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) diff --git a/docker-compose.yml b/docker-compose.yml index 26db0ab..4af0a0e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -19,6 +19,7 @@ services: - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root} - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root} - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev} + command: [--auth] ports: - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}" volumes: diff --git a/requirements.txt b/requirements.txt index 285b936..950679b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py environs==9.5.0 feedparser==6.0.10 markdownify==0.11.2 +motor==3.0.0 overrides==6.1.0 pydantic==1.9.1 python-dateutil==2.8.2 diff --git a/tests/conftest.py b/tests/conftest.py index 219568c..c218524 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,7 +5,7 @@ from unittest.mock import MagicMock import pytest from sqlalchemy.ext.asyncio import AsyncSession -from database.engine import engine +from database.engine import postgres_engine from database.migrations import ensure_latest_migration, migrate from didier import Didier @@ -40,7 +40,7 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]: Rollbacks the transaction afterwards so that the future tests start with a clean database """ - connection = await engine.connect() + connection = await postgres_engine.connect() transaction = await connection.begin() session = AsyncSession(bind=connection, expire_on_commit=False)