From 4c187d1df5f66b5de2d03baddccabff74cf8e6be Mon Sep 17 00:00:00 2001 From: SaintShit Date: Wed, 13 Sep 2023 16:06:53 +0330 Subject: [PATCH] fix: calculation problems of recording usages --- app/__init__.py | 2 +- app/jobs/record_usages.py | 84 +++++++++++++++------------------------ 2 files changed, 32 insertions(+), 54 deletions(-) diff --git a/app/__init__.py b/app/__init__.py index 2b39afa1..c0017bda 100755 --- a/app/__init__.py +++ b/app/__init__.py @@ -21,7 +21,7 @@ redoc_url='/redoc' if DOCS else None ) app.openapi = custom_openapi(app) -scheduler = BackgroundScheduler({'apscheduler.job_defaults.max_instances': 5}, timezone='UTC') +scheduler = BackgroundScheduler({'apscheduler.job_defaults.max_instances': 10}, timezone='UTC') logger = logging.getLogger('uvicorn.error') app.add_middleware( CORSMiddleware, diff --git a/app/jobs/record_usages.py b/app/jobs/record_usages.py index 081e9d4f..bfadbb50 100644 --- a/app/jobs/record_usages.py +++ b/app/jobs/record_usages.py @@ -5,7 +5,7 @@ from typing import Union from pymysql.err import OperationalError -from sqlalchemy import and_, bindparam, insert, select, update +from sqlalchemy import and_, bindparam, insert, select, sql, update from app import scheduler, xray from app.db import GetDB @@ -15,6 +15,30 @@ from xray_api import exc as xray_exc +def safe_execute(db, stmt, params=None): + if db.bind.name == 'mysql': + if isinstance(stmt, sql.dml.Insert): + stmt = stmt.prefix_with('IGNORE') + + tries = 0 + done = False + while not done: + try: + db.execute(stmt, params) + db.commit() + done = True + except OperationalError as err: + if err.args[0] == 1213 and tries < 3: # Deadlock + db.rollback() + tries += 1 + continue + raise err + + else: + db.execute(stmt, params) + db.commit() + + def record_user_stats(params: list, node_id: Union[int, None]): if not params: return @@ -41,9 +65,7 @@ def record_user_stats(params: list, node_id: Union[int, None]): node_id=node_id, used_traffic=0 ) - if db.bind.name == 'mysql': - stmt = stmt.prefix_with('IGNORE') - db.execute(stmt, [{'uid': uid} for uid in uids_to_insert]) + safe_execute(db, stmt, [{'uid': uid} for uid in uids_to_insert]) # record stmt = update(NodeUserUsage) \ @@ -51,20 +73,7 @@ def record_user_stats(params: list, node_id: Union[int, None]): .where(and_(NodeUserUsage.user_id == bindparam('uid'), NodeUserUsage.node_id == node_id, NodeUserUsage.created_at == created_at)) - - tries = 0 - done = False - while not done: - try: - db.execute(stmt, params) - db.commit() - done = True - except OperationalError as err: - if err.args[0] == 1213 and tries < 3: # Deadlock - db.rollback() - tries += 1 - continue - raise err + safe_execute(db, stmt, params) def record_node_stats(params: dict, node_id: Union[int, None]): @@ -81,31 +90,14 @@ def record_node_stats(params: dict, node_id: Union[int, None]): notfound = db.execute(select_stmt).first() is None if notfound: stmt = insert(NodeUsage).values(created_at=created_at, node_id=node_id, uplink=0, downlink=0) - if db.bind.name == 'mysql': - stmt = stmt.prefix_with('IGNORE') - db.execute(stmt) + safe_execute(db, stmt) # record stmt = update(NodeUsage). \ values(uplink=NodeUsage.uplink + bindparam('up'), downlink=NodeUsage.downlink + bindparam('down')). \ where(and_(NodeUsage.node_id == node_id, NodeUsage.created_at == created_at)) - db.execute(stmt, params) - - # commit changes - tries = 0 - done = False - while not done: - try: - db.execute(stmt, params) - db.commit() - done = True - except OperationalError as err: - if err.args[0] == 1213 and tries < 3: # Deadlock - db.rollback() - tries += 1 - continue - raise err + safe_execute(db, stmt, params) def get_users_stats(api: XRayAPI): @@ -151,20 +143,7 @@ def record_user_usages(): stmt = update(User). \ where(User.id == bindparam('uid')). \ values(used_traffic=User.used_traffic + bindparam('value')) - - tries = 0 - done = False - while not done: - try: - db.execute(stmt, users_usage) - db.commit() - done = True - except OperationalError as err: - if err.args[0] == 1213 and tries < 3: # Deadlock - db.rollback() - tries += 1 - continue - raise err + safe_execute(db, stmt, users_usage) if DISABLE_RECORDING_NODE_USAGE: return @@ -198,8 +177,7 @@ def record_node_usages(): uplink=System.uplink + total_up, downlink=System.downlink + total_down ) - db.execute(stmt) - db.commit() + safe_execute(db, stmt) if DISABLE_RECORDING_NODE_USAGE: return