diff --git a/dispatcher/backend/src/db/__init__.py b/dispatcher/backend/src/db/__init__.py index 69bad2e9..5f03486a 100644 --- a/dispatcher/backend/src/db/__init__.py +++ b/dispatcher/backend/src/db/__init__.py @@ -52,6 +52,20 @@ def inner(*args, **kwargs): return inner +def dbsession_manual(func): + """Decorator to create an SQLAlchemy ORM session object and wrap the function + inside the session. A `session` argument is automatically set. Transaction must + be managed by the developer (e.g. perform a commit / rollback). + """ + + def inner(*args, **kwargs): + with Session() as session: + kwargs["session"] = session + return func(*args, **kwargs) + + return inner + + def count_from_stmt(session: OrmSession, stmt: SelectBase) -> int: """Count all records returned by any statement `stmt` passed as parameter""" return session.execute( diff --git a/dispatcher/backend/src/routes/requested_tasks/requested_task.py b/dispatcher/backend/src/routes/requested_tasks/requested_task.py index c58080dd..6474e7b6 100644 --- a/dispatcher/backend/src/routes/requested_tasks/requested_task.py +++ b/dispatcher/backend/src/routes/requested_tasks/requested_task.py @@ -24,8 +24,8 @@ WorkerRequestedTaskSchema, ) from common.utils import task_event_handler -from db import count_from_stmt, dbsession -from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound +from db import count_from_stmt, dbsession, dbsession_manual +from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound, HTTPBase from routes import auth_info_if_supplied, authenticate, require_perm, url_uuid from routes.base import BaseRoute from routes.errors import NotFound @@ -208,7 +208,7 @@ class RequestedTasksForWorkers(BaseRoute): methods = ["GET"] @authenticate - @dbsession + @dbsession_manual def get(self, session: so.Session, token: AccessToken.Payload): """list of requested tasks to be retrieved by workers, auth-only""" @@ -237,8 +237,18 @@ def get(self, session: so.Session, token: AccessToken.Payload): f"IP changed from {worker.last_ip} to {worker_ip}" ) worker.last_ip = worker_ip - if USES_WORKERS_IPS_WHITELIST(): - record_ip_change(session=session, worker_name=worker_name) + # commit explicitely since we are not using an explicit transaction, + # and do it before calling Wasabi so that changes are propagated + # quickly and transaction is not blocking + session.commit() + if constants.USES_WORKERS_IPS_WHITELIST: + try: + record_ip_change(session=session, worker_name=worker_name) + except Exception: + raise HTTPBase( + status_code=HTTPStatus.SERVICE_UNAVAILABLE, + error="Recording IP changes failed", + ) request_args = WorkerRequestedTaskSchema().load(request_args)