mirror of https://github.com/stijndcl/didier
Fix mongo connection
parent
9abe5dd519
commit
52b452c85a
|
@ -4,7 +4,7 @@ from logging.config import fileConfig
|
||||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from database.engine import engine
|
from database.engine import postgres_engine
|
||||||
from database.models import Base
|
from database.models import Base
|
||||||
|
|
||||||
# this is the Alembic Config object, which provides
|
# this is the Alembic Config object, which provides
|
||||||
|
@ -40,7 +40,7 @@ def run_migrations_online() -> None:
|
||||||
and associate a connection with the context.
|
and associate a connection with the context.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
connectable = context.config.attributes.get("connection", None) or engine
|
connectable = context.config.attributes.get("connection", None) or postgres_engine
|
||||||
|
|
||||||
if isinstance(connectable, AsyncEngine):
|
if isinstance(connectable, AsyncEngine):
|
||||||
asyncio.run(run_async_migrations(connectable))
|
asyncio.run(run_async_migrations(connectable))
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
import motor.motor_asyncio
|
||||||
from sqlalchemy.engine import URL
|
from sqlalchemy.engine import URL
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
|
||||||
from sqlalchemy.orm import sessionmaker
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
@ -8,7 +9,8 @@ import settings
|
||||||
|
|
||||||
encoded_password = quote_plus(settings.POSTGRES_PASS)
|
encoded_password = quote_plus(settings.POSTGRES_PASS)
|
||||||
|
|
||||||
engine = create_async_engine(
|
# PostgreSQL engine
|
||||||
|
postgres_engine = create_async_engine(
|
||||||
URL.create(
|
URL.create(
|
||||||
drivername="postgresql+asyncpg",
|
drivername="postgresql+asyncpg",
|
||||||
username=settings.POSTGRES_USER,
|
username=settings.POSTGRES_USER,
|
||||||
|
@ -21,4 +23,9 @@ engine = create_async_engine(
|
||||||
future=True,
|
future=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
DBSession = sessionmaker(autocommit=False, autoflush=False, bind=engine, class_=AsyncSession, expire_on_commit=False)
|
DBSession = sessionmaker(
|
||||||
|
autocommit=False, autoflush=False, bind=postgres_engine, class_=AsyncSession, expire_on_commit=False
|
||||||
|
)
|
||||||
|
|
||||||
|
# MongoDB client
|
||||||
|
mongo_client = motor.motor_asyncio.AsyncIOMotorClient(settings.MONGO_HOST, settings.MONGO_PORT)
|
||||||
|
|
|
@ -5,7 +5,7 @@ from sqlalchemy.orm import Session
|
||||||
from alembic import command, script
|
from alembic import command, script
|
||||||
from alembic.config import Config
|
from alembic.config import Config
|
||||||
from alembic.runtime import migration
|
from alembic.runtime import migration
|
||||||
from database.engine import engine
|
from database.engine import postgres_engine
|
||||||
|
|
||||||
__config_path__ = "alembic.ini"
|
__config_path__ = "alembic.ini"
|
||||||
__migrations_path__ = "alembic/"
|
__migrations_path__ = "alembic/"
|
||||||
|
@ -22,7 +22,7 @@ async def ensure_latest_migration():
|
||||||
"""Make sure we are currently on the latest revision, otherwise raise an exception"""
|
"""Make sure we are currently on the latest revision, otherwise raise an exception"""
|
||||||
alembic_script = script.ScriptDirectory.from_config(cfg)
|
alembic_script = script.ScriptDirectory.from_config(cfg)
|
||||||
|
|
||||||
async with engine.begin() as connection:
|
async with postgres_engine.begin() as connection:
|
||||||
current_revision = await connection.run_sync(
|
current_revision = await connection.run_sync(
|
||||||
lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision()
|
lambda sync_connection: migration.MigrationContext.configure(sync_connection).get_current_revision()
|
||||||
)
|
)
|
||||||
|
@ -49,5 +49,5 @@ def __execute_downgrade(connection: Session):
|
||||||
|
|
||||||
async def migrate(up: bool):
|
async def migrate(up: bool):
|
||||||
"""Migrate the database upwards or downwards"""
|
"""Migrate the database upwards or downwards"""
|
||||||
async with engine.begin() as connection:
|
async with postgres_engine.begin() as connection:
|
||||||
await connection.run_sync(__execute_upgrade if up else __execute_downgrade)
|
await connection.run_sync(__execute_upgrade if up else __execute_downgrade)
|
||||||
|
|
|
@ -31,7 +31,7 @@ class Currency(commands.Cog):
|
||||||
"""Award a user a given amount of Didier Dinks"""
|
"""Award a user a given amount of Didier Dinks"""
|
||||||
amount = typing.cast(int, amount)
|
amount = typing.cast(int, amount)
|
||||||
|
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
await crud.add_dinks(session, user.id, amount)
|
await crud.add_dinks(session, user.id, amount)
|
||||||
plural = pluralize("Didier Dink", amount)
|
plural = pluralize("Didier Dink", amount)
|
||||||
await ctx.reply(
|
await ctx.reply(
|
||||||
|
@ -42,7 +42,7 @@ class Currency(commands.Cog):
|
||||||
@commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True)
|
@commands.group(name="bank", aliases=["B"], case_insensitive=True, invoke_without_command=True)
|
||||||
async def bank(self, ctx: commands.Context):
|
async def bank(self, ctx: commands.Context):
|
||||||
"""Show your Didier Bank information"""
|
"""Show your Didier Bank information"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
bank = await crud.get_bank(session, ctx.author.id)
|
bank = await crud.get_bank(session, ctx.author.id)
|
||||||
|
|
||||||
embed = discord.Embed(colour=discord.Colour.blue())
|
embed = discord.Embed(colour=discord.Colour.blue())
|
||||||
|
@ -58,7 +58,7 @@ class Currency(commands.Cog):
|
||||||
@bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True)
|
@bank.group(name="Upgrade", aliases=["U", "Upgrades"], case_insensitive=True, invoke_without_command=True)
|
||||||
async def bank_upgrades(self, ctx: commands.Context):
|
async def bank_upgrades(self, ctx: commands.Context):
|
||||||
"""List the upgrades you can buy & their prices"""
|
"""List the upgrades you can buy & their prices"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
bank = await crud.get_bank(session, ctx.author.id)
|
bank = await crud.get_bank(session, ctx.author.id)
|
||||||
|
|
||||||
embed = discord.Embed(colour=discord.Colour.blue())
|
embed = discord.Embed(colour=discord.Colour.blue())
|
||||||
|
@ -79,7 +79,7 @@ class Currency(commands.Cog):
|
||||||
@bank_upgrades.command(name="Capacity", aliases=["C"])
|
@bank_upgrades.command(name="Capacity", aliases=["C"])
|
||||||
async def bank_upgrade_capacity(self, ctx: commands.Context):
|
async def bank_upgrade_capacity(self, ctx: commands.Context):
|
||||||
"""Upgrade the capacity level of your bank"""
|
"""Upgrade the capacity level of your bank"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await crud.upgrade_capacity(session, ctx.author.id)
|
await crud.upgrade_capacity(session, ctx.author.id)
|
||||||
await ctx.message.add_reaction("⏫")
|
await ctx.message.add_reaction("⏫")
|
||||||
|
@ -90,7 +90,7 @@ class Currency(commands.Cog):
|
||||||
@bank_upgrades.command(name="Interest", aliases=["I"])
|
@bank_upgrades.command(name="Interest", aliases=["I"])
|
||||||
async def bank_upgrade_interest(self, ctx: commands.Context):
|
async def bank_upgrade_interest(self, ctx: commands.Context):
|
||||||
"""Upgrade the interest level of your bank"""
|
"""Upgrade the interest level of your bank"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await crud.upgrade_interest(session, ctx.author.id)
|
await crud.upgrade_interest(session, ctx.author.id)
|
||||||
await ctx.message.add_reaction("⏫")
|
await ctx.message.add_reaction("⏫")
|
||||||
|
@ -101,7 +101,7 @@ class Currency(commands.Cog):
|
||||||
@bank_upgrades.command(name="Rob", aliases=["R"])
|
@bank_upgrades.command(name="Rob", aliases=["R"])
|
||||||
async def bank_upgrade_rob(self, ctx: commands.Context):
|
async def bank_upgrade_rob(self, ctx: commands.Context):
|
||||||
"""Upgrade the rob level of your bank"""
|
"""Upgrade the rob level of your bank"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await crud.upgrade_rob(session, ctx.author.id)
|
await crud.upgrade_rob(session, ctx.author.id)
|
||||||
await ctx.message.add_reaction("⏫")
|
await ctx.message.add_reaction("⏫")
|
||||||
|
@ -112,7 +112,7 @@ class Currency(commands.Cog):
|
||||||
@commands.hybrid_command(name="dinks")
|
@commands.hybrid_command(name="dinks")
|
||||||
async def dinks(self, ctx: commands.Context):
|
async def dinks(self, ctx: commands.Context):
|
||||||
"""Check your Didier Dinks"""
|
"""Check your Didier Dinks"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
bank = await crud.get_bank(session, ctx.author.id)
|
bank = await crud.get_bank(session, ctx.author.id)
|
||||||
plural = pluralize("Didier Dink", bank.dinks)
|
plural = pluralize("Didier Dink", bank.dinks)
|
||||||
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
|
await ctx.reply(f"**{ctx.author.display_name}** heeft **{bank.dinks}** {plural}.", mention_author=False)
|
||||||
|
@ -122,7 +122,7 @@ class Currency(commands.Cog):
|
||||||
"""Invest a given amount of Didier Dinks"""
|
"""Invest a given amount of Didier Dinks"""
|
||||||
amount = typing.cast(typing.Union[str, int], amount)
|
amount = typing.cast(typing.Union[str, int], amount)
|
||||||
|
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
invested = await crud.invest(session, ctx.author.id, amount)
|
invested = await crud.invest(session, ctx.author.id, amount)
|
||||||
plural = pluralize("Didier Dink", invested)
|
plural = pluralize("Didier Dink", invested)
|
||||||
|
|
||||||
|
@ -136,7 +136,7 @@ class Currency(commands.Cog):
|
||||||
@commands.hybrid_command(name="nightly")
|
@commands.hybrid_command(name="nightly")
|
||||||
async def nightly(self, ctx: commands.Context):
|
async def nightly(self, ctx: commands.Context):
|
||||||
"""Claim nightly Didier Dinks"""
|
"""Claim nightly Didier Dinks"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await crud.claim_nightly(session, ctx.author.id)
|
await crud.claim_nightly(session, ctx.author.id)
|
||||||
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")
|
await ctx.reply(f"Je hebt je dagelijkse **{crud.NIGHTLY_AMOUNT}** Didier Dinks geclaimd.")
|
||||||
|
|
|
@ -19,7 +19,7 @@ class Discord(commands.Cog):
|
||||||
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
async def birthday(self, ctx: commands.Context, user: discord.User = None):
|
||||||
"""Command to check the birthday of a user"""
|
"""Command to check the birthday of a user"""
|
||||||
user_id = (user and user.id) or ctx.author.id
|
user_id = (user and user.id) or ctx.author.id
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
birthday = await birthdays.get_birthday_for_user(session, user_id)
|
birthday = await birthdays.get_birthday_for_user(session, user_id)
|
||||||
|
|
||||||
name = "Jouw" if user is None else f"{user.display_name}'s"
|
name = "Jouw" if user is None else f"{user.display_name}'s"
|
||||||
|
@ -45,7 +45,7 @@ class Discord(commands.Cog):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
return await ctx.reply(f"`{date_str}` is geen geldige datum.", mention_author=False)
|
||||||
|
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
await birthdays.add_birthday(session, ctx.author.id, date)
|
await birthdays.add_birthday(session, ctx.author.id, date)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ class Fun(commands.Cog):
|
||||||
)
|
)
|
||||||
async def dad_joke(self, ctx: commands.Context):
|
async def dad_joke(self, ctx: commands.Context):
|
||||||
"""Get a random dad joke"""
|
"""Get a random dad joke"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
joke = await get_random_dad_joke(session)
|
joke = await get_random_dad_joke(session)
|
||||||
return await ctx.reply(joke.joke, mention_author=False)
|
return await ctx.reply(joke.joke, mention_author=False)
|
||||||
|
|
||||||
|
|
|
@ -83,7 +83,7 @@ class Owner(commands.Cog):
|
||||||
@add_msg.command(name="Custom")
|
@add_msg.command(name="Custom")
|
||||||
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
|
async def add_custom_msg(self, ctx: commands.Context, name: str, *, response: str):
|
||||||
"""Add a new custom command"""
|
"""Add a new custom command"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await custom_commands.create_command(session, name, response)
|
await custom_commands.create_command(session, name, response)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
|
@ -94,7 +94,7 @@ class Owner(commands.Cog):
|
||||||
@add_msg.command(name="Alias")
|
@add_msg.command(name="Alias")
|
||||||
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
|
async def add_alias_msg(self, ctx: commands.Context, command: str, alias: str):
|
||||||
"""Add a new alias for a custom command"""
|
"""Add a new alias for a custom command"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await custom_commands.create_alias(session, command, alias)
|
await custom_commands.create_alias(session, command, alias)
|
||||||
await self.client.confirm_message(ctx.message)
|
await self.client.confirm_message(ctx.message)
|
||||||
|
@ -130,7 +130,7 @@ class Owner(commands.Cog):
|
||||||
@edit_msg.command(name="Custom")
|
@edit_msg.command(name="Custom")
|
||||||
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
|
async def edit_custom_msg(self, ctx: commands.Context, command: str, *, flags: EditCustomFlags):
|
||||||
"""Edit an existing custom command"""
|
"""Edit an existing custom command"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
try:
|
try:
|
||||||
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
await custom_commands.edit_command(session, command, flags.name, flags.response)
|
||||||
return await self.client.confirm_message(ctx.message)
|
return await self.client.confirm_message(ctx.message)
|
||||||
|
@ -147,7 +147,7 @@ class Owner(commands.Cog):
|
||||||
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
"Je hebt geen toestemming om dit commando uit te voeren.", ephemeral=True
|
||||||
)
|
)
|
||||||
|
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
_command = await custom_commands.get_command(session, command)
|
_command = await custom_commands.get_command(session, command)
|
||||||
if _command is None:
|
if _command is None:
|
||||||
return await interaction.response.send_message(
|
return await interaction.response.send_message(
|
||||||
|
|
|
@ -68,7 +68,7 @@ class School(commands.Cog):
|
||||||
@app_commands.describe(course="vak")
|
@app_commands.describe(course="vak")
|
||||||
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
|
async def study_guide(self, ctx: commands.Context, course: str, *, flags: StudyGuideFlags):
|
||||||
"""Create links to study guides"""
|
"""Create links to study guides"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
ufora_course = await ufora_courses.get_course_by_name(session, course)
|
ufora_course = await ufora_courses.get_course_by_name(session, course)
|
||||||
|
|
||||||
if ufora_course is None:
|
if ufora_course is None:
|
||||||
|
|
|
@ -72,7 +72,7 @@ class Tasks(commands.Cog):
|
||||||
async def check_birthdays(self):
|
async def check_birthdays(self):
|
||||||
"""Check if it's currently anyone's birthday"""
|
"""Check if it's currently anyone's birthday"""
|
||||||
now = tz_aware_now().date()
|
now = tz_aware_now().date()
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
birthdays = await get_birthdays_on_day(session, now)
|
birthdays = await get_birthdays_on_day(session, now)
|
||||||
|
|
||||||
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
channel = self.client.get_channel(settings.BIRTHDAY_ANNOUNCEMENT_CHANNEL)
|
||||||
|
@ -96,7 +96,7 @@ class Tasks(commands.Cog):
|
||||||
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
if settings.UFORA_RSS_TOKEN is None or settings.UFORA_ANNOUNCEMENTS_CHANNEL is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
async with self.client.db_session as db_session:
|
async with self.client.postgres_session as db_session:
|
||||||
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
announcements_channel = self.client.get_channel(settings.UFORA_ANNOUNCEMENTS_CHANNEL)
|
||||||
announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
|
announcements = await fetch_ufora_announcements(self.client.http_session, db_session)
|
||||||
|
|
||||||
|
@ -110,7 +110,7 @@ class Tasks(commands.Cog):
|
||||||
@tasks.loop(hours=24)
|
@tasks.loop(hours=24)
|
||||||
async def remove_old_ufora_announcements(self):
|
async def remove_old_ufora_announcements(self):
|
||||||
"""Remove all announcements that are over 1 week old, once per day"""
|
"""Remove all announcements that are over 1 week old, once per day"""
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
await remove_old_announcements(session)
|
await remove_old_announcements(session)
|
||||||
|
|
||||||
@check_birthdays.error
|
@check_birthdays.error
|
||||||
|
|
|
@ -20,7 +20,7 @@ def timed_task(task: enums.TaskType):
|
||||||
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
|
async def _wrapper(tasks_cog: Tasks, *args, **kwargs):
|
||||||
await func(tasks_cog, *args, **kwargs)
|
await func(tasks_cog, *args, **kwargs)
|
||||||
|
|
||||||
async with tasks_cog.client.db_session as session:
|
async with tasks_cog.client.postgres_session as session:
|
||||||
await set_last_task_execution_time(session, task)
|
await set_last_task_execution_time(session, task)
|
||||||
|
|
||||||
return _wrapper
|
return _wrapper
|
||||||
|
|
|
@ -2,13 +2,14 @@ import logging
|
||||||
import os
|
import os
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import motor.motor_asyncio
|
||||||
from aiohttp import ClientSession
|
from aiohttp import ClientSession
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
import settings
|
import settings
|
||||||
from database.crud import custom_commands
|
from database.crud import custom_commands
|
||||||
from database.engine import DBSession
|
from database.engine import DBSession, mongo_client
|
||||||
from database.utils.caches import CacheManager
|
from database.utils.caches import CacheManager
|
||||||
from didier.data.embeds.error_embed import create_error_embed
|
from didier.data.embeds.error_embed import create_error_embed
|
||||||
from didier.utils.discord.prefix import get_prefix
|
from didier.utils.discord.prefix import get_prefix
|
||||||
|
@ -45,10 +46,15 @@ class Didier(commands.Bot):
|
||||||
)
|
)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def db_session(self) -> AsyncSession:
|
def postgres_session(self) -> AsyncSession:
|
||||||
"""Obtain a database session"""
|
"""Obtain a session for the PostgreSQL database"""
|
||||||
return DBSession()
|
return DBSession()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def mongo_db(self) -> motor.motor_asyncio.AsyncIOMotorDatabase:
|
||||||
|
"""Obtain a reference to the MongoDB database"""
|
||||||
|
return mongo_client[settings.MONGO_DB]
|
||||||
|
|
||||||
async def setup_hook(self) -> None:
|
async def setup_hook(self) -> None:
|
||||||
"""Do some initial setup
|
"""Do some initial setup
|
||||||
|
|
||||||
|
@ -60,7 +66,7 @@ class Didier(commands.Bot):
|
||||||
|
|
||||||
# Initialize caches
|
# Initialize caches
|
||||||
self.database_caches = CacheManager()
|
self.database_caches = CacheManager()
|
||||||
async with self.db_session as session:
|
async with self.postgres_session as session:
|
||||||
await self.database_caches.initialize_caches(session)
|
await self.database_caches.initialize_caches(session)
|
||||||
|
|
||||||
# Create aiohttp session
|
# Create aiohttp session
|
||||||
|
@ -153,7 +159,7 @@ class Didier(commands.Bot):
|
||||||
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
if not message.content.startswith(settings.DISCORD_CUSTOM_COMMAND_PREFIX):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async with self.db_session as session:
|
async with self.postgres_session as session:
|
||||||
# Remove the prefix
|
# Remove the prefix
|
||||||
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
|
content = message.content[len(settings.DISCORD_CUSTOM_COMMAND_PREFIX) :]
|
||||||
command = await custom_commands.get_command(session, content)
|
command = await custom_commands.get_command(session, content)
|
||||||
|
|
|
@ -27,7 +27,7 @@ class CreateCustomCommand(discord.ui.Modal, title="Create Custom Command"):
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
async def on_submit(self, interaction: discord.Interaction):
|
async def on_submit(self, interaction: discord.Interaction):
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
command = await create_command(session, str(self.name.value), str(self.response.value))
|
command = await create_command(session, str(self.name.value), str(self.response.value))
|
||||||
|
|
||||||
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
|
await interaction.response.send_message(f"Successfully created ``{command.name}``.", ephemeral=True)
|
||||||
|
@ -68,7 +68,7 @@ class EditCustomCommand(discord.ui.Modal, title="Edit Custom Command"):
|
||||||
name_field = typing.cast(discord.ui.TextInput, self.children[0])
|
name_field = typing.cast(discord.ui.TextInput, self.children[0])
|
||||||
response_field = typing.cast(discord.ui.TextInput, self.children[1])
|
response_field = typing.cast(discord.ui.TextInput, self.children[1])
|
||||||
|
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
await edit_command(session, self.original_name, name_field.value, response_field.value)
|
await edit_command(session, self.original_name, name_field.value, response_field.value)
|
||||||
|
|
||||||
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)
|
await interaction.response.send_message(f"Successfully edited ``{self.original_name}``.", ephemeral=True)
|
||||||
|
|
|
@ -26,7 +26,7 @@ class AddDadJoke(discord.ui.Modal, title="Add Dad Joke"):
|
||||||
|
|
||||||
@overrides
|
@overrides
|
||||||
async def on_submit(self, interaction: discord.Interaction):
|
async def on_submit(self, interaction: discord.Interaction):
|
||||||
async with self.client.db_session as session:
|
async with self.client.postgres_session as session:
|
||||||
joke = await add_dad_joke(session, str(self.name.value))
|
joke = await add_dad_joke(session, str(self.name.value))
|
||||||
|
|
||||||
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
|
await interaction.response.send_message(f"Successfully added joke #{joke.dad_joke_id}", ephemeral=True)
|
||||||
|
|
|
@ -19,6 +19,7 @@ services:
|
||||||
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
|
- MONGO_INITDB_ROOT_USERNAME=${MONGO_USER:-root}
|
||||||
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
|
- MONGO_INITDB_ROOT_PASSWORD=${MONGO_PASS:-root}
|
||||||
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
|
- MONGO_INITDB_DATABASE=${MONGO_DB:-didier_dev}
|
||||||
|
command: [--auth]
|
||||||
ports:
|
ports:
|
||||||
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
|
- "${MONGO_PORT:-27017}:${MONGO_PORT:-27017}"
|
||||||
volumes:
|
volumes:
|
||||||
|
|
|
@ -7,6 +7,7 @@ git+https://github.com/Rapptz/discord.py
|
||||||
environs==9.5.0
|
environs==9.5.0
|
||||||
feedparser==6.0.10
|
feedparser==6.0.10
|
||||||
markdownify==0.11.2
|
markdownify==0.11.2
|
||||||
|
motor==3.0.0
|
||||||
overrides==6.1.0
|
overrides==6.1.0
|
||||||
pydantic==1.9.1
|
pydantic==1.9.1
|
||||||
python-dateutil==2.8.2
|
python-dateutil==2.8.2
|
||||||
|
|
|
@ -5,7 +5,7 @@ from unittest.mock import MagicMock
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from database.engine import engine
|
from database.engine import postgres_engine
|
||||||
from database.migrations import ensure_latest_migration, migrate
|
from database.migrations import ensure_latest_migration, migrate
|
||||||
from didier import Didier
|
from didier import Didier
|
||||||
|
|
||||||
|
@ -40,7 +40,7 @@ async def postgres(tables) -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
|
||||||
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
Rollbacks the transaction afterwards so that the future tests start with a clean database
|
||||||
"""
|
"""
|
||||||
connection = await engine.connect()
|
connection = await postgres_engine.connect()
|
||||||
transaction = await connection.begin()
|
transaction = await connection.begin()
|
||||||
session = AsyncSession(bind=connection, expire_on_commit=False)
|
session = AsyncSession(bind=connection, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue