Skip to content

Commit

Permalink
Suspend the processor with a synchronization event instead of SIGSTOP.
Browse files Browse the repository at this point in the history
Signed-off-by: rafa-be <[email protected]>
  • Loading branch information
rafa-be committed Oct 1, 2024
1 parent 151a82e commit ee16d60
Show file tree
Hide file tree
Showing 11 changed files with 78 additions and 7 deletions.
2 changes: 1 addition & 1 deletion scaler/about.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "1.8.3"
__version__ = "1.8.4"
3 changes: 3 additions & 0 deletions scaler/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(
death_timeout_seconds: int,
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
hard_processor_suspend: bool,
event_loop: str,
logging_paths: Tuple[str, ...],
logging_level: str,
Expand All @@ -35,6 +36,7 @@ def __init__(
self._death_timeout_seconds = death_timeout_seconds
self._garbage_collect_interval_seconds = garbage_collect_interval_seconds
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
self._hard_processor_suspend = hard_processor_suspend
self._event_loop = event_loop

self._logging_paths = logging_paths
Expand Down Expand Up @@ -76,6 +78,7 @@ def __start_workers_and_run_forever(self):
trim_memory_threshold_bytes=self._trim_memory_threshold_bytes,
task_timeout_seconds=self._task_timeout_seconds,
death_timeout_seconds=self._death_timeout_seconds,
hard_processor_suspend=self._hard_processor_suspend,
logging_paths=self._logging_paths,
logging_level=self._logging_level,
)
Expand Down
3 changes: 3 additions & 0 deletions scaler/cluster/combo.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
DEFAULT_WORKER_DEATH_TIMEOUT,
DEFAULT_WORKER_TIMEOUT_SECONDS,
DEFAULT_HARD_PROCESSOR_SUSPEND,
)
from scaler.utility.zmq_config import ZMQConfig

