Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support qml.sample() without specifying the observable #266

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions src/braket/pennylane_plugin/translation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pennylane as qml
from braket.aws import AwsDevice
from braket.circuits import FreeParameter, Gate, ResultType, gates, noises, observables
from braket.circuits.observables import Observable as BraketObservable
from braket.circuits.result_types import (
AdjointGradient,
DensityMatrix,
Expand Down Expand Up @@ -530,7 +531,7 @@
return AdjointGradient(observable=braket_observable, target=targets, parameters=parameters)


def translate_result_type(
def translate_result_type( # noqa: C901
measurement: MeasurementProcess, targets: list[int], supported_result_types: frozenset[str]
) -> Union[ResultType, tuple[ResultType, ...]]:
"""Translates a PennyLane ``MeasurementProcess`` into the corresponding Braket ``ResultType``.
Expand All @@ -547,6 +548,7 @@
then this will return a result type for each term.
"""
return_type = measurement.return_type
observable = measurement.obs

if return_type is ObservableReturnTypes.Probability:
return Probability(targets)
Expand All @@ -558,14 +560,21 @@
return DensityMatrix(targets)
raise NotImplementedError(f"Unsupported return type: {return_type}")

if isinstance(measurement.obs, (Hamiltonian, qml.Hamiltonian)):
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
if return_type is ObservableReturnTypes.Expectation:
return tuple(
Expectation(_translate_observable(term), term.wires) for term in measurement.obs.ops
Expectation(_translate_observable(term), term.wires) for term in observable.ops
)
raise NotImplementedError(f"Return type {return_type} unsupported for Hamiltonian")

braket_observable = _translate_observable(measurement.obs)
if return_type is ObservableReturnTypes.Sample and observable is None:
if isinstance(measurement, qml.measurements.SampleMeasurement):
return tuple(
Sample(BraketObservable.Z(), target) for target in targets or measurement.wires
)
raise NotImplementedError(f"Unsupported measurement type: {type(measurement)}")

Check warning on line 575 in src/braket/pennylane_plugin/translation.py

View check run for this annotation

Codecov / codecov/patch

src/braket/pennylane_plugin/translation.py#L575

Added line #L575 was not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

instead of embedding this logic here, would it be possible to have _translate_observable return both an observable and targets, and have it fall back to these defaults if observable is none and/or targets is none? Might be slightly cleaner (and then maybe wouldn't require the C901 suppression above either, since the logic would be outside this function).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure about that. The target cannot be specified if an observable is specified, so the branching will be based on whether the observable is None. Therefore, I believe that the overall logic will remain the same, and it will only be divided into more functions.


braket_observable = _translate_observable(observable)
if return_type is ObservableReturnTypes.Expectation:
return Expectation(braket_observable, targets)
elif return_type is ObservableReturnTypes.Variance:
Expand Down Expand Up @@ -698,6 +707,14 @@
ag_result.value["gradient"][f"p_{i}"]
for i in sorted(key_indices)
]

if measurement.return_type is ObservableReturnTypes.Sample and observable is None:
if isinstance(measurement, qml.measurements.SampleMeasurement):
if targets:
return [m[targets] for m in braket_result.measurements]
return braket_result.measurements
raise NotImplementedError(f"Unsupported measurement type: {type(measurement)}")

Check warning on line 716 in src/braket/pennylane_plugin/translation.py

View check run for this annotation

Codecov / codecov/patch

src/braket/pennylane_plugin/translation.py#L715-L716

Added lines #L715 - L716 were not covered by tests
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same question here, would it be possible to embed this logic inside translate_result_type somehow?


translated = translate_result_type(measurement, targets, supported_result_types)
if isinstance(observable, (Hamiltonian, qml.Hamiltonian)):
coeffs, _ = observable.terms()
Expand Down
80 changes: 80 additions & 0 deletions test/integ_tests/test_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,86 @@
class TestSample:
"""Tests for the sample return type"""

def test_sample_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified
"""
dev = device(2)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
return qml.sample()

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(3)

@qml.qnode(dev)
def circuit():
qml.RX(np.pi / 4, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2])

shot_vector = circuit()

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (shots, 2)

def test_sample_batch_default(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when no observable is specified and
the batch dimension is returned
"""
dev = device(3)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample()

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_batch_wires(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values when only wires are specified
"""
dev = device(4)

@qml.qnode(dev)
def circuit(a):
qml.RX(a, wires=0)
qml.CNOT(wires=[0, 1])
qml.CNOT(wires=[1, 2])
return qml.sample(wires=[0, 2, 3])

shot_vector = circuit([np.pi / 4, np.pi / 3])

# The sample should only contain 1 and 0
assert shot_vector.dtype == np.dtype("int")
assert np.all((shot_vector == 0) | (shot_vector == 1))
assert shot_vector.shape == (2, shots, 3)

def test_sample_values(self, device, shots, tol):
"""Tests if the samples returned by sample have
the correct values
Expand Down
Loading