Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for mid-circuit measurement program capture #6015

Merged
merged 33 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
54e097b
[skip ci] Prototyping MCM program capture
mudit2812 Jul 18, 2024
8e83958
Remove measurement_value updates
mudit2812 Jul 18, 2024
d76dfe8
Remove old test
mudit2812 Jul 18, 2024
c0bd8bb
Merge branch 'master' into capture-mcm1
mudit2812 Jul 18, 2024
30c659f
Add empty test file
mudit2812 Jul 18, 2024
3250dd1
New implementation: AbstractMCM which inherits ShapedArray
mudit2812 Jul 23, 2024
8b3f3a3
[skip ci] Skip CI
mudit2812 Jul 23, 2024
0645d3e
Merge branch 'master' into capture-mcm1
mudit2812 Jul 23, 2024
d71a22c
[skip ci] Skip CI
mudit2812 Jul 23, 2024
50af05d
[skip ci] Tidying up
mudit2812 Jul 23, 2024
6444fc4
[skip ci] Trying out new prototype
mudit2812 Jul 24, 2024
ee60e1e
Terminal measurements can have MCMs
mudit2812 Jul 26, 2024
f8154f4
Remove prototype implementation
mudit2812 Jul 26, 2024
8251f03
Fix capture module __init__.py
mudit2812 Jul 26, 2024
3dbe1f3
Update capture measurement tests
mudit2812 Jul 26, 2024
41426f2
Tidying up
mudit2812 Jul 29, 2024
13be228
Merge branch 'master' into capture-mcm1
mudit2812 Jul 29, 2024
4300ae9
Add test skeleton
mudit2812 Jul 29, 2024
6412f24
Added unit tests
mudit2812 Jul 29, 2024
57f0a44
Merge branch 'master' into capture-mcm1
mudit2812 Jul 29, 2024
1c79739
Finished adding tests
mudit2812 Jul 29, 2024
21e1819
Merge branch 'master' into capture-mcm1
mudit2812 Jul 29, 2024
1846f56
Linting
mudit2812 Jul 29, 2024
27c59fe
Merge branch 'master' into capture-mcm1
mudit2812 Jul 30, 2024
f57738e
Update MidMeasureMP abstract eval
mudit2812 Jul 30, 2024
7401bdf
Merge branch 'master' into capture-mcm1
mudit2812 Jul 31, 2024
2d513c3
Update tests
mudit2812 Jul 31, 2024
6c13265
Update bool aval to int
mudit2812 Jul 31, 2024
7d39390
Update dtype computation
mudit2812 Jul 31, 2024
ec65ed9
Merge branch 'master' into capture-mcm1
mudit2812 Jul 31, 2024
3ad7ba0
Merge branch 'master' into capture-mcm1
mudit2812 Jul 31, 2024
41b04ca
Remove duplicate changelog entry
mudit2812 Aug 1, 2024
2959d55
Merge branch 'master' into capture-mcm1
mudit2812 Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,15 @@

<h3>New features since last release</h3>

