Skip to content

Commit

Permalink
Merge branch 'master' into remove-shadow-expval
Browse files Browse the repository at this point in the history
  • Loading branch information
albi3ro authored Nov 8, 2024
2 parents cf4e662 + 0d497ec commit 1d45225
Show file tree
Hide file tree
Showing 34 changed files with 962 additions and 197 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check_in_artifact.yml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ jobs:
run: |
git checkout stash -- .
git add ${{ inputs.artifact_save_path }}
git commit -m "Check in artifacts$COMMIT_MESSAGE_DESCRIPTION"
git commit --allow-empty -m "Check in artifacts$COMMIT_MESSAGE_DESCRIPTION"
git push -f --set-upstream origin "$HEAD_BRANCH_NAME"
# Create PR to master
Expand Down
Binary file added doc/_static/draw_mpl/per_wire_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added doc/_static/tape_mpl/per_wire_options.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
2 changes: 1 addition & 1 deletion doc/code/qml_drawer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ Currently Available Styles
+|pls|+|plw|+|skd|+
+-----+-----+-----+
+|sol|+|sod|+|def|+
+-----+-----+-----+
+-----+-----+-----+
11 changes: 11 additions & 0 deletions doc/introduction/interfaces.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,17 @@ a :class:`QNode <pennylane.QNode>`, e.g.,
If no interface is specified, PennyLane will automatically determine the interface based on provided arguments and keyword arguments.
See ``qml.workflow.SUPPORTED_INTERFACES`` for a list of all accepted interface strings.

.. warning::

``ComplexWarning`` messages may appear when running differentiable workflows involving both complex and float types, particularly
with certain interfaces. These warnings are common in backpropagation due to the nature of complex casting and do not
indicate an error in computation. If desired, you can suppress these warnings by adding the following code:

.. code-block:: python
import warnings
warnings.filterwarnings("ignore", category=np.ComplexWarning)
This will allow native numerical objects of the specified library (NumPy arrays, JAX arrays, Torch Tensors,
or TensorFlow Tensors) to be passed as parameters to the quantum circuit. It also makes
the gradients of the quantum circuit accessible to the classical library, enabling the
Expand Down
51 changes: 37 additions & 14 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Release 0.40.0-dev (development release)

<h3>New features since last release</h3>

