Fix mongo connection

pull/123/head
stijndcl 2022-07-25 20:33:20 +02:00
parent 9abe5dd519
commit 52b452c85a
16 changed files with 53 additions and 38 deletions

View File

@ -4,7 +4,7 @@ from logging.config import fileConfig
from sqlalchemy.ext.asyncio import AsyncEngine from sqlalchemy.ext.asyncio import AsyncEngine
from alembic import context from alembic import context
from database.engine import engine from database.engine import postgres_engine
from database.models import Base from database.models import Base
# this is the Alembic Config object, which provides # this is the Alembic Config object, which provides
@ -40,7 +40,7 @@ def run_migrations_online() -> None:
and associate a connection with the context. and associate a connection with the context.
""" """
connectable = context.config.attributes.get("connection", None) or engine connectable = context.config.attributes.get("connection", None) or postgres_engine
if isinstance(connectable, AsyncEngine): if isinstance(connectable, AsyncEngine):
asyncio.run(run_async_migrations(connectable)) asyncio.run(run_async_migrations(connectable))

View File

@ -1,5 +1,6 @@
from urllib.parse import quote_plus from urllib.parse import quote_plus
import motor.motor_asyncio
from sqlalchemy.engine import URL from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
@ -8,7 +9,8 @@ import settings
encoded_password = quote_plus(settings.POSTGRES_PASS) encoded_password = quote_plus(settings.POSTGRES_PASS)
engine = create_async_engine( # PostgreSQL engine
postgres_engine = create_async_engine(
URL.create( URL.create(
drivername="postgresql+asyncpg", drivername="postgresql+asyncpg",
username=settings.POSTGRES_USER, username=settings.POSTGRES_USER,
@ -21,4 +23,9 @@ engine = create_async_engine(
future=True, future=True,
) )
DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False) DBSession = sessionmaker(
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
)
# MongoDB client
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(settings.MONGO_HOST, settings.MONGO_PORT)

View File

@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
from alembic import command, script from alembic import command, script
from alembic.config import Config from alembic.config import Config
from alembic.runtime import migration from alembic.runtime import migration
from database.engine import engine from database.engine import postgres_engine
__config_path__ = "alembic.ini" __config_path__ = "alembic.ini"
__migrations_path__ = "alembic/" __migrations_path__ = "alembic/"
@ -22,7 +22,7 @@ async def ensure_latest_migration():
"""Make sure we are currently on the latest revision, otherwise raise an exception""" """Make sure we are currently on the latest revision, otherwise raise an exception"""
alembic_script = script.ScriptDirectory.from_config(cfg) alembic_script = script.ScriptDirectory.from_config(cfg)
async with engine.begin() as connection: async with postgres_engine.begin() as connection:
current_revision = await connection.run_sync( current_revision = await connection.run_sync(
lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision() lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision()
) )
@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session):
async def migrate(up: bool): async def migrate(up: bool):
"""Migrate the database upwards or downwards""" """Migrate the database upwards or downwards"""
async with engine.begin() as connection: async with postgres_engine.begin() as connection:
await connection.run_sync(__execute_upgrade if up else __execute_downgrade) await connection.run_sync(__execute_upgrade if up else __execute_downgrade)

View File

@ -31,7 +31,7 @@ class Currency(commands.Cog):
"""Award a user a given amount of Didier Dinks""" """Award a user a given amount of Didier Dinks"""
amount = typing.cast(int, amount) amount = typing.cast(int, amount)
async with self.client.db_session as session: async with self.client.postgres_session as session:
await crud.add_dinks(session, user.id, amount) await crud.add_dinks(session, user.id, amount)
plural = pluralize("Didier Dink", amount) plural = pluralize("Didier Dink", amount)
await ctx.reply( await ctx.reply(
@ -42,7 +42,7 @@ class Currency(commands.Cog):
@commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True) @commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True)
async def bank(self, ctx: commands.Context): async def bank(self, ctx: commands.Context):
"""Show your Didier Bank information""" """Show your Didier Bank information"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(colour=discord.Colour.blue()) embed = discord.Embed(colour=discord.Colour.blue())
@ -58,7 +58,7 @@ class Currency(commands.Cog):
@bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True) @bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True)
async def bank_upgrades(self, ctx: commands.Context): async def bank_upgrades(self, ctx: commands.Context):
"""List the upgrades you can buy & their prices""" """List the upgrades you can buy & their prices"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
embed = discord.Embed(colour=discord.Colour.blue()) embed = discord.Embed(colour=discord.Colour.blue())
@ -79,7 +79,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Capacity", aliases=["C"]) @bank_upgrades.command(name="Capacity", aliases=["C"])
async def bank_upgrade_capacity(self, ctx: commands.Context): async def bank_upgrade_capacity(self, ctx: commands.Context):
"""Upgrade the capacity level of your bank""" """Upgrade the capacity level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_capacity(session, ctx.author.id) await crud.upgrade_capacity(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -90,7 +90,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Interest", aliases=["I"]) @bank_upgrades.command(name="Interest", aliases=["I"])
async def bank_upgrade_interest(self, ctx: commands.Context): async def bank_upgrade_interest(self, ctx: commands.Context):
"""Upgrade the interest level of your bank""" """Upgrade the interest level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_interest(session, ctx.author.id) await crud.upgrade_interest(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -101,7 +101,7 @@ class Currency(commands.Cog):
@bank_upgrades.command(name="Rob", aliases=["R"]) @bank_upgrades.command(name="Rob", aliases=["R"])
async def bank_upgrade_rob(self, ctx: commands.Context): async def bank_upgrade_rob(self, ctx: commands.Context):
"""Upgrade the rob level of your bank""" """Upgrade the rob level of your bank"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.upgrade_rob(session, ctx.author.id) await crud.upgrade_rob(session, ctx.author.id)
await ctx.message.add_reaction("") await ctx.message.add_reaction("")
@ -112,7 +112,7 @@ class Currency(commands.Cog):
@commands.hybrid_command(name="dinks") @commands.hybrid_command(name="dinks")
async def dinks(self, ctx: commands.Context): async def dinks(self, ctx: commands.Context):
"""Check your Didier Dinks""" """Check your Didier Dinks"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
bank = await crud.get_bank(session, ctx.author.id) bank = await crud.get_bank(session, ctx.author.id)
plural = pluralize("Didier Dink", bank.dinks) plural = pluralize("Didier Dink", bank.dinks)
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False) await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
@ -122,7 +122,7 @@ class Currency(commands.Cog):
"""Invest a given amount of Didier Dinks""" """Invest a given amount of Didier Dinks"""
amount = typing.cast(typing.Union[str, int], amount) amount = typing.cast(typing.Union[str, int], amount)
async with self.client.db_session as session: async with self.client.postgres_session as session:
invested = await crud.invest(session, ctx.author.id, amount) invested = await crud.invest(session, ctx.author.id, amount)
plural = pluralize("Didier Dink", invested) plural = pluralize("Didier Dink", invested)
@ -136,7 +136,7 @@ class Currency(commands.Cog):
@commands.hybrid_command(name="nightly") @commands.hybrid_command(name="nightly")
async def nightly(self, ctx: commands.Context): async def nightly(self, ctx: commands.Context):
"""Claim nightly Didier Dinks""" """Claim nightly Didier Dinks"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await crud.claim_nightly(session, ctx.author.id) await crud.claim_nightly(session, ctx.author.id)
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.") await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")

View File

@ -19,7 +19,7 @@ class Discord(commands.Cog):
async def birthday(self, ctx: commands.Context, user: discord.User = None): async def birthday(self, ctx: commands.Context, user: discord.User = None):
"""Command to check the birthday of a user""" """Command to check the birthday of a user"""
user_id = (user and user.id) or ctx.author.id user_id = (user and user.id) or ctx.author.id
async with self.client.db_session as session: async with self.client.postgres_session as session:
birthday = await birthdays.get_birthday_for_user(session, user_id) birthday = await birthdays.get_birthday_for_user(session, user_id)
name = "Jouw" if user is None else f"{user.display_name}'s" name = "Jouw" if user is None else f"{user.display_name}'s"
@ -45,7 +45,7 @@ class Discord(commands.Cog):
except ValueError: except ValueError:
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False) return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
async with self.client.db_session as session: async with self.client.postgres_session as session:
await birthdays.add_birthday(session, ctx.author.id, date) await birthdays.add_birthday(session, ctx.author.id, date)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)

View File

@ -19,7 +19,7 @@ class Fun(commands.Cog):
) )
async def dad_joke(self, ctx: commands.Context): async def dad_joke(self, ctx: commands.Context):
"""Get a random dad joke""" """Get a random dad joke"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
joke = await get_random_dad_joke(session) joke = await get_random_dad_joke(session)
return await ctx.reply(joke.joke, mention_author=False) return await ctx.reply(joke.joke, mention_author=False)

View File

@ -83,7 +83,7 @@ class Owner(commands.Cog):
@add_msg.command(name="Custom") @add_msg.command(name="Custom")
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str): async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
"""Add a new custom command""" """Add a new custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.create_command(session, name, response) await custom_commands.create_command(session, name, response)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@ -94,7 +94,7 @@ class Owner(commands.Cog):
@add_msg.command(name="Alias") @add_msg.command(name="Alias")
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str): async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
"""Add a new alias for a custom command""" """Add a new alias for a custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.create_alias(session, command, alias) await custom_commands.create_alias(session, command, alias)
await self.client.confirm_message(ctx.message) await self.client.confirm_message(ctx.message)
@ -130,7 +130,7 @@ class Owner(commands.Cog):
@edit_msg.command(name="Custom") @edit_msg.command(name="Custom")
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags): async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
"""Edit an existing custom command""" """Edit an existing custom command"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
try: try:
await custom_commands.edit_command(session, command, flags.name, flags.response) await custom_commands.edit_command(session, command, flags.name, flags.response)
return await self.client.confirm_message(ctx.message) return await self.client.confirm_message(ctx.message)
@ -147,7 +147,7 @@ class Owner(commands.Cog):
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True "Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
) )
async with self.client.db_session as session: async with self.client.postgres_session as session:
_command = await custom_commands.get_command(session, command) _command = await custom_commands.get_command(session, command)
if _command is None: if _command is None:
return await interaction.response.send_message( return await interaction.response.send_message(

View File

@ -68,7 +68,7 @@ class School(commands.Cog):
@app_commands.describe(course="vak") @app_commands.describe(course="vak")
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags): async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
"""Create links to study guides""" """Create links to study guides"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
ufora_course = await ufora_courses.get_course_by_name(session, course) ufora_course = await ufora_courses.get_course_by_name(session, course)
if ufora_course is None: if ufora_course is None:

