Skip to content

Commit

Permalink
[BUGFIX] Pytrees: Handle empty shot vector (#6155)
Browse files Browse the repository at this point in the history
  • Loading branch information
brownj85 authored Aug 27, 2024
1 parent 9397ccb commit c34f247
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 1 deletion.
5 changes: 5 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@

<h3>Bug fixes 🐛</h3>

* Fix Pytree serialization of operators with empty shot vectors:
[(#6155)](https://github.com/PennyLaneAI/pennylane/pull/6155)

<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):

Jack Brown
3 changes: 3 additions & 0 deletions pennylane/pytrees/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def _wires_to_json(obj: Wires) -> JSON:

def _shots_to_json(obj: Shots) -> JSON:
"""JSON handler for serializing ``Shots``."""
if obj.total_shots is None:
return None

return obj.shot_vector


Expand Down
9 changes: 9 additions & 0 deletions tests/data/attributes/test_pytree.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import pytest

import pennylane as qml
from pennylane.data import Dataset, DatasetPyTree
from pennylane.pytrees.pytrees import (
_register_pytree_with_pennylane,
Expand Down Expand Up @@ -105,3 +106,11 @@ def test_bind_init(self):
attr = DatasetPyTree(bind=bind)

assert attr == value


@pytest.mark.parametrize("shots", [None, 1, [1, 2]])
def test_quantum_scripts(shots):
"""Test that ``QuantumScript`` can be serialized as Pytrees."""
script = qml.tape.QuantumScript([qml.X(0)], shots=shots)

assert qml.equal(DatasetPyTree(script).get_value(), script)
48 changes: 47 additions & 1 deletion tests/pytrees/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import pytest

import pennylane as qml
from pennylane.measurements import Shots
from pennylane.ops import PauliX, Prod, Sum
from pennylane.pytrees import PyTreeStructure, flatten, is_pytree, leaf, unflatten
from pennylane.pytrees.pytrees import (
Expand Down Expand Up @@ -118,6 +119,23 @@ def test_pytree_structure_dump(decode):
]


@pytest.mark.parametrize(
"shots, expect_metadata",
[
(Shots(), None),
(Shots(1), [[1, 1]]),
(Shots([1, 2]), [[1, 1], [2, 1]]),
],
)
def test_pytree_structure_dump_shots(shots, expect_metadata):
"""Test that ``pytree_structure_dump`` handles all forms of shots."""
_, struct = flatten(CustomNode([], {"shots": shots}))

flattened = pytree_structure_dump(struct)

assert json.loads(flattened) == ["test.CustomNode", {"shots": expect_metadata}, []]


def test_pytree_structure_dump_unserializable_metadata():
"""Test that a ``TypeError`` is raised if a Pytree has unserializable metadata."""
_, struct = flatten(CustomNode([1, 2, 4], {"operator": qml.PauliX(0)}))
Expand Down Expand Up @@ -190,9 +208,37 @@ def test_pytree_structure_load():
],
)
def test_pennylane_pytree_roundtrip(obj_in: Any):
"""Test that Pennylane Pytree objects are requal to themselves after
"""Test that Pennylane Pytree objects are equal to themselves after
a serialization roundtrip."""
data, struct = flatten(obj_in)
obj_out = unflatten(data, pytree_structure_load(pytree_structure_dump(struct)))

assert qml.equal(obj_in, obj_out)


@pytest.mark.parametrize(
"obj_in",
[
[
qml.tape.QuantumScript(
[qml.adjoint(qml.RX(0.1, wires=0))],
[qml.expval(2 * qml.X(0))],
trainable_params=[0, 1],
),
Prod(qml.X(0), qml.RX(0.1, wires=0), qml.X(1), id="id"),
Sum(
qml.Hermitian(H_ONE_QUBIT, 2),
qml.Hermitian(H_TWO_QUBITS, [0, 1]),
qml.PauliX(1),
qml.Identity("a"),
),
]
],
)
def test_pennylane_pytree_roundtrip_list(obj_in: Any):
"""Test that lists Pennylane Pytree objects are equal to themselves after
a serialization roundtrip."""
data, struct = flatten(obj_in)
obj_out = unflatten(data, pytree_structure_load(pytree_structure_dump(struct)))

assert all(qml.equal(in_, out) for in_, out in zip(obj_in, obj_out))

0 comments on commit c34f247

Please sign in to comment.