diff --git a/Dockerfile b/Dockerfile index d8a9d3ff..270c0869 100644 --- a/Dockerfile +++ b/Dockerfile @@ -5,11 +5,14 @@ ENV APP_HOME /app ENV PROMPT_PROCESSING_DIR $APP_HOME # Normally defined in the Kubernetes config. ENV WORKER_RESTART_FREQ ${WORKER_RESTART_FREQ:-0} +ENV WORKER_TIMEOUT ${WORKER_TIMEOUT:-0} +ENV WORKER_GRACE_PERIOD ${WORKER_GRACE_PERIOD:-30} ARG PORT WORKDIR $APP_HOME COPY python/activator activator/ COPY pipelines pipelines/ CMD source /opt/lsst/software/stack/loadLSST.bash \ && setup lsst_distrib \ - && exec gunicorn --workers 1 --threads 1 --timeout 0 --max-requests $WORKER_RESTART_FREQ \ + && exec gunicorn --workers 1 --threads 1 --timeout $WORKER_TIMEOUT --max-requests $WORKER_RESTART_FREQ \ + --graceful-timeout $WORKER_GRACE_PERIOD \ --bind :$PORT activator.activator:app diff --git a/python/activator/activator.py b/python/activator/activator.py index aa5cd401..8ed6c263 100644 --- a/python/activator/activator.py +++ b/python/activator/activator.py @@ -21,12 +21,13 @@ __all__ = ["check_for_snap", "next_visit_handler"] +import collections.abc import json import logging import os import sys import time -from typing import Optional, Tuple +import signal import uuid import boto3 @@ -37,7 +38,7 @@ from werkzeug.exceptions import ServiceUnavailable from .config import PipelinesConfig -from .exception import NonRetriableError, RetriableError +from .exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError from .logger import setup_usdf_logger from .middleware_interface import get_central_butler, flush_local_repo, \ make_local_repo, make_local_cache, MiddlewareInterface @@ -125,9 +126,58 @@ def find_local_repos(base_path): sys.exit(3) +def _graceful_shutdown(signum: int, stack_frame): + """Signal handler for cases where the service should gracefully shut down. + + Parameters + ---------- + signum : `int` + The signal received. + stack_frame : `frame` or `None` + The "current" stack frame. + + Raises + ------ + GracefulShutdownInterrupt + Raised unconditionally. + """ + signame = signal.Signals(signum).name + _log.info("Signal %s detected, cleaning up and shutting down.", signame) + # TODO DM-45339: raising in signal handlers is dangerous; can we get a way + # for pipeline processing to check for interrupts? + raise GracefulShutdownInterrupt(f"Received signal {signame}.") + + +def with_signal(signum: int, + handler: collections.abc.Callable | signal.Handlers, + ) -> collections.abc.Callable: + """A decorator that registers a signal handler for the duration of a + function call. + + Parameters + ---------- + signum : `int` + The signal for which to register a handler; see `signal.signal`. + handler : callable or `signal.Handlers` + The handler to register. + """ + def decorator(func): + def wrapper(*args, **kwargs): + old_handler = signal.signal(signum, handler) + try: + return func(*args, **kwargs) + finally: + if old_handler is not None: + signal.signal(signum, old_handler) + else: + signal.signal(signum, signal.SIG_DFL) + return wrapper + return decorator + + def check_for_snap( instrument: str, group: int, snap: int, detector: int -) -> Optional[str]: +) -> str | None: """Search for new raw files matching a particular data ID. The search is performed in the active image bucket. @@ -266,7 +316,9 @@ def _try_export(mwi: MiddlewareInterface, exposures: set[int], log: logging.Logg @app.route("/next-visit", methods=["POST"]) -def next_visit_handler() -> Tuple[str, int]: +@with_signal(signal.SIGHUP, _graceful_shutdown) +@with_signal(signal.SIGTERM, _graceful_shutdown) +def next_visit_handler() -> tuple[str, int]: """A Flask view function for handling next-visit events. Like all Flask handlers, this function accepts input through the @@ -302,76 +354,83 @@ def next_visit_handler() -> Tuple[str, int]: survey=expected_visit.survey, detector=expected_visit.detector, ): - expid_set = set() - - # Create a fresh MiddlewareInterface object to avoid accidental - # "cross-talk" between different visits. - mwi = MiddlewareInterface(central_butler, - image_bucket, - expected_visit, - pre_pipelines, - main_pipelines, - skymap, - local_repo.name, - local_cache) - # Copy calibrations for this detector/visit - mwi.prep_butler() - - # expected_visit.nimages == 0 means "not known in advance"; keep listening until timeout - expected_snaps = expected_visit.nimages if expected_visit.nimages else 100 - # Heuristic: take the upcoming script's duration and multiply by 2 to - # include the currently executing script, then add time to transfer - # the last image. - timeout = expected_visit.duration * 2 + image_timeout - # Check to see if any snaps have already arrived - for snap in range(expected_snaps): - oid = check_for_snap( - expected_visit.instrument, - expected_visit.groupId, - snap, - expected_visit.detector, - ) - if oid: - _log.debug("Found object %s already present", oid) - exp_id = mwi.ingest_image(oid) - expid_set.add(exp_id) - - _log.debug("Waiting for snaps...") - start = time.time() - while len(expid_set) < expected_snaps and time.time() - start < timeout: - if startup_response: - response = startup_response - else: - time_remaining = max(0.0, timeout - (time.time() - start)) - response = consumer.consume(num_messages=1, timeout=time_remaining + 1.0) - end = time.time() - messages = _filter_messages(response) - response = [] - if len(messages) == 0 and end - start < timeout and not startup_response: - _log.debug(f"Empty consume after {end - start}s.") - continue - startup_response = [] - - # Not all notifications are for this group/detector - for received in messages: - for oid in _parse_bucket_notifications(received.value()): - try: - if is_path_consistent(oid, expected_visit): - _log.debug("Received %r", oid) - group_id = get_group_id_from_oid(oid) - if group_id == expected_visit.groupId: - # Ingest the snap - exp_id = mwi.ingest_image(oid) - expid_set.add(exp_id) - except ValueError: - _log.error(f"Failed to match object id '{oid}'") - # Commits are per-group, so this can't interfere with other - # workers. This may wipe messages associated with a next_visit - # that will later be assigned to this worker, but those cases - # should be caught by the "already arrived" check. - consumer.commit(message=received) - if len(expid_set) < expected_snaps: - _log.warning(f"Timed out waiting for image after receiving exposures {expid_set}.") + try: + expid_set = set() + + # Create a fresh MiddlewareInterface object to avoid accidental + # "cross-talk" between different visits. + mwi = MiddlewareInterface(central_butler, + image_bucket, + expected_visit, + pre_pipelines, + main_pipelines, + skymap, + local_repo.name, + local_cache) + # Copy calibrations for this detector/visit + mwi.prep_butler() + + # expected_visit.nimages == 0 means "not known in advance"; keep listening until timeout + expected_snaps = expected_visit.nimages if expected_visit.nimages else 100 + # Heuristic: take the upcoming script's duration and multiply by 2 to + # include the currently executing script, then add time to transfer + # the last image. + timeout = expected_visit.duration * 2 + image_timeout + # Check to see if any snaps have already arrived + for snap in range(expected_snaps): + oid = check_for_snap( + expected_visit.instrument, + expected_visit.groupId, + snap, + expected_visit.detector, + ) + if oid: + _log.debug("Found object %s already present", oid) + exp_id = mwi.ingest_image(oid) + expid_set.add(exp_id) + + _log.debug("Waiting for snaps...") + start = time.time() + while len(expid_set) < expected_snaps and time.time() - start < timeout: + if startup_response: + response = startup_response + else: + time_remaining = max(0.0, timeout - (time.time() - start)) + response = consumer.consume(num_messages=1, timeout=time_remaining + 1.0) + end = time.time() + messages = _filter_messages(response) + response = [] + if len(messages) == 0 and end - start < timeout and not startup_response: + _log.debug(f"Empty consume after {end - start}s.") + continue + startup_response = [] + + # Not all notifications are for this group/detector + for received in messages: + for oid in _parse_bucket_notifications(received.value()): + try: + if is_path_consistent(oid, expected_visit): + _log.debug("Received %r", oid) + group_id = get_group_id_from_oid(oid) + if group_id == expected_visit.groupId: + # Ingest the snap + exp_id = mwi.ingest_image(oid) + expid_set.add(exp_id) + except ValueError: + _log.error(f"Failed to match object id '{oid}'") + # Commits are per-group, so this can't interfere with other + # workers. This may wipe messages associated with a next_visit + # that will later be assigned to this worker, but those cases + # should be caught by the "already arrived" check. + consumer.commit(message=received) + if len(expid_set) < expected_snaps: + _log.warning(f"Timed out waiting for image after receiving exposures {expid_set}.") + except GracefulShutdownInterrupt as e: + _log.exception("Processing interrupted before pipeline execution") + # Do not export, to leave room for the next attempt + # Service unavailable is not quite right, but no better standard response + raise ServiceUnavailable(f"The server aborted processing, but it can be retried: {e}", + retry_after=10) from None if expid_set: with log_factory.add_context(exposures=expid_set): @@ -392,7 +451,7 @@ def next_visit_handler() -> Tuple[str, int]: from e except RetriableError as e: error = e.nested if e.nested else e - _log.error("Processing failed: ", exc_info=error) + _log.error("Processing failed but can be retried: ", exc_info=error) # Do not export, to leave room for the next attempt # Service unavailable is not quite right, but no better standard response raise ServiceUnavailable(f"A temporary error occurred during processing: {error}", @@ -414,6 +473,12 @@ def next_visit_handler() -> Tuple[str, int]: else: _log.error("Timed out waiting for images.") return "Timed out waiting for images", 500 + except GracefulShutdownInterrupt: + # Safety net to minimize chance of interrupt propagating out of the worker. + # Ideally, this would be a Flask.errorhandler, but Flask ignores BaseExceptions. + _log.error("Service interrupted. Shutting down *without* syncing to the central repo.") + return "The worker was interrupted before it could complete the request. " \ + "Retrying the request may not be safe.", 500 finally: consumer.unsubscribe() # Want to know when the handler exited for any reason. @@ -421,7 +486,7 @@ def next_visit_handler() -> Tuple[str, int]: @app.errorhandler(500) -def server_error(e) -> Tuple[str, int]: +def server_error(e) -> tuple[str, int]: _log.exception("An error occurred during a request.") return ( f""" diff --git a/python/activator/exception.py b/python/activator/exception.py index df25c6d7..5ed26f08 100644 --- a/python/activator/exception.py +++ b/python/activator/exception.py @@ -20,7 +20,7 @@ # along with this program. If not, see . -__all__ = ["NonRetriableError", "RetriableError"] +__all__ = ["NonRetriableError", "RetriableError", "GracefulShutdownInterrupt"] class NonRetriableError(Exception): @@ -72,3 +72,14 @@ def nested(self): return self.__context__ else: return None + + +# See https://docs.python.org/3.11/library/exceptions.html#KeyboardInterrupt +# for why interrupts should not subclass Exception. +class GracefulShutdownInterrupt(BaseException): + """An interrupt indicating that the service should shut down gracefully. + + Like all interrupts, ``GracefulShutdownInterrupt`` can be raised between + *any* two bytecode instructions, and handling it requires special care. See + `the Python docs `__. + """ diff --git a/python/activator/middleware_interface.py b/python/activator/middleware_interface.py index f4ad6021..95a0b26d 100644 --- a/python/activator/middleware_interface.py +++ b/python/activator/middleware_interface.py @@ -53,7 +53,7 @@ from .caching import DatasetCache from .config import PipelinesConfig -from .exception import NonRetriableError +from .exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError from .visit import FannedOutVisit from .timer import enforce_schema, time_this_to_bundle @@ -1318,6 +1318,31 @@ def _run_preprocessing(self) -> None: label="preprocessing", ) + def _check_permanent_changes(self, where: str) -> bool: + """Test whether the APDB, alert stream, or other external state has + changed in a way that makes retries unsafe. + + Parameters + ---------- + where : `str` + A :ref:`Butler query string ` identifying the + current visit. The query should return exactly one visit. + + Returns + ---------- + changes : `bool` + `True` if changes have been made, `False` if retries are safe. + """ + data_ids = set(self.butler.registry.queryDataIds(["instrument", "visit", "detector"], where=where)) + if len(data_ids) == 1: + data_id = data_ids.pop() + apdb = lsst.dax.apdb.Apdb.from_uri(self._apdb_config) + return apdb.containsVisitDetector(data_id["visit"], self.visit.detector) + else: + # Don't know how this could happen, so won't try to handle it gracefully. + _log.warning("Unexpected visit ids: %s. Assuming APDB modified.", data_ids) + return True + def run_pipeline(self, exposure_ids: set[int]) -> None: """Process the received image(s). @@ -1341,6 +1366,11 @@ def run_pipeline(self, exposure_ids: set[int]) -> None: may have been left in a state that makes it unsafe to retry failures. This exception is always chained to another exception representing the original error. + RetriableError + Raised if the conditions for NonRetriableError are not met, *and* + the pipeline fails in a way that is expected to be transient. This + exception is always chained to another exception representing the + original error. """ # TODO: we want to define visits earlier, but we have to ingest a # faked raw file and appropriate SSO data during prep (and then @@ -1368,23 +1398,25 @@ def run_pipeline(self, exposure_ids: set[int]) -> None: data_ids=where, label="main", ) - except Exception as e: - state_changed = True # better safe than sorry + except GracefulShutdownInterrupt as e: try: - data_ids = set(self.butler.registry.queryDataIds( - ["instrument", "visit", "detector"], where=where)) - if len(data_ids) == 1: - data_id = list(data_ids)[0] - apdb = lsst.dax.apdb.Apdb.from_uri(self._apdb_config) - if not apdb.containsVisitDetector(data_id["visit"], self.visit.detector): - state_changed = False - else: - # Don't know how this could happen, so won't try to handle it gracefully. - _log.warning("Unexpected visit ids: %s. Assuming APDB modified.", data_ids) + state_changed = self._check_permanent_changes(where) except Exception: # Failure in registry or APDB queries _log.exception("Could not determine APDB state, assuming modified.") raise NonRetriableError("APDB potentially modified") from e + else: + if state_changed: + raise NonRetriableError("APDB modified") from e + else: + raise RetriableError("External interrupt") from e + except Exception as e: + try: + state_changed = self._check_permanent_changes(where) + except (Exception, GracefulShutdownInterrupt): + # Failure in registry or APDB queries + _log.exception("Could not determine APDB state, assuming modified.") + raise NonRetriableError("APDB potentially modified") from e else: if state_changed: raise NonRetriableError("APDB modified") from e diff --git a/tests/test_exception.py b/tests/test_exception.py index c98541db..5e4d80c5 100644 --- a/tests/test_exception.py +++ b/tests/test_exception.py @@ -22,7 +22,7 @@ import unittest -from activator.exception import NonRetriableError, RetriableError +from activator.exception import GracefulShutdownInterrupt, NonRetriableError, RetriableError class NonRetriableErrorTest(unittest.TestCase): @@ -105,3 +105,20 @@ def test_raise_orphaned(self): raise RetriableError("Cannot compute!") from None except RetriableError as e: self.assertIs(e.nested, None) + + +class GracefulShutdownInterruptTest(unittest.TestCase): + def test_catchable(self): + try: + raise GracefulShutdownInterrupt("Last call!") + except GracefulShutdownInterrupt: + pass + else: + self.fail("Did not catch GracefulShutdownInterrupt.") + + def test_uncatchable(self): + with self.assertRaises(GracefulShutdownInterrupt): + try: + raise GracefulShutdownInterrupt("Last call!") + except Exception: + pass # assertRaises should fail