Skip to content

Commit

Permalink
fix(multiprocessing): Reset pool if tasks are not completed (#315)
Browse files Browse the repository at this point in the history
If multiprocessing tasks are not completed within the timeout specified, we need to reset the pool to avoid state being carried over between assignments.
  • Loading branch information
lynnagara authored Dec 11, 2023
1 parent 17b75da commit 7e16346
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 10 deletions.
49 changes: 39 additions & 10 deletions arroyo/processing/strategies/run_task_with_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,11 @@ class MultiprocessingPool:
NOTE: The close() method must be called when shutting down the consumer.
The `close()` method is also called by `RunTaskWithMultiprocessing` when
there are uncompleted pending tasks to ensure no state is carried over.
The `maybe_create_pool` function is called on every assignment to ensure
the pool is re-created if it was closed during the previous revocation.
:param num_processes: The number of processes to spawn.
:param initializer: A function to run at the beginning of each subprocess.
Expand All @@ -301,13 +306,22 @@ def __init__(
num_processes: int,
initializer: Optional[Callable[[], None]] = None,
) -> None:
self.__pool = Pool(
num_processes,
initializer=partial(parallel_worker_initializer, initializer),
context=multiprocessing.get_context("spawn"),
)
self.__num_processes = num_processes
self.__initializer = initializer
self.__pool: Optional[Pool] = None
self.__metrics = get_metrics()
self.maybe_create_pool()

def maybe_create_pool(self) -> None:
    """Build the worker pool if one does not currently exist.

    Invoked on every assignment so that a pool torn down during the
    previous revocation is rebuilt before any new work is submitted.
    Emits the ``pool.create`` counter each time a pool is actually made.
    """
    if self.__pool is not None:
        # A live pool is already available; nothing to do.
        return

    self.__metrics.increment(
        "arroyo.strategies.run_task_with_multiprocessing.pool.create"
    )
    self.__pool = Pool(
        self.__num_processes,
        initializer=partial(parallel_worker_initializer, self.__initializer),
        context=multiprocessing.get_context("spawn"),
    )

@property
def num_processes(self) -> int:
Expand All @@ -318,13 +332,20 @@ def initializer(self) -> Optional[Callable[[], None]]:
return self.__initializer

def apply_async(self, *args: Any, **kwargs: Any) -> Any:
return self.__pool.apply_async(*args, **kwargs)
if self.__pool:
return self.__pool.apply_async(*args, **kwargs)
else:
raise RuntimeError("No pool available")

def close(self) -> None:
"""
Must be called manually when shutting down the consumer.
Also called from strategy.join() if there are pending futures in order
to ensure state is completely cleaned up.
"""
self.__pool.terminate()
if self.__pool:
self.__pool.terminate()
self.__pool = None


class RunTaskWithMultiprocessing(
Expand Down Expand Up @@ -488,12 +509,13 @@ def __init__(
self.__max_input_block_size = max_input_block_size
self.__max_output_block_size = max_output_block_size

self.__shared_memory_manager = SharedMemoryManager()
self.__shared_memory_manager.start()

self.__pool = pool
self.__pool.maybe_create_pool()
num_processes = self.__pool.num_processes

self.__shared_memory_manager = SharedMemoryManager()
self.__shared_memory_manager.start()

self.__input_blocks = [
self.__shared_memory_manager.SharedMemory(
input_block_size or DEFAULT_INPUT_BLOCK_SIZE
Expand Down Expand Up @@ -817,6 +839,8 @@ def terminate(self) -> None:
logger.info("Terminating %r...", self.__pool)

logger.info("Shutting down %r...", self.__shared_memory_manager)
self.__pool.close()

self.__shared_memory_manager.shutdown()

logger.info("Terminating %r...", self.__next_step)
Expand All @@ -841,6 +865,11 @@ def join(self, timeout: Optional[float] = None) -> None:

logger.debug("Waiting for %s...", self.__pool)

# XXX: If there are still pending futures, close the pool so that a fresh
# one is created on the next assignment and no state from the previous
# assignment is carried over.
if len(self.__processes):
self.__pool.close()

self.__shared_memory_manager.shutdown()

self.__next_step.close()
Expand Down
2 changes: 2 additions & 0 deletions arroyo/utils/metric_defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
# Gauge. Shows how many processes the multiprocessing strategy is
# configured with.
"arroyo.strategies.run_task_with_multiprocessing.processes",
# Counter. Incremented when the multiprocessing pool is created (or re-created).
"arroyo.strategies.run_task_with_multiprocessing.pool.create",
# Time (unitless) spent polling librdkafka for new messages.
"arroyo.consumer.poll.time",
# Time (unitless) spent in strategies (blocking in strategy.submit or
Expand Down
10 changes: 10 additions & 0 deletions tests/processing/strategies/test_run_task_with_multiprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,11 @@ def test_parallel_transform_step() -> None:
lambda: metrics.calls,
[],
[
IncrementCall(
name="arroyo.strategies.run_task_with_multiprocessing.pool.create",
value=1,
tags=None,
),
GaugeCall(
"arroyo.strategies.run_task_with_multiprocessing.batches_in_progress",
0.0,
Expand Down Expand Up @@ -394,6 +399,11 @@ def test_message_rejected_multiple() -> None:
]

assert TestingMetricsBackend.calls == [
IncrementCall(
name="arroyo.strategies.run_task_with_multiprocessing.pool.create",
value=1,
tags=None,
),
GaugeCall(
name="arroyo.strategies.run_task_with_multiprocessing.batches_in_progress",
value=0.0,
Expand Down

0 comments on commit 7e16346

Please sign in to comment.