View File

@ -72,7 +72,7 @@ class Tasks(commands.Cog):
async def check_birthdays(self): async def check_birthdays(self):
"""Check if it's currently anyone's birthday""" """Check if it's currently anyone's birthday"""
now = tz_aware_now().date() now = tz_aware_now().date()
async with self.client.db_session as session: async with self.client.postgres_session as session:
birthdays = await get_birthdays_on_day(session, now) birthdays = await get_birthdays_on_day(session, now)
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL) channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
@ -96,7 +96,7 @@ class Tasks(commands.Cog):
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None: if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
return return
async with self.client.db_session as db_session: async with self.client.postgres_session as db_session:
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL) announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
announcements = await fetch_ufora_announcements(self.client.http_session, db_session) announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
@ -110,7 +110,7 @@ class Tasks(commands.Cog):
@tasks.loop(hours=24) @tasks.loop(hours=24)
async def remove_old_ufora_announcements(self): async def remove_old_ufora_announcements(self):
"""Remove all announcements that are over 1 week old, once per day""" """Remove all announcements that are over 1 week old, once per day"""
async with self.client.db_session as session: async with self.client.postgres_session as session:
await remove_old_announcements(session) await remove_old_announcements(session)
@check_birthdays.error @check_birthdays.error

View File

