diff --git a/Dockerfile b/Dockerfile
index d8a9d3ff..270c0869 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -5,11 +5,14 @@ ENV APP_HOME /app
ENV PROMPT_PROCESSING_DIR $APP_HOME
# Normally defined in the Kubernetes config.
ENV WORKER_RESTART_FREQ ${WORKER_RESTART_FREQ:-0}
+ENV WORKER_TIMEOUT ${WORKER_TIMEOUT:-0}
+ENV WORKER_GRACE_PERIOD ${WORKER_GRACE_PERIOD:-30}
ARG PORT
WORKDIR $APP_HOME
COPY python/activator activator/
COPY pipelines pipelines/
CMD source /opt/lsst/software/stack/loadLSST.bash \
&& setup lsst_distrib \
- && exec gunicorn --workers 1 --threads 1 --timeout 0 --max-requests $WORKER_RESTART_FREQ \
+ && exec gunicorn --workers 1 --threads 1 --timeout $WORKER_TIMEOUT --max-requests $WORKER_RESTART_FREQ \
+ --graceful-timeout $WORKER_GRACE_PERIOD \
--bind :$PORT activator.activator:app
diff --git a/python/activator/activator.py b/python/activator/activator.py
index aa5cd401..8ed6c263 100644
--- a/python/activator/activator.py
+++ b/python/activator/activator.py
@@ -21,12 +21,13 @@
__all__ = ["check_for_snap", "next_visit_handler"]
+import collections.abc
import json
import logging
import os
import sys
import time
-from typing import Optional, Tuple
+import signal
import uuid
import boto3
@@ -37,7 +38,7 @@
from werkzeug.exceptions import ServiceUnavailable
from .config import PipelinesConfig
-from .exception import NonRetriableError, RetriableError
+from .exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError
from .logger import setup_usdf_logger
from .middleware_interface import get_central_butler, flush_local_repo, \
make_local_repo, make_local_cache, MiddlewareInterface
@@ -125,9 +126,58 @@ def find_local_repos(base_path):
sys.exit(3)
+def _graceful_shutdown(signum: int, stack_frame):
+    """Abort the running request in response to a shutdown signal.
+
+    This function is meant to be installed as a `signal.signal` handler.
+
+    Parameters
+    ----------
+    signum : `int`
+        The number of the signal being handled.
+    stack_frame : `frame` or `None`
+        The stack frame that was interrupted by the signal. Not used.
+
+    Raises
+    ------
+    GracefulShutdownInterrupt
+        Always raised; this function never returns normally.
+    """
+    # TODO DM-45339: raising in signal handlers is dangerous; can we get a way
+    # for pipeline processing to check for interrupts?
+    name = signal.Signals(signum).name
+    _log.info("Signal %s detected, cleaning up and shutting down.", name)
+    raise GracefulShutdownInterrupt(f"Received signal {name}.")
+
+
+def with_signal(signum: int,
+                handler: collections.abc.Callable | signal.Handlers,
+                ) -> collections.abc.Callable:
+    """A decorator that registers a signal handler for the duration of a
+    function call.
+
+    Parameters
+    ----------
+    signum : `int`
+        The signal for which to register a handler; see `signal.signal`.
+    handler : callable or `signal.Handlers`
+        The handler to register.
+
+    Returns
+    -------
+    decorator : callable
+        A decorator that wraps a function so that ``handler`` is installed
+        when the function is called, and the previously installed handler is
+        restored when the function exits, whether normally or by raising.
+        The wrapped function keeps the original's name, docstring, and other
+        metadata.
+    """
+    # Local import keeps this helper self-contained.
+    import functools
+
+    def decorator(func):
+        # Preserve __name__, __doc__, etc. Frameworks such as Flask derive a
+        # view's endpoint name from __name__, so two decorated views that
+        # were both called "wrapper" would collide.
+        @functools.wraps(func)
+        def wrapper(*args, **kwargs):
+            old_handler = signal.signal(signum, handler)
+            try:
+                return func(*args, **kwargs)
+            finally:
+                # signal.signal returns None when the previous handler was
+                # not installed from Python. None cannot be passed back in,
+                # so fall back to the default action in that case.
+                if old_handler is not None:
+                    signal.signal(signum, old_handler)
+                else:
+                    signal.signal(signum, signal.SIG_DFL)
+        return wrapper
+    return decorator
+
+
def check_for_snap(
instrument: str, group: int, snap: int, detector: int
-) -> Optional[str]:
+) -> str | None:
"""Search for new raw files matching a particular data ID.
The search is performed in the active image bucket.
@@ -266,7 +316,9 @@ def _try_export(mwi: MiddlewareInterface, exposures: set[int], log: logging.Logg
@app.route("/next-visit", methods=["POST"])
-def next_visit_handler() -> Tuple[str, int]:
+@with_signal(signal.SIGHUP, _graceful_shutdown)
+@with_signal(signal.SIGTERM, _graceful_shutdown)
+def next_visit_handler() -> tuple[str, int]:
"""A Flask view function for handling next-visit events.
Like all Flask handlers, this function accepts input through the
@@ -302,76 +354,83 @@ def next_visit_handler() -> Tuple[str, int]:
survey=expected_visit.survey,
detector=expected_visit.detector,
):
- expid_set = set()
-
- # Create a fresh MiddlewareInterface object to avoid accidental
- # "cross-talk" between different visits.
- mwi = MiddlewareInterface(central_butler,
- image_bucket,
- expected_visit,
- pre_pipelines,
- main_pipelines,
- skymap,
- local_repo.name,
- local_cache)
- # Copy calibrations for this detector/visit
- mwi.prep_butler()
-
- # expected_visit.nimages == 0 means "not known in advance"; keep listening until timeout
- expected_snaps = expected_visit.nimages if expected_visit.nimages else 100
- # Heuristic: take the upcoming script's duration and multiply by 2 to
- # include the currently executing script, then add time to transfer
- # the last image.
- timeout = expected_visit.duration * 2 + image_timeout
- # Check to see if any snaps have already arrived
- for snap in range(expected_snaps):
- oid = check_for_snap(
- expected_visit.instrument,
- expected_visit.groupId,
- snap,
- expected_visit.detector,
- )
- if oid:
- _log.debug("Found object %s already present", oid)
- exp_id = mwi.ingest_image(oid)
- expid_set.add(exp_id)
-
- _log.debug("Waiting for snaps...")
- start = time.time()
- while len(expid_set) < expected_snaps and time.time() - start < timeout:
- if startup_response:
- response = startup_response
- else:
- time_remaining = max(0.0, timeout - (time.time() - start))
- response = consumer.consume(num_messages=1, timeout=time_remaining + 1.0)
- end = time.time()
- messages = _filter_messages(response)
- response = []
- if len(messages) == 0 and end - start < timeout and not startup_response:
- _log.debug(f"Empty consume after {end - start}s.")
- continue
- startup_response = []
-
- # Not all notifications are for this group/detector
- for received in messages:
- for oid in _parse_bucket_notifications(received.value()):
- try:
- if is_path_consistent(oid, expected_visit):
- _log.debug("Received %r", oid)
- group_id = get_group_id_from_oid(oid)
- if group_id == expected_visit.groupId:
- # Ingest the snap
- exp_id = mwi.ingest_image(oid)
- expid_set.add(exp_id)
- except ValueError:
- _log.error(f"Failed to match object id '{oid}'")
- # Commits are per-group, so this can't interfere with other
- # workers. This may wipe messages associated with a next_visit
- # that will later be assigned to this worker, but those cases
- # should be caught by the "already arrived" check.
- consumer.commit(message=received)
- if len(expid_set) < expected_snaps:
- _log.warning(f"Timed out waiting for image after receiving exposures {expid_set}.")
+ try:
+ expid_set = set()
+
+ # Create a fresh MiddlewareInterface object to avoid accidental
+ # "cross-talk" between different visits.
+ mwi = MiddlewareInterface(central_butler,
+ image_bucket,
+ expected_visit,
+ pre_pipelines,
+ main_pipelines,
+ skymap,
+ local_repo.name,
+ local_cache)
+ # Copy calibrations for this detector/visit
+ mwi.prep_butler()
+
+ # expected_visit.nimages == 0 means "not known in advance"; keep listening until timeout
+ expected_snaps = expected_visit.nimages if expected_visit.nimages else 100
+ # Heuristic: take the upcoming script's duration and multiply by 2 to
+ # include the currently executing script, then add time to transfer
+ # the last image.
+ timeout = expected_visit.duration * 2 + image_timeout
+ # Check to see if any snaps have already arrived
+ for snap in range(expected_snaps):
+ oid = check_for_snap(
+ expected_visit.instrument,
+ expected_visit.groupId,
+ snap,
+ expected_visit.detector,
+ )
+ if oid:
+ _log.debug("Found object %s already present", oid)
+ exp_id = mwi.ingest_image(oid)
+ expid_set.add(exp_id)
+
+ _log.debug("Waiting for snaps...")
+ start = time.time()
+ while len(expid_set) < expected_snaps and time.time() - start < timeout:
+ if startup_response:
+ response = startup_response
+ else:
+ time_remaining = max(0.0, timeout - (time.time() - start))
+ response = consumer.consume(num_messages=1, timeout=time_remaining + 1.0)
+ end = time.time()
+ messages = _filter_messages(response)
+ response = []
+ if len(messages) == 0 and end - start < timeout and not startup_response:
+ _log.debug(f"Empty consume after {end - start}s.")
+ continue
+ startup_response = []
+
+ # Not all notifications are for this group/detector
+ for received in messages:
+ for oid in _parse_bucket_notifications(received.value()):
+ try:
+ if is_path_consistent(oid, expected_visit):
+ _log.debug("Received %r", oid)
+ group_id = get_group_id_from_oid(oid)
+ if group_id == expected_visit.groupId:
+ # Ingest the snap
+ exp_id = mwi.ingest_image(oid)
+ expid_set.add(exp_id)
+ except ValueError:
+ _log.error(f"Failed to match object id '{oid}'")
+ # Commits are per-group, so this can't interfere with other
+ # workers. This may wipe messages associated with a next_visit
+ # that will later be assigned to this worker, but those cases
+ # should be caught by the "already arrived" check.
+ consumer.commit(message=received)
+ if len(expid_set) < expected_snaps:
+ _log.warning(f"Timed out waiting for image after receiving exposures {expid_set}.")
+ except GracefulShutdownInterrupt as e:
+ _log.exception("Processing interrupted before pipeline execution")
+ # Do not export, to leave room for the next attempt
+ # Service unavailable is not quite right, but no better standard response
+ raise ServiceUnavailable(f"The server aborted processing, but it can be retried: {e}",
+ retry_after=10) from None
if expid_set:
with log_factory.add_context(exposures=expid_set):
@@ -392,7 +451,7 @@ def next_visit_handler() -> Tuple[str, int]:
from e
except RetriableError as e:
error = e.nested if e.nested else e
- _log.error("Processing failed: ", exc_info=error)
+ _log.error("Processing failed but can be retried: ", exc_info=error)
# Do not export, to leave room for the next attempt
# Service unavailable is not quite right, but no better standard response
raise ServiceUnavailable(f"A temporary error occurred during processing: {error}",
@@ -414,6 +473,12 @@ def next_visit_handler() -> Tuple[str, int]:
else:
_log.error("Timed out waiting for images.")
return "Timed out waiting for images", 500
+ except GracefulShutdownInterrupt:
+ # Safety net to minimize chance of interrupt propagating out of the worker.
+ # Ideally, this would be a Flask.errorhandler, but Flask ignores BaseExceptions.
+ _log.error("Service interrupted. Shutting down *without* syncing to the central repo.")
+ return "The worker was interrupted before it could complete the request. " \
+ "Retrying the request may not be safe.", 500
finally:
consumer.unsubscribe()
# Want to know when the handler exited for any reason.
@@ -421,7 +486,7 @@ def next_visit_handler() -> Tuple[str, int]:
@app.errorhandler(500)
-def server_error(e) -> Tuple[str, int]:
+def server_error(e) -> tuple[str, int]:
_log.exception("An error occurred during a request.")
return (
f"""
diff --git a/python/activator/exception.py b/python/activator/exception.py
index df25c6d7..5ed26f08 100644
--- a/python/activator/exception.py
+++ b/python/activator/exception.py
@@ -20,7 +20,7 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.
-__all__ = ["NonRetriableError", "RetriableError"]
+__all__ = ["NonRetriableError", "RetriableError", "GracefulShutdownInterrupt"]
class NonRetriableError(Exception):
@@ -72,3 +72,14 @@ def nested(self):
return self.__context__
else:
return None
+
+
+# See https://docs.python.org/3.11/library/exceptions.html#KeyboardInterrupt
+# for why interrupts should not subclass Exception.
+class GracefulShutdownInterrupt(BaseException):
+    """An interrupt indicating that the service should shut down gracefully.
+
+    Like all interrupts, ``GracefulShutdownInterrupt`` can be raised between
+    *any* two bytecode instructions, and handling it requires special care. See
+    `the Python docs <https://docs.python.org/3.11/library/signal.html#handlers-and-exceptions>`__.
+    """
diff --git a/python/activator/middleware_interface.py b/python/activator/middleware_interface.py
index f4ad6021..95a0b26d 100644
--- a/python/activator/middleware_interface.py
+++ b/python/activator/middleware_interface.py
@@ -53,7 +53,7 @@
from .caching import DatasetCache
from .config import PipelinesConfig
-from .exception import NonRetriableError
+from .exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError
from .visit import FannedOutVisit
from .timer import enforce_schema, time_this_to_bundle
@@ -1318,6 +1318,31 @@ def _run_preprocessing(self) -> None:
label="preprocessing",
)
+    def _check_permanent_changes(self, where: str) -> bool:
+        """Report whether external state (the APDB, alert stream, etc.) has
+        been modified such that retrying the current visit would be unsafe.
+
+        Only the APDB is consulted at present.
+
+        Parameters
+        ----------
+        where : `str`
+            A :ref:`Butler query string <daf_butler_dimension_expressions>`
+            selecting the current visit. It is expected to match exactly
+            one visit.
+
+        Returns
+        -------
+        changes : `bool`
+            `True` if changes have been made or cannot be ruled out,
+            `False` if retries are safe.
+        """
+        matches = set(self.butler.registry.queryDataIds(["instrument", "visit", "detector"], where=where))
+        if len(matches) != 1:
+            # Don't know how this could happen, so won't try to handle it gracefully.
+            _log.warning("Unexpected visit ids: %s. Assuming APDB modified.", matches)
+            return True
+
+        (data_id,) = matches
+        apdb = lsst.dax.apdb.Apdb.from_uri(self._apdb_config)
+        return apdb.containsVisitDetector(data_id["visit"], self.visit.detector)
+
def run_pipeline(self, exposure_ids: set[int]) -> None:
"""Process the received image(s).
@@ -1341,6 +1366,11 @@ def run_pipeline(self, exposure_ids: set[int]) -> None:
may have been left in a state that makes it unsafe to retry
failures. This exception is always chained to another exception
representing the original error.
+ RetriableError
+ Raised if the conditions for NonRetriableError are not met, *and*
+ the pipeline fails in a way that is expected to be transient. This
+ exception is always chained to another exception representing the
+ original error.
"""
# TODO: we want to define visits earlier, but we have to ingest a
# faked raw file and appropriate SSO data during prep (and then
@@ -1368,23 +1398,25 @@ def run_pipeline(self, exposure_ids: set[int]) -> None:
data_ids=where,
label="main",
)
- except Exception as e:
- state_changed = True # better safe than sorry
+ except GracefulShutdownInterrupt as e:
try:
- data_ids = set(self.butler.registry.queryDataIds(
- ["instrument", "visit", "detector"], where=where))
- if len(data_ids) == 1:
- data_id = list(data_ids)[0]
- apdb = lsst.dax.apdb.Apdb.from_uri(self._apdb_config)
- if not apdb.containsVisitDetector(data_id["visit"], self.visit.detector):
- state_changed = False
- else:
- # Don't know how this could happen, so won't try to handle it gracefully.
- _log.warning("Unexpected visit ids: %s. Assuming APDB modified.", data_ids)
+ state_changed = self._check_permanent_changes(where)
except Exception:
# Failure in registry or APDB queries
_log.exception("Could not determine APDB state, assuming modified.")
raise NonRetriableError("APDB potentially modified") from e
+ else:
+ if state_changed:
+ raise NonRetriableError("APDB modified") from e
+ else:
+ raise RetriableError("External interrupt") from e
+ except Exception as e:
+ try:
+ state_changed = self._check_permanent_changes(where)
+ except (Exception, GracefulShutdownInterrupt):
+ # Failure in registry or APDB queries
+ _log.exception("Could not determine APDB state, assuming modified.")
+ raise NonRetriableError("APDB potentially modified") from e
else:
if state_changed:
raise NonRetriableError("APDB modified") from e
diff --git a/tests/test_exception.py b/tests/test_exception.py
index c98541db..5e4d80c5 100644
--- a/tests/test_exception.py
+++ b/tests/test_exception.py
@@ -22,7 +22,7 @@
import unittest
-from activator.exception import NonRetriableError, RetriableError
+from activator.exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError
class NonRetriableErrorTest(unittest.TestCase):
@@ -105,3 +105,20 @@ def test_raise_orphaned(self):
raise RetriableError("Cannot compute!") from None
except RetriableError as e:
self.assertIs(e.nested, None)
+
+
+class GracefulShutdownInterruptTest(unittest.TestCase):
+    """Check how GracefulShutdownInterrupt interacts with ``except`` clauses."""
+
+    def test_catchable(self):
+        """A handler that names the interrupt explicitly catches it."""
+        caught = False
+        try:
+            raise GracefulShutdownInterrupt("Last call!")
+        except GracefulShutdownInterrupt:
+            caught = True
+        if not caught:
+            self.fail("Did not catch GracefulShutdownInterrupt.")
+
+    def test_uncatchable(self):
+        """A generic ``except Exception`` handler does not swallow the interrupt."""
+        def raise_under_generic_handler():
+            try:
+                raise GracefulShutdownInterrupt("Last call!")
+            except Exception:
+                pass  # assertRaises should fail
+
+        self.assertRaises(GracefulShutdownInterrupt, raise_under_generic_handler)