From ee16d60c96cbb4ad63ca04a0a6c0a72e36c18b57 Mon Sep 17 00:00:00 2001 From: rafa-be Date: Fri, 27 Sep 2024 17:00:40 +0200 Subject: [PATCH] Suspend the processor with a synchronization event instead of SIGSTOP. Signed-off-by: rafa-be --- scaler/about.py | 2 +- scaler/cluster/cluster.py | 3 ++ scaler/cluster/combo.py | 3 ++ scaler/entry_points/cluster.py | 13 +++++++++ scaler/io/config.py | 4 +++ scaler/worker/agent/processor/processor.py | 17 +++++++++-- scaler/worker/agent/processor_holder.py | 33 ++++++++++++++++++++-- scaler/worker/agent/processor_manager.py | 4 +++ scaler/worker/worker.py | 3 ++ tests/test_balance.py | 1 + tests/test_death_timeout.py | 2 +- 11 files changed, 78 insertions(+), 7 deletions(-) diff --git a/scaler/about.py b/scaler/about.py index a44132d..2294476 100644 --- a/scaler/about.py +++ b/scaler/about.py @@ -1 +1 @@ -__version__ = "1.8.3" +__version__ = "1.8.4" diff --git a/scaler/cluster/cluster.py b/scaler/cluster/cluster.py index 2303a17..1993dc4 100644 --- a/scaler/cluster/cluster.py +++ b/scaler/cluster/cluster.py @@ -20,6 +20,7 @@ def __init__( death_timeout_seconds: int, garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, + hard_processor_suspend: bool, event_loop: str, logging_paths: Tuple[str, ...], logging_level: str, @@ -35,6 +36,7 @@ def __init__( self._death_timeout_seconds = death_timeout_seconds self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes + self._hard_processor_suspend = hard_processor_suspend self._event_loop = event_loop self._logging_paths = logging_paths @@ -76,6 +78,7 @@ def __start_workers_and_run_forever(self): trim_memory_threshold_bytes=self._trim_memory_threshold_bytes, task_timeout_seconds=self._task_timeout_seconds, death_timeout_seconds=self._death_timeout_seconds, + hard_processor_suspend=self._hard_processor_suspend, logging_paths=self._logging_paths, logging_level=self._logging_level, ) diff 
--git a/scaler/cluster/combo.py b/scaler/cluster/combo.py index c01e918..d7af8ec 100644 --- a/scaler/cluster/combo.py +++ b/scaler/cluster/combo.py @@ -18,6 +18,7 @@ DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, DEFAULT_WORKER_DEATH_TIMEOUT, DEFAULT_WORKER_TIMEOUT_SECONDS, + DEFAULT_HARD_PROCESSOR_SUSPEND, ) from scaler.utility.zmq_config import ZMQConfig @@ -41,6 +42,7 @@ def __init__( garbage_collect_interval_seconds: int = DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS, trim_memory_threshold_bytes: int = DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, per_worker_queue_size: int = DEFAULT_PER_WORKER_QUEUE_SIZE, + hard_processor_suspend: bool = DEFAULT_HARD_PROCESSOR_SUSPEND, protected: bool = True, event_loop: str = "builtin", logging_paths: Tuple[str, ...] = ("/dev/stdout",), @@ -56,6 +58,7 @@ def __init__( death_timeout_seconds=death_timeout_seconds, garbage_collect_interval_seconds=garbage_collect_interval_seconds, trim_memory_threshold_bytes=trim_memory_threshold_bytes, + hard_processor_suspend=hard_processor_suspend, event_loop=event_loop, logging_paths=logging_paths, logging_level=logging_level, diff --git a/scaler/entry_points/cluster.py b/scaler/entry_points/cluster.py index 126d204..581465a 100644 --- a/scaler/entry_points/cluster.py +++ b/scaler/entry_points/cluster.py @@ -10,6 +10,7 @@ DEFAULT_TASK_TIMEOUT_SECONDS, DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, DEFAULT_WORKER_DEATH_TIMEOUT, + DEFAULT_HARD_PROCESSOR_SUSPEND, ) from scaler.utility.event_loop import EventLoopType, register_event_loop from scaler.utility.zmq_config import ZMQConfig @@ -66,6 +67,17 @@ def get_args(): parser.add_argument( "--io-threads", "-it", default=DEFAULT_IO_THREADS, help="specify number of io threads per worker" ) + parser.add_argument( + "--hard-processor-suspend", + "-hps", + action="store_true", + default=DEFAULT_HARD_PROCESSOR_SUSPEND, + help=( + "When set, suspends worker processors using the SIGSTOP signal instead of a synchronization event, " + "fully halting computation on suspended tasks. 
Note that this may cause some tasks to fail if they " "do not support being paused at the OS level (e.g. tasks requiring active network connections)." ) ) parser.add_argument( "--log-hub-address", "-la", default=None, type=ZMQConfig.from_string, help="address for Worker send logs" ) @@ -119,6 +131,7 @@ def main(): garbage_collect_interval_seconds=args.garbage_collect_interval_seconds, trim_memory_threshold_bytes=args.trim_memory_threshold_bytes, death_timeout_seconds=args.death_timeout_seconds, + hard_processor_suspend=args.hard_processor_suspend, event_loop=args.event_loop, worker_io_threads=args.io_threads, logging_paths=args.logging_paths, diff --git a/scaler/io/config.py b/scaler/io/config.py index 5c55fd7..3934444 100644 --- a/scaler/io/config.py +++ b/scaler/io/config.py @@ -73,3 +73,7 @@ # the global client name for get log, right now, all client shared remote log, task need have client information to # deliver log DUMMY_CLIENT = b"dummy_client" + +# if true, suspended worker's processors will be actively suspended with a SIGSTOP signal, otherwise a synchronization
+DEFAULT_HARD_PROCESSOR_SUSPEND = False diff --git a/scaler/worker/agent/processor/processor.py b/scaler/worker/agent/processor/processor.py index a35152f..41cfdb2 100644 --- a/scaler/worker/agent/processor/processor.py +++ b/scaler/worker/agent/processor/processor.py @@ -5,6 +5,7 @@ import signal import uuid from contextvars import ContextVar, Token +from multiprocessing.synchronize import Event as EventType from typing import Callable, List, Optional, Tuple import tblib.pickling_support @@ -28,6 +29,8 @@ from scaler.utility.zmq_config import ZMQConfig from scaler.worker.agent.processor.object_cache import ObjectCache +SUSPEND_SIGNAL: signal.Signals = signal.SIGUSR1 + _current_processor: ContextVar[Optional["Processor"]] = ContextVar("_current_processor", default=None) @@ -36,6 +39,7 @@ def __init__( self, event_loop: str, address: ZMQConfig, + resume_event: Optional[EventType], garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, logging_paths: Tuple[str, ...], @@ -46,6 +50,8 @@ def __init__( self._event_loop = event_loop self._address = address + self._resume_event = resume_event + self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes self._logging_paths = logging_paths @@ -88,14 +94,21 @@ def __initialize(self): ) self._object_cache.start() - self.__register_signal() + self.__register_signals() - def __register_signal(self): + def __register_signals(self): signal.signal(signal.SIGTERM, self.__interrupt) + if self._resume_event is not None: + signal.signal(SUSPEND_SIGNAL, self.__suspend) + def __interrupt(self, *args): self._connector.close() # interrupts any blocking socket. 
+ def __suspend(self, *args): + assert self._resume_event is not None + self._resume_event.wait() # stops any computation in the main thread until the event is triggered + def __run_forever(self): try: self._connector.send(ProcessorInitialized.new_msg()) diff --git a/scaler/worker/agent/processor_holder.py b/scaler/worker/agent/processor_holder.py index 8cd29cd..54be9e6 100644 --- a/scaler/worker/agent/processor_holder.py +++ b/scaler/worker/agent/processor_holder.py @@ -2,6 +2,7 @@ import logging import os import signal +from multiprocessing import Event from typing import Optional, Tuple import psutil @@ -9,7 +10,7 @@ from scaler.io.config import DEFAULT_PROCESSOR_KILL_DELAY_SECONDS from scaler.protocol.python.message import Task from scaler.utility.zmq_config import ZMQConfig -from scaler.worker.agent.processor.processor import Processor +from scaler.worker.agent.processor.processor import Processor, SUSPEND_SIGNAL class ProcessorHolder: @@ -19,6 +20,7 @@ def __init__( self, event_loop: str, address: ZMQConfig, garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, + hard_suspend: bool, logging_paths: Tuple[str, ...], logging_level: str, ): @@ -27,9 +29,16 @@ def __init__( self._initialized = asyncio.Event() self._suspended = False + self._hard_suspend = hard_suspend + if hard_suspend: + self._resume_event = None + else: + self._resume_event = Event() + self._processor = Processor( event_loop=event_loop, address=address, + resume_event=self._resume_event, garbage_collect_interval_seconds=garbage_collect_interval_seconds, trim_memory_threshold_bytes=trim_memory_threshold_bytes, logging_paths=logging_paths, @@ -73,14 +82,32 @@ def suspend(self): assert self._task is not None assert self._suspended is False - os.kill(self.pid(), signal.SIGSTOP) # type: ignore + if self._hard_suspend: + os.kill(self.pid(), signal.SIGSTOP) + else: + # If we do not want to hard-suspend the processor's process (e.g. 
to keep network links alive), we request + # the process to wait on a synchronization event. That will stop the main thread while allowing the helper + # threads to continue running. + # + # See https://github.com/Citi/scaler/issues/14 + + assert self._resume_event is not None + self._resume_event.clear() + + os.kill(self.pid(), SUSPEND_SIGNAL) + self._suspended = True def resume(self): assert self._task is not None assert self._suspended is True - os.kill(self.pid(), signal.SIGCONT) # type: ignore + if self._hard_suspend: + os.kill(self.pid(), signal.SIGCONT) + else: + assert self._resume_event is not None + self._resume_event.set() + self._suspended = False def kill(self): diff --git a/scaler/worker/agent/processor_manager.py b/scaler/worker/agent/processor_manager.py index 53c5c80..5653df0 100644 --- a/scaler/worker/agent/processor_manager.py +++ b/scaler/worker/agent/processor_manager.py @@ -36,6 +36,7 @@ def __init__( io_threads: int, garbage_collect_interval_seconds: int, trim_memory_threshold_bytes: int, + hard_processor_suspend: bool, logging_paths: Tuple[str, ...], logging_level: str, ): @@ -43,8 +44,10 @@ def __init__( self._event_loop = event_loop self._io_threads = io_threads + self._garbage_collect_interval_seconds = garbage_collect_interval_seconds self._trim_memory_threshold_bytes = trim_memory_threshold_bytes + self._hard_processor_suspend = hard_processor_suspend self._logging_paths = logging_paths self._logging_level = logging_level @@ -245,6 +248,7 @@ def __start_new_processor(self): self._address, self._garbage_collect_interval_seconds, self._trim_memory_threshold_bytes, + self._hard_processor_suspend, self._logging_paths, self._logging_level, ) diff --git a/scaler/worker/worker.py b/scaler/worker/worker.py index 4de8d2b..e9ec4a7 100644 --- a/scaler/worker/worker.py +++ b/scaler/worker/worker.py @@ -42,6 +42,7 @@ def __init__( trim_memory_threshold_bytes: int, task_timeout_seconds: int, death_timeout_seconds: int, + hard_processor_suspend: bool, 
logging_paths: Tuple[str, ...], logging_level: str, ): @@ -57,6 +58,7 @@ def __init__( self._trim_memory_threshold_bytes = trim_memory_threshold_bytes self._task_timeout_seconds = task_timeout_seconds self._death_timeout_seconds = death_timeout_seconds + self._hard_processor_suspend = hard_processor_suspend self._logging_paths = logging_paths self._logging_level = logging_level @@ -98,6 +100,7 @@ def __initialize(self): io_threads=self._io_threads, garbage_collect_interval_seconds=self._garbage_collect_interval_seconds, trim_memory_threshold_bytes=self._trim_memory_threshold_bytes, + hard_processor_suspend=self._hard_processor_suspend, logging_paths=self._logging_paths, logging_level=self._logging_level, ) diff --git a/tests/test_balance.py b/tests/test_balance.py index 0900516..339c8e0 100644 --- a/tests/test_balance.py +++ b/tests/test_balance.py @@ -39,6 +39,7 @@ def test_balance(self): death_timeout_seconds=combo._cluster._death_timeout_seconds, garbage_collect_interval_seconds=combo._cluster._garbage_collect_interval_seconds, trim_memory_threshold_bytes=combo._cluster._trim_memory_threshold_bytes, + hard_processor_suspend=combo._cluster._hard_processor_suspend, event_loop=combo._cluster._event_loop, logging_paths=combo._cluster._logging_paths, logging_level=combo._cluster._logging_level, diff --git a/tests/test_death_timeout.py b/tests/test_death_timeout.py index 2a0fc12..39cdcb3 100644 --- a/tests/test_death_timeout.py +++ b/tests/test_death_timeout.py @@ -33,6 +33,7 @@ def test_no_scheduler(self): trim_memory_threshold_bytes=DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES, task_timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS, death_timeout_seconds=10, + hard_processor_suspend=False, event_loop="builtin", logging_paths=("/dev/stdout",), logging_level="INFO", @@ -58,7 +59,6 @@ def test_shutdown(self): # this is combo cluster, client only shutdown clusters, not scheduler, so scheduler need be shutdown also cluster.shutdown() - @unittest.skip("client timeout is currently not 
prevented on suspended processors") def test_no_timeout_if_suspended(self): """ Client and scheduler shouldn't timeout a client if it is running inside a suspended processor.