Expand All @@ -41,6 +42,7 @@ def __init__(
garbage_collect_interval_seconds: int = DEFAULT_GARBAGE_COLLECT_INTERVAL_SECONDS,
trim_memory_threshold_bytes: int = DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
per_worker_queue_size: int = DEFAULT_PER_WORKER_QUEUE_SIZE,
hard_processor_suspend: bool = DEFAULT_HARD_PROCESSOR_SUSPEND,
protected: bool = True,
event_loop: str = "builtin",
logging_paths: Tuple[str, ...] = ("/dev/stdout",),
Expand All @@ -56,6 +58,7 @@ def __init__(
death_timeout_seconds=death_timeout_seconds,
garbage_collect_interval_seconds=garbage_collect_interval_seconds,
trim_memory_threshold_bytes=trim_memory_threshold_bytes,
hard_processor_suspend=hard_processor_suspend,
event_loop=event_loop,
logging_paths=logging_paths,
logging_level=logging_level,
Expand Down
13 changes: 13 additions & 0 deletions scaler/entry_points/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
DEFAULT_TASK_TIMEOUT_SECONDS,
DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
DEFAULT_WORKER_DEATH_TIMEOUT,
DEFAULT_HARD_PROCESSOR_SUSPEND,
)
from scaler.utility.event_loop import EventLoopType, register_event_loop
from scaler.utility.zmq_config import ZMQConfig
Expand Down Expand Up @@ -66,6 +67,17 @@ def get_args():
parser.add_argument(
"--io-threads", "-it", default=DEFAULT_IO_THREADS, help="specify number of io threads per worker"
)
parser.add_argument(
    "--hard-processor-suspend",
    "-hps",
    action="store_true",
    default=DEFAULT_HARD_PROCESSOR_SUSPEND,
    help=(
        # NOTE: the implementation (ProcessorHolder.suspend) sends SIGSTOP, not SIGTSTP; the help text must
        # match the actual signal used so users can reason about signal handling in their tasks.
        "When set, suspends worker processors using the SIGSTOP signal instead of a synchronization event, "
        "fully halting computation on suspended tasks. Note that this may cause some tasks to fail if they "
        "do not support being paused at the OS level (e.g. tasks requiring active network connections)."
    ),
)
parser.add_argument(
"--log-hub-address", "-la", default=None, type=ZMQConfig.from_string, help="address for Worker send logs"
)
Expand Down Expand Up @@ -119,6 +131,7 @@ def main():
garbage_collect_interval_seconds=args.garbage_collect_interval_seconds,
trim_memory_threshold_bytes=args.trim_memory_threshold_bytes,
death_timeout_seconds=args.death_timeout_seconds,
hard_processor_suspend=args.hard_processor_suspend,
event_loop=args.event_loop,
worker_io_threads=args.io_threads,
logging_paths=args.logging_paths,
Expand Down
4 changes: 4 additions & 0 deletions scaler/io/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,3 +73,7 @@
# The global client name used for log retrieval. For now, all clients share the remote log, so a task needs to
# carry client information for its logs to be delivered.
DUMMY_CLIENT = b"dummy_client"

# If True, a suspended worker's processors are halted with the SIGSTOP signal (fully stopping the process);
# otherwise a synchronization event is used, pausing only the processor's main thread.
DEFAULT_HARD_PROCESSOR_SUSPEND = False
17 changes: 15 additions & 2 deletions scaler/worker/agent/processor/processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import signal
import uuid
from contextvars import ContextVar, Token
from multiprocessing.synchronize import Event as EventType
from typing import Callable, List, Optional, Tuple

import tblib.pickling_support
Expand All @@ -28,6 +29,8 @@
from scaler.utility.zmq_config import ZMQConfig
from scaler.worker.agent.processor.object_cache import ObjectCache

# Signal used to request a soft suspend: the processor's handler blocks the main thread on the resume event
# instead of stopping the whole process (helper threads presumably keep running — see ProcessorHolder).
SUSPEND_SIGNAL: signal.Signals = signal.SIGUSR1

# The Processor instance executing in the current context, if any.
_current_processor: ContextVar[Optional["Processor"]] = ContextVar("_current_processor", default=None)


Expand All @@ -36,6 +39,7 @@ def __init__(
self,
event_loop: str,
address: ZMQConfig,
resume_event: Optional[EventType],
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
logging_paths: Tuple[str, ...],
Expand All @@ -46,6 +50,8 @@ def __init__(
self._event_loop = event_loop
self._address = address

self._resume_event = resume_event

self._garbage_collect_interval_seconds = garbage_collect_interval_seconds
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
self._logging_paths = logging_paths
Expand Down Expand Up @@ -88,14 +94,21 @@ def __initialize(self):
)
self._object_cache.start()

self.__register_signal()
self.__register_signals()

def __register_signal(self):
def __register_signals(self):
    """Installs the processor's signal handlers.

    SIGTERM always triggers an interrupt; the soft-suspend handler is only installed when a resume event was
    provided (i.e. hard suspend is disabled).
    """
    signal.signal(signal.SIGTERM, self.__interrupt)

    if self._resume_event is not None:
        signal.signal(SUSPEND_SIGNAL, self.__suspend)

def __interrupt(self, *args):
    """SIGTERM handler: closes the connector so the processor exits any blocking receive."""
    self._connector.close()  # interrupts any blocking socket.

def __suspend(self, *args):
    """Soft-suspend handler: blocks the main thread until the agent sets the resume event."""
    assert self._resume_event is not None
    self._resume_event.wait()  # stops any computation in the main thread until the event is triggered

def __run_forever(self):
try:
self._connector.send(ProcessorInitialized.new_msg())
Expand Down
33 changes: 30 additions & 3 deletions scaler/worker/agent/processor_holder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,15 @@
import logging
import os
import signal
from multiprocessing import Event
from typing import Optional, Tuple

import psutil

from scaler.io.config import DEFAULT_PROCESSOR_KILL_DELAY_SECONDS
from scaler.protocol.python.message import Task
from scaler.utility.zmq_config import ZMQConfig
from scaler.worker.agent.processor.processor import Processor
from scaler.worker.agent.processor.processor import Processor, SUSPEND_SIGNAL


class ProcessorHolder:
Expand All @@ -19,6 +20,7 @@ def __init__(
address: ZMQConfig,
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
hard_suspend: bool,
logging_paths: Tuple[str, ...],
logging_level: str,
):
Expand All @@ -27,9 +29,16 @@ def __init__(
self._initialized = asyncio.Event()
self._suspended = False

self._hard_suspend = hard_suspend
if hard_suspend:
self._resume_event = None
else:
self._resume_event = Event()

self._processor = Processor(
event_loop=event_loop,
address=address,
resume_event=self._resume_event,
garbage_collect_interval_seconds=garbage_collect_interval_seconds,
trim_memory_threshold_bytes=trim_memory_threshold_bytes,
logging_paths=logging_paths,
Expand Down Expand Up @@ -73,14 +82,32 @@ def suspend(self):
assert self._task is not None
assert self._suspended is False

os.kill(self.pid(), signal.SIGSTOP) # type: ignore
if self._hard_suspend:
os.kill(self.pid(), signal.SIGSTOP)
else:
# If we do not want to hard-suspend the processor's process (e.g. to keep network links alive), we request
# the process to wait on a synchronization event. That will stop the main thread while allowing the helper
# threads to continue running.
#
# See https://github.com/Citi/scaler/issues/14

assert self._resume_event is not None
self._resume_event.clear()

os.kill(self.pid(), SUSPEND_SIGNAL)

self._suspended = True

def resume(self):
    """Resumes a processor previously paused by `suspend()`.

    Requires an active task and a currently suspended processor.
    """
    assert self._task is not None
    assert self._suspended is True

    if not self._hard_suspend:
        # Soft suspend: setting the event releases the processor's main thread, which is blocked in its
        # suspend-signal handler.
        assert self._resume_event is not None
        self._resume_event.set()
    else:
        # Hard suspend: the whole process was stopped with SIGSTOP, so SIGCONT restarts it.
        os.kill(self.pid(), signal.SIGCONT)

    self._suspended = False

def kill(self):
Expand Down
4 changes: 4 additions & 0 deletions scaler/worker/agent/processor_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,18 @@ def __init__(
io_threads: int,
garbage_collect_interval_seconds: int,
trim_memory_threshold_bytes: int,
hard_processor_suspend: bool,
logging_paths: Tuple[str, ...],
logging_level: str,
):
tblib.pickling_support.install()

self._event_loop = event_loop
self._io_threads = io_threads

self._garbage_collect_interval_seconds = garbage_collect_interval_seconds
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
self._hard_processor_suspend = hard_processor_suspend
self._logging_paths = logging_paths
self._logging_level = logging_level

Expand Down Expand Up @@ -245,6 +248,7 @@ def __start_new_processor(self):
self._address,
self._garbage_collect_interval_seconds,
self._trim_memory_threshold_bytes,
self._hard_processor_suspend,
self._logging_paths,
self._logging_level,
)
Expand Down
3 changes: 3 additions & 0 deletions scaler/worker/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ def __init__(
trim_memory_threshold_bytes: int,
task_timeout_seconds: int,
death_timeout_seconds: int,
hard_processor_suspend: bool,
logging_paths: Tuple[str, ...],
logging_level: str,
):
Expand All @@ -57,6 +58,7 @@ def __init__(
self._trim_memory_threshold_bytes = trim_memory_threshold_bytes
self._task_timeout_seconds = task_timeout_seconds
self._death_timeout_seconds = death_timeout_seconds
self._hard_processor_suspend = hard_processor_suspend

self._logging_paths = logging_paths
self._logging_level = logging_level
Expand Down Expand Up @@ -98,6 +100,7 @@ def __initialize(self):
io_threads=self._io_threads,
garbage_collect_interval_seconds=self._garbage_collect_interval_seconds,
trim_memory_threshold_bytes=self._trim_memory_threshold_bytes,
hard_processor_suspend=self._hard_processor_suspend,
logging_paths=self._logging_paths,
logging_level=self._logging_level,
)
Expand Down
1 change: 1 addition & 0 deletions tests/test_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_balance(self):
death_timeout_seconds=combo._cluster._death_timeout_seconds,
garbage_collect_interval_seconds=combo._cluster._garbage_collect_interval_seconds,
trim_memory_threshold_bytes=combo._cluster._trim_memory_threshold_bytes,
hard_processor_suspend=combo._cluster._hard_processor_suspend,
event_loop=combo._cluster._event_loop,
logging_paths=combo._cluster._logging_paths,
logging_level=combo._cluster._logging_level,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_death_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_no_scheduler(self):
trim_memory_threshold_bytes=DEFAULT_TRIM_MEMORY_THRESHOLD_BYTES,
task_timeout_seconds=DEFAULT_TASK_TIMEOUT_SECONDS,
death_timeout_seconds=10,
hard_processor_suspend=False,
event_loop="builtin",
logging_paths=("/dev/stdout",),
logging_level="INFO",
Expand All @@ -58,7 +59,6 @@ def test_shutdown(self):
# this is combo cluster, client only shutdown clusters, not scheduler, so scheduler need be shutdown also
cluster.shutdown()

@unittest.skip("client timeout is currently not prevented on suspended processors")
def test_no_timeout_if_suspended(self):
"""
Client and scheduler shouldn't timeout a client if it is running inside a suspended processor.
Expand Down

0 comments on commit ee16d60

Please sign in to comment.