Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Users can capture conditionals using qml.cond #5999

Merged
merged 45 commits into from
Jul 31, 2024
Merged
Show file tree
Hide file tree
Changes from 44 commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
5db098b
E.C. [ci skip]
PietropaoloFrisoni Jul 15, 2024
50ca82d
just doubts so far [ci skip]
PietropaoloFrisoni Jul 17, 2024
cf46ee1
providing `elifs` causes an error at this stage (I don't know why yet…
PietropaoloFrisoni Jul 17, 2024
60cdcd2
not skipping the CI
PietropaoloFrisoni Jul 17, 2024
c425896
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 17, 2024
cafd846
[ci skip] elifs are still not implemented correctly, but hopefully we…
PietropaoloFrisoni Jul 17, 2024
52b684f
The main problem right now is that all the `elifs` condition after th…
PietropaoloFrisoni Jul 18, 2024
e6af28a
Fixing multiple elifs issue [ci skip]
PietropaoloFrisoni Jul 18, 2024
656b4f7
Removing usage of jax.lax.cond [ci skip]
PietropaoloFrisoni Jul 19, 2024
0f81364
Undersanding how to handle dynamic tracer [ci skip]
PietropaoloFrisoni Jul 20, 2024
399c8be
Solved dynamic inconsistent behavior [ci skip]
PietropaoloFrisoni Jul 22, 2024
1d95f90
Improving code style and removing debug msgs [ci skip]
PietropaoloFrisoni Jul 22, 2024
f1de555
Improving code style and removing debug msgs [ci skip]
PietropaoloFrisoni Jul 22, 2024
d9b8422
TODO: add tests and check for more than one operator in the queue [ci…
PietropaoloFrisoni Jul 22, 2024
24ee3b4
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 22, 2024
91c3e0b
Adding tests for logic of `qml.cond` with captured enabled
PietropaoloFrisoni Jul 23, 2024
9317509
Re-naming file
PietropaoloFrisoni Jul 23, 2024
b92e674
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jul 23, 2024
2a7bbb3
removing useless for loop [ci skip]
PietropaoloFrisoni Jul 23, 2024
f95a18e
removing useless for loop [ci skip]
PietropaoloFrisoni Jul 23, 2024
fd76361
changed variable name [ci skip]
PietropaoloFrisoni Jul 23, 2024
f9ad42a
implemented abstract definition [ci skip]
PietropaoloFrisoni Jul 23, 2024
e8fef5d
Adding tests to catch errors
PietropaoloFrisoni Jul 24, 2024
4e9fe78
Removed import above skipping
PietropaoloFrisoni Jul 24, 2024
6a56972
Adding a few more tests [ci skip]
PietropaoloFrisoni Jul 24, 2024
9d50198
Adding more unit tests
PietropaoloFrisoni Jul 24, 2024
ab9e33b
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 24, 2024
61c059a
Adding test and re-naming file
PietropaoloFrisoni Jul 24, 2024
29b9c02
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 24, 2024
feadf30
Doc fix
PietropaoloFrisoni Jul 25, 2024
2cb2538
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 25, 2024
b8f80fc
Suggestions from code review
PietropaoloFrisoni Jul 25, 2024
c0c4c72
Docstring clarification
PietropaoloFrisoni Jul 26, 2024
4c57aa7
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jul 26, 2024
75687ba
Suggestions from code review
PietropaoloFrisoni Jul 26, 2024
4cee045
Adding test with multiple cond
PietropaoloFrisoni Jul 29, 2024
64ebe82
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jul 29, 2024
37862ee
Capturing jaxpr.consts and passing them as positional arguments
PietropaoloFrisoni Jul 29, 2024
121b885
Adding test to cover line (usual codecov stuff)
PietropaoloFrisoni Jul 29, 2024
6d8e2eb
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 29, 2024
ee7ea52
Codecov is buggy
PietropaoloFrisoni Jul 29, 2024
a9b7010
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
PietropaoloFrisoni Jul 30, 2024
d5e823b
Suggestions from code review
PietropaoloFrisoni Jul 30, 2024
046eab4
Adding more tests
PietropaoloFrisoni Jul 31, 2024
db95c9b
Merge branch 'master' into capture_qml_cond
PietropaoloFrisoni Jul 31, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@
[(#5919)](https://github.com/PennyLaneAI/pennylane/pull/5919)

* Applying `adjoint` and `ctrl` to a quantum function can now be captured into plxpr.
Furthermore, the `qml.cond` function can be captured into plxpr.
[(#5966)](https://github.com/PennyLaneAI/pennylane/pull/5966)
[(#5967)](https://github.com/PennyLaneAI/pennylane/pull/5967)
[(#5999)](https://github.com/PennyLaneAI/pennylane/pull/5999)

* Set operations are now supported by Wires.
[(#5983)](https://github.com/PennyLaneAI/pennylane/pull/5983)
Expand Down
2 changes: 1 addition & 1 deletion pennylane/math/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,7 +227,7 @@ def get_interface(*values):
# contains autograd and another interface
warnings.warn(
f"Contains tensors of types {non_numpy_scipy_interfaces}; dispatch will prioritize "
"TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.",
"TensorFlow, PyTorch, and Jax over Autograd. Consider replacing Autograd with vanilla NumPy.",
UserWarning,
)

Expand Down
141 changes: 139 additions & 2 deletions pennylane/ops/op_math/condition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,11 @@
"""
Contains the condition transform.
"""
import functools
from functools import wraps
from typing import Type
from typing import Callable, Optional, Type

import pennylane as qml
from pennylane import QueuingManager
from pennylane.compiler import compiler
from pennylane.operation import AnyWires, Operation, Operator
Expand Down Expand Up @@ -100,7 +102,7 @@ def adjoint(self):
return Conditional(self.meas_val, self.base.adjoint())


def cond(condition, true_fn, false_fn=None, elifs=()):
def cond(condition, true_fn: Callable, false_fn: Optional[Callable] = None, elifs=()):
"""Quantum-compatible if-else conditionals --- condition quantum operations
on parameters such as the results of mid-circuit qubit measurements.

Expand All @@ -120,12 +122,22 @@ def cond(condition, true_fn, false_fn=None, elifs=()):
apply the :func:`defer_measurements` transform.

.. note::

When used with :func:`~.qjit`, this function only supports
the Catalyst compiler. See :func:`catalyst.cond` for more details.

Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`.

.. note::

When used with :func:`~.pennylane.capture.enabled`, this function allows for general
if-elif-else constructs. As with the JIT mode, all branches are captured,
with the executed branch determined at runtime.

Each branch can receive arguments, but the arguments must be JAX-compatible.
If a branch returns one or more variables, every other branch must return the same abstract values.

Args:
condition (Union[.MeasurementValue, bool]): a conditional expression involving a mid-circuit
measurement value (see :func:`.pennylane.measure`). This can only be of type ``bool`` when
Expand Down Expand Up @@ -364,6 +376,7 @@ def qnode(a, x, y, z):
>>> qnode(par, x, y, z)
tensor(-0.30922805, requires_grad=True)
"""

if active_jit := compiler.active_compiler():
available_eps = compiler.AvailableCompilers.names_entrypoints
ops_loader = available_eps[active_jit]["ops"].load()
Expand All @@ -379,6 +392,9 @@ def qnode(a, x, y, z):

return cond_func

if qml.capture.enabled():
return _capture_cond(condition, true_fn, false_fn, elifs)

if elifs:
raise ConditionalTransformError("'elif' branches are not supported in interpreted mode.")

Expand Down Expand Up @@ -430,3 +446,124 @@ def wrapper(*args, **kwargs):
)

return wrapper


def _validate_abstract_values(
outvals: list, expected_outvals: list, branch_type: str, index: int = None
) -> None:
"""Ensure the collected abstract values match the expected ones."""

if len(outvals) != len(expected_outvals):
raise ValueError(
f"Mismatch in number of output variables in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)}: "
f"{len(outvals)} vs {len(expected_outvals)}"
)

for i, (outval, expected_outval) in enumerate(zip(outvals, expected_outvals)):
if outval != expected_outval:
raise ValueError(
f"Mismatch in output abstract values in {branch_type} branch"
f"{'' if index is None else ' #' + str(index)} at position {i}: "
f"{outval} vs {expected_outval}"
)


@functools.lru_cache
def _get_cond_qfunc_prim():
"""Get the cond primitive for quantum functions."""

import jax # pylint: disable=import-outside-toplevel

cond_prim = jax.core.Primitive("cond")
cond_prim.multiple_results = True

@cond_prim.def_impl
def _(conditions, *args_and_consts, jaxpr_branches, n_consts_per_branch, n_args):

args = args_and_consts[:n_args]
consts_flat = args_and_consts[n_args:]

start = 0
for pred, jaxpr, n_consts in zip(conditions, jaxpr_branches, n_consts_per_branch):
consts = consts_flat[start : start + n_consts]
start += n_consts
if pred and jaxpr is not None:
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
return jax.core.eval_jaxpr(jaxpr.jaxpr, consts, *args)

return ()

@cond_prim.def_abstract_eval
def _(*_, jaxpr_branches, **__):

outvals_true = jaxpr_branches[0].out_avals

for idx, jaxpr_branch in enumerate(jaxpr_branches):
if idx == 0:
continue

if jaxpr_branch is None:
if outvals_true:
raise ValueError(
"The false branch must be provided if the true branch returns any variables"
)
# this is tested, but coverage does not pick it up
continue # pragma: no cover

outvals_branch = jaxpr_branch.out_avals
branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false"
_validate_abstract_values(outvals_branch, outvals_true, branch_type, idx - 1)

# We return the abstract values of the true branch since the abstract values
# of the other branches (if they exist) should be the same
return outvals_true

return cond_prim


def _capture_cond(condition, true_fn, false_fn=None, elifs=()) -> Callable:
"""Capture compatible way to apply conditionals."""

import jax # pylint: disable=import-outside-toplevel

cond_prim = _get_cond_qfunc_prim()

elifs = (elifs,) if len(elifs) > 0 and not isinstance(elifs[0], tuple) else elifs
dime10 marked this conversation as resolved.
Show resolved Hide resolved

@wraps(true_fn)
def new_wrapper(*args, **kwargs):

jaxpr_true = jax.make_jaxpr(functools.partial(true_fn, **kwargs))(*args)
jaxpr_false = (
jax.make_jaxpr(functools.partial(false_fn, **kwargs))(*args) if false_fn else None
)

# We extract each condition (or predicate) from the elifs argument list
# since these are traced by JAX and are passed as positional arguments to the primitive
elifs_conditions = []
jaxpr_elifs = []

for pred, elif_fn in elifs:
elifs_conditions.append(pred)
jaxpr_elifs.append(jax.make_jaxpr(functools.partial(elif_fn, **kwargs))(*args))

conditions = jax.numpy.array([condition, *elifs_conditions, True])

jaxpr_branches = [jaxpr_true, *jaxpr_elifs, jaxpr_false]
jaxpr_consts = [jaxpr.consts if jaxpr is not None else () for jaxpr in jaxpr_branches]

# We need to flatten the constants since JAX does not allow
# to pass lists as positional arguments
consts_flat = [const for sublist in jaxpr_consts for const in sublist]
n_consts_per_branch = [len(consts) for consts in jaxpr_consts]

return cond_prim.bind(
conditions,
*args,
*consts_flat,
jaxpr_branches=jaxpr_branches,
n_consts_per_branch=n_consts_per_branch,
n_args=len(args),
)

return new_wrapper
Loading
Loading