Skip to content

Commit

Permalink
Commit Worker IP change to DB asap
Browse files Browse the repository at this point in the history
  • Loading branch information
benoit74 committed Nov 14, 2023
1 parent 45b05ee commit 7b534de
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 5 deletions.
14 changes: 14 additions & 0 deletions dispatcher/backend/src/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ def inner(*args, **kwargs):
return inner


def dbsession_manual(func):
"""Decorator to create an SQLAlchemy ORM session object and wrap the function
inside the session. A `session` argument is automatically set. Transaction must
be managed by the developer (e.g. perform a commit / rollback).
"""

def inner(*args, **kwargs):
with Session() as session:
kwargs["session"] = session
return func(*args, **kwargs)

return inner


def count_from_stmt(session: OrmSession, stmt: SelectBase) -> int:
"""Count all records returned by any statement `stmt` passed as parameter"""
return session.execute(
Expand Down
20 changes: 15 additions & 5 deletions dispatcher/backend/src/routes/requested_tasks/requested_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
WorkerRequestedTaskSchema,
)
from common.utils import task_event_handler
from db import count_from_stmt, dbsession
from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound
from db import count_from_stmt, dbsession, dbsession_manual
from errors.http import InvalidRequestJSON, TaskNotFound, WorkerNotFound, HTTPBase
from routes import auth_info_if_supplied, authenticate, require_perm, url_uuid
from routes.base import BaseRoute
from routes.errors import NotFound
Expand Down Expand Up @@ -208,7 +208,7 @@ class RequestedTasksForWorkers(BaseRoute):
methods = ["GET"]

@authenticate
@dbsession
@dbsession_manual
def get(self, session: so.Session, token: AccessToken.Payload):
"""list of requested tasks to be retrieved by workers, auth-only"""

Expand Down Expand Up @@ -237,8 +237,18 @@ def get(self, session: so.Session, token: AccessToken.Payload):
f"IP changed from {worker.last_ip} to {worker_ip}"
)
worker.last_ip = worker_ip
if USES_WORKERS_IPS_WHITELIST():
record_ip_change(session=session, worker_name=worker_name)
# commit explicitely since we are not using an explicit transaction,
# and do it before calling Wasabi so that changes are propagated
# quickly and transaction is not blocking
session.commit()
if constants.USES_WORKERS_IPS_WHITELIST:
try:
record_ip_change(session=session, worker_name=worker_name)
except Exception:
raise HTTPBase(
status_code=HTTPStatus.SERVICE_UNAVAILABLE,
error="Recording IP changes failed",
)

request_args = WorkerRequestedTaskSchema().load(request_args)

Expand Down

0 comments on commit 7b534de

Please sign in to comment.