Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add update kwarg to QuantumScript.copy #6285

Merged
merged 22 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@

* The `diagonalize_measurements` transform now uses a more efficient method of diagonalization
when possible, based on the `pauli_rep` of the relevant observables.
[#6113](https://github.com/PennyLaneAI/pennylane/pull/6113/)
[(#6113)](https://github.com/PennyLaneAI/pennylane/pull/6113/)

* An `update` argument is added to `QuantumScript.copy` to make it easier to create
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
an updated version of a tape where some of `tape.operations`, `tape.measurements`,
`tape.shots`, and `tape.trainable_params` are modified while ensuring other attributes
are unchanged.
[(#6285)](https://github.com/PennyLaneAI/pennylane/pull/6285)

* The `Hermitian` operator now has a `compute_sparse_matrix` implementation.
[(#6225)](https://github.com/PennyLaneAI/pennylane/pull/6225)
Expand Down
52 changes: 44 additions & 8 deletions pennylane/tape/qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -834,25 +834,60 @@ def numeric_type(self) -> Union[type, tuple[type, ...]]:
# Transforms: QuantumScript to QuantumScript
# ========================================================

def copy(self, copy_operations: bool = False) -> "QuantumScript":
"""Returns a shallow copy of the quantum script.
def copy(
self, copy_operations: bool = False, update: Optional[Union[dict, bool]] = False
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
) -> "QuantumScript":
"""Returns a copy of the quantum script. If `update` was passed, the updated attributes are
modified, otherwise, all attributes match the original tape. The copy is a shallow copy if
`copy_operations` and `update` are both `False`.
lillian542 marked this conversation as resolved.
Show resolved Hide resolved

Args:
copy_operations (bool): If True, the operations are also shallow copied.
Otherwise, if False, the copied operations will simply be references
to the original operations; changing the parameters of one script will likewise
change the parameters of all copies.
update (dict): An optional dictionary to pass new operations, measurements, shots or
trainable_params with. These will be modified on the copied tape.

Returns:
QuantumScript : a shallow copy of the quantum script
QuantumScript : a copy of the quantum script. If `update` was passed, the updated attributes are modified.

**Example**
astralcai marked this conversation as resolved.
Show resolved Hide resolved

.. code-block:: python

tape = qml.tape.QuantumScript(
ops= [qml.X(0), qml.Y(1)],
measurements=[qml.expval(qml.Z(0))],
shots=2000)

new_tape = tape.copy(update={"measurements" :[qml.expval(qml.X(1))]})

>>> tape.measurements
[qml.expval(qml.Z(0)]

>>> new_tape.measurements
[qml.expval(qml.X(1))]

>>> new_tape.shots
Shots(total_shots=2000, shot_vector=(ShotCopies(2000 shots x 1),))
"""

if copy_operations:
if update:
for k in update:
if k not in ["operations", "measurements", "shots", "trainable_params"]:
raise TypeError(
f"{self.__class__}.copy() got an unexpected key '{k}' in update dict"
)
else:
update = {}

if copy_operations or update:
# Perform a shallow copy of all operations in the operation and measurement
# queues. The operations will continue to share data with the original script operations
# unless modified.
_ops = [copy.copy(op) for op in self.operations]
_measurements = [copy.copy(op) for op in self.measurements]
_ops = update.get("operations", [copy.copy(op) for op in self.operations])
_measurements = update.get("measurements", [copy.copy(op) for op in self.measurements])
else:
# Perform a shallow copy of the operation and measurement queues. The
# operations within the queues will be references to the original script operations;
Expand All @@ -865,9 +900,10 @@ def copy(self, copy_operations: bool = False) -> "QuantumScript":
new_qscript = self.__class__(
ops=_ops,
measurements=_measurements,
shots=self.shots,
trainable_params=list(self.trainable_params),
shots=update.get("shots", self.shots),
trainable_params=list(update.get("trainable_params", self.trainable_params)),
)

new_qscript._graph = None if copy_operations else self._graph
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
new_qscript._specs = None
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
new_qscript._batch_size = self._batch_size
lillian542 marked this conversation as resolved.
Show resolved Hide resolved
Expand Down
74 changes: 73 additions & 1 deletion tests/tape/test_qscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from pennylane.operation import _UNSET_BATCH_SIZE
from pennylane.tape import QuantumScript

# pylint: disable=protected-access, unused-argument, too-few-public-methods
# pylint: disable=protected-access, unused-argument, too-few-public-methods, use-implicit-booleaness-not-comparison


class TestInitialization:
Expand Down Expand Up @@ -637,6 +637,78 @@ def test_deep_copy(self):
# to support PyTorch, which does not support deep copying of tensors
assert copied_qs.operations[0].data[0] is qs.operations[0].data[0]

@pytest.mark.parametrize("shots", [50, (1000, 2000), None])
def test_copy_update_shots(self, shots):
"""Test that copy with update dict behaves as expected for setting shots"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_tape = tape.copy(update={"shots": shots})
assert tape.shots == Shots(2500)
assert new_tape.shots == Shots(shots)

assert new_tape.operations == tape.operations == ops
assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_copy_update_measurements(self):
"""Test that copy with update dict behaves as expected for setting measurements"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_measurements = [qml.expval(qml.X(0)), qml.sample()]
new_tape = tape.copy(update={"measurements": new_measurements})

assert tape.measurements == [qml.counts()]
assert new_tape.measurements == new_measurements

assert new_tape.operations == tape.operations == ops
assert new_tape.shots == tape.shots == Shots(2500)
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_copy_update_operations(self):
"""Test that copy with update dict behaves as expected for setting operations"""

ops = [qml.X("b"), qml.RX(1.2, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_ops = [qml.X(0)]
new_tape = tape.copy(update={"operations": new_ops})

assert tape.operations == ops
assert new_tape.operations == new_ops

assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.shots == tape.shots == Shots(2500)
assert new_tape.trainable_params == tape.trainable_params == [1]

def test_copy_update_trainable_params(self):
"""Test that copy with update dict behaves as expected for setting trainable parameters"""

ops = [qml.RX(1.23, "b"), qml.RX(4.56, "a")]
tape = QuantumScript(ops, measurements=[qml.counts()], shots=2500, trainable_params=[1])

new_tape = tape.copy(update={"trainable_params": [0]})

assert tape.trainable_params == [1]
assert tape.get_parameters() == [4.56]
assert new_tape.trainable_params == [0]
assert new_tape.get_parameters() == [1.23]

assert new_tape.operations == tape.operations == ops
assert new_tape.measurements == tape.measurements == [qml.counts()]
assert new_tape.shots == tape.shots == Shots(2500)

def test_copy_update_bad_key(self):
"""Test that an unrecognized key in update dict raises an error"""

tape = QuantumScript([qml.X(0)], [qml.counts()], shots=2500)

with pytest.raises(TypeError, match="got an unexpected key"):
_ = tape.copy(update={"bad_kwarg": 3})


def test_adjoint():
"""Tests taking the adjoint of a quantum script."""
Expand Down
Loading