From 5db098bdb2ad9ea41c296c692af02b2de9f3d9ef Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 15 Jul 2024 15:31:24 -0400 Subject: [PATCH 01/34] E.C. [ci skip] From 50ca82db94d9bffbed7b87ed5fa70b1da0e7e0dc Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 16 Jul 2024 22:24:50 -0400 Subject: [PATCH 02/34] just doubts so far [ci skip] --- pennylane/ops/op_math/condition.py | 149 +++++++++++++++++++++++++++++ 1 file changed, 149 insertions(+) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d3da2814719..e94632cf939 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -17,6 +17,7 @@ from functools import wraps from typing import Type +import pennylane as qml from pennylane import QueuingManager from pennylane.compiler import compiler from pennylane.operation import AnyWires, Operation, Operator @@ -100,6 +101,20 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) +from typing import Callable, Optional, overload, Union, Tuple + +from pennylane.measurements import MeasurementValue + +import functools + + +@overload +def cond( + condition: Union[MeasurementValue, bool], + true_fn: Callable, + false_fn: Optional[Callable] = None, + elifs: Tuple[Tuple[bool, Callable], ...] = (), +) -> Callable: ... def cond(condition, true_fn, false_fn=None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -364,6 +379,11 @@ def qnode(a, x, y, z): >>> qnode(par, x, y, z) tensor(-0.30922805, requires_grad=True) """ + + print( + f"cond function called with condition={condition}, true_fn={true_fn}, false_fn={false_fn}, elifs={elifs}" + ) + if active_jit := compiler.active_compiler(): available_eps = compiler.AvailableCompilers.names_entrypoints ops_loader = available_eps[active_jit]["ops"].load() @@ -379,6 +399,11 @@ def qnode(a, x, y, z): return cond_func + # This will not be the final place for this logic, but it is a start) + if qml.capture.enabled(): + print("Capture mode for cond") + return _capture_cond(condition, true_fn, false_fn, elifs) + if elifs: raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") @@ -430,3 +455,127 @@ def wrapper(*args, **kwargs): ) return wrapper + + +@functools.lru_cache # only create the first time requested +def _get_cond_qfunc_prim(): + # if capture is enabled, jax should be installed + import jax # pylint: disable=import-outside-toplevel + + AbstractOperator = qml.capture.AbstractOperator + + cond_prim = jax.core.Primitive("cond") + cond_prim.multiple_results = True + + @cond_prim.def_impl + def _(*args, n_elif, jaxpr_true, jaxpr_false, jaxprs_elif, condition): + def run_jaxpr(jaxpr, *args): + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + + def true_branch(args): + return run_jaxpr(jaxpr_true, *args) + + def false_branch(args): + for cond, jaxpr in jaxprs_elif: + + def elif_branch(args): + return run_jaxpr(jaxpr, *args) + + args = jax.lax.cond(cond, elif_branch, lambda x: x, args) + if jaxpr_false: + return run_jaxpr(jaxpr_false, *args) + return args + + return jax.lax.cond(condition, true_branch, false_branch, args) + + def _is_queued_outvar(outvars): + if not outvars: + return False + return isinstance(outvars[0].aval, AbstractOperator) and isinstance( + outvars[0], jax.core.DropVar + ) + + @cond_prim.def_abstract_eval + def _abstract(*args, **kwargs): + return [qml.capture.AbstractOperator()] + + return cond_prim + + +# Vogliamo catturare la funzione 'true_fn', e probabilmente passare 'condition' e 'false_fn' come argomenti. +def _capture_cond(condition, true_fn, false_fn, elifs) -> Callable: + """Capture compatible way to apply conditionally a ....""" + # note that this logic is tested in `tests/capture/test_... + + print("Capture mode for cond") + + import jax # pylint: disable=import-outside-toplevel + + cond_prim = _get_cond_qfunc_prim() + + @wraps(true_fn) + def new_wrapper(*args, **kwargs): + jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) + jaxpr_false = ( + jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None + ) + jaxprs_elif = [ + (cond_val, jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) + for cond_val, elif_fn in elifs + ] + return cond_prim.bind( + *args, + condition=condition, + n_elif=len(elifs), + jaxpr_true=jaxpr_true, + jaxpr_false=jaxpr_false, + jaxprs_elif=jaxprs_elif, + ) + + return new_wrapper + + +def _cond(condition, true_fn, false_fn): + + # We assume that the callable is an operation or a quantum function + with_meas_err = ( + "Only quantum functions that contain no measurements can be applied conditionally." + ) + + @wraps(true_fn) + def wrapper(*args, **kwargs): + # We assume that the callable is a quantum function + + recorded_ops = [a for a in args if isinstance(a, Operator)] + [ + k for k in kwargs.values() if isinstance(k, Operator) + ] + + # This will dequeue all operators passed in as arguments to the qfunc that is + # being conditioned. These are queued incorrectly due to be fully constructed + # before the wrapper function is called. + if recorded_ops and QueuingManager.recording(): + for op in recorded_ops: + QueuingManager.remove(op) + + # 1. Apply true_fn conditionally + qscript = make_qscript(true_fn)(*args, **kwargs) + + if qscript.measurements: + raise ConditionalTransformError(with_meas_err) + + for op in qscript.operations: + Conditional(condition, op) + + if false_fn is not None: + # 2. Apply false_fn conditionally + else_qscript = make_qscript(false_fn)(*args, **kwargs) + + if else_qscript.measurements: + raise ConditionalTransformError(with_meas_err) + + inverted_condition = ~condition + + for op in else_qscript.operations: + Conditional(inverted_condition, op) + + return wrapper From cf46ee108ff52e8cb9f6b8961f0c298578e922f1 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 17 Jul 2024 15:42:52 -0400 Subject: [PATCH 03/34] providing `elifs` causes an error at this stage (I don't know why yet) [ci skip] --- pennylane/ops/op_math/condition.py | 114 +++++++++-------------------- 1 file changed, 34 insertions(+), 80 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index e94632cf939..a59064f3060 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -399,7 +399,8 @@ def qnode(a, x, y, z): return cond_func - # This will not be the final place for this logic, but it is a start) + # This will not be the final place for this logic, but it is a start + # TODO: providing `elifs` raises an error at this stage if qml.capture.enabled(): print("Capture mode for cond") return _capture_cond(condition, true_fn, false_fn, elifs) @@ -462,13 +463,19 @@ def _get_cond_qfunc_prim(): # if capture is enabled, jax should be installed import jax # pylint: disable=import-outside-toplevel - AbstractOperator = qml.capture.AbstractOperator + print("Creating the cond primitive (executed only once)") cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True @cond_prim.def_impl - def _(*args, n_elif, jaxpr_true, jaxpr_false, jaxprs_elif, condition): + def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): + + print("We are in the cond primitive definition implementation") + print( + f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" + ) + def run_jaxpr(jaxpr, *args): return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) @@ -476,106 +483,53 @@ def true_branch(args): return run_jaxpr(jaxpr_true, *args) def false_branch(args): - for cond, jaxpr in jaxprs_elif: - - def elif_branch(args): - return run_jaxpr(jaxpr, *args) - - args = jax.lax.cond(cond, elif_branch, lambda x: x, args) - if jaxpr_false: + if not jaxpr_elifs: return run_jaxpr(jaxpr_false, *args) - return args + else: + pred, elif_jaxpr, rest_jaxpr_elifs = jaxpr_elifs[0] + return jax.lax.cond( + pred, lambda y: run_jaxpr(elif_jaxpr, *y), lambda y: false_branch(y), args + ) return jax.lax.cond(condition, true_branch, false_branch, args) - def _is_queued_outvar(outvars): - if not outvars: - return False - return isinstance(outvars[0].aval, AbstractOperator) and isinstance( - outvars[0], jax.core.DropVar - ) - @cond_prim.def_abstract_eval - def _abstract(*args, **kwargs): - return [qml.capture.AbstractOperator()] + def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): + print("We are in the cond primitive abstract evaluation") + print( + f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}" + ) + out_avals = jaxpr_true.out_avals + return out_avals return cond_prim -# Vogliamo catturare la funzione 'true_fn', e probabilmente passare 'condition' e 'false_fn' come argomenti. -def _capture_cond(condition, true_fn, false_fn, elifs) -> Callable: - """Capture compatible way to apply conditionally a ....""" - # note that this logic is tested in `tests/capture/test_... +def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: + """Capture compatible way to apply conditionals.""" + # TODO: implement tests print("Capture mode for cond") + print(f"condition={condition}, true_fn={true_fn}, false_fn={false_fn}, elifs={elifs}") + import jax # pylint: disable=import-outside-toplevel cond_prim = _get_cond_qfunc_prim() @wraps(true_fn) def new_wrapper(*args, **kwargs): - jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) - jaxpr_false = ( - jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None - ) - jaxprs_elif = [ - (cond_val, jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) - for cond_val, elif_fn in elifs - ] + jaxpr_true = jax.make_jaxpr(true_fn)(*args) + jaxpr_false = jax.make_jaxpr(false_fn)(*args) if false_fn else jaxpr_true + + jaxpr_elifs = [(cond, jax.make_jaxpr(fn)(*args), []) for cond, fn in elifs] + return cond_prim.bind( *args, condition=condition, - n_elif=len(elifs), jaxpr_true=jaxpr_true, jaxpr_false=jaxpr_false, - jaxprs_elif=jaxprs_elif, + jaxpr_elifs=jaxpr_elifs, ) return new_wrapper - - -def _cond(condition, true_fn, false_fn): - - # We assume that the callable is an operation or a quantum function - with_meas_err = ( - "Only quantum functions that contain no measurements can be applied conditionally." - ) - - @wraps(true_fn) - def wrapper(*args, **kwargs): - # We assume that the callable is a quantum function - - recorded_ops = [a for a in args if isinstance(a, Operator)] + [ - k for k in kwargs.values() if isinstance(k, Operator) - ] - - # This will dequeue all operators passed in as arguments to the qfunc that is - # being conditioned. These are queued incorrectly due to be fully constructed - # before the wrapper function is called. - if recorded_ops and QueuingManager.recording(): - for op in recorded_ops: - QueuingManager.remove(op) - - # 1. Apply true_fn conditionally - qscript = make_qscript(true_fn)(*args, **kwargs) - - if qscript.measurements: - raise ConditionalTransformError(with_meas_err) - - for op in qscript.operations: - Conditional(condition, op) - - if false_fn is not None: - # 2. Apply false_fn conditionally - else_qscript = make_qscript(false_fn)(*args, **kwargs) - - if else_qscript.measurements: - raise ConditionalTransformError(with_meas_err) - - inverted_condition = ~condition - - for op in else_qscript.operations: - Conditional(inverted_condition, op) - - return wrapper From 60cdcd278278758aa6a1899ed46c414eab787884 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 17 Jul 2024 15:43:08 -0400 Subject: [PATCH 04/34] not skipping the CI From cafd846ed252aca2c5ac892fbdbd345f13f5d736 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 17 Jul 2024 16:55:14 -0400 Subject: [PATCH 05/34] [ci skip] elifs are still not implemented correctly, but hopefully we are making progresses... [ci skip] --- pennylane/ops/op_math/condition.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index a59064f3060..b415834e343 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -476,6 +476,8 @@ def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" ) + # I don't think this function is necessary for the requirement of this epic + def run_jaxpr(jaxpr, *args): return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) @@ -517,12 +519,33 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: cond_prim = _get_cond_qfunc_prim() + def handle_elifs(elifs): + if len(elifs) == 2 and isinstance(elifs[0], bool) and callable(elifs[1]): + # Case 2: single (condition, fn) tuple + return [(elifs[0], elifs[1])] + else: + # Case 1: tuple of (condition, fn) tuples + return list(elifs) + + elifs = handle_elifs(elifs) + + print(f"elifs={elifs}") + @wraps(true_fn) def new_wrapper(*args, **kwargs): jaxpr_true = jax.make_jaxpr(true_fn)(*args) jaxpr_false = jax.make_jaxpr(false_fn)(*args) if false_fn else jaxpr_true - jaxpr_elifs = [(cond, jax.make_jaxpr(fn)(*args), []) for cond, fn in elifs] + # TODO: find a better way to distinguish the 2 cases + if len(elifs) == 2 and callable(elifs[1]): + print("elifs caso singolo") + jaxpr_elifs = [(elifs[0], jax.make_jaxpr(elifs[1])(*args), [])] + + else: + print("elifs caso multiplo") + jaxpr_elifs = [(cond, jax.make_jaxpr(fn)(*args), []) for cond, fn in elifs] + + print(f"jaxpr_elifs={jaxpr_elifs}") return cond_prim.bind( *args, From 52b684f10678862263cfb53ea7e621d92e0eccc1 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 17 Jul 2024 21:10:31 -0400 Subject: [PATCH 06/34] The main problem right now is that all the `elifs` condition after the second one are ignored... --- pennylane/ops/op_math/condition.py | 63 ++++++++++++++++++------------ 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index b415834e343..95d2ac326d3 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -14,8 +14,9 @@ """ Contains the condition transform. """ +import functools from functools import wraps -from typing import Type +from typing import Callable, Optional, Tuple, Type, overload import pennylane as qml from pennylane import QueuingManager @@ -101,19 +102,12 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -from typing import Callable, Optional, overload, Union, Tuple - -from pennylane.measurements import MeasurementValue - -import functools - - @overload def cond( - condition: Union[MeasurementValue, bool], + condition: bool, true_fn: Callable, false_fn: Optional[Callable] = None, - elifs: Tuple[Tuple[bool, Callable], ...] = (), + elifs: Optional[Tuple[Tuple[bool, Callable], ...]] = (), ) -> Callable: ... def cond(condition, true_fn, false_fn=None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations @@ -473,33 +467,43 @@ def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive definition implementation") print( - f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" + f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}" ) - # I don't think this function is necessary for the requirement of this epic - def run_jaxpr(jaxpr, *args): return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) def true_branch(args): + print("We are in the true branch") return run_jaxpr(jaxpr_true, *args) + # pylint: disable=unused-variable def false_branch(args): + print("We are in the false branch") if not jaxpr_elifs: return run_jaxpr(jaxpr_false, *args) - else: - pred, elif_jaxpr, rest_jaxpr_elifs = jaxpr_elifs[0] - return jax.lax.cond( - pred, lambda y: run_jaxpr(elif_jaxpr, *y), lambda y: false_branch(y), args - ) - return jax.lax.cond(condition, true_branch, false_branch, args) + def elif_branch(args, jaxpr_elifs): + print("We are in the elif branch") + print(f"jaxpr_elifs={jaxpr_elifs}") + if not jaxpr_elifs: + return run_jaxpr(jaxpr_false, *args) + + pred, jaxpr_elif, rest_jaxpr_elifs = jaxpr_elifs[0] + return jax.lax.cond( + pred, + lambda y: run_jaxpr(jaxpr_elif, *y), + lambda y: elif_branch(y, rest_jaxpr_elifs), + args, + ) + + return jax.lax.cond(condition, true_branch, lambda y: elif_branch(y, jaxpr_elifs), args) @cond_prim.def_abstract_eval def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive abstract evaluation") print( - f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}" + f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}" ) out_avals = jaxpr_true.out_avals return out_avals @@ -521,16 +525,15 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: def handle_elifs(elifs): if len(elifs) == 2 and isinstance(elifs[0], bool) and callable(elifs[1]): - # Case 2: single (condition, fn) tuple return [(elifs[0], elifs[1])] - else: - # Case 1: tuple of (condition, fn) tuples - return list(elifs) + return list(elifs) elifs = handle_elifs(elifs) print(f"elifs={elifs}") + # pylint: disable=unused-argument + # pylint: disable=unused-variable @wraps(true_fn) def new_wrapper(*args, **kwargs): jaxpr_true = jax.make_jaxpr(true_fn)(*args) @@ -538,13 +541,21 @@ def new_wrapper(*args, **kwargs): # TODO: find a better way to distinguish the 2 cases if len(elifs) == 2 and callable(elifs[1]): - print("elifs caso singolo") + print("elifs single case") + # this is the case where we only have one elif, like: + # elifs=((x == 1, elif_fn)) jaxpr_elifs = [(elifs[0], jax.make_jaxpr(elifs[1])(*args), [])] else: - print("elifs caso multiplo") + print("elifs multiple case") + # this is the case where we have multiple elifs, like: + # elifs=((x == 1, elif_fn), (x == 2, elif_fn2)) jaxpr_elifs = [(cond, jax.make_jaxpr(fn)(*args), []) for cond, fn in elifs] + # Create a nested structure for jaxpr_elifs + for i in range(len(jaxpr_elifs) - 1): + jaxpr_elifs[i] = (jaxpr_elifs[i][0], jaxpr_elifs[i][1], jaxpr_elifs[i + 1 :]) + print(f"jaxpr_elifs={jaxpr_elifs}") return cond_prim.bind( From e6af28a8cc95eee4c2b24c675729007f36df9a56 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 18 Jul 2024 10:52:28 -0400 Subject: [PATCH 07/34] Fixing multiple elifs issue [ci skip] --- pennylane/ops/op_math/condition.py | 53 ++++++++---------------------- 1 file changed, 13 insertions(+), 40 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 95d2ac326d3..d430f91e724 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -467,7 +467,7 @@ def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive definition implementation") print( - f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}" + f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}\n" ) def run_jaxpr(jaxpr, *args): @@ -477,19 +477,15 @@ def true_branch(args): print("We are in the true branch") return run_jaxpr(jaxpr_true, *args) - # pylint: disable=unused-variable - def false_branch(args): - print("We are in the false branch") - if not jaxpr_elifs: - return run_jaxpr(jaxpr_false, *args) - def elif_branch(args, jaxpr_elifs): print("We are in the elif branch") print(f"jaxpr_elifs={jaxpr_elifs}") if not jaxpr_elifs: return run_jaxpr(jaxpr_false, *args) - pred, jaxpr_elif, rest_jaxpr_elifs = jaxpr_elifs[0] + pred, jaxpr_elif = jaxpr_elifs[0] + rest_jaxpr_elifs = jaxpr_elifs[1:] + return jax.lax.cond( pred, lambda y: run_jaxpr(jaxpr_elif, *y), @@ -497,14 +493,18 @@ def elif_branch(args, jaxpr_elifs): args, ) - return jax.lax.cond(condition, true_branch, lambda y: elif_branch(y, jaxpr_elifs), args) + def false_branch(args): + print("We are in the false branch") + if not jaxpr_elifs: + return run_jaxpr(jaxpr_false, *args) + return elif_branch(args, jaxpr_elifs) + + return jax.lax.cond(condition, true_branch, false_branch, args) @cond_prim.def_abstract_eval def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive abstract evaluation") - print( - f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}" - ) + out_avals = jaxpr_true.out_avals return out_avals @@ -517,44 +517,17 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: print("Capture mode for cond") - print(f"condition={condition}, true_fn={true_fn}, false_fn={false_fn}, elifs={elifs}") - import jax # pylint: disable=import-outside-toplevel cond_prim = _get_cond_qfunc_prim() - def handle_elifs(elifs): - if len(elifs) == 2 and isinstance(elifs[0], bool) and callable(elifs[1]): - return [(elifs[0], elifs[1])] - return list(elifs) - - elifs = handle_elifs(elifs) - - print(f"elifs={elifs}") - # pylint: disable=unused-argument # pylint: disable=unused-variable @wraps(true_fn) def new_wrapper(*args, **kwargs): jaxpr_true = jax.make_jaxpr(true_fn)(*args) jaxpr_false = jax.make_jaxpr(false_fn)(*args) if false_fn else jaxpr_true - - # TODO: find a better way to distinguish the 2 cases - if len(elifs) == 2 and callable(elifs[1]): - print("elifs single case") - # this is the case where we only have one elif, like: - # elifs=((x == 1, elif_fn)) - jaxpr_elifs = [(elifs[0], jax.make_jaxpr(elifs[1])(*args), [])] - - else: - print("elifs multiple case") - # this is the case where we have multiple elifs, like: - # elifs=((x == 1, elif_fn), (x == 2, elif_fn2)) - jaxpr_elifs = [(cond, jax.make_jaxpr(fn)(*args), []) for cond, fn in elifs] - - # Create a nested structure for jaxpr_elifs - for i in range(len(jaxpr_elifs) - 1): - jaxpr_elifs[i] = (jaxpr_elifs[i][0], jaxpr_elifs[i][1], jaxpr_elifs[i + 1 :]) + jaxpr_elifs = [(cond, jax.make_jaxpr(elif_fn)(*args)) for cond, elif_fn in elifs] print(f"jaxpr_elifs={jaxpr_elifs}") From 656b4f7ef9aa45cea81da6c476acefce85ca3fdf Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 19 Jul 2024 09:26:46 -0400 Subject: [PATCH 08/34] Removing usage of jax.lax.cond [ci skip] --- pennylane/ops/op_math/condition.py | 39 +++++++++++++++++++++++------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d430f91e724..800ebb0983d 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -462,6 +462,9 @@ def _get_cond_qfunc_prim(): cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True + # The AbstractOperator class is defined in the PennyLane capture module. Don't worry about it. + AbstractOperator = qml.capture.AbstractOperator + @cond_prim.def_impl def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): @@ -486,12 +489,10 @@ def elif_branch(args, jaxpr_elifs): pred, jaxpr_elif = jaxpr_elifs[0] rest_jaxpr_elifs = jaxpr_elifs[1:] - return jax.lax.cond( - pred, - lambda y: run_jaxpr(jaxpr_elif, *y), - lambda y: elif_branch(y, rest_jaxpr_elifs), - args, - ) + if pred: + return run_jaxpr(jaxpr_elif, *args) + else: + return elif_branch(args, rest_jaxpr_elifs) def false_branch(args): print("We are in the false branch") @@ -499,14 +500,34 @@ def false_branch(args): return run_jaxpr(jaxpr_false, *args) return elif_branch(args, jaxpr_elifs) - return jax.lax.cond(condition, true_branch, false_branch, args) + if condition: + return true_branch(args) + else: + return false_branch(args) + + def _is_queued_outvar(outvars): + if not outvars: + return False + return isinstance(outvars[0].aval, AbstractOperator) and isinstance( + outvars[0], jax.core.DropVar + ) @cond_prim.def_abstract_eval def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive abstract evaluation") - out_avals = jaxpr_true.out_avals - return out_avals + outvars = [AbstractOperator() for eqn in jaxpr_true.eqns if _is_queued_outvar(eqn.outvars)] + + # operators that are not dropped var because they are returned + outvars += [ + AbstractOperator() + for aval in jaxpr_true.out_avals + if isinstance(aval, AbstractOperator) + ] + return outvars + + # out_avals = jaxpr_true.out_avals + # return out_avals return cond_prim From 0f81364161852e12fee9a333b40f1e1ce5ecf8d3 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Sat, 20 Jul 2024 12:47:19 -0400 Subject: [PATCH 09/34] Undersanding how to handle dynamic tracer [ci skip] --- pennylane/ops/op_math/condition.py | 38 +++++++++++++++++++++++------- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 800ebb0983d..bd9b0be630f 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -474,7 +474,22 @@ def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): ) def run_jaxpr(jaxpr, *args): - return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + + print(f"Running jaxpr: {jaxpr}") + print(f"jaxpr.eqns={jaxpr.eqns}") + print(f"jaxpr.out_avals={jaxpr.out_avals}") + + out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + print(f"Jaxpr evaluation result: {out}") + + if not isinstance(out, tuple): + out = (out,) + + for outvar in out: + if isinstance(outvar, Operator): + QueuingManager.append(outvar) + + return out def true_branch(args): print("We are in the true branch") @@ -500,10 +515,14 @@ def false_branch(args): return run_jaxpr(jaxpr_false, *args) return elif_branch(args, jaxpr_elifs) - if condition: - return true_branch(args) + if isinstance(condition, bool): + if condition: + return true_branch(args) + else: + return false_branch(args) else: - return false_branch(args) + # understand what to do with the condition + pass def _is_queued_outvar(outvars): if not outvars: @@ -515,6 +534,9 @@ def _is_queued_outvar(outvars): @cond_prim.def_abstract_eval def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive abstract evaluation") + print( + f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" + ) outvars = [AbstractOperator() for eqn in jaxpr_true.eqns if _is_queued_outvar(eqn.outvars)] @@ -526,9 +548,6 @@ def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): ] return outvars - # out_avals = jaxpr_true.out_avals - # return out_avals - return cond_prim @@ -536,8 +555,6 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: """Capture compatible way to apply conditionals.""" # TODO: implement tests - print("Capture mode for cond") - import jax # pylint: disable=import-outside-toplevel cond_prim = _get_cond_qfunc_prim() @@ -546,10 +563,13 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: # pylint: disable=unused-variable @wraps(true_fn) def new_wrapper(*args, **kwargs): + jaxpr_true = jax.make_jaxpr(true_fn)(*args) jaxpr_false = jax.make_jaxpr(false_fn)(*args) if false_fn else jaxpr_true jaxpr_elifs = [(cond, jax.make_jaxpr(elif_fn)(*args)) for cond, elif_fn in elifs] + print(f"jaxpr_true={jaxpr_true}") + print(f"jaxpr_false={jaxpr_false}") print(f"jaxpr_elifs={jaxpr_elifs}") return cond_prim.bind( From 399c8be28c6bf7126b92cb9f475181b20f286508 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 22 Jul 2024 16:23:44 -0400 Subject: [PATCH 10/34] Solved dynamic inconsistent behavior [ci skip] --- pennylane/ops/op_math/condition.py | 86 +++++++++++++++++------------- 1 file changed, 50 insertions(+), 36 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index bd9b0be630f..d365b46dfb2 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -466,21 +466,23 @@ def _get_cond_qfunc_prim(): AbstractOperator = qml.capture.AbstractOperator @cond_prim.def_impl - def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): + def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive definition implementation") - print( - f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}\n" - ) + # print( + # f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}\n" + # ) + + # print(f"elifs_conditions={elifs_conditions}") def run_jaxpr(jaxpr, *args): - print(f"Running jaxpr: {jaxpr}") - print(f"jaxpr.eqns={jaxpr.eqns}") - print(f"jaxpr.out_avals={jaxpr.out_avals}") + # print(f"Running jaxpr: {jaxpr}") + # print(f"jaxpr.eqns={jaxpr.eqns}") + # print(f"jaxpr.out_avals={jaxpr.out_avals}") out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - print(f"Jaxpr evaluation result: {out}") + # print(f"Jaxpr evaluation result: {out}") if not isinstance(out, tuple): out = (out,) @@ -495,34 +497,34 @@ def true_branch(args): print("We are in the true branch") return run_jaxpr(jaxpr_true, *args) - def elif_branch(args, jaxpr_elifs): + def elif_branch(args, elifs_conditions, jaxpr_elifs): print("We are in the elif branch") print(f"jaxpr_elifs={jaxpr_elifs}") if not jaxpr_elifs: return run_jaxpr(jaxpr_false, *args) - pred, jaxpr_elif = jaxpr_elifs[0] + pred = elifs_conditions[0] + rest_preds = elifs_conditions[1:] + + jaxpr_elif = jaxpr_elifs[0] rest_jaxpr_elifs = jaxpr_elifs[1:] if pred: return run_jaxpr(jaxpr_elif, *args) else: - return elif_branch(args, rest_jaxpr_elifs) + return elif_branch(args, rest_preds, rest_jaxpr_elifs) def false_branch(args): print("We are in the false branch") if not jaxpr_elifs: return run_jaxpr(jaxpr_false, *args) - return elif_branch(args, jaxpr_elifs) + return elif_branch(args, elifs_conditions, jaxpr_elifs) - if isinstance(condition, bool): - if condition: - return true_branch(args) - else: - return false_branch(args) + if condition: + return true_branch(args) else: - # understand what to do with the condition - pass + # if elifs_conditions + return false_branch(args) def _is_queued_outvar(outvars): if not outvars: @@ -532,26 +534,24 @@ def _is_queued_outvar(outvars): ) @cond_prim.def_abstract_eval - def _(*args, condition, jaxpr_true, jaxpr_false, jaxpr_elifs): + def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): print("We are in the cond primitive abstract evaluation") print( f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" ) - outvars = [AbstractOperator() for eqn in jaxpr_true.eqns if _is_queued_outvar(eqn.outvars)] + def collect_outvars(jaxpr): + return [AbstractOperator() for eqn in jaxpr.eqns if _is_queued_outvar(eqn.outvars)] + [ + AbstractOperator() for aval in jaxpr.out_avals if isinstance(aval, AbstractOperator) + ] - # operators that are not dropped var because they are returned - outvars += [ - AbstractOperator() - for aval in jaxpr_true.out_avals - if isinstance(aval, AbstractOperator) - ] - return outvars + outvars_true = collect_outvars(jaxpr_true) + return outvars_true return cond_prim -def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: +def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: """Capture compatible way to apply conditionals.""" # TODO: implement tests @@ -564,17 +564,31 @@ def _capture_cond(condition, true_fn, false_fn, elifs=()) -> Callable: @wraps(true_fn) def new_wrapper(*args, **kwargs): - jaxpr_true = jax.make_jaxpr(true_fn)(*args) - jaxpr_false = jax.make_jaxpr(false_fn)(*args) if false_fn else jaxpr_true - jaxpr_elifs = [(cond, jax.make_jaxpr(elif_fn)(*args)) for cond, elif_fn in elifs] + # Each predicate in the elifs argument is traced by JAX + elifs_conditions = ( + jax.numpy.array([cond for cond, _ in elifs]) if elifs else jax.numpy.empty(0) + ) + + jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) + jaxpr_false = ( + ((jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None)) + if false_fn + else None + ) + jaxpr_elifs = ( + ([jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args) for _, elif_fn in elifs]) + if elifs + else () + ) - print(f"jaxpr_true={jaxpr_true}") - print(f"jaxpr_false={jaxpr_false}") - print(f"jaxpr_elifs={jaxpr_elifs}") + # print(f"jaxpr_true={jaxpr_true}") + # print(f"jaxpr_false={jaxpr_false}") + # print(f"jaxpr_elifs={jaxpr_elifs}") return cond_prim.bind( + condition, + elifs_conditions, *args, - condition=condition, jaxpr_true=jaxpr_true, jaxpr_false=jaxpr_false, jaxpr_elifs=jaxpr_elifs, From 1d95f90661422ed3ac537087fba93dbc1e67e490 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 22 Jul 2024 18:25:40 -0400 Subject: [PATCH 11/34] Improving code style and removing debug msgs [ci skip] --- pennylane/ops/op_math/condition.py | 48 ++++++++---------------------- 1 file changed, 12 insertions(+), 36 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d365b46dfb2..0783d272e63 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -457,8 +457,6 @@ def _get_cond_qfunc_prim(): # if capture is enabled, jax should be installed import jax # pylint: disable=import-outside-toplevel - print("Creating the cond primitive (executed only once)") - cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True @@ -468,21 +466,9 @@ def _get_cond_qfunc_prim(): @cond_prim.def_impl def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): - print("We are in the cond primitive definition implementation") - # print( - # f"args={args}, \ncondition={condition}, \njaxpr_true={jaxpr_true}, \njaxpr_false={jaxpr_false}, \njaxpr_elifs={jaxpr_elifs}\n" - # ) - - # print(f"elifs_conditions={elifs_conditions}") - def run_jaxpr(jaxpr, *args): - # print(f"Running jaxpr: {jaxpr}") - # print(f"jaxpr.eqns={jaxpr.eqns}") - # print(f"jaxpr.out_avals={jaxpr.out_avals}") - out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - # print(f"Jaxpr evaluation result: {out}") if not isinstance(out, tuple): out = (out,) @@ -494,14 +480,12 @@ def run_jaxpr(jaxpr, *args): return out def true_branch(args): - print("We are in the true branch") return run_jaxpr(jaxpr_true, *args) def elif_branch(args, elifs_conditions, jaxpr_elifs): - print("We are in the elif branch") - print(f"jaxpr_elifs={jaxpr_elifs}") + if not jaxpr_elifs: - return run_jaxpr(jaxpr_false, *args) + return false_branch(args) pred = elifs_conditions[0] rest_preds = elifs_conditions[1:] @@ -511,20 +495,19 @@ def elif_branch(args, elifs_conditions, jaxpr_elifs): if pred: return run_jaxpr(jaxpr_elif, *args) - else: - return elif_branch(args, rest_preds, rest_jaxpr_elifs) + + return elif_branch(args, rest_preds, rest_jaxpr_elifs) def false_branch(args): - print("We are in the false branch") - if not jaxpr_elifs: + if jaxpr_false is not None: return run_jaxpr(jaxpr_false, *args) - return elif_branch(args, elifs_conditions, jaxpr_elifs) + return () if condition: return true_branch(args) - else: - # if elifs_conditions - return false_branch(args) + if elifs_conditions.size > 0: + return elif_branch(args, elifs_conditions, jaxpr_elifs) + return false_branch(args) def _is_queued_outvar(outvars): if not outvars: @@ -534,11 +517,7 @@ def _is_queued_outvar(outvars): ) @cond_prim.def_abstract_eval - def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): - print("We are in the cond primitive abstract evaluation") - print( - f"args={args}, condition={condition}, jaxpr_true={jaxpr_true}, jaxpr_false={jaxpr_false}, jaxpr_elifs={jaxpr_elifs}" - ) + def _(*_, jaxpr_true, **__): def collect_outvars(jaxpr): return [AbstractOperator() for eqn in jaxpr.eqns if _is_queued_outvar(eqn.outvars)] + [ @@ -564,7 +543,8 @@ def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: @wraps(true_fn) def new_wrapper(*args, **kwargs): - # Each predicate in the elifs argument is traced by JAX + # We extract each predicate from the elifs list + # since these are traced by JAX and should be passed as positional arguments elifs_conditions = ( jax.numpy.array([cond for cond, _ in elifs]) if elifs else jax.numpy.empty(0) ) @@ -581,10 +561,6 @@ def new_wrapper(*args, **kwargs): else () ) - # print(f"jaxpr_true={jaxpr_true}") - # print(f"jaxpr_false={jaxpr_false}") - # print(f"jaxpr_elifs={jaxpr_elifs}") - return cond_prim.bind( condition, elifs_conditions, From f1de555bf062ea5889db66c757a51c92fd9b97e9 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 22 Jul 2024 18:28:48 -0400 Subject: [PATCH 12/34] Improving code style and removing debug msgs [ci skip] --- pennylane/ops/op_math/condition.py | 16 +--------------- 1 file changed, 1 insertion(+), 15 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 0783d272e63..35a09dce7be 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -16,7 +16,7 @@ """ import functools from functools import wraps -from typing import Callable, Optional, Tuple, Type, overload +from typing import Callable, Type import pennylane as qml from pennylane import QueuingManager @@ -102,13 +102,6 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -@overload -def cond( - condition: bool, - true_fn: Callable, - false_fn: Optional[Callable] = None, - elifs: Optional[Tuple[Tuple[bool, Callable], ...]] = (), -) -> Callable: ... def cond(condition, true_fn, false_fn=None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -374,10 +367,6 @@ def qnode(a, x, y, z): tensor(-0.30922805, requires_grad=True) """ - print( - f"cond function called with condition={condition}, true_fn={true_fn}, false_fn={false_fn}, elifs={elifs}" - ) - if active_jit := compiler.active_compiler(): available_eps = compiler.AvailableCompilers.names_entrypoints ops_loader = available_eps[active_jit]["ops"].load() @@ -393,8 +382,6 @@ def qnode(a, x, y, z): return cond_func - # This will not be the final place for this logic, but it is a start - # TODO: providing `elifs` raises an error at this stage if qml.capture.enabled(): print("Capture mode for cond") return _capture_cond(condition, true_fn, false_fn, elifs) @@ -460,7 +447,6 @@ def _get_cond_qfunc_prim(): cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True - # The AbstractOperator class is defined in the PennyLane capture module. Don't worry about it. AbstractOperator = qml.capture.AbstractOperator @cond_prim.def_impl From d9b8422c296592c00ba5e36c9b4da1aa6920681d Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 22 Jul 2024 19:10:12 -0400 Subject: [PATCH 13/34] TODO: add tests and check for more than one operator in the queue [ci skip] --- pennylane/ops/op_math/condition.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 35a09dce7be..41f11524553 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -491,8 +491,10 @@ def false_branch(args): if condition: return true_branch(args) + if elifs_conditions.size > 0: return elif_branch(args, elifs_conditions, jaxpr_elifs) + return false_branch(args) def _is_queued_outvar(outvars): @@ -518,26 +520,26 @@ def collect_outvars(jaxpr): def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: """Capture compatible way to apply conditionals.""" - # TODO: implement tests import jax # pylint: disable=import-outside-toplevel cond_prim = _get_cond_qfunc_prim() - # pylint: disable=unused-argument - # pylint: disable=unused-variable + elifs = (elifs,) if len(elifs) > 0 and not isinstance(elifs[0], tuple) else elifs + @wraps(true_fn) def new_wrapper(*args, **kwargs): # We extract each predicate from the elifs list - # since these are traced by JAX and should be passed as positional arguments + # since these are traced by JAX and are passed as positional arguments + elifs_conditions = ( jax.numpy.array([cond for cond, _ in elifs]) if elifs else jax.numpy.empty(0) ) jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) jaxpr_false = ( - ((jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None)) + (jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None) if false_fn else None ) From 91c3e0b1e5c9e8477e02a92b559296af0074b785 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 11:03:05 -0400 Subject: [PATCH 14/34] Adding tests for logic of `qml.cond` with captured enabled --- pennylane/ops/op_math/condition.py | 1 - tests/capture/test_conditionals.py | 180 +++++++++++++++++++++++++++++ 2 files changed, 180 insertions(+), 1 deletion(-) create mode 100644 tests/capture/test_conditionals.py diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 41f11524553..51ea974f548 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -532,7 +532,6 @@ def new_wrapper(*args, **kwargs): # We extract each predicate from the elifs list # since these are traced by JAX and are passed as positional arguments - elifs_conditions = ( jax.numpy.array([cond for cond, _ in elifs]) if elifs else jax.numpy.empty(0) ) diff --git a/tests/capture/test_conditionals.py b/tests/capture/test_conditionals.py new file mode 100644 index 00000000000..84cd3712ca6 --- /dev/null +++ b/tests/capture/test_conditionals.py @@ -0,0 +1,180 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing conditionals into jaxpr. +""" +import numpy as np + +# pylint: disable=protected-access +import pytest + +import pennylane as qml + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + """Enable and disable the PennyLane JAX capture context manager.""" + qml.capture.enable() + yield + qml.capture.disable() + + +def cond_true_elifs_false(selector, arg): + """A function with conditional containing true, elifs, and false branches.""" + + def true_fn(arg): + return 2 * arg + + def elif_fn1(arg): + return arg - 1 + + def elif_fn2(arg): + return arg - 2 + + def elif_fn3(arg): + return arg - 3 + + def elif_fn4(arg): + return arg - 4 + + def false_fn(arg): + return 3 * arg + + return qml.cond( + selector > 0, + true_fn, + false_fn, + elifs=( + (selector == -1, elif_fn1), + (selector == -2, elif_fn2), + (selector == -3, elif_fn3), + (selector == -4, elif_fn4), + ), + )(arg) + + +def cond_true_elifs(selector, arg): + """A function with conditional containing true and elifs branches.""" + + def true_fn(arg): + return 2 * arg + + def elif_fn1(arg): + return arg - 1 + + def elif_fn2(arg): + return arg - 2 + + return qml.cond( + selector > 0, + true_fn, + elifs=( + (selector == -1, elif_fn1), + (selector == -2, elif_fn2), + ), + )(arg) + + +def cond_true_false(selector, arg): + """A function with conditional containing true and false branches.""" + + def true_fn(arg): + return 2 * arg + + def false_fn(arg): + return 3 * arg + + return qml.cond( + selector > 0, + true_fn, + false_fn, + )(arg) + + +def cond_true(selector, arg): + """A function with conditional containing only the true branch.""" + + def true_fn(arg): + return 2 * arg + + return qml.cond( + selector > 0, + true_fn, + )(arg) + +# pylint: disable=no-self-use +class TestCond: + """Tests for capturing conditional statements.""" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, 7), # Elif condition 3 + (-4, 10, 6), # Elif condition 4 + (0, 10, 30), # False condition + ], + ) + def test_cond_true_elifs_false(self, selector, arg, expected): + """Test the conditional with true, elifs, and false branches.""" + + result = cond_true_elifs_false(selector, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, ()), # No condition met + ], + ) + def test_cond_true_elifs(self, selector, arg, expected): + """Test the conditional with true and elifs branches.""" + + result = cond_true_elifs(selector, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, 30), # False condition + ], + ) + def test_cond_true_false(self, selector, arg, expected): + """Test the conditional with true and false branches.""" + + result = cond_true_false(selector, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, ()), # No condition met + ], + ) + def test_cond_true(self, selector, arg, expected): + """Test the conditional with only the true branch.""" + + result = cond_true(selector, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" From 9317509db0cfc237ddde2ed02a2a3e9f7dcca2ea Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 11:16:17 -0400 Subject: [PATCH 15/34] Re-naming file --- .../{test_conditionals.py => test_capture_conditionals.py} | 1 + 1 file changed, 1 insertion(+) rename tests/capture/{test_conditionals.py => test_capture_conditionals.py} (99%) diff --git a/tests/capture/test_conditionals.py b/tests/capture/test_capture_conditionals.py similarity index 99% rename from tests/capture/test_conditionals.py rename to tests/capture/test_capture_conditionals.py index 84cd3712ca6..de73719a0bd 100644 --- a/tests/capture/test_conditionals.py +++ b/tests/capture/test_capture_conditionals.py @@ -117,6 +117,7 @@ def true_fn(arg): true_fn, )(arg) + # pylint: disable=no-self-use class TestCond: """Tests for capturing conditional statements.""" From 2a7bbb3e269bab40652fa65af6cd50cb7e77a7fa Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 12:31:33 -0400 Subject: [PATCH 16/34] removing useless for loop [ci skip] --- pennylane/ops/op_math/condition.py | 30 ++++++++++++------------------ 1 file changed, 12 insertions(+), 18 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 51ea974f548..0672e85b4d6 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -455,9 +455,8 @@ def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): def run_jaxpr(jaxpr, *args): out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - - if not isinstance(out, tuple): - out = (out,) + print("out", out) + out = (out,) if not isinstance(out, tuple) else out for outvar in out: if isinstance(outvar, Operator): @@ -469,19 +468,14 @@ def true_branch(args): return run_jaxpr(jaxpr_true, *args) def elif_branch(args, elifs_conditions, jaxpr_elifs): - if not jaxpr_elifs: return false_branch(args) - pred = elifs_conditions[0] rest_preds = elifs_conditions[1:] - jaxpr_elif = jaxpr_elifs[0] rest_jaxpr_elifs = jaxpr_elifs[1:] - if pred: return run_jaxpr(jaxpr_elif, *args) - return elif_branch(args, rest_preds, rest_jaxpr_elifs) def false_branch(args): @@ -491,10 +485,8 @@ def false_branch(args): if condition: return true_branch(args) - if elifs_conditions.size > 0: return elif_branch(args, elifs_conditions, jaxpr_elifs) - return false_branch(args) def _is_queued_outvar(outvars): @@ -530,10 +522,17 @@ def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: @wraps(true_fn) def new_wrapper(*args, **kwargs): - # We extract each predicate from the elifs list - # since these are traced by JAX and are passed as positional arguments + # We extract each condition (or predicate) from the elifs argument list + # since these are traced by JAX and are passed as positional arguments to the cond primitive + elifs_conditions = [] + jaxpr_elifs = [] + + for cond, elif_fn in elifs: + elifs_conditions.append(cond) + jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) + elifs_conditions = ( - jax.numpy.array([cond for cond, _ in elifs]) if elifs else jax.numpy.empty(0) + jax.numpy.array(elifs_conditions) if elifs_conditions else jax.numpy.empty(0) ) jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) @@ -542,11 +541,6 @@ def new_wrapper(*args, **kwargs): if false_fn else None ) - jaxpr_elifs = ( - ([jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args) for _, elif_fn in elifs]) - if elifs - else () - ) return cond_prim.bind( condition, From f95a18e9bae827db68f8b434dd618ae404adb266 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 12:31:51 -0400 Subject: [PATCH 17/34] removing useless for loop [ci skip] --- pennylane/ops/op_math/condition.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 0672e85b4d6..08d022edc84 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -455,7 +455,6 @@ def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): def run_jaxpr(jaxpr, *args): out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - print("out", out) out = (out,) if not isinstance(out, tuple) else out for outvar in out: From fd763615d00f754fab815f6ac7e6ff4f98046ed8 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 12:36:58 -0400 Subject: [PATCH 18/34] changed variable name [ci skip] --- pennylane/ops/op_math/condition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 08d022edc84..6c4675afa8a 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -526,8 +526,8 @@ def new_wrapper(*args, **kwargs): elifs_conditions = [] jaxpr_elifs = [] - for cond, elif_fn in elifs: - elifs_conditions.append(cond) + for pred, elif_fn in elifs: + elifs_conditions.append(pred) jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) elifs_conditions = ( From f9ad42a52cef0f4bf4d8a99e2c563290a02f86df Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 23 Jul 2024 17:42:00 -0400 Subject: [PATCH 19/34] implemented abstract definition [ci skip] --- pennylane/ops/op_math/condition.py | 73 +++++++++++++++++++----------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 6c4675afa8a..921acc9e804 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -447,15 +447,12 @@ def _get_cond_qfunc_prim(): cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True - AbstractOperator = qml.capture.AbstractOperator - @cond_prim.def_impl def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): def run_jaxpr(jaxpr, *args): out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - out = (out,) if not isinstance(out, tuple) else out for outvar in out: if isinstance(outvar, Operator): @@ -488,23 +485,47 @@ def false_branch(args): return elif_branch(args, elifs_conditions, jaxpr_elifs) return false_branch(args) - def _is_queued_outvar(outvars): - if not outvars: - return False - return isinstance(outvars[0].aval, AbstractOperator) and isinstance( - outvars[0], jax.core.DropVar - ) - @cond_prim.def_abstract_eval - def _(*_, jaxpr_true, **__): - - def collect_outvars(jaxpr): - return [AbstractOperator() for eqn in jaxpr.eqns if _is_queued_outvar(eqn.outvars)] + [ - AbstractOperator() for aval in jaxpr.out_avals if isinstance(aval, AbstractOperator) - ] - - outvars_true = collect_outvars(jaxpr_true) - return outvars_true + def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs): + + # We check that the return values in each branch (true, and possibly false and elifs) + # have the same abstract values (length, type, and value). + # The error messages are detailed to help debugging + def validate_abstract_values( + outvals: list, expected_outvals: list, branch_type: str, index: int = None + ) -> None: + """Ensure the collected abstract values match the expected ones.""" + + assert len(outvals) == len(expected_outvals), ( + f"Mismatch in number of output variables in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)}: " + f"{len(outvals)} vs {len(expected_outvals)}" + ) + for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): + assert isinstance(outval, type(expected_outval)), ( + f"Mismatch in output variable types in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{type(outval)} vs {type(expected_outval)}" + ) + assert outval == expected_outval, ( + f"Mismatch in output variable values in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{outval} vs {expected_outval}" + ) + + outvals_true = jaxpr_true.out_avals + outvals_false = jaxpr_false.out_avals if jaxpr_false is not None else [] + + for idx, jaxpr_elif in enumerate(jaxpr_elifs): + outvals_elif = jaxpr_elif.out_avals + validate_abstract_values(outvals_elif, outvals_true, "elif", idx) + + if outvals_false: + validate_abstract_values(outvals_false, outvals_true, "false") + + # We return the abstract values of the true branch since the abstract values + # of the false and elif branches (if they exist) should be the same + return outvals_true return cond_prim @@ -521,6 +542,13 @@ def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: @wraps(true_fn) def new_wrapper(*args, **kwargs): + jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) + jaxpr_false = ( + (jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None) + if false_fn + else None + ) + # We extract each condition (or predicate) from the elifs argument list # since these are traced by JAX and are passed as positional arguments to the cond primitive elifs_conditions = [] @@ -534,13 +562,6 @@ def new_wrapper(*args, **kwargs): jax.numpy.array(elifs_conditions) if elifs_conditions else jax.numpy.empty(0) ) - jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) - jaxpr_false = ( - (jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None) - if false_fn - else None - ) - return cond_prim.bind( condition, elifs_conditions, From e8fef5d758b1af1dc7bc2464a5ca0e05f6080048 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 24 Jul 2024 10:06:32 -0400 Subject: [PATCH 20/34] Adding tests to catch errors --- pennylane/ops/op_math/condition.py | 9 +- tests/capture/test_capture_conditionals.py | 235 +++++++++++++-------- 2 files changed, 148 insertions(+), 96 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 921acc9e804..70215f364bf 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -454,6 +454,8 @@ def run_jaxpr(jaxpr, *args): out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + # If the branch returns an Operator, we append it to the QueuingManager + # so that it is applied to the circuit for outvar in out: if isinstance(outvar, Operator): QueuingManager.append(outvar) @@ -502,13 +504,8 @@ def validate_abstract_values( f"{len(outvals)} vs {len(expected_outvals)}" ) for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): - assert isinstance(outval, type(expected_outval)), ( - f"Mismatch in output variable types in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)} at position {i}: " - f"{type(outval)} vs {type(expected_outval)}" - ) assert outval == expected_outval, ( - f"Mismatch in output variable values in {branch_type} branch" + f"Mismatch in output abstract values in {branch_type} branch" f"{'' if index is None else ' #' + str(index)} at position {i}: " f"{outval} vs {expected_outval}" ) diff --git a/tests/capture/test_capture_conditionals.py b/tests/capture/test_capture_conditionals.py index de73719a0bd..39e9f7c4075 100644 --- a/tests/capture/test_capture_conditionals.py +++ b/tests/capture/test_capture_conditionals.py @@ -14,12 +14,15 @@ """ Tests for capturing conditionals into jaxpr. """ -import numpy as np -# pylint: disable=protected-access +# pylint: disable=redefined-outer-name + +import jax.numpy as jnp +import numpy as np import pytest import pennylane as qml +from pennylane.ops.op_math.condition import _capture_cond pytestmark = pytest.mark.jax @@ -34,8 +37,9 @@ def enable_disable_plxpr(): qml.capture.disable() -def cond_true_elifs_false(selector, arg): - """A function with conditional containing true, elifs, and false branches.""" +@pytest.fixture +def testing_functions(): + """Returns a set of functions for testing.""" def true_fn(arg): return 2 * arg @@ -55,7 +59,26 @@ def elif_fn4(arg): def false_fn(arg): return 3 * arg - return qml.cond( + return true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 + + +@pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, 7), # Elif condition 3 + (-4, 10, 6), # Elif condition 4 + (0, 10, 30), # False condition + ], +) +def test_cond_true_elifs_false(testing_functions, selector, arg, expected): + """Test the conditional with true, elifs, and false branches.""" + + true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions + + result = qml.cond( selector > 0, true_fn, false_fn, @@ -66,21 +89,24 @@ def false_fn(arg): (selector == -4, elif_fn4), ), )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" -def cond_true_elifs(selector, arg): - """A function with conditional containing true and elifs branches.""" - - def true_fn(arg): - return 2 * arg - - def elif_fn1(arg): - return arg - 1 +@pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, ()), # No condition met + ], +) +def test_cond_true_elifs(testing_functions, selector, arg, expected): + """Test the conditional with true and elifs branches.""" - def elif_fn2(arg): - return arg - 2 + true_fn, _, elif_fn1, elif_fn2, _, _ = testing_functions - return qml.cond( + result = qml.cond( selector > 0, true_fn, elifs=( @@ -88,94 +114,123 @@ def elif_fn2(arg): (selector == -2, elif_fn2), ), )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" -def cond_true_false(selector, arg): - """A function with conditional containing true and false branches.""" +@pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, 30), # False condition + ], +) +def test_cond_true_false(testing_functions, selector, arg, expected): + """Test the conditional with true and false branches.""" - def true_fn(arg): - return 2 * arg + true_fn, false_fn, _, _, _, _ = testing_functions - def false_fn(arg): - return 3 * arg - - return qml.cond( + result = qml.cond( selector > 0, true_fn, false_fn, )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" -def cond_true(selector, arg): - """A function with conditional containing only the true branch.""" +@pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, ()), # No condition met + ], +) +def test_cond_true(testing_functions, selector, arg, expected): + """Test the conditional with only the true branch.""" - def true_fn(arg): - return 2 * arg + true_fn, _, _, _, _, _ = testing_functions - return qml.cond( + result = qml.cond( selector > 0, true_fn, )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" -# pylint: disable=no-self-use -class TestCond: - """Tests for capturing conditional statements.""" - - @pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (-1, 10, 9), # Elif condition 1 - (-2, 10, 8), # Elif condition 2 - (-3, 10, 7), # Elif condition 3 - (-4, 10, 6), # Elif condition 4 - (0, 10, 30), # False condition - ], - ) - def test_cond_true_elifs_false(self, selector, arg, expected): - """Test the conditional with true, elifs, and false branches.""" - - result = cond_true_elifs_false(selector, arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - @pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (-1, 10, 9), # Elif condition 1 - (-2, 10, 8), # Elif condition 2 - (-3, 10, ()), # No condition met - ], - ) - def test_cond_true_elifs(self, selector, arg, expected): - """Test the conditional with true and elifs branches.""" - - result = cond_true_elifs(selector, arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - @pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (0, 10, 30), # False condition - ], - ) - def test_cond_true_false(self, selector, arg, expected): - """Test the conditional with true and false branches.""" - - result = cond_true_false(selector, arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - @pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (0, 10, ()), # No condition met - ], - ) - def test_cond_true(self, selector, arg, expected): - """Test the conditional with only the true branch.""" - - result = cond_true(selector, arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" +def test_validate_number_of_output_variables(): + """Test mismatch in number of output variables.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1 + + with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) + + +def test_validate_output_variable_types(): + """Test mismatch in output variable types.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2.0 + + with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) + + +def test_validate_elif_branches(): + """Test elif branch mismatches.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2 + + def elif_fn1(x): + return x + 1, x + 2 + + def elif_fn2(x): + return x + 1, x + 2.0 # Type mismatch + + def elif_fn3(x): + return x + 1 # Length mismatch + + with pytest.raises( + AssertionError, match=r"Mismatch in output abstract values in elif branch #1" + ): + jax.make_jaxpr( + _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) + )(jnp.array(1)) + + with pytest.raises( + AssertionError, match=r"Mismatch in number of output variables in elif branch #0" + ): + jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))(jnp.array(1)) + + +@pytest.mark.parametrize( + "true_fn, false_fn, expected_error, match", + [ + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1), + AssertionError, + r"Mismatch in number of output variables", + ), + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1, x + 2.0), + AssertionError, + r"Mismatch in output abstract values", + ), + ], +) +def test_validate_mismatches(true_fn, false_fn, expected_error, match): + """Test mismatch in number and type of output variables.""" + with pytest.raises(expected_error, match=match): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) From 4e9fe78e904bf1d920138ed0062ce0f4bd5a46c6 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 24 Jul 2024 10:13:50 -0400 Subject: [PATCH 21/34] Removed import above skipping --- pennylane/ops/op_math/condition.py | 2 +- tests/capture/test_capture_conditionals.py | 13 +++++++------ 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 70215f364bf..7413dbcb3a7 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -491,7 +491,7 @@ def false_branch(args): def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs): # We check that the return values in each branch (true, and possibly false and elifs) - # have the same abstract values (length, type, and value). + # have the same abstract values. # The error messages are detailed to help debugging def validate_abstract_values( outvals: list, expected_outvals: list, branch_type: str, index: int = None diff --git a/tests/capture/test_capture_conditionals.py b/tests/capture/test_capture_conditionals.py index 39e9f7c4075..07ded688505 100644 --- a/tests/capture/test_capture_conditionals.py +++ b/tests/capture/test_capture_conditionals.py @@ -17,7 +17,6 @@ # pylint: disable=redefined-outer-name -import jax.numpy as jnp import numpy as np import pytest @@ -166,7 +165,7 @@ def false_fn(x): return x + 1 with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_output_variable_types(): @@ -179,7 +178,7 @@ def false_fn(x): return x + 1, x + 2.0 with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_elif_branches(): @@ -205,12 +204,14 @@ def elif_fn3(x): ): jax.make_jaxpr( _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) - )(jnp.array(1)) + )(jax.numpy.array(1)) with pytest.raises( AssertionError, match=r"Mismatch in number of output variables in elif branch #0" ): - jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))(jnp.array(1)) + jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( + jax.numpy.array(1) + ) @pytest.mark.parametrize( @@ -233,4 +234,4 @@ def elif_fn3(x): def test_validate_mismatches(true_fn, false_fn, expected_error, match): """Test mismatch in number and type of output variables.""" with pytest.raises(expected_error, match=match): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jnp.array(1)) + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) From 6a569728e93abc5cea7d14704cd8084210f05286 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 24 Jul 2024 12:14:07 -0400 Subject: [PATCH 22/34] Adding a few more tests [ci skip] --- pennylane/ops/op_math/condition.py | 4 +- tests/capture/test_capture_conditionals.py | 91 +++++++++++++++++++++- 2 files changed, 91 insertions(+), 4 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 7413dbcb3a7..e0ec5441f36 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -454,8 +454,8 @@ def run_jaxpr(jaxpr, *args): out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - # If the branch returns an Operator, we append it to the QueuingManager - # so that it is applied to the circuit + # If the branch returns one or more Operators, we append them to the QueuingManager + # so that they are applied to the quantum circuit for outvar in out: if isinstance(outvar, Operator): QueuingManager.append(outvar) diff --git a/tests/capture/test_capture_conditionals.py b/tests/capture/test_capture_conditionals.py index 07ded688505..cad8e031ea9 100644 --- a/tests/capture/test_capture_conditionals.py +++ b/tests/capture/test_capture_conditionals.py @@ -194,10 +194,10 @@ def elif_fn1(x): return x + 1, x + 2 def elif_fn2(x): - return x + 1, x + 2.0 # Type mismatch + return x + 1, x + 2.0 def elif_fn3(x): - return x + 1 # Length mismatch + return x + 1 with pytest.raises( AssertionError, match=r"Mismatch in output abstract values in elif branch #1" @@ -235,3 +235,90 @@ def test_validate_mismatches(true_fn, false_fn, expected_error, match): """Test mismatch in number and type of output variables.""" with pytest.raises(expected_error, match=match): jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + +dev = qml.device("default.qubit", wires=3) + + +@qml.qnode(dev) +def circuit(pred, arg1, arg2): + + qml.RX(0.10, wires=0) + + def true_fn(arg1, arg2): + qml.RX(arg1, wires=0) + qml.RX(arg2, wires=0) + qml.RX(arg1, wires=0) + + def false_fn(arg1, arg2): + qml.RX(arg2, wires=0) + + def elif_fn1(arg1, arg2): + qml.RX(arg1, wires=0) + + qml.cond( + pred > 0, + true_fn, + false_fn, + elifs=((pred == -1, elif_fn1)), + )(arg1, arg2) + + qml.RX(0.10, wires=0) + + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def reference_circuit_true(arg1, arg2): + qml.RX(0.10, wires=0) + qml.RX(arg1, wires=0) + qml.RX(arg2, wires=0) + qml.RX(arg1, wires=0) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def reference_circuit_false(arg2): + qml.RX(0.10, wires=0) + qml.RX(arg2, wires=0) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def reference_circuit_elif1(arg1): + qml.RX(0.10, wires=0) + qml.RX(arg1, wires=0) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def reference_circuit_no_branch(): + qml.RX(0.10, wires=0) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +class TestQuantumConditionals: + @pytest.mark.parametrize( + "pred, arg1, arg2", + [ + (1, 0.5, 1.0), + (0, 0.5, 1.0), + (-1, 0.5, 1.0), + (-2, 0.5, 1.0), + ], + ) + def test_conditional_branches(self, pred, arg1, arg2): + result = circuit(pred, arg1, arg2) + + if pred > 0: + expected_result = reference_circuit_true(arg1, arg2) + elif pred == -1: + expected_result = reference_circuit_elif1(arg1) + else: + expected_result = reference_circuit_false(arg2) + + assert np.allclose(result, expected_result), f"Expected {expected_result}, but got {result}" From 9d50198cef5083706a7dd906357a531ba7c0aacc Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 24 Jul 2024 14:32:20 -0400 Subject: [PATCH 23/34] Adding more unit tests --- pennylane/math/utils.py | 2 +- pennylane/ops/op_math/condition.py | 15 +- tests/capture/test_capture_conditionals.py | 431 ++++++++++----------- 3 files changed, 223 insertions(+), 225 deletions(-) diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 9c42d037466..a03954b3746 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -227,7 +227,7 @@ def get_interface(*values): # contains autograd and another interface warnings.warn( f"Contains tensors of types {non_numpy_scipy_interfaces}; dispatch will prioritize " - "TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.", + "TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.", UserWarning, ) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index e0ec5441f36..58e487344f4 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -114,6 +114,13 @@ def cond(condition, true_fn, false_fn=None, elifs=()): will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed branch determined at runtime. For more details, please see :func:`catalyst.cond`. + When used with `qml.capture.enabled()` equal to ``True``, this function allows + for general if-elif-else constructs. As with the JIT mode, all branches will be + captured, with the executed branch determined at runtime. Each branch can receive parameters. + However, the function cannot branch on mid-circuit measurements. + If a branch returns one or more variables, every other branch must return the same abstract values. + If a branch returns one or more operators, these will be appended to the QueuingManager. + .. note:: With the Python interpreter, support for :func:`~.cond` @@ -511,15 +518,15 @@ def validate_abstract_values( ) outvals_true = jaxpr_true.out_avals - outvals_false = jaxpr_false.out_avals if jaxpr_false is not None else [] + + if jaxpr_false is not None: + outvals_false = jaxpr_false.out_avals + validate_abstract_values(outvals_false, outvals_true, "false") for idx, jaxpr_elif in enumerate(jaxpr_elifs): outvals_elif = jaxpr_elif.out_avals validate_abstract_values(outvals_elif, outvals_true, "elif", idx) - if outvals_false: - validate_abstract_values(outvals_false, outvals_true, "false") - # We return the abstract values of the true branch since the abstract values # of the false and elif branches (if they exist) should be the same return outvals_true diff --git a/tests/capture/test_capture_conditionals.py b/tests/capture/test_capture_conditionals.py index cad8e031ea9..77843f64191 100644 --- a/tests/capture/test_capture_conditionals.py +++ b/tests/capture/test_capture_conditionals.py @@ -16,6 +16,7 @@ """ # pylint: disable=redefined-outer-name +# pylint: disable=no-self-use import numpy as np import pytest @@ -61,180 +62,182 @@ def false_fn(arg): return true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 -@pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (-1, 10, 9), # Elif condition 1 - (-2, 10, 8), # Elif condition 2 - (-3, 10, 7), # Elif condition 3 - (-4, 10, 6), # Elif condition 4 - (0, 10, 30), # False condition - ], -) -def test_cond_true_elifs_false(testing_functions, selector, arg, expected): - """Test the conditional with true, elifs, and false branches.""" - - true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions - - result = qml.cond( - selector > 0, - true_fn, - false_fn, - elifs=( - (selector == -1, elif_fn1), - (selector == -2, elif_fn2), - (selector == -3, elif_fn3), - (selector == -4, elif_fn4), - ), - )(arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - -@pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (-1, 10, 9), # Elif condition 1 - (-2, 10, 8), # Elif condition 2 - (-3, 10, ()), # No condition met - ], -) -def test_cond_true_elifs(testing_functions, selector, arg, expected): - """Test the conditional with true and elifs branches.""" - - true_fn, _, elif_fn1, elif_fn2, _, _ = testing_functions - - result = qml.cond( - selector > 0, - true_fn, - elifs=( - (selector == -1, elif_fn1), - (selector == -2, elif_fn2), - ), - )(arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - -@pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (0, 10, 30), # False condition - ], -) -def test_cond_true_false(testing_functions, selector, arg, expected): - """Test the conditional with true and false branches.""" - - true_fn, false_fn, _, _, _, _ = testing_functions - - result = qml.cond( - selector > 0, - true_fn, - false_fn, - )(arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - -@pytest.mark.parametrize( - "selector, arg, expected", - [ - (1, 10, 20), # True condition - (0, 10, ()), # No condition met - ], -) -def test_cond_true(testing_functions, selector, arg, expected): - """Test the conditional with only the true branch.""" - - true_fn, _, _, _, _, _ = testing_functions - - result = qml.cond( - selector > 0, - true_fn, - )(arg) - assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - - -def test_validate_number_of_output_variables(): - """Test mismatch in number of output variables.""" - - def true_fn(x): - return x + 1, x + 2 - - def false_fn(x): - return x + 1 - - with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) - - -def test_validate_output_variable_types(): - """Test mismatch in output variable types.""" - - def true_fn(x): - return x + 1, x + 2 - - def false_fn(x): - return x + 1, x + 2.0 - - with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) - - -def test_validate_elif_branches(): - """Test elif branch mismatches.""" - - def true_fn(x): - return x + 1, x + 2 - - def false_fn(x): - return x + 1, x + 2 - - def elif_fn1(x): - return x + 1, x + 2 - - def elif_fn2(x): - return x + 1, x + 2.0 - - def elif_fn3(x): - return x + 1 - - with pytest.raises( - AssertionError, match=r"Mismatch in output abstract values in elif branch #1" - ): - jax.make_jaxpr( - _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) - )(jax.numpy.array(1)) - - with pytest.raises( - AssertionError, match=r"Mismatch in number of output variables in elif branch #0" - ): - jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( - jax.numpy.array(1) - ) +class TestCond: + """Tests for conditional functions using qml.cond.""" + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, 7), # Elif condition 3 + (-4, 10, 6), # Elif condition 4 + (0, 10, 30), # False condition + ], + ) + def test_cond_true_elifs_false(self, testing_functions, selector, arg, expected): + """Test the conditional with true, elifs, and false branches.""" + true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + false_fn, + elifs=( + (selector == -1, elif_fn1), + (selector == -2, elif_fn2), + (selector == -3, elif_fn3), + (selector == -4, elif_fn4), + ), + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, ()), # No condition met + ], + ) + def test_cond_true_elifs(self, testing_functions, selector, arg, expected): + """Test the conditional with true and elifs branches.""" + true_fn, _, elif_fn1, elif_fn2, _, _ = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + elifs=( + (selector == -1, elif_fn1), + (selector == -2, elif_fn2), + ), + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, 30), # False condition + ], + ) + def test_cond_true_false(self, testing_functions, selector, arg, expected): + """Test the conditional with true and false branches.""" + true_fn, false_fn, _, _, _, _ = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + false_fn, + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (0, 10, ()), # No condition met + ], + ) + def test_cond_true(self, testing_functions, selector, arg, expected): + """Test the conditional with only the true branch.""" + true_fn, _, _, _, _, _ = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + +class TestCondReturns: + """Tests for validating the number and types of output variables in conditional functions.""" + + @pytest.mark.parametrize( + "true_fn, false_fn, expected_error, match", + [ + ( + lambda x: (x + 1, x + 2), + lambda x: None, + AssertionError, + r"Mismatch in number of output variables", + ), + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1,), + AssertionError, + r"Mismatch in number of output variables", + ), + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1, x + 2.0), + AssertionError, + r"Mismatch in output abstract values", + ), + ], + ) + def test_validate_mismatches(self, true_fn, false_fn, expected_error, match): + """Test mismatch in number and type of output variables.""" + with pytest.raises(expected_error, match=match): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_number_of_output_variables(self): + """Test mismatch in number of output variables.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1 + + with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_output_variable_types(self): + """Test mismatch in output variable types.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2.0 -@pytest.mark.parametrize( - "true_fn, false_fn, expected_error, match", - [ - ( - lambda x: (x + 1, x + 2), - lambda x: (x + 1), - AssertionError, - r"Mismatch in number of output variables", - ), - ( - lambda x: (x + 1, x + 2), - lambda x: (x + 1, x + 2.0), - AssertionError, - r"Mismatch in output abstract values", - ), - ], -) -def test_validate_mismatches(true_fn, false_fn, expected_error, match): - """Test mismatch in number and type of output variables.""" - with pytest.raises(expected_error, match=match): - jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_elif_branches(self): + """Test elif branch mismatches.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2 + + def elif_fn1(x): + return x + 1, x + 2 + + def elif_fn2(x): + return x + 1, x + 2.0 + + def elif_fn3(x): + return x + 1 + + with pytest.raises( + AssertionError, match=r"Mismatch in output abstract values in elif branch #1" + ): + jax.make_jaxpr( + _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) + )(jax.numpy.array(1)) + + with pytest.raises( + AssertionError, match=r"Mismatch in number of output variables in elif branch #0" + ): + jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( + jax.numpy.array(1) + ) dev = qml.device("default.qubit", wires=3) @@ -242,83 +245,71 @@ def test_validate_mismatches(true_fn, false_fn, expected_error, match): @qml.qnode(dev) def circuit(pred, arg1, arg2): + """Quantum circuit with conditional branches.""" qml.RX(0.10, wires=0) def true_fn(arg1, arg2): - qml.RX(arg1, wires=0) + qml.RY(arg1, wires=0) qml.RX(arg2, wires=0) - qml.RX(arg1, wires=0) + qml.RZ(arg1, wires=0) def false_fn(arg1, arg2): + qml.RX(arg1, wires=0) qml.RX(arg2, wires=0) def elif_fn1(arg1, arg2): + qml.RZ(arg2, wires=0) qml.RX(arg1, wires=0) - qml.cond( - pred > 0, - true_fn, - false_fn, - elifs=((pred == -1, elif_fn1)), - )(arg1, arg2) - + qml.cond(pred > 0, true_fn, false_fn, elifs=(pred == -1, elif_fn1))(arg1, arg2) qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) @qml.qnode(dev) -def reference_circuit_true(arg1, arg2): - qml.RX(0.10, wires=0) - qml.RX(arg1, wires=0) - qml.RX(arg2, wires=0) - qml.RX(arg1, wires=0) - qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) +def circuit_with_returned_operator(pred, arg1, arg2): + """Quantum circuit with conditional branches that return operators.""" - -@qml.qnode(dev) -def reference_circuit_false(arg2): qml.RX(0.10, wires=0) - qml.RX(arg2, wires=0) - qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) - -@qml.qnode(dev) -def reference_circuit_elif1(arg1): - qml.RX(0.10, wires=0) - qml.RX(arg1, wires=0) - qml.RX(0.10, wires=0) - return qml.expval(qml.PauliZ(wires=0)) + def true_fn(arg1, arg2): + qml.RY(arg1, wires=0) + return 7, 4.6, qml.RY(arg2, wires=0), True + def false_fn(arg1, arg2): + qml.RZ(arg2, wires=0) + return 2, 2.2, qml.RZ(arg1, wires=0), False -@qml.qnode(dev) -def reference_circuit_no_branch(): - qml.RX(0.10, wires=0) + qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) qml.RX(0.10, wires=0) return qml.expval(qml.PauliZ(wires=0)) -class TestQuantumConditionals: +class TestCondCircuits: + """Tests for conditional quantum circuits.""" + @pytest.mark.parametrize( - "pred, arg1, arg2", + "pred, arg1, arg2, expected", [ - (1, 0.5, 1.0), - (0, 0.5, 1.0), - (-1, 0.5, 1.0), - (-2, 0.5, 1.0), + (1, 0.5, 0.6, 0.63340907), # RX(0.10) -> RY(0.5) -> RX(0.6) -> RZ(0.5) -> RX(0.10) + (0, 0.5, 0.6, 0.26749883), # RX(0.10) -> RX(0.5) -> RX(0.6) -> RX(0.10) + (-1, 0.5, 0.6, 0.77468805), # RX(0.10) -> RZ(0.6) -> RX(0.5) -> RX(0.10) ], ) - def test_conditional_branches(self, pred, arg1, arg2): + def test_circuit(self, pred, arg1, arg2, expected): + """Test circuit with true, false, and elif branches.""" result = circuit(pred, arg1, arg2) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" - if pred > 0: - expected_result = reference_circuit_true(arg1, arg2) - elif pred == -1: - expected_result = reference_circuit_elif1(arg1) - else: - expected_result = reference_circuit_false(arg2) - - assert np.allclose(result, expected_result), f"Expected {expected_result}, but got {result}" + @pytest.mark.parametrize( + "pred, arg1, arg2, expected", + [ + (1, 0.5, 0.6, 0.43910855), # RX(0.10) -> RY(0.5) -> RY(0.6) -> RX(0.10) + (0, 0.5, 0.6, 0.98551243), # RX(0.10) -> RZ(0.6) -> RX(0.5) -> RX(0.10) + ], + ) + def test_circuit_with_returned_operator(self, pred, arg1, arg2, expected): + """Test circuit with returned operators in the branches.""" + result = circuit_with_returned_operator(pred, arg1, arg2) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" From 61c059ad8d78b4263abe9324890382a1463e3d73 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 24 Jul 2024 17:19:55 -0400 Subject: [PATCH 24/34] Adding test and re-naming file --- doc/releases/changelog-dev.md | 3 +++ pennylane/ops/op_math/condition.py | 14 +++++----- ...e_conditionals.py => test_capture_cond.py} | 27 ++++++++++++++++--- 3 files changed, 34 insertions(+), 10 deletions(-) rename tests/capture/{test_capture_conditionals.py => test_capture_cond.py} (92%) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index f67840f549b..e2dce45dd17 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,6 +4,9 @@

New features since last release

+* The `qml.cond` function can be captured into plxpr. + [(#5999)](https://github.com/PennyLaneAI/pennylane/pull/5999) + * Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`. [(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 58e487344f4..df14da2b0bf 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -102,7 +102,7 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -def cond(condition, true_fn, false_fn=None, elifs=()): +def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -114,10 +114,12 @@ def cond(condition, true_fn, false_fn=None, elifs=()): will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed branch determined at runtime. For more details, please see :func:`catalyst.cond`. - When used with `qml.capture.enabled()` equal to ``True``, this function allows - for general if-elif-else constructs. As with the JIT mode, all branches will be - captured, with the executed branch determined at runtime. Each branch can receive parameters. + When used with `qml.capture.enabled()`, this function allows for general + if-elif-else constructs. As with the JIT mode, all branches will be captured, + with the executed branch determined at runtime. However, the function cannot branch on mid-circuit measurements. + Each branch can receive arguments, but the arguments must be the same for all branches. + Both the arguments and the branches must be JAX-compatible. If a branch returns one or more variables, every other branch must return the same abstract values. If a branch returns one or more operators, these will be appended to the QueuingManager. @@ -138,13 +140,13 @@ def cond(condition, true_fn, false_fn=None, elifs=()): Args: condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when - decorated by :func:`~.qjit`. + decorated by :func:`~.qjit` or when using :func:`~.qml.capture.enabled()`. true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True`` false_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only - be used when decorated by :func:`~.qjit`. + be used when decorated by :func:`~.qjit` or when using :func:`~.qml.capture.enabled()`. Returns: function: A new function that applies the conditional equivalent of ``true_fn``. The returned diff --git a/tests/capture/test_capture_conditionals.py b/tests/capture/test_capture_cond.py similarity index 92% rename from tests/capture/test_capture_conditionals.py rename to tests/capture/test_capture_cond.py index 77843f64191..934f37c4ae3 100644 --- a/tests/capture/test_capture_conditionals.py +++ b/tests/capture/test_capture_cond.py @@ -119,8 +119,8 @@ def test_cond_true_elifs(self, testing_functions, selector, arg, expected): @pytest.mark.parametrize( "selector, arg, expected", [ - (1, 10, 20), # True condition - (0, 10, 30), # False condition + (1, 10, 20), + (0, 10, 30), ], ) def test_cond_true_false(self, testing_functions, selector, arg, expected): @@ -137,8 +137,8 @@ def test_cond_true_false(self, testing_functions, selector, arg, expected): @pytest.mark.parametrize( "selector, arg, expected", [ - (1, 10, 20), # True condition - (0, 10, ()), # No condition met + (1, 10, 20), + (0, 10, ()), ], ) def test_cond_true(self, testing_functions, selector, arg, expected): @@ -151,6 +151,25 @@ def test_cond_true(self, testing_functions, selector, arg, expected): )(arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, jax.numpy.array([2, 3]), 12), + (0, jax.numpy.array([2, 3]), 15), + ], + ) + def test_cond_with_jax_array(self, selector, arg, expected): + """Test the conditional with array arguments.""" + + def true_fn(jax_array): + return jax_array[0] * jax_array[1] * 2 + + def false_fn(jax_array): + return jax_array[0] * jax_array[1] * 2.5 + + result = qml.cond(selector > 0, true_fn, false_fn)(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + class TestCondReturns: """Tests for validating the number and types of output variables in conditional functions.""" From feadf305e1b1d36ab7c319018d9819229e86cd6a Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Thu, 25 Jul 2024 09:03:03 -0400 Subject: [PATCH 25/34] Doc fix --- pennylane/ops/op_math/condition.py | 23 ++++++++--------------- 1 file changed, 8 insertions(+), 15 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index df14da2b0bf..f12d629f162 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -114,7 +114,7 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed branch determined at runtime. For more details, please see :func:`catalyst.cond`. - When used with `qml.capture.enabled()`, this function allows for general + When used with :func:`~.pennylane.capture.enabled`, this function allows for general if-elif-else constructs. As with the JIT mode, all branches will be captured, with the executed branch determined at runtime. However, the function cannot branch on mid-circuit measurements. @@ -140,13 +140,13 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): Args: condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when - decorated by :func:`~.qjit` or when using :func:`~.qml.capture.enabled()`. + decorated by :func:`~.qjit` or when using :func:`~.pennylane.capture.enabled`. true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True`` false_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only - be used when decorated by :func:`~.qjit` or when using :func:`~.qml.capture.enabled()`. + be used when decorated by :func:`~.qjit` or when using :func:`~.pennylane.capture.enabled`. Returns: function: A new function that applies the conditional equivalent of ``true_fn``. The returned @@ -448,9 +448,11 @@ def wrapper(*args, **kwargs): return wrapper -@functools.lru_cache # only create the first time requested +@functools.lru_cache def _get_cond_qfunc_prim(): - # if capture is enabled, jax should be installed + """Get the cond primitive for quantum functions.""" + + # JAX should be installed if capture is enabled import jax # pylint: disable=import-outside-toplevel cond_prim = jax.core.Primitive("cond") @@ -460,16 +462,7 @@ def _get_cond_qfunc_prim(): def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): def run_jaxpr(jaxpr, *args): - - out = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - - # If the branch returns one or more Operators, we append them to the QueuingManager - # so that they are applied to the quantum circuit - for outvar in out: - if isinstance(outvar, Operator): - QueuingManager.append(outvar) - - return out + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) def true_branch(args): return run_jaxpr(jaxpr_true, *args) From b8f80fc851d0f73b18a37e864ffd8ec25324b93c Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Thu, 25 Jul 2024 13:32:53 -0400 Subject: [PATCH 26/34] Suggestions from code review --- pennylane/ops/op_math/condition.py | 47 ++++++++++++++++++------------ tests/capture/test_capture_cond.py | 26 ++++++++++++----- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index f12d629f162..6be1de601e6 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -121,7 +121,7 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): Each branch can receive arguments, but the arguments must be the same for all branches. Both the arguments and the branches must be JAX-compatible. If a branch returns one or more variables, every other branch must return the same abstract values. - If a branch returns one or more operators, these will be appended to the QueuingManager. + If a branch returns one or more operators, these will be applied to the circuit. .. note:: @@ -392,7 +392,6 @@ def qnode(a, x, y, z): return cond_func if qml.capture.enabled(): - print("Capture mode for cond") return _capture_cond(condition, true_fn, false_fn, elifs) if elifs: @@ -469,7 +468,7 @@ def true_branch(args): def elif_branch(args, elifs_conditions, jaxpr_elifs): if not jaxpr_elifs: - return false_branch(args) + return None pred = elifs_conditions[0] rest_preds = elifs_conditions[1:] jaxpr_elif = jaxpr_elifs[0] @@ -485,9 +484,12 @@ def false_branch(args): if condition: return true_branch(args) - if elifs_conditions.size > 0: - return elif_branch(args, elifs_conditions, jaxpr_elifs) - return false_branch(args) + + elif_branch_out = ( + elif_branch(args, elifs_conditions, jaxpr_elifs) if elifs_conditions.size > 0 else None + ) + + return false_branch(args) if elif_branch_out is None else elif_branch_out @cond_prim.def_abstract_eval def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs): @@ -500,24 +502,33 @@ def validate_abstract_values( ) -> None: """Ensure the collected abstract values match the expected ones.""" - assert len(outvals) == len(expected_outvals), ( - f"Mismatch in number of output variables in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)}: " - f"{len(outvals)} vs {len(expected_outvals)}" - ) - for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): - assert outval == expected_outval, ( - f"Mismatch in output abstract values in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)} at position {i}: " - f"{outval} vs {expected_outval}" + if len(outvals) != len(expected_outvals): + raise ValueError( + f"Mismatch in number of output variables in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)}: " + f"{len(outvals)} vs {len(expected_outvals)}" ) + for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): + if outval != expected_outval: + raise ValueError( + f"Mismatch in output abstract values in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{outval} vs {expected_outval}" + ) + outvals_true = jaxpr_true.out_avals if jaxpr_false is not None: outvals_false = jaxpr_false.out_avals validate_abstract_values(outvals_false, outvals_true, "false") + else: + if outvals_true is not None: + raise ValueError( + "The false branch must be provided if the true branch returns any variables" + ) + for idx, jaxpr_elif in enumerate(jaxpr_elifs): outvals_elif = jaxpr_elif.out_avals validate_abstract_values(outvals_elif, outvals_true, "elif", idx) @@ -543,9 +554,7 @@ def new_wrapper(*args, **kwargs): jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) jaxpr_false = ( - (jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None) - if false_fn - else None + jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None ) # We extract each condition (or predicate) from the elifs argument list diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 934f37c4ae3..6cec5e68246 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -180,19 +180,19 @@ class TestCondReturns: ( lambda x: (x + 1, x + 2), lambda x: None, - AssertionError, + ValueError, r"Mismatch in number of output variables", ), ( lambda x: (x + 1, x + 2), lambda x: (x + 1,), - AssertionError, + ValueError, r"Mismatch in number of output variables", ), ( lambda x: (x + 1, x + 2), lambda x: (x + 1, x + 2.0), - AssertionError, + ValueError, r"Mismatch in output abstract values", ), ], @@ -211,7 +211,7 @@ def true_fn(x): def false_fn(x): return x + 1 - with pytest.raises(AssertionError, match=r"Mismatch in number of output variables"): + with pytest.raises(ValueError, match=r"Mismatch in number of output variables"): jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) def test_validate_output_variable_types(self): @@ -223,9 +223,21 @@ def true_fn(x): def false_fn(x): return x + 1, x + 2.0 - with pytest.raises(AssertionError, match=r"Mismatch in output abstract values"): + with pytest.raises(ValueError, match=r"Mismatch in output abstract values"): jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + def test_validate_no_false_branch_with_return(self): + """Test no false branch provided with return variables.""" + + def true_fn(x): + return x + 1, x + 2 + + with pytest.raises( + ValueError, + match=r"The false branch must be provided if the true branch returns any variables", + ): + jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + def test_validate_elif_branches(self): """Test elif branch mismatches.""" @@ -245,14 +257,14 @@ def elif_fn3(x): return x + 1 with pytest.raises( - AssertionError, match=r"Mismatch in output abstract values in elif branch #1" + ValueError, match=r"Mismatch in output abstract values in elif branch #1" ): jax.make_jaxpr( _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) )(jax.numpy.array(1)) with pytest.raises( - AssertionError, match=r"Mismatch in number of output variables in elif branch #0" + ValueError, match=r"Mismatch in number of output variables in elif branch #0" ): jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( jax.numpy.array(1) From c0c4c72785820d3bc8731205dae1c322bc2ec452 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 26 Jul 2024 09:40:43 -0400 Subject: [PATCH 27/34] Docstring clarification --- pennylane/ops/op_math/condition.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 6be1de601e6..60c9fd3434b 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -115,13 +115,14 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): branch determined at runtime. For more details, please see :func:`catalyst.cond`. When used with :func:`~.pennylane.capture.enabled`, this function allows for general - if-elif-else constructs. As with the JIT mode, all branches will be captured, + if-elif-else constructs. As with the JIT mode, all branches are captured, with the executed branch determined at runtime. However, the function cannot branch on mid-circuit measurements. Each branch can receive arguments, but the arguments must be the same for all branches. Both the arguments and the branches must be JAX-compatible. If a branch returns one or more variables, every other branch must return the same abstract values. - If a branch returns one or more operators, these will be applied to the circuit. + If used inside a quantum function, operators in the branch executed + at runtime are applied to the circuit, even if they are not explicitly returned. .. note:: From 75687ba933cc7ce736542968ae0ebb5520933072 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Fri, 26 Jul 2024 14:32:42 -0400 Subject: [PATCH 28/34] Suggestions from code review --- pennylane/ops/op_math/condition.py | 144 +++++++++++------------------ tests/capture/test_capture_cond.py | 17 ++++ 2 files changed, 73 insertions(+), 88 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 60c9fd3434b..aa10801f195 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -16,7 +16,7 @@ """ import functools from functools import wraps -from typing import Callable, Type +from typing import Callable, Optional, Type import pennylane as qml from pennylane import QueuingManager @@ -102,7 +102,7 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): +def cond(condition, true_fn: Callable, false_fn: Optional[Callable] = None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -114,16 +114,6 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): will be captured by Catalyst, the just-in-time (JIT) compiler, with the executed branch determined at runtime. For more details, please see :func:`catalyst.cond`. - When used with :func:`~.pennylane.capture.enabled`, this function allows for general - if-elif-else constructs. As with the JIT mode, all branches are captured, - with the executed branch determined at runtime. - However, the function cannot branch on mid-circuit measurements. - Each branch can receive arguments, but the arguments must be the same for all branches. - Both the arguments and the branches must be JAX-compatible. - If a branch returns one or more variables, every other branch must return the same abstract values. - If used inside a quantum function, operators in the branch executed - at runtime are applied to the circuit, even if they are not explicitly returned. - .. note:: With the Python interpreter, support for :func:`~.cond` @@ -132,22 +122,32 @@ def cond(condition, true_fn: Callable, false_fn: Callable = None, elifs=()): apply the :func:`defer_measurements` transform. .. note:: + When used with :func:`~.qjit`, this function only supports the Catalyst compiler. See :func:`catalyst.cond` for more details. Please see the Catalyst :doc:`quickstart guide `, as well as the :doc:`sharp bits and debugging tips `. + .. note:: + + When used with :func:`~.pennylane.capture.enabled`, this function allows for general + if-elif-else constructs. As with the JIT mode, all branches are captured, + with the executed branch determined at runtime. + + Each branch can receive arguments, but the arguments must be JAX-compatible. + If a branch returns one or more variables, every other branch must return the same abstract values. + Args: condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when - decorated by :func:`~.qjit` or when using :func:`~.pennylane.capture.enabled`. + decorated by :func:`~.qjit`. true_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``True`` false_fn (callable): The quantum function or PennyLane operation to apply if ``condition`` is ``False`` elifs (List(Tuple(bool, callable))): A list of (bool, elif_fn) clauses. Can only - be used when decorated by :func:`~.qjit` or when using :func:`~.pennylane.capture.enabled`. + be used when decorated by :func:`~.qjit`. Returns: function: A new function that applies the conditional equivalent of ``true_fn``. The returned @@ -448,94 +448,66 @@ def wrapper(*args, **kwargs): return wrapper +def _validate_abstract_values( + outvals: list, expected_outvals: list, branch_type: str, index: int = None +) -> None: + """Ensure the collected abstract values match the expected ones.""" + + if len(outvals) != len(expected_outvals): + raise ValueError( + f"Mismatch in number of output variables in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)}: " + f"{len(outvals)} vs {len(expected_outvals)}" + ) + + for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): + if outval != expected_outval: + raise ValueError( + f"Mismatch in output abstract values in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{outval} vs {expected_outval}" + ) + + @functools.lru_cache def _get_cond_qfunc_prim(): """Get the cond primitive for quantum functions.""" - # JAX should be installed if capture is enabled import jax # pylint: disable=import-outside-toplevel cond_prim = jax.core.Primitive("cond") cond_prim.multiple_results = True @cond_prim.def_impl - def _(condition, elifs_conditions, *args, jaxpr_true, jaxpr_false, jaxpr_elifs): - - def run_jaxpr(jaxpr, *args): - return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - - def true_branch(args): - return run_jaxpr(jaxpr_true, *args) - - def elif_branch(args, elifs_conditions, jaxpr_elifs): - if not jaxpr_elifs: - return None - pred = elifs_conditions[0] - rest_preds = elifs_conditions[1:] - jaxpr_elif = jaxpr_elifs[0] - rest_jaxpr_elifs = jaxpr_elifs[1:] - if pred: - return run_jaxpr(jaxpr_elif, *args) - return elif_branch(args, rest_preds, rest_jaxpr_elifs) - - def false_branch(args): - if jaxpr_false is not None: - return run_jaxpr(jaxpr_false, *args) - return () - - if condition: - return true_branch(args) - - elif_branch_out = ( - elif_branch(args, elifs_conditions, jaxpr_elifs) if elifs_conditions.size > 0 else None - ) + def _(conditions, *args, jaxpr_branches): - return false_branch(args) if elif_branch_out is None else elif_branch_out + for pred, jaxpr in zip(conditions, jaxpr_branches): + if pred and jaxpr is not None: + return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) - @cond_prim.def_abstract_eval - def _(*_, jaxpr_true, jaxpr_false, jaxpr_elifs): + return () - # We check that the return values in each branch (true, and possibly false and elifs) - # have the same abstract values. - # The error messages are detailed to help debugging - def validate_abstract_values( - outvals: list, expected_outvals: list, branch_type: str, index: int = None - ) -> None: - """Ensure the collected abstract values match the expected ones.""" - - if len(outvals) != len(expected_outvals): - raise ValueError( - f"Mismatch in number of output variables in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)}: " - f"{len(outvals)} vs {len(expected_outvals)}" - ) - - for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): - if outval != expected_outval: - raise ValueError( - f"Mismatch in output abstract values in {branch_type} branch" - f"{'' if index is None else ' #' + str(index)} at position {i}: " - f"{outval} vs {expected_outval}" - ) + @cond_prim.def_abstract_eval + def _(*_, jaxpr_branches): - outvals_true = jaxpr_true.out_avals + # Index 0 corresponds to the true branch + outvals_true = jaxpr_branches[0].out_avals - if jaxpr_false is not None: - outvals_false = jaxpr_false.out_avals - validate_abstract_values(outvals_false, outvals_true, "false") + for idx, jaxpr_branch in enumerate(jaxpr_branches): + if idx == 0: + continue - else: - if outvals_true is not None: + if outvals_true and jaxpr_branch is None: raise ValueError( "The false branch must be provided if the true branch returns any variables" ) - for idx, jaxpr_elif in enumerate(jaxpr_elifs): - outvals_elif = jaxpr_elif.out_avals - validate_abstract_values(outvals_elif, outvals_true, "elif", idx) + outvals_branch = jaxpr_branch.out_avals + branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false" + _validate_abstract_values(outvals_branch, outvals_true, branch_type, idx - 1) # We return the abstract values of the true branch since the abstract values - # of the false and elif branches (if they exist) should be the same + # of the other branches (if they exist) should be the same return outvals_true return cond_prim @@ -567,17 +539,13 @@ def new_wrapper(*args, **kwargs): elifs_conditions.append(pred) jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) - elifs_conditions = ( - jax.numpy.array(elifs_conditions) if elifs_conditions else jax.numpy.empty(0) - ) + jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] + conditions = jax.numpy.array([condition, *elifs_conditions, True]) return cond_prim.bind( - condition, - elifs_conditions, + conditions, *args, - jaxpr_true=jaxpr_true, - jaxpr_false=jaxpr_false, - jaxpr_elifs=jaxpr_elifs, + jaxpr_branches=jaxpr_branches, ) return new_wrapper diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 6cec5e68246..88d7785af8b 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -238,6 +238,23 @@ def true_fn(x): ): jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + def test_validate_no_false_branch_with_return_2(self): + """Test no false branch provided with return variables.""" + + def true_fn(x): + return x + 1, x + 2 + + def elif_fn(x): + return x + 1, x + 2 + + with pytest.raises( + ValueError, + match=r"The false branch must be provided if the true branch returns any variables", + ): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn=None, elifs=(False, elif_fn)))( + jax.numpy.array(1) + ) + def test_validate_elif_branches(self): """Test elif branch mismatches.""" From 4cee0457e9e8ef6a955cdb4ac995258b2b62273c Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 29 Jul 2024 10:39:43 -0400 Subject: [PATCH 29/34] Adding test with multiple cond --- tests/capture/test_capture_cond.py | 38 ++++++++++++++++++++++++++++++ 1 file changed, 38 insertions(+) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 88d7785af8b..31c24f89754 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -334,6 +334,32 @@ def false_fn(arg1, arg2): return qml.expval(qml.PauliZ(wires=0)) +@qml.qnode(dev) +def circuit_multiple_cond(tmp_pred, tmp_arg): + """Quantum circuit with multiple dynamic conditional branches.""" + + dyn_pred_1 = tmp_pred > 0 + arg = tmp_arg + + def true_fn_1(arg): + return True, qml.RX(arg, wires=0) + + # pylint: disable=unused-argument + def false_fn_1(arg): + return False, qml.RY(0.1, wires=0) + + def true_fn_2(arg): + return qml.RX(arg, wires=0) + + # pylint: disable=unused-argument + def false_fn_2(arg): + return qml.RY(0.1, wires=0) + + [dyn_pred_2, _] = qml.cond(dyn_pred_1, true_fn_1, false_fn_1, elifs=())(arg) + qml.cond(dyn_pred_2, true_fn_2, false_fn_2, elifs=())(arg) + return qml.expval(qml.Z(0)) + + class TestCondCircuits: """Tests for conditional quantum circuits.""" @@ -361,3 +387,15 @@ def test_circuit_with_returned_operator(self, pred, arg1, arg2, expected): """Test circuit with returned operators in the branches.""" result = circuit_with_returned_operator(pred, arg1, arg2) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "tmp_pred, tmp_arg, expected", + [ + (1, 0.5, 0.54030231), # RX(0.5) -> RX(0.5) + (-1, 0.5, 0.98006658), # RY(0.1) -> RY(0.1) + ], + ) + def test_circuit_multiple_cond(self, tmp_pred, tmp_arg, expected): + """Test circuit with returned operators in the branches.""" + result = circuit_multiple_cond(tmp_pred, tmp_arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" From 37862eefe05e55de1aac8d1287ef60d9e9bbd9af Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 29 Jul 2024 17:04:26 -0400 Subject: [PATCH 30/34] Capturing jaxpr.consts and passing them as positional arguments --- pennylane/ops/op_math/condition.py | 42 ++++++++++++++++++++++-------- tests/capture/test_capture_cond.py | 42 ++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 11 deletions(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index aa10801f195..286d5a457c5 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -479,28 +479,38 @@ def _get_cond_qfunc_prim(): cond_prim.multiple_results = True @cond_prim.def_impl - def _(conditions, *args, jaxpr_branches): + def _(conditions, *args_and_consts, jaxpr_branches, n_consts_per_branch, n_args): - for pred, jaxpr in zip(conditions, jaxpr_branches): + args = args_and_consts[:n_args] + consts_flat = args_and_consts[n_args:] + + consts_per_branch = [] + start = 0 + for n in n_consts_per_branch: + consts_per_branch.append(consts_flat[start : start + n]) + start += n + + for pred, jaxpr, consts in zip(conditions, jaxpr_branches, consts_per_branch): if pred and jaxpr is not None: - return jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + return jax.core.eval_jaxpr(jaxpr.jaxpr, consts, *args) return () @cond_prim.def_abstract_eval - def _(*_, jaxpr_branches): + def _(*_, jaxpr_branches, **__): - # Index 0 corresponds to the true branch outvals_true = jaxpr_branches[0].out_avals for idx, jaxpr_branch in enumerate(jaxpr_branches): if idx == 0: continue - if outvals_true and jaxpr_branch is None: - raise ValueError( - "The false branch must be provided if the true branch returns any variables" - ) + if jaxpr_branch is None: + if outvals_true: + raise ValueError( + "The false branch must be provided if the true branch returns any variables" + ) + continue outvals_branch = jaxpr_branch.out_avals branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false" @@ -531,7 +541,7 @@ def new_wrapper(*args, **kwargs): ) # We extract each condition (or predicate) from the elifs argument list - # since these are traced by JAX and are passed as positional arguments to the cond primitive + # since these are traced by JAX and are passed as positional arguments to the primitive elifs_conditions = [] jaxpr_elifs = [] @@ -539,13 +549,23 @@ def new_wrapper(*args, **kwargs): elifs_conditions.append(pred) jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) - jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] conditions = jax.numpy.array([condition, *elifs_conditions, True]) + jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] + jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] + + # We need to flatten the constants since JAX does not allow + # to pass lists as positional arguments + consts_flat = [const for sublist in jaxpr_consts for const in sublist] + n_consts_per_branch = [len(consts) for consts in jaxpr_consts] + return cond_prim.bind( conditions, *args, + *consts_flat, jaxpr_branches=jaxpr_branches, + n_consts_per_branch=n_consts_per_branch, + n_args=len(args), ) return new_wrapper diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 31c24f89754..546ef4b3e18 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -360,6 +360,35 @@ def false_fn_2(arg): return qml.expval(qml.Z(0)) +@qml.qnode(dev) +def circuit_with_consts(pred, arg): + """Quantum circuit with jaxpr constants.""" + + # these are captured as consts + arg1 = arg + arg2 = arg + 0.2 + arg3 = arg + 0.3 + arg4 = arg + 0.4 + arg5 = arg + 0.5 + arg6 = arg + 0.6 + + def true_fn(): + qml.RX(arg1, 0) + + def false_fn(): + qml.RX(arg2, 0) + qml.RX(arg3, 0) + + def elif_fn1(): + qml.RX(arg4, 0) + qml.RX(arg5, 0) + qml.RX(arg6, 0) + + qml.cond(pred > 0, true_fn, false_fn, elifs=((pred == 0, elif_fn1),))() + + return qml.expval(qml.Z(0)) + + class TestCondCircuits: """Tests for conditional quantum circuits.""" @@ -399,3 +428,16 @@ def test_circuit_multiple_cond(self, tmp_pred, tmp_arg, expected): """Test circuit with returned operators in the branches.""" result = circuit_multiple_cond(tmp_pred, tmp_arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "pred, arg, expected", + [ + (1, 0.5, 0.87758256), # RX(0.5) + (-1, 0.5, 0.0707372), # RX(0.7) -> RX(0.8) + (0, 0.5, -0.9899925), # RX(0.9) -> RX(1.0) -> RX(1.1) + ], + ) + def test_circuit_consts(self, pred, arg, expected): + """Test circuit with jaxpr constants.""" + result = circuit_with_consts(pred, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" From 121b885482f9dc91b96e274db49a2eeeab00b0ae Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Mon, 29 Jul 2024 17:56:25 -0400 Subject: [PATCH 31/34] Adding test to cover line (usual codecov stuff) --- tests/capture/test_capture_cond.py | 29 ++++++++++++++++++++++++++--- 1 file changed, 26 insertions(+), 3 deletions(-) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 546ef4b3e18..bc191edf8fc 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -292,7 +292,18 @@ def elif_fn3(x): @qml.qnode(dev) -def circuit(pred, arg1, arg2): +def circuit(pred): + """Quantum circuit with only a true branch.""" + + def true_fn(): + qml.RX(0.1, wires=0) + + qml.cond(pred > 0, true_fn)() + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def circuit_branches(pred, arg1, arg2): """Quantum circuit with conditional branches.""" qml.RX(0.10, wires=0) @@ -392,6 +403,18 @@ def elif_fn1(): class TestCondCircuits: """Tests for conditional quantum circuits.""" + @pytest.mark.parametrize( + "pred, expected", + [ + (1, 0.99500417), # RX(0.1) + (0, 1.0), # No operation + ], + ) + def test_circuit(self, pred, expected): + """Test circuit with only a true branch.""" + result = circuit(pred) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + @pytest.mark.parametrize( "pred, arg1, arg2, expected", [ @@ -400,9 +423,9 @@ class TestCondCircuits: (-1, 0.5, 0.6, 0.77468805), # RX(0.10) -> RZ(0.6) -> RX(0.5) -> RX(0.10) ], ) - def test_circuit(self, pred, arg1, arg2, expected): + def test_circuit_branches(self, pred, arg1, arg2, expected): """Test circuit with true, false, and elif branches.""" - result = circuit(pred, arg1, arg2) + result = circuit_branches(pred, arg1, arg2) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" @pytest.mark.parametrize( From ee7ea5232bcfb914822d77db63f8f526b30ee1ab Mon Sep 17 00:00:00 2001 From: Pietropaolo Frisoni Date: Mon, 29 Jul 2024 18:55:26 -0400 Subject: [PATCH 32/34] Codecov is buggy --- pennylane/ops/op_math/condition.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index 286d5a457c5..e1ef20c79bb 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -510,7 +510,8 @@ def _(*_, jaxpr_branches, **__): raise ValueError( "The false branch must be provided if the true branch returns any variables" ) - continue + # this is tested, but coverage does not pick it up + continue # pragma: no cover outvals_branch = jaxpr_branch.out_avals branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false" From d5e823bfe499665b0cd5eba0966c8f273323e518 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Tue, 30 Jul 2024 15:37:05 -0400 Subject: [PATCH 33/34] Suggestions from code review --- doc/releases/changelog-dev.md | 5 ++--- pennylane/ops/op_math/condition.py | 9 +++------ tests/capture/test_capture_cond.py | 20 ++++++++++++++++++++ 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 47861aae9e2..494ce3234ac 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -4,9 +4,6 @@

New features since last release

-* The `qml.cond` function can be captured into plxpr. - [(#5999)](https://github.com/PennyLaneAI/pennylane/pull/5999) - * A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP` classes, allowing for more efficient handling of quantum density matrices, particularly with batch processing support. This method simplifies the calculation of probabilities from quantum states @@ -74,8 +71,10 @@ [(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919) * Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr. + Furthermore, the `qml.cond` function can be captured into plxpr. [(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966) [(#5967)](https://github.com/PennyLaneAI/pennylane/pull/5967) + [(#5999)](https://github.com/PennyLaneAI/pennylane/pull/5999) * Set operations are now supported by Wires. [(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index e1ef20c79bb..7df1bcb54f6 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -484,13 +484,10 @@ def _(conditions, *args_and_consts, jaxpr_branches, n_consts_per_branch, n_args) args = args_and_consts[:n_args] consts_flat = args_and_consts[n_args:] - consts_per_branch = [] start = 0 - for n in n_consts_per_branch: - consts_per_branch.append(consts_flat[start : start + n]) - start += n - - for pred, jaxpr, consts in zip(conditions, jaxpr_branches, consts_per_branch): + for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch): + consts = consts_flat[start : start + n_consts] + start += n_consts if pred and jaxpr is not None: return jax.core.eval_jaxpr(jaxpr.jaxpr, consts, *args) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index bc191edf8fc..3b9b778b604 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -428,6 +428,11 @@ def test_circuit_branches(self, pred, arg1, arg2, expected): result = circuit_branches(pred, arg1, arg2) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + args = [pred, arg1, arg2] + jaxpr = jax.make_jaxpr(circuit_branches)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize( "pred, arg1, arg2, expected", [ @@ -440,6 +445,11 @@ def test_circuit_with_returned_operator(self, pred, arg1, arg2, expected): result = circuit_with_returned_operator(pred, arg1, arg2) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + args = [pred, arg1, arg2] + jaxpr = jax.make_jaxpr(circuit_with_returned_operator)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize( "tmp_pred, tmp_arg, expected", [ @@ -452,6 +462,11 @@ def test_circuit_multiple_cond(self, tmp_pred, tmp_arg, expected): result = circuit_multiple_cond(tmp_pred, tmp_arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + args = [tmp_pred, tmp_arg] + jaxpr = jax.make_jaxpr(circuit_multiple_cond)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize( "pred, arg, expected", [ @@ -464,3 +479,8 @@ def test_circuit_consts(self, pred, arg, expected): """Test circuit with jaxpr constants.""" result = circuit_with_consts(pred, arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [pred, arg] + jaxpr = jax.make_jaxpr(circuit_with_consts)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" From 046eab4b67fac055723dc9347210307e99947b14 Mon Sep 17 00:00:00 2001 From: PietropaoloFrisoni Date: Wed, 31 Jul 2024 10:24:47 -0400 Subject: [PATCH 34/34] Adding more tests --- tests/capture/test_capture_cond.py | 70 ++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 18 deletions(-) diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py index 3b9b778b604..a6ddfa63dce 100644 --- a/tests/capture/test_capture_cond.py +++ b/tests/capture/test_capture_cond.py @@ -80,19 +80,28 @@ def test_cond_true_elifs_false(self, testing_functions, selector, arg, expected) """Test the conditional with true, elifs, and false branches.""" true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions - result = qml.cond( - selector > 0, - true_fn, - false_fn, - elifs=( - (selector == -1, elif_fn1), - (selector == -2, elif_fn2), - (selector == -3, elif_fn3), - (selector == -4, elif_fn4), - ), - )(arg) + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + elifs=( + (pred == -1, elif_fn1), + (pred == -2, elif_fn2), + (pred == -3, elif_fn3), + (pred == -4, elif_fn4), + ), + ) + + result = test_func(selector)(arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + # Note that this would fail by running the abstract evaluation of the jaxpr + # because the false branch must be provided if the true branch returns any variables. @pytest.mark.parametrize( "selector, arg, expected", [ @@ -127,13 +136,22 @@ def test_cond_true_false(self, testing_functions, selector, arg, expected): """Test the conditional with true and false branches.""" true_fn, false_fn, _, _, _, _ = testing_functions - result = qml.cond( - selector > 0, - true_fn, - false_fn, - )(arg) + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + ) + + result = test_func(selector)(arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + # Note that this would fail by running the abstract evaluation of the jaxpr + # because the false branch must be provided if the true branch returns any variables. @pytest.mark.parametrize( "selector, arg, expected", [ @@ -162,14 +180,25 @@ def test_cond_with_jax_array(self, selector, arg, expected): """Test the conditional with array arguments.""" def true_fn(jax_array): - return jax_array[0] * jax_array[1] * 2 + return jax_array[0] * jax_array[1] * 2.0 def false_fn(jax_array): return jax_array[0] * jax_array[1] * 2.5 - result = qml.cond(selector > 0, true_fn, false_fn)(arg) + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + ) + + result = test_func(selector)(arg) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + class TestCondReturns: """Tests for validating the number and types of output variables in conditional functions.""" @@ -415,6 +444,11 @@ def test_circuit(self, pred, expected): result = circuit(pred) assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + args = [pred] + jaxpr = jax.make_jaxpr(circuit)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + @pytest.mark.parametrize( "pred, arg1, arg2, expected", [