* Mid-circuit measurements can now be captured with `qml.capture` enabled.
[(#6015)](https://github.com/PennyLaneAI/pennylane/pull/6015)

PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
* A new method `process_density_matrix` has been added to the `ProbabilityMP` and `DensityMatrixMP`
classes, allowing for more efficient handling of quantum density matrices, particularly with batch
processing support. This method simplifies the calculation of probabilities from quantum states
represented as density matrices.
[(#5830)](https://github.com/PennyLaneAI/pennylane/pull/5830)

* Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`.
[(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988)

* The `qml.PrepSelPrep` template is added. The template implements a block-encoding of a linear
combination of unitaries.
[(#5756)](https://github.com/PennyLaneAI/pennylane/pull/5756)
Expand Down Expand Up @@ -174,6 +174,9 @@

<h4>Community contributions 🥳</h4>

* Resolved the bug in `qml.ThermalRelaxationError` where there was a typo from `tq` to `tg`.
[(#5988)](https://github.com/PennyLaneAI/pennylane/issues/5988)

PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
* `DefaultQutritMixed` readout error has been added using parameters `readout_relaxation_probs` and
`readout_misclassification_probs` on the `default.qutrit.mixed` device. These parameters add a `~.QutritAmplitudeDamping` and a `~.TritFlip` channel, respectively,
after measurement diagonalization. The amplitude damping error represents the potential for
Expand Down
9 changes: 2 additions & 7 deletions pennylane/capture/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,13 +266,8 @@ def create_measurement_mcm_primitive(
primitive = jax.core.Primitive(name + "_mcm")

@primitive.def_impl
def _(*mcms, **kwargs):
raise NotImplementedError(
"mcm measurements do not currently have a concrete implementation"
)
# need to figure out how to convert a jaxpr int into a measurement value, and pass
# that measurment value here.
# return type.__call__(measurement_type, obs=mcms, **kwargs)
def _(*mcms, single_mcm=True, **kwargs):
return type.__call__(measurement_type, obs=mcms[0] if single_mcm else mcms, **kwargs)

abstract_type = _get_abstract_measurement()

Expand Down
4 changes: 2 additions & 2 deletions pennylane/measurements/measurements.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,8 +172,8 @@ def _primitive_bind_call(cls, obs=None, wires=None, eigvals=None, id=None, **kwa
):
return cls._obs_primitive.bind(obs, **kwargs)
if isinstance(obs, (list, tuple)):
return cls._mcm_primitive.bind(*obs, **kwargs) # iterable of mcms
return cls._mcm_primitive.bind(obs, **kwargs) # single mcm
return cls._mcm_primitive.bind(*obs, single_mcm=False, **kwargs) # iterable of mcms
return cls._mcm_primitive.bind(obs, single_mcm=True, **kwargs) # single mcm

# pylint: disable=unused-argument
@classmethod
Expand Down
51 changes: 43 additions & 8 deletions pennylane/measurements/mid_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
This module contains the qml.measure measurement.
"""
import uuid
from functools import lru_cache
from typing import Generic, Hashable, Optional, TypeVar, Union

import pennylane as qml
Expand Down Expand Up @@ -209,22 +210,56 @@ def func(x):
samples, leading to unexpected or incorrect results.

"""
wire = Wires(wires)
if len(wire) > 1:
if qml.capture.enabled():
primitive = _create_mid_measure_primitive()
return primitive.bind(wires, reset=reset, postselect=postselect)

return _measure_impl(wires, reset=reset, postselect=postselect)


def _measure_impl(
wires: Union[Hashable, Wires], reset: Optional[bool] = False, postselect: Optional[int] = None
):
"""Concrete implementation of qml.measure"""
wires = Wires(wires)
if len(wires) > 1:
raise qml.QuantumFunctionError(
"Only a single qubit can be measured in the middle of the circuit"
)

# Create a UUID and a map between MP and MV to support serialization
measurement_id = str(uuid.uuid4())[:8]
mp = MidMeasureMP(wires=wire, reset=reset, postselect=postselect, id=measurement_id)
if qml.capture.enabled():
raise NotImplementedError(
"Capture cannot currently handle classical output from mid circuit measurements."
)
mp = MidMeasureMP(wires=wires, reset=reset, postselect=postselect, id=measurement_id)
return MeasurementValue([mp], processing_fn=lambda v: v)


@lru_cache
def _create_mid_measure_primitive():
"""Create a primitive corresponding to an mid-circuit measurement type.

Called when using :func:`~pennylane.measure`.

Returns:
jax.core.Primitive: A new jax primitive corresponding to a mid-circuit
measurement.

"""
import jax # pylint: disable=import-outside-toplevel

mid_measure_p = jax.core.Primitive("measure")

@mid_measure_p.def_impl
def _(wires, reset=False, postselect=None):
return _measure_impl(wires, reset=reset, postselect=postselect)

@mid_measure_p.def_abstract_eval
def _(*_, **__):
dtype = jax.numpy.int64 if jax.config.jax_enable_x64 else jax.numpy.int32
return jax.core.ShapedArray((), dtype)

return mid_measure_p


T = TypeVar("T")


Expand Down Expand Up @@ -266,7 +301,7 @@ def __init__(
@classmethod
def _primitive_bind_call(cls, wires=None, reset=False, postselect=None, id=None):
wires = () if wires is None else wires
return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect)
return cls._wires_primitive.bind(*wires, reset=reset, postselect=postselect, id=id)

@classmethod
def _abstract_eval(
Expand Down
6 changes: 5 additions & 1 deletion pennylane/workflow/qnode.py
Original file line number Diff line number Diff line change
Expand Up @@ -1169,7 +1169,11 @@ def _impl_call(self, *args, **kwargs) -> qml.typing.Result:

old_interface = self.interface
if old_interface == "auto":
interface = qml.math.get_interface(*args, *list(kwargs.values()))
interface = (
"jax"
if qml.capture.enabled()
else qml.math.get_interface(*args, *list(kwargs.values()))
)
PietropaoloFrisoni marked this conversation as resolved.
Show resolved Hide resolved
self._interface = INTERFACE_MAP[interface]

if self._qfunc_uses_shots_arg:
Expand Down
Loading
Loading