@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType):
async def _wrapper(tasks_cog: Tasks, *args, **kwargs): async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
await func(tasks_cog, *args, **kwargs) await func(tasks_cog, *args, **kwargs)
async with tasks_cog.client.db_session as session: async with tasks_cog.client.postgres_session as session:
await set_last_task_execution_time(session, task) await set_last_task_execution_time(session, task)
return _wrapper return _wrapper

View File

@ -2,13 +2,14 @@ import logging
import os import os
import discord import discord
import motor.motor_asyncio
from aiohttp import ClientSession from aiohttp import ClientSession
from discord.ext import commands from discord.ext import commands
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
import settings import settings
from database.crud import custom_commands from database.crud import custom_commands
from database.engine import DBSession from database.engine import DBSession, mongo_client
from database.utils.caches import CacheManager from database.utils.caches import CacheManager
from didier.data.embeds.error_embed import create_error_embed from didier.data.embeds.error_embed import create_error_embed
from didier.utils.discord.prefix import get_prefix from didier.utils.discord.prefix import get_prefix
@ -45,10 +46,15 @@ class Didier(commands.Bot):
) )
@property @property
def db_session(self) -> AsyncSession: def postgres_session(self) -> AsyncSession:
"""Obtain a database session""" """Obtain a session for the PostgreSQL database"""
return DBSession() return DBSession()
@property
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
"""Obtain a reference to the MongoDB database"""
return mongo_client[settings.MONGO_DB]
async def setup_hook(self) -> None: async def setup_hook(self) -> None:
"""Do some initial setup """Do some initial setup
@ -60,7 +66,7 @@ class Didier(commands.Bot):
# Initialize caches # Initialize caches
self.database_caches = CacheManager() self.database_caches = CacheManager()
async with self.db_session as session: async with self.postgres_session as session:
await self.database_caches.initialize_caches(session) await self.database_caches.initialize_caches(session)
# Create aiohttp session # Create aiohttp session
@ -153,7 +159,7 @@ class Didier(commands.Bot):
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX): if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
return False return False
async with self.db_session as session: async with self.postgres_session as session:
# Remove the prefix # Remove the prefix
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :] content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
command = await custom_commands.get_command(session, content) command = await custom_commands.get_command(session, content)

