diff --git a/backend/json_encoder.py b/backend/json_encoder.py new file mode 100644 index 0000000..0eaeed8 --- /dev/null +++ b/backend/json_encoder.py @@ -0,0 +1,15 @@ +from datetime import date + +from quart.json import JSONEncoder + + +class CustomEncoder(JSONEncoder): + """ + Custom JSON encoder to handle things like date parsing + """ + def default(self, o): + # Return dates as ISO strings + if isinstance(o, date): + return o.isoformat() + + return super().default(o) diff --git a/backend/server.py b/backend/server.py index 1f190fe..7edd0c5 100644 --- a/backend/server.py +++ b/backend/server.py @@ -1,6 +1,7 @@ from werkzeug.exceptions import abort from .error_handlers import handler_404, handler_422 +from .json_encoder import CustomEncoder from .routes.dm import dm_blueprint from .routes.ping import ping_blueprint from .routes.stats import stats_blueprint @@ -15,6 +16,7 @@ app = Quart(__name__) # needs higher Python & Quart version app = cors(app, allow_origin="*") app.url_map.strict_slashes = False +app.json_encoder = CustomEncoder # Register blueprints app.register_blueprint(dm_blueprint) diff --git a/database/utils.py b/database/utils.py index cd4b05e..7fc089b 100644 --- a/database/utils.py +++ b/database/utils.py @@ -1,4 +1,5 @@ -from database import all_models, engine +from database import all_models, engine, Base +from typing import Dict def create_all_tables(): @@ -9,3 +10,16 @@ def create_all_tables(): """ for model in all_models: model.__table__.create(engine, checkfirst=True) + + +def row_to_dict(row: Base) -> Dict: + """ + Create a dictionary from a database entry + vars() or dict() can NOT be used as these add extra SQLAlchemy-related properties! + """ + d = {} + + for column in row.__table__.columns: + d[column.name] = getattr(row, column.name) + + return d diff --git a/functions/database/commands.py b/functions/database/commands.py index 9f63d68..6732806 100644 --- a/functions/database/commands.py +++ b/functions/database/commands.py @@ -2,8 +2,8 @@ from enum import IntEnum from typing import Dict, List from database.db import session +from database.utils import row_to_dict from database.models import CommandStats -from functions.database import utils from functions.stringFormatters import leading_zero as lz import time @@ -24,27 +24,16 @@ def _is_present(date: str) -> bool: """ Check if a given date is present in the database """ - connection = utils.connect() - cursor = connection.cursor() - - cursor.execute("SELECT * FROM command_stats WHERE day = %s", (date,)) - res = cursor.fetchone() - - if res: - return True - - return False + res = session.query(CommandStats).filter(CommandStats.day == date).scalar() + return res is not None def _add_date(date: str): """ Add a date into the db """ - connection = utils.connect() - cursor = connection.cursor() - - cursor.execute("INSERT INTO command_stats(day, commands, slash_commands, context_menus) VALUES (%s, 0, 0, 0)", (date,)) - connection.commit() + entry = CommandStats(day=date, commands=0, slash_commands=0, context_menus=0) + session.add(entry) def _update(date: str, inv: InvocationType): @@ -55,40 +44,22 @@ def _update(date: str, inv: InvocationType): if not _is_present(date): _add_date(date) - connection = utils.connect() - cursor = connection.cursor() + column_name: str = ["commands", "slash_commands", "context_menus"][inv.value] + session.query(CommandStats).filter(CommandStats.day == date)\ + .update({column_name: (getattr(CommandStats, column_name) + 1)}) - column_name = ["commands", "slash_commands", "context_menus"][inv.value] - - # String formatting is safe here because the input comes from above ^ - cursor.execute(f""" - UPDATE command_stats - SET {column_name} = {column_name} + 1 - WHERE day = %s - """, (date,)) - connection.commit() - - -def _get_all(): - """ - Get all rows - """ - connection = utils.connect() - cursor = connection.cursor() - - cursor.execute("SELECT * FROM command_stats") - return cursor.fetchall() + # Commit changes, including adding the date if necessary + session.commit() def query_command_stats() -> List[Dict]: + """ + Return all rows as dicts + """ stats = [] + instance: CommandStats for instance in session.query(CommandStats).order_by(CommandStats.day): - stats.append({ - "day": instance.day, - "commands": instance.commands, - "slash_commands": instance.slash_commands, - "context_menus": instance.context_menus - }) + stats.append(row_to_dict(instance)) return stats