Skip to content

Commit

Permalink
move multiprocessing pre-processing to preprocess (#4425)
Browse files Browse the repository at this point in the history
* move multiprocessing pre-processing to preprocess

* add test for None case

* you knew it was a bad idea... 🙃

* changelog
  • Loading branch information
timmysilv authored Aug 3, 2023
1 parent cae8caf commit 109b3e8
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 88 deletions.
1 change: 1 addition & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
or not given, only the current process executes tapes. If you experience any
issue, say using JAX, TensorFlow, Torch, try setting `max_workers` to `None`.
[(#4319)](https://github.com/PennyLaneAI/pennylane/pull/4319)
[(#4425)](https://github.com/PennyLaneAI/pennylane/pull/4425)

* Transform Programs are now integrated with the `QNode`.
[(#4404)](https://github.com/PennyLaneAI/pennylane/pull/4404)
Expand Down
110 changes: 29 additions & 81 deletions pennylane/devices/experimental/default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
from numbers import Number
from typing import Union, Callable, Tuple, Optional, Sequence
import concurrent.futures
import os
import warnings
import numpy as np

from pennylane.tape import QuantumTape, QuantumScript
Expand All @@ -31,7 +29,11 @@
from . import Device
from .execution_config import ExecutionConfig, DefaultExecutionConfig
from ..qubit.simulate import simulate, get_final_state, measure_final_state
from ..qubit.preprocess import preprocess, validate_and_expand_adjoint
from ..qubit.preprocess import (
preprocess,
validate_and_expand_adjoint,
validate_multiprocessing_workers,
)
from ..qubit.adjoint_jacobian import adjoint_jacobian, adjoint_vjp, adjoint_jvp

Result_or_ResultBatch = Union[Result, ResultBatch]
Expand Down Expand Up @@ -170,7 +172,7 @@ def supports_derivatives(
# do once device accepts finite shots
if (
execution_config.gradient_method == "backprop"
and self._get_max_workers(execution_config) is None
and execution_config.device_options.get("max_workers", self._max_workers) is None
):
return True

Expand Down Expand Up @@ -211,6 +213,10 @@ def preprocess(
circuits = [circuits]
is_single_circuit = True

# prefer config over device value
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
self._validate_multiprocessing(max_workers, circuits)

batch, post_processing_fn, config = preprocess(circuits, execution_config=execution_config)

if is_single_circuit:
Expand Down Expand Up @@ -239,17 +245,16 @@ def execute(
self.tracker.update(batches=1, executions=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
results = tuple(simulate(c, rng=self._rng, debugger=self._debugger) for c in circuits)
else:
self._validate_multiprocessing_circuits(circuits)
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))
_wrap_simulate = partial(simulate, debugger=None)
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(_wrap_simulate, vanilla_circuits, seeds)
results = tuple(circuit for circuit in exec_map)
results = tuple(exec_map)

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))
Expand All @@ -270,14 +275,14 @@ def compute_derivatives(
self.tracker.update(derivative_batches=1, derivatives=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
res = tuple(adjoint_jacobian(circuit) for circuit in circuits)
else:
vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
exec_map = executor.map(adjoint_jacobian, vanilla_circuits)
res = tuple(circuit for circuit in exec_map)
res = tuple(exec_map)

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))
Expand All @@ -304,26 +309,22 @@ def execute_and_compute_derivatives(
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
results = tuple(
_adjoint_jac_wrapper(c, rng=self._rng, debugger=self._debugger) for c in circuits
)
results, jacs = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
results = tuple(executor.map(_adjoint_jac_wrapper, vanilla_circuits, seeds))

results, jacs = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

results, jacs = tuple(zip(*results))
return (results[0], jacs[0]) if is_single_circuit else (results, jacs)

def supports_jvp(
Expand Down Expand Up @@ -361,7 +362,7 @@ def compute_jvp(
self.tracker.update(jvp_batches=1, jvps=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
res = tuple(adjoint_jvp(circuit, tans) for circuit, tans in zip(circuits, tangents))
else:
Expand Down Expand Up @@ -394,16 +395,13 @@ def execute_and_compute_jvp(
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
results = tuple(
_adjoint_jvp_wrapper(c, t, rng=self._rng, debugger=self._debugger)
for c, t in zip(circuits, tangents)
)
results, jvps = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

Expand All @@ -412,11 +410,10 @@ def execute_and_compute_jvp(
executor.map(_adjoint_jvp_wrapper, vanilla_circuits, tangents, seeds)
)

results, jvps = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

results, jvps = tuple(zip(*results))
return (results[0], jvps[0]) if is_single_circuit else (results, jvps)

def supports_vjp(
Expand Down Expand Up @@ -454,7 +451,7 @@ def compute_vjp(
self.tracker.update(vjp_batches=1, vjps=len(circuits))
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
res = tuple(adjoint_vjp(circuit, cots) for circuit, cots in zip(circuits, cotangents))
else:
Expand Down Expand Up @@ -487,16 +484,13 @@ def execute_and_compute_vjp(
)
self.tracker.record()

max_workers = self._get_max_workers(execution_config)
max_workers = execution_config.device_options.get("max_workers", self._max_workers)
if max_workers is None:
results = tuple(
_adjoint_vjp_wrapper(c, t, rng=self._rng, debugger=self._debugger)
for c, t in zip(circuits, cotangents)
)
results, vjps = tuple(zip(*results))
else:
self._validate_multiprocessing_circuits(circuits)

vanilla_circuits = [convert_to_numpy_parameters(c) for c in circuits]
seeds = self._rng.integers(2**31 - 1, size=len(vanilla_circuits))

Expand All @@ -505,33 +499,24 @@ def execute_and_compute_vjp(
executor.map(_adjoint_vjp_wrapper, vanilla_circuits, cotangents, seeds)
)

results, vjps = tuple(zip(*results))

# reset _rng to mimic serial behavior
self._rng = np.random.default_rng(self._rng.integers(2**31 - 1))

results, vjps = tuple(zip(*results))
return (results[0], vjps[0]) if is_single_circuit else (results, vjps)

# pylint: disable=missing-function-docstring
def _get_max_workers(self, execution_config=None):
max_workers = None
if (
execution_config
and execution_config.device_options
and "max_workers" in execution_config.device_options
):
max_workers = execution_config.device_options["max_workers"]
else:
max_workers = self._max_workers
_validate_multiprocessing_workers(max_workers)
return max_workers

def _validate_multiprocessing_circuits(self, circuits):
def _validate_multiprocessing(self, max_workers, circuits):
"""Make sure the tapes can be processed by a ProcessPoolExecutor instance.
Args:
max_workers (Union[int]): Maximal number of multiprocessing workers
circuits (QuantumTape_or_Batch): Quantum tapes
"""
if max_workers is None:
return

validate_multiprocessing_workers(max_workers)

if self._debugger and self._debugger.active:
raise DeviceError("Debugging with ``Snapshots`` is not available with multiprocessing.")

Expand All @@ -546,43 +531,6 @@ def _has_snapshot(circuit):
)


def _validate_multiprocessing_workers(max_workers):
"""Validates the number of workers for multiprocessing.
Checks that the CPU is not oversubscribed and warns user if it is,
making suggestions for the number of workers and/or the number of
threads per worker.
Args:
max_workers (int): Maximal number of multiprocessing workers
"""
if max_workers is None:
return
threads_per_proc = os.cpu_count() # all threads by default
varname = "OMP_NUM_THREADS"
varnames = ["MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS"]
for var in varnames:
if os.getenv(var): # pragma: no cover
varname = var
threads_per_proc = int(os.getenv(var))
break
num_threads = threads_per_proc * max_workers
num_cpu = os.cpu_count()
num_threads_suggest = max(1, os.cpu_count() // max_workers)
num_workers_suggest = max(1, os.cpu_count() // threads_per_proc)
if num_threads > num_cpu:
warnings.warn(
f"""The device requested {num_threads} threads ({max_workers} processes
times {threads_per_proc} threads per process), but the processor only has
{num_cpu} logical cores. The processor is likely oversubscribed, which may
lead to performance deterioration. Consider decreasing the number of processes,
setting the device or execution config argument `max_workers={num_workers_suggest}`
for example, or decreasing the number of threads per process by setting the
environment variable `{varname}={num_threads_suggest}`.""",
UserWarning,
)


def _adjoint_jac_wrapper(c, rng=None, debugger=None):
state, is_state_batched = get_final_state(c, debugger=debugger)
jac = adjoint_jacobian(c, state=state)
Expand Down
46 changes: 40 additions & 6 deletions pennylane/devices/qubit/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
that they are supported for execution by a device."""
# pylint: disable=protected-access
from dataclasses import replace
import os
from typing import Generator, Callable, Tuple, Union
import warnings
from functools import partial
Expand Down Expand Up @@ -66,10 +67,7 @@ def _accepted_operator(op: qml.operation.Operator) -> bool:
return False
if op.name == "GroverOperator" and len(op.wires) >= 13:
return False
if op.name == "Snapshot":
return True

return op.has_matrix
return op.name == "Snapshot" or op.has_matrix


def _accepted_adjoint_operator(op: qml.operation.Operator) -> bool:
Expand Down Expand Up @@ -98,6 +96,43 @@ def _operator_decomposition_gen(
#######################


def validate_multiprocessing_workers(max_workers):
"""Validates the number of workers for multiprocessing.
Checks that the CPU is not oversubscribed and warns user if it is,
making suggestions for the number of workers and/or the number of
threads per worker.
Args:
max_workers (int): Maximal number of multiprocessing workers
"""
if max_workers is None:
return
threads_per_proc = os.cpu_count() # all threads by default
varname = "OMP_NUM_THREADS"
varnames = ["MKL_NUM_THREADS", "OPENBLAS_NUM_THREADS", "OMP_NUM_THREADS"]
for var in varnames:
if os.getenv(var): # pragma: no cover
varname = var
threads_per_proc = int(os.getenv(var))
break
num_threads = threads_per_proc * max_workers
num_cpu = os.cpu_count()
num_threads_suggest = max(1, os.cpu_count() // max_workers)
num_workers_suggest = max(1, os.cpu_count() // threads_per_proc)
if num_threads > num_cpu:
warnings.warn(
f"""The device requested {num_threads} threads ({max_workers} processes
times {threads_per_proc} threads per process), but the processor only has
{num_cpu} logical cores. The processor is likely oversubscribed, which may
lead to performance deterioration. Consider decreasing the number of processes,
setting the device or execution config argument `max_workers={num_workers_suggest}`
for example, or decreasing the number of threads per process by setting the
environment variable `{varname}={num_threads_suggest}`.""",
UserWarning,
)


def validate_and_expand_adjoint(
circuit: qml.tape.QuantumTape,
) -> Union[qml.tape.QuantumTape, DeviceError]: # pylint: disable=protected-access
Expand Down Expand Up @@ -154,8 +189,7 @@ def validate_and_expand_adjoint(

measurements.append(m)

expanded_tape = qml.tape.QuantumScript(new_ops, measurements, prep, circuit.shots)
return expanded_tape
return qml.tape.QuantumScript(new_ops, measurements, prep, circuit.shots)


def validate_measurements(
Expand Down
2 changes: 1 addition & 1 deletion tests/devices/experimental/test_default_qubit_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_snapshot_multiprocessing_execute():
[qml.expval(qml.PauliX(0))],
)
with pytest.raises(RuntimeError, match="ProcessPoolExecutor cannot execute a QuantumScript"):
dev.execute(tape)
dev.preprocess(tape)


def test_snapshot_multiprocessing_qnode():
Expand Down
6 changes: 6 additions & 0 deletions tests/devices/qubit/test_preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
preprocess,
validate_and_expand_adjoint,
validate_measurements,
validate_multiprocessing_workers,
)
from pennylane.devices.experimental import ExecutionConfig
from pennylane.measurements import MidMeasureMP, MeasurementValue
Expand Down Expand Up @@ -783,3 +784,8 @@ def test_preprocess_tape_for_adjoint(self):
qml.equal(o1, o2) for o1, o2 in zip(expanded_qs.measurements, expected_qs.measurements)
)
assert expanded_qs.trainable_params == expected_qs.trainable_params


def test_validate_multiprocessing_workers_None():
"""Test that validation does not fail when max_workers is None"""
validate_multiprocessing_workers(None)

0 comments on commit 109b3e8

Please sign in to comment.