View File

@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session: async with self.client.postgres_session as session:
command = await create_command(session, str(self.name.value), str(self.response.value)) command = await create_command(session, str(self.name.value), str(self.response.value))
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True) await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
name_field = typing.cast(discord.ui.TextInput, self.children[0]) name_field = typing.cast(discord.ui.TextInput, self.children[0])
response_field = typing.cast(discord.ui.TextInput, self.children[1]) response_field = typing.cast(discord.ui.TextInput, self.children[1])
async with self.client.db_session as session: async with self.client.postgres_session as session:
await edit_command(session, self.original_name, name_field.value, response_field.value) await edit_command(session, self.original_name, name_field.value, response_field.value)
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True) await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)

View File

@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
@overrides @overrides
async def on_submit(self, interaction: discord.Interaction): async def on_submit(self, interaction: discord.Interaction):
async with self.client.db_session as session: async with self.client.postgres_session as session:
joke = await add_dad_joke(session, str(self.name.value)) joke = await add_dad_joke(session, str(self.name.value))
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True) await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)

View File

@ -19,6 +19,7 @@ services:
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root} - MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root} - MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev} - MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
command: [--auth]
ports: ports:
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}" - "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
volumes: volumes:

View File

@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py
environs==9.5.0 environs==9.5.0
feedparser==6.0.10 feedparser==6.0.10
markdownify==0.11.2 markdownify==0.11.2
motor==3.0.0
overrides==6.1.0 overrides==6.1.0
pydantic==1.9.1 pydantic==1.9.1
python-dateutil==2.8.2 python-dateutil==2.8.2

View File

@ -5,7 +5,7 @@ from unittest.mock import MagicMock
import pytest import pytest
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from database.engine import engine from database.engine import postgres_engine
from database.migrations import ensure_latest_migration, migrate from database.migrations import ensure_latest_migration, migrate
from didier import Didier from didier import Didier
@ -40,7 +40,7 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
Rollbacks the transaction afterwards so that the future tests start with a clean database Rollbacks the transaction afterwards so that the future tests start with a clean database
""" """
connection = await engine.connect() connection = await postgres_engine.connect()
transaction = await connection.begin() transaction = await connection.begin()
session = AsyncSession(bind=connection, expire_on_commit=False) session = AsyncSession(bind=connection, expire_on_commit=False)