Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

devices.qubit.measure uses csr_dot_products only when it is usable. #6278

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 21 additions & 3 deletions pennylane/devices/qubit/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from scipy.sparse import csr_matrix

import pennylane as qml
from pennylane import math
from pennylane.measurements import (
ExpectationMP,
Expand Down Expand Up @@ -168,7 +169,7 @@ def sum_of_terms_method(
)


# pylint: disable=too-many-return-statements
# pylint: disable=too-many-return-statements,too-many-branches
def get_measurement_function(
measurementprocess: MeasurementProcess, state: TensorLike
) -> Callable[[MeasurementProcess, TensorLike], TensorLike]:
Expand All @@ -195,13 +196,30 @@ def get_measurement_function(

backprop_mode = math.get_interface(state, *measurementprocess.obs.data) != "numpy"
if isinstance(measurementprocess.obs, (Hamiltonian, LinearCombination)):
# need to work out thresholds for when its faster to use "backprop mode" measurements
return sum_of_terms_method if backprop_mode else csr_dot_products

# need to work out thresholds for when it's faster to use "backprop mode"
if backprop_mode:
return sum_of_terms_method

if not all(obs.has_sparse_matrix for obs in measurementprocess.obs.terms()[1]):
return sum_of_terms_method

if isinstance(measurementprocess.obs, Hamiltonian) and any(
any(len(o.wires) > 1 for o in qml.operation.Tensor(op).obs)
for op in measurementprocess.obs.ops
):
return sum_of_terms_method
astralcai marked this conversation as resolved.
Show resolved Hide resolved

return csr_dot_products

if isinstance(measurementprocess.obs, Sum):
if backprop_mode:
# always use sum_of_terms_method for Sum observables in backprop mode
return sum_of_terms_method

if not all(obs.has_sparse_matrix for obs in measurementprocess.obs):
return sum_of_terms_method

if (
measurementprocess.obs.has_overlapping_wires
and len(measurementprocess.obs.wires) > 7
Expand Down
12 changes: 12 additions & 0 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,18 @@ def compute_sparse_matrix(
"""
raise SparseMatrixUndefinedError

# pylint: disable=no-self-argument, comparison-with-callable
@classproperty
def has_sparse_matrix(cls) -> bool:
r"""Bool: Whether the Operator returns a defined sparse matrix.

Note: Child classes may have this as an instance property instead of as a class property.
"""
return (
cls.compute_sparse_matrix != Operator.compute_sparse_matrix
or cls.sparse_matrix != Operator.sparse_matrix
)

def sparse_matrix(self, wire_order: Optional[WiresLike] = None) -> csr_matrix:
r"""Representation of the operator as a sparse matrix in the computational basis.

Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def sparse_matrix(self, wire_order=None):
full_mat = reduce(sparse_kron, mats)
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)

@property
def has_sparse_matrix(self):
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)

mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
# pylint: disable=protected-access
@property
def _queue_category(self):
Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def sparse_matrix(self, wire_order=None):
mat.eliminate_zeros()
return mat

@property
def has_sparse_matrix(self):
return self.pauli_rep is not None or self.base.has_sparse_matrix

@property
def has_matrix(self):
"""Bool: Whether or not the Operator returns a defined matrix."""
Expand Down
44 changes: 44 additions & 0 deletions tests/devices/qubit/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_sample_based_observable(self):
_ = measure(qml.sample(wires=0), state)


@pytest.mark.unit
class TestMeasurementDispatch:
"""Test that get_measurement_function dispatchs to the correct place."""

Expand Down Expand Up @@ -96,6 +97,49 @@ def test_sum_sum_of_terms_when_backprop(self):
state = qml.numpy.zeros(2)
assert get_measurement_function(qml.expval(S), state) is sum_of_terms_method

@pytest.mark.usefixtures("use_legacy_opmath")
def test_hamiltonian_with_multi_wire_obs(self):
"""Check that a Hamiltonian with a multi-wire observable uses the sum of terms method."""

S = qml.Hamiltonian(
[0.5, 0.5],
[
qml.X(0),
qml.Hermitian(
np.array(
[
[0.5, 1.0j, 0.0, -3j],
[-1.0j, -1.1, 0.0, -0.1],
[0.0, 0.0, -0.9, 12.0],
[3j, -0.1, 12.0, 0.0],
]
),
wires=[0, 1],
),
],
)
state = np.zeros(2)
assert get_measurement_function(qml.expval(S), state) is sum_of_terms_method

def test_no_sparse_matrix(self):
"""Tests that Hamiltonians/Sums containing observables that does not have sparse matrix."""
astralcai marked this conversation as resolved.
Show resolved Hide resolved

class DummyOp(qml.operation.Operator): # pylint: disable=too-few-public-methods
num_wires = 1

S1 = qml.Hamiltonian([0.5, 0.5], [qml.X(0), DummyOp(wires=1)])
state = np.zeros(2)
assert get_measurement_function(qml.expval(S1), state) is sum_of_terms_method

S2 = qml.X(0) + DummyOp(wires=1)
assert get_measurement_function(qml.expval(S2), state) is sum_of_terms_method

S3 = 0.5 * qml.X(0) + 0.5 * DummyOp(wires=1)
assert get_measurement_function(qml.expval(S3), state) is sum_of_terms_method

S4 = qml.Y(0) + qml.X(0) @ DummyOp(wires=1)
assert get_measurement_function(qml.expval(S4), state) is sum_of_terms_method


class TestMeasurements:
@pytest.mark.parametrize(
Expand Down
Loading