diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 3af8b773be9..da4fa4994f5 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -71,8 +71,10 @@ [(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919) * Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr. + Furthermore, the `qml.cond` function can be captured into plxpr. [(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966) [(#5967)](https://github.com/PennyLaneAI/pennylane/pull/5967) + [(#5999)](https://github.com/PennyLaneAI/pennylane/pull/5999) * Set operations are now supported by Wires. [(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983) diff --git a/pennylane/math/utils.py b/pennylane/math/utils.py index 9c42d037466..a03954b3746 100644 --- a/pennylane/math/utils.py +++ b/pennylane/math/utils.py @@ -227,7 +227,7 @@ def get_interface(*values): # contains autograd and another interface warnings.warn( f"Contains tensors of types {non_numpy_scipy_interfaces}; dispatch will prioritize " - "TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.", + "TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.", UserWarning, ) diff --git a/pennylane/ops/op_math/condition.py b/pennylane/ops/op_math/condition.py index d3da2814719..7df1bcb54f6 100644 --- a/pennylane/ops/op_math/condition.py +++ b/pennylane/ops/op_math/condition.py @@ -14,9 +14,11 @@ """ Contains the condition transform. """ +import functools from functools import wraps -from typing import Type +from typing import Callable, Optional, Type +import pennylane as qml from pennylane import QueuingManager from pennylane.compiler import compiler from pennylane.operation import AnyWires, Operation, Operator @@ -100,7 +102,7 @@ def adjoint(self): return Conditional(self.meas_val, self.base.adjoint()) -def cond(condition, true_fn, false_fn=None, elifs=()): +def cond(condition, true_fn: Callable, false_fn: Optional[Callable] = None, elifs=()): """Quantum-compatible if-else conditionals --- condition quantum operations on parameters such as the results of mid-circuit qubit measurements. @@ -120,12 +122,22 @@ def cond(condition, true_fn, false_fn=None, elifs=()): apply the :func:`defer_measurements` transform. .. note:: + When used with :func:`~.qjit`, this function only supports the Catalyst compiler. See :func:`catalyst.cond` for more details. Please see the Catalyst :doc:`quickstart guide `, as well as the :doc:`sharp bits and debugging tips `. + .. note:: + + When used with :func:`~.pennylane.capture.enabled`, this function allows for general + if-elif-else constructs. As with the JIT mode, all branches are captured, + with the executed branch determined at runtime. + + Each branch can receive arguments, but the arguments must be JAX-compatible. + If a branch returns one or more variables, every other branch must return the same abstract values. + Args: condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when @@ -364,6 +376,7 @@ def qnode(a, x, y, z): >>> qnode(par, x, y, z) tensor(-0.30922805, requires_grad=True) """ + if active_jit := compiler.active_compiler(): available_eps = compiler.AvailableCompilers.names_entrypoints ops_loader = available_eps[active_jit]["ops"].load() @@ -379,6 +392,9 @@ def qnode(a, x, y, z): return cond_func + if qml.capture.enabled(): + return _capture_cond(condition, true_fn, false_fn, elifs) + if elifs: raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.") @@ -430,3 +446,124 @@ def wrapper(*args, **kwargs): ) return wrapper + + +def _validate_abstract_values( + outvals: list, expected_outvals: list, branch_type: str, index: int = None +) -> None: + """Ensure the collected abstract values match the expected ones.""" + + if len(outvals) != len(expected_outvals): + raise ValueError( + f"Mismatch in number of output variables in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)}: " + f"{len(outvals)} vs {len(expected_outvals)}" + ) + + for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)): + if outval != expected_outval: + raise ValueError( + f"Mismatch in output abstract values in {branch_type} branch" + f"{'' if index is None else ' #' + str(index)} at position {i}: " + f"{outval} vs {expected_outval}" + ) + + +@functools.lru_cache +def _get_cond_qfunc_prim(): + """Get the cond primitive for quantum functions.""" + + import jax # pylint: disable=import-outside-toplevel + + cond_prim = jax.core.Primitive("cond") + cond_prim.multiple_results = True + + @cond_prim.def_impl + def _(conditions, *args_and_consts, jaxpr_branches, n_consts_per_branch, n_args): + + args = args_and_consts[:n_args] + consts_flat = args_and_consts[n_args:] + + start = 0 + for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch): + consts = consts_flat[start : start + n_consts] + start += n_consts + if pred and jaxpr is not None: + return jax.core.eval_jaxpr(jaxpr.jaxpr, consts, *args) + + return () + + @cond_prim.def_abstract_eval + def _(*_, jaxpr_branches, **__): + + outvals_true = jaxpr_branches[0].out_avals + + for idx, jaxpr_branch in enumerate(jaxpr_branches): + if idx == 0: + continue + + if jaxpr_branch is None: + if outvals_true: + raise ValueError( + "The false branch must be provided if the true branch returns any variables" + ) + # this is tested, but coverage does not pick it up + continue # pragma: no cover + + outvals_branch = jaxpr_branch.out_avals + branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false" + _validate_abstract_values(outvals_branch, outvals_true, branch_type, idx - 1) + + # We return the abstract values of the true branch since the abstract values + # of the other branches (if they exist) should be the same + return outvals_true + + return cond_prim + + +def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable: + """Capture compatible way to apply conditionals.""" + + import jax # pylint: disable=import-outside-toplevel + + cond_prim = _get_cond_qfunc_prim() + + elifs = (elifs,) if len(elifs) > 0 and not isinstance(elifs[0], tuple) else elifs + + @wraps(true_fn) + def new_wrapper(*args, **kwargs): + + jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args) + jaxpr_false = ( + jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None + ) + + # We extract each condition (or predicate) from the elifs argument list + # since these are traced by JAX and are passed as positional arguments to the primitive + elifs_conditions = [] + jaxpr_elifs = [] + + for pred, elif_fn in elifs: + elifs_conditions.append(pred) + jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args)) + + conditions = jax.numpy.array([condition, *elifs_conditions, True]) + + jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false] + jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches] + + # We need to flatten the constants since JAX does not allow + # to pass lists as positional arguments + consts_flat = [const for sublist in jaxpr_consts for const in sublist] + n_consts_per_branch = [len(consts) for consts in jaxpr_consts] + + return cond_prim.bind( + conditions, + *args, + *consts_flat, + jaxpr_branches=jaxpr_branches, + n_consts_per_branch=n_consts_per_branch, + n_args=len(args), + ) + + return new_wrapper diff --git a/tests/capture/test_capture_cond.py b/tests/capture/test_capture_cond.py new file mode 100644 index 00000000000..a6ddfa63dce --- /dev/null +++ b/tests/capture/test_capture_cond.py @@ -0,0 +1,520 @@ +# Copyright 2018-2024 Xanadu Quantum Technologies Inc. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Tests for capturing conditionals into jaxpr. +""" + +# pylint: disable=redefined-outer-name +# pylint: disable=no-self-use + +import numpy as np +import pytest + +import pennylane as qml +from pennylane.ops.op_math.condition import _capture_cond + +pytestmark = pytest.mark.jax + +jax = pytest.importorskip("jax") + + +@pytest.fixture(autouse=True) +def enable_disable_plxpr(): + """Enable and disable the PennyLane JAX capture context manager.""" + qml.capture.enable() + yield + qml.capture.disable() + + +@pytest.fixture +def testing_functions(): + """Returns a set of functions for testing.""" + + def true_fn(arg): + return 2 * arg + + def elif_fn1(arg): + return arg - 1 + + def elif_fn2(arg): + return arg - 2 + + def elif_fn3(arg): + return arg - 3 + + def elif_fn4(arg): + return arg - 4 + + def false_fn(arg): + return 3 * arg + + return true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 + + +class TestCond: + """Tests for conditional functions using qml.cond.""" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, 7), # Elif condition 3 + (-4, 10, 6), # Elif condition 4 + (0, 10, 30), # False condition + ], + ) + def test_cond_true_elifs_false(self, testing_functions, selector, arg, expected): + """Test the conditional with true, elifs, and false branches.""" + true_fn, false_fn, elif_fn1, elif_fn2, elif_fn3, elif_fn4 = testing_functions + + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + elifs=( + (pred == -1, elif_fn1), + (pred == -2, elif_fn2), + (pred == -3, elif_fn3), + (pred == -4, elif_fn4), + ), + ) + + result = test_func(selector)(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + # Note that this would fail by running the abstract evaluation of the jaxpr + # because the false branch must be provided if the true branch returns any variables. + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), # True condition + (-1, 10, 9), # Elif condition 1 + (-2, 10, 8), # Elif condition 2 + (-3, 10, ()), # No condition met + ], + ) + def test_cond_true_elifs(self, testing_functions, selector, arg, expected): + """Test the conditional with true and elifs branches.""" + true_fn, _, elif_fn1, elif_fn2, _, _ = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + elifs=( + (selector == -1, elif_fn1), + (selector == -2, elif_fn2), + ), + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), + (0, 10, 30), + ], + ) + def test_cond_true_false(self, testing_functions, selector, arg, expected): + """Test the conditional with true and false branches.""" + true_fn, false_fn, _, _, _, _ = testing_functions + + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + ) + + result = test_func(selector)(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + # Note that this would fail by running the abstract evaluation of the jaxpr + # because the false branch must be provided if the true branch returns any variables. + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, 10, 20), + (0, 10, ()), + ], + ) + def test_cond_true(self, testing_functions, selector, arg, expected): + """Test the conditional with only the true branch.""" + true_fn, _, _, _, _, _ = testing_functions + + result = qml.cond( + selector > 0, + true_fn, + )(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + @pytest.mark.parametrize( + "selector, arg, expected", + [ + (1, jax.numpy.array([2, 3]), 12), + (0, jax.numpy.array([2, 3]), 15), + ], + ) + def test_cond_with_jax_array(self, selector, arg, expected): + """Test the conditional with array arguments.""" + + def true_fn(jax_array): + return jax_array[0] * jax_array[1] * 2.0 + + def false_fn(jax_array): + return jax_array[0] * jax_array[1] * 2.5 + + def test_func(pred): + return qml.cond( + pred > 0, + true_fn, + false_fn, + ) + + result = test_func(selector)(arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + jaxpr = jax.make_jaxpr(test_func(selector))(arg) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, arg) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + +class TestCondReturns: + """Tests for validating the number and types of output variables in conditional functions.""" + + @pytest.mark.parametrize( + "true_fn, false_fn, expected_error, match", + [ + ( + lambda x: (x + 1, x + 2), + lambda x: None, + ValueError, + r"Mismatch in number of output variables", + ), + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1,), + ValueError, + r"Mismatch in number of output variables", + ), + ( + lambda x: (x + 1, x + 2), + lambda x: (x + 1, x + 2.0), + ValueError, + r"Mismatch in output abstract values", + ), + ], + ) + def test_validate_mismatches(self, true_fn, false_fn, expected_error, match): + """Test mismatch in number and type of output variables.""" + with pytest.raises(expected_error, match=match): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_number_of_output_variables(self): + """Test mismatch in number of output variables.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1 + + with pytest.raises(ValueError, match=r"Mismatch in number of output variables"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_output_variable_types(self): + """Test mismatch in output variable types.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2.0 + + with pytest.raises(ValueError, match=r"Mismatch in output abstract values"): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn))(jax.numpy.array(1)) + + def test_validate_no_false_branch_with_return(self): + """Test no false branch provided with return variables.""" + + def true_fn(x): + return x + 1, x + 2 + + with pytest.raises( + ValueError, + match=r"The false branch must be provided if the true branch returns any variables", + ): + jax.make_jaxpr(_capture_cond(True, true_fn))(jax.numpy.array(1)) + + def test_validate_no_false_branch_with_return_2(self): + """Test no false branch provided with return variables.""" + + def true_fn(x): + return x + 1, x + 2 + + def elif_fn(x): + return x + 1, x + 2 + + with pytest.raises( + ValueError, + match=r"The false branch must be provided if the true branch returns any variables", + ): + jax.make_jaxpr(_capture_cond(True, true_fn, false_fn=None, elifs=(False, elif_fn)))( + jax.numpy.array(1) + ) + + def test_validate_elif_branches(self): + """Test elif branch mismatches.""" + + def true_fn(x): + return x + 1, x + 2 + + def false_fn(x): + return x + 1, x + 2 + + def elif_fn1(x): + return x + 1, x + 2 + + def elif_fn2(x): + return x + 1, x + 2.0 + + def elif_fn3(x): + return x + 1 + + with pytest.raises( + ValueError, match=r"Mismatch in output abstract values in elif branch #1" + ): + jax.make_jaxpr( + _capture_cond(False, true_fn, false_fn, [(True, elif_fn1), (False, elif_fn2)]) + )(jax.numpy.array(1)) + + with pytest.raises( + ValueError, match=r"Mismatch in number of output variables in elif branch #0" + ): + jax.make_jaxpr(_capture_cond(False, true_fn, false_fn, [(True, elif_fn3)]))( + jax.numpy.array(1) + ) + + +dev = qml.device("default.qubit", wires=3) + + +@qml.qnode(dev) +def circuit(pred): + """Quantum circuit with only a true branch.""" + + def true_fn(): + qml.RX(0.1, wires=0) + + qml.cond(pred > 0, true_fn)() + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def circuit_branches(pred, arg1, arg2): + """Quantum circuit with conditional branches.""" + + qml.RX(0.10, wires=0) + + def true_fn(arg1, arg2): + qml.RY(arg1, wires=0) + qml.RX(arg2, wires=0) + qml.RZ(arg1, wires=0) + + def false_fn(arg1, arg2): + qml.RX(arg1, wires=0) + qml.RX(arg2, wires=0) + + def elif_fn1(arg1, arg2): + qml.RZ(arg2, wires=0) + qml.RX(arg1, wires=0) + + qml.cond(pred > 0, true_fn, false_fn, elifs=(pred == -1, elif_fn1))(arg1, arg2) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def circuit_with_returned_operator(pred, arg1, arg2): + """Quantum circuit with conditional branches that return operators.""" + + qml.RX(0.10, wires=0) + + def true_fn(arg1, arg2): + qml.RY(arg1, wires=0) + return 7, 4.6, qml.RY(arg2, wires=0), True + + def false_fn(arg1, arg2): + qml.RZ(arg2, wires=0) + return 2, 2.2, qml.RZ(arg1, wires=0), False + + qml.cond(pred > 0, true_fn, false_fn)(arg1, arg2) + qml.RX(0.10, wires=0) + return qml.expval(qml.PauliZ(wires=0)) + + +@qml.qnode(dev) +def circuit_multiple_cond(tmp_pred, tmp_arg): + """Quantum circuit with multiple dynamic conditional branches.""" + + dyn_pred_1 = tmp_pred > 0 + arg = tmp_arg + + def true_fn_1(arg): + return True, qml.RX(arg, wires=0) + + # pylint: disable=unused-argument + def false_fn_1(arg): + return False, qml.RY(0.1, wires=0) + + def true_fn_2(arg): + return qml.RX(arg, wires=0) + + # pylint: disable=unused-argument + def false_fn_2(arg): + return qml.RY(0.1, wires=0) + + [dyn_pred_2, _] = qml.cond(dyn_pred_1, true_fn_1, false_fn_1, elifs=())(arg) + qml.cond(dyn_pred_2, true_fn_2, false_fn_2, elifs=())(arg) + return qml.expval(qml.Z(0)) + + +@qml.qnode(dev) +def circuit_with_consts(pred, arg): + """Quantum circuit with jaxpr constants.""" + + # these are captured as consts + arg1 = arg + arg2 = arg + 0.2 + arg3 = arg + 0.3 + arg4 = arg + 0.4 + arg5 = arg + 0.5 + arg6 = arg + 0.6 + + def true_fn(): + qml.RX(arg1, 0) + + def false_fn(): + qml.RX(arg2, 0) + qml.RX(arg3, 0) + + def elif_fn1(): + qml.RX(arg4, 0) + qml.RX(arg5, 0) + qml.RX(arg6, 0) + + qml.cond(pred > 0, true_fn, false_fn, elifs=((pred == 0, elif_fn1),))() + + return qml.expval(qml.Z(0)) + + +class TestCondCircuits: + """Tests for conditional quantum circuits.""" + + @pytest.mark.parametrize( + "pred, expected", + [ + (1, 0.99500417), # RX(0.1) + (0, 1.0), # No operation + ], + ) + def test_circuit(self, pred, expected): + """Test circuit with only a true branch.""" + result = circuit(pred) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [pred] + jaxpr = jax.make_jaxpr(circuit)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + @pytest.mark.parametrize( + "pred, arg1, arg2, expected", + [ + (1, 0.5, 0.6, 0.63340907), # RX(0.10) -> RY(0.5) -> RX(0.6) -> RZ(0.5) -> RX(0.10) + (0, 0.5, 0.6, 0.26749883), # RX(0.10) -> RX(0.5) -> RX(0.6) -> RX(0.10) + (-1, 0.5, 0.6, 0.77468805), # RX(0.10) -> RZ(0.6) -> RX(0.5) -> RX(0.10) + ], + ) + def test_circuit_branches(self, pred, arg1, arg2, expected): + """Test circuit with true, false, and elif branches.""" + result = circuit_branches(pred, arg1, arg2) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [pred, arg1, arg2] + jaxpr = jax.make_jaxpr(circuit_branches)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + @pytest.mark.parametrize( + "pred, arg1, arg2, expected", + [ + (1, 0.5, 0.6, 0.43910855), # RX(0.10) -> RY(0.5) -> RY(0.6) -> RX(0.10) + (0, 0.5, 0.6, 0.98551243), # RX(0.10) -> RZ(0.6) -> RX(0.5) -> RX(0.10) + ], + ) + def test_circuit_with_returned_operator(self, pred, arg1, arg2, expected): + """Test circuit with returned operators in the branches.""" + result = circuit_with_returned_operator(pred, arg1, arg2) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [pred, arg1, arg2] + jaxpr = jax.make_jaxpr(circuit_with_returned_operator)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + @pytest.mark.parametrize( + "tmp_pred, tmp_arg, expected", + [ + (1, 0.5, 0.54030231), # RX(0.5) -> RX(0.5) + (-1, 0.5, 0.98006658), # RY(0.1) -> RY(0.1) + ], + ) + def test_circuit_multiple_cond(self, tmp_pred, tmp_arg, expected): + """Test circuit with returned operators in the branches.""" + result = circuit_multiple_cond(tmp_pred, tmp_arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [tmp_pred, tmp_arg] + jaxpr = jax.make_jaxpr(circuit_multiple_cond)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}" + + @pytest.mark.parametrize( + "pred, arg, expected", + [ + (1, 0.5, 0.87758256), # RX(0.5) + (-1, 0.5, 0.0707372), # RX(0.7) -> RX(0.8) + (0, 0.5, -0.9899925), # RX(0.9) -> RX(1.0) -> RX(1.1) + ], + ) + def test_circuit_consts(self, pred, arg, expected): + """Test circuit with jaxpr constants.""" + result = circuit_with_consts(pred, arg) + assert np.allclose(result, expected), f"Expected {expected}, but got {result}" + + args = [pred, arg] + jaxpr = jax.make_jaxpr(circuit_with_consts)(*args) + res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args) + assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"