* A `DeviceCapabilities` data class is defined to contain all capabilities of the device's execution interface (i.e. its implementation of `Device.execute`). A TOML file can be used to define the capabilities of a device, and it can be loaded into a `DeviceCapabilities` object.
[(#6407)](https://github.com/PennyLaneAI/pennylane/pull/6407)

Expand All @@ -15,25 +15,38 @@
True
```

<h4>New API for Qubit Mixed</h4>

* Added `qml.devices.qubit_mixed` module for mixed-state qubit device support [(#6379)](https://github.com/PennyLaneAI/pennylane/pull/6379). This module introduces an `apply_operation` helper function that features:


* Two density matrix contraction methods using `einsum` and `tensordot`

* Optimized handling of special cases including: Diagonal operators, Identity operators, CX (controlled-X), Multi-controlled X gates, Grover operators

* Added submodule 'initialize_state' featuring a `create_initial_state` function for initializing a density matrix from `qml.StatePrep` operations or `qml.QubitDensityMatrix` operations.
[(#6503)](https://github.com/PennyLaneAI/pennylane/pull/6503)

<h3>Improvements 🛠</h3>

<h4>Other Improvements</h4>
* Added support for the `wire_options` dictionary to customize wire line formatting in `qml.draw_mpl` circuit
visualizations, allowing global and per-wire customization with options like `color`, `linestyle`, and `linewidth`.
[(#6486)](https://github.com/PennyLaneAI/pennylane/pull/6486)

<h4>Capturing and representing hybrid programs</h4>

* Added `qml.devices.qubit_mixed` module for mixed-state qubit device support. This module introduces:
- A new API for mixed-state operations
- An `apply_operation` helper function featuring:
- Two density matrix contraction methods using `einsum` and `tensordot`
- Optimized handling of special cases including:
- Diagonal operators
- Identity operators
- CX (controlled-X)
- Multi-controlled X gates
- Grover operators
[(#6379)](https://github.com/PennyLaneAI/pennylane/pull/6379)
* `jax.vmap` can be captured with `qml.capture.make_plxpr` and is compatible with quantum circuits.
[(#6349)](https://github.com/PennyLaneAI/pennylane/pull/6349)

<h4>Other Improvements</h4>

* `qml.BasisRotation` template is now JIT compatible.
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)

* The Jaxpr primitives for `for_loop`, `while_loop` and `cond` now store slices instead of
numbers of args.
[(#6521)](https://github.com/PennyLaneAI/pennylane/pull/6521)

* Expand `ExecutionConfig.gradient_method` to store `TransformDispatcher` type.
[(#6455)](https://github.com/PennyLaneAI/pennylane/pull/6455)

Expand All @@ -47,11 +60,21 @@

<h3>Documentation 📝</h3>

* Add a warning message to Gradients and training documentation about ComplexWarnings
[(#6543)](https://github.com/PennyLaneAI/pennylane/pull/6543)

<h3>Bug fixes 🐛</h3>

* Fixed `Identity.__repr__` to return correct wires list.
[(#6506)](https://github.com/PennyLaneAI/pennylane/pull/6506)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Shiwen An
Astral Cai,
Andrija Paurevic
Yushao Chen,
Pietropaolo Frisoni,
Andrija Paurevic,
Justin Pickering
27 changes: 15 additions & 12 deletions pennylane/capture/capture_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,28 +25,31 @@


@lru_cache
def create_non_jvp_primitive():
"""Create a primitive type ``NonJVPPrimitive``, which binds to JAX's JVPTrace
like a standard Python function and otherwise behaves like jax.core.Primitive.
def create_non_interpreted_prim():
"""Create a primitive type ``NonInterpPrimitive``, which binds to JAX's JVPTrace
and BatchTrace objects like a standard Python function and otherwise behaves like jax.core.Primitive.
"""

if not has_jax: # pragma: no cover
return None

# pylint: disable=too-few-public-methods
class NonJVPPrimitive(jax.core.Primitive):
class NonInterpPrimitive(jax.core.Primitive):
"""A subclass to JAX's Primitive that works like a Python function
when evaluating JVPTracers."""
when evaluating JVPTracers and BatchTracers."""

def bind_with_trace(self, trace, args, params):
"""Bind the ``NonJVPPrimitive`` with a trace. If the trace is a ``JVPTrace``,
binding falls back to a standard Python function call. Otherwise, the
bind call of JAX's standard Primitive is used."""
if isinstance(trace, jax.interpreters.ad.JVPTrace):
"""Bind the ``NonInterpPrimitive`` with a trace.
If the trace is a ``JVPTrace``or a ``BatchTrace``, binding falls back to a standard Python function call.
Otherwise, the bind call of JAX's standard Primitive is used."""
if isinstance(
trace, (jax.interpreters.ad.JVPTrace, jax.interpreters.batching.BatchTrace)
):
return self.impl(*args, **params)
return super().bind_with_trace(trace, args, params)

return NonJVPPrimitive
return NonInterpPrimitive


@lru_cache
Expand All @@ -57,7 +60,7 @@ def _get_grad_prim():
if not has_jax: # pragma: no cover
return None

grad_prim = create_non_jvp_primitive()("grad")
grad_prim = create_non_interpreted_prim()("grad")
grad_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init

# pylint: disable=too-many-arguments
Expand Down Expand Up @@ -89,7 +92,7 @@ def _get_jacobian_prim():
"""Create a primitive for Jacobian computations.
This primitive is used when capturing ``qml.jacobian``.
"""
jacobian_prim = create_non_jvp_primitive()("jacobian")
jacobian_prim = create_non_interpreted_prim()("jacobian")
jacobian_prim.multiple_results = True # pylint: disable=attribute-defined-outside-init

# pylint: disable=too-many-arguments
Expand Down
4 changes: 2 additions & 2 deletions pennylane/capture/capture_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import pennylane as qml

from .capture_diff import create_non_jvp_primitive
from .capture_diff import create_non_interpreted_prim

has_jax = True
try:
Expand Down Expand Up @@ -103,7 +103,7 @@ def create_operator_primitive(
if not has_jax:
return None

primitive = create_non_jvp_primitive()(operator_type.__name__)
primitive = create_non_interpreted_prim()(operator_type.__name__)

@primitive.def_impl
def _(*args, **kwargs):
Expand Down
130 changes: 121 additions & 9 deletions pennylane/capture/capture_qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,65 @@
"""
This submodule defines a capture compatible call to QNodes.
"""

from copy import copy
from dataclasses import asdict
from functools import lru_cache, partial
from numbers import Number
from warnings import warn

import pennylane as qml
from pennylane.typing import TensorLike

from .flatfn import FlatFn

has_jax = True
try:
import jax
from jax.interpreters import ad
from jax.interpreters import ad, batching

except ImportError:
has_jax = False


def _get_shapes_for(*measurements, shots=None, num_device_wires=0):
def _is_scalar_tensor(arg) -> bool:
"""Check if an argument is a scalar tensor-like object or a numeric scalar."""

if isinstance(arg, Number):
return True

if isinstance(arg, TensorLike):

if arg.size == 0:
raise ValueError("Empty tensors are not supported with jax.vmap.")

if arg.shape == ():
return True

if len(arg.shape) > 1:
raise ValueError(
"One argument has more than one dimension. "
"Currently, only single-dimension batching is supported."
)

return False


def _get_batch_shape(args, batch_dims):
"""Calculate the batch shape for the given arguments and batch dimensions."""

if batch_dims is None:
return ()

input_shapes = [
(arg.shape[batch_dim],) for arg, batch_dim in zip(args, batch_dims) if batch_dim is not None
]

return jax.lax.broadcast_shapes(*input_shapes)


def _get_shapes_for(*measurements, shots=None, num_device_wires=0, batch_shape=()):
"""Calculate the abstract output shapes for the given measurements."""

if jax.config.jax_enable_x64: # pylint: disable=no-member
dtype_map = {
float: jax.numpy.float64,
Expand All @@ -53,7 +93,8 @@ def _get_shapes_for(*measurements, shots=None, num_device_wires=0):
for s in shots:
for m in measurements:
shape, dtype = m.aval.abstract_eval(shots=s, num_device_wires=num_device_wires)
shapes.append(jax.core.ShapedArray(shape, dtype_map.get(dtype, dtype)))
shapes.append(jax.core.ShapedArray(batch_shape + shape, dtype_map.get(dtype, dtype)))

return shapes


Expand All @@ -66,21 +107,88 @@ def _get_qnode_prim():

# pylint: disable=too-many-arguments
@qnode_prim.def_impl
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):
consts = args[:n_consts]
args = args[n_consts:]
non_const_args = args[n_consts:]

def qfunc(*inner_args):
return jax.core.eval_jaxpr(qfunc_jaxpr, consts, *inner_args)

qnode = qml.QNode(qfunc, device, **qnode_kwargs)
return qnode._impl_call(*args, shots=shots) # pylint: disable=protected-access

if batch_dims is not None:
# pylint: disable=protected-access
return jax.vmap(partial(qnode._impl_call, shots=shots), batch_dims)(*non_const_args)

# pylint: disable=protected-access
return qnode._impl_call(*non_const_args, shots=shots)

# pylint: disable=unused-argument
@qnode_prim.def_abstract_eval
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts):
def _(*args, qnode, shots, device, qnode_kwargs, qfunc_jaxpr, n_consts, batch_dims=None):

mps = qfunc_jaxpr.outvars
return _get_shapes_for(*mps, shots=shots, num_device_wires=len(device.wires))

return _get_shapes_for(
*mps,
shots=shots,
num_device_wires=len(device.wires),
batch_shape=_get_batch_shape(args[n_consts:], batch_dims),
)

def _qnode_batching_rule(
batched_args,
batch_dims,
qnode,
shots,
device,
qnode_kwargs,
qfunc_jaxpr,
n_consts,
):
"""
Batching rule for the ``qnode`` primitive.
This rule exploits the parameter broadcasting feature of the QNode to vectorize the circuit execution.
"""

for i, (arg, batch_dim) in enumerate(zip(batched_args, batch_dims)):

if _is_scalar_tensor(arg):
continue

# Regardless of their shape, jax.vmap treats constants as scalars
# by automatically inserting `None` as the batch dimension.
if i < n_consts:
raise ValueError(
f"Constant argument at index {i} is not scalar. ",
"Only scalar constants are currently supported with jax.vmap.",
)

# To resolve this, we need to add more properties to the AbstractOperator
# class to indicate which operators support batching and check them here
if arg.size > 1 and batch_dim is None:
warn(
f"Argument at index {i} has more than 1 element but is not batched. "
"This may lead to unintended behavior or wrong results if the argument is provided "
"using parameter broadcasting to a quantum operation that supports batching.",
UserWarning,
)

result = qnode_prim.bind(
*batched_args,
shots=shots,
qnode=qnode,
device=device,
qnode_kwargs=qnode_kwargs,
qfunc_jaxpr=qfunc_jaxpr,
n_consts=n_consts,
batch_dims=batch_dims[n_consts:],
)

# The batch dimension is at the front (axis 0) for all elements in the result.
# JAX doesn't expose `out_axes` in the batching rule.
return result, (0,) * len(result)

def make_zero(tan, arg):
return jax.lax.zeros_like_array(arg) if isinstance(tan, ad.Zero) else tan
Expand All @@ -91,6 +199,8 @@ def _qnode_jvp(args, tangents, **impl_kwargs):

ad.primitive_jvps[qnode_prim] = _qnode_jvp

batching.primitive_batchers[qnode_prim] = _qnode_batching_rule

return qnode_prim


Expand Down Expand Up @@ -173,12 +283,14 @@ def f(x):

flat_fn = FlatFn(qfunc)
qfunc_jaxpr = jax.make_jaxpr(flat_fn)(*args)

execute_kwargs = copy(qnode.execute_kwargs)
mcm_config = asdict(execute_kwargs.pop("mcm_config"))
qnode_kwargs = {"diff_method": qnode.diff_method, **execute_kwargs, **mcm_config}
qnode_prim = _get_qnode_prim()

flat_args = jax.tree_util.tree_leaves(args)

res = qnode_prim.bind(
*qfunc_jaxpr.consts,
*flat_args,
Expand Down
Loading

0 comments on commit 1d45225

Please sign in to comment.