Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

devices.qubit.measure uses csr_dot_products only when it is usable. #6278

Merged
merged 19 commits into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions pennylane/devices/qubit/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

from scipy.sparse import csr_matrix

import pennylane as qml
from pennylane import math
from pennylane.measurements import (
ExpectationMP,
Expand Down Expand Up @@ -168,7 +169,7 @@ def sum_of_terms_method(
)


# pylint: disable=too-many-return-statements
# pylint: disable=too-many-return-statements,too-many-branches
def get_measurement_function(
measurementprocess: MeasurementProcess, state: TensorLike
) -> Callable[[MeasurementProcess, TensorLike], TensorLike]:
Expand All @@ -195,13 +196,31 @@ def get_measurement_function(

backprop_mode = math.get_interface(state, *measurementprocess.obs.data) != "numpy"
if isinstance(measurementprocess.obs, (Hamiltonian, LinearCombination)):
# need to work out thresholds for when its faster to use "backprop mode" measurements
return sum_of_terms_method if backprop_mode else csr_dot_products

# need to work out thresholds for when it's faster to use "backprop mode"
if backprop_mode:
return sum_of_terms_method

if not all(obs.has_sparse_matrix for obs in measurementprocess.obs.terms()[1]):
return sum_of_terms_method

# Hamiltonian.sparse_matrix raises a ValueError for this scenario.
if isinstance(measurementprocess.obs, Hamiltonian) and any(
any(len(o.wires) > 1 for o in qml.operation.Tensor(op).obs)
for op in measurementprocess.obs.ops
):
return sum_of_terms_method
astralcai marked this conversation as resolved.
Show resolved Hide resolved

return csr_dot_products

if isinstance(measurementprocess.obs, Sum):
if backprop_mode:
# always use sum_of_terms_method for Sum observables in backprop mode
return sum_of_terms_method

if not all(obs.has_sparse_matrix for obs in measurementprocess.obs):
return sum_of_terms_method

if (
measurementprocess.obs.has_overlapping_wires
and len(measurementprocess.obs.wires) > 7
Expand Down
12 changes: 12 additions & 0 deletions pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -870,6 +870,18 @@ def compute_sparse_matrix(
"""
raise SparseMatrixUndefinedError

# pylint: disable=no-self-argument, comparison-with-callable
@classproperty
def has_sparse_matrix(cls) -> bool:
r"""Bool: Whether the Operator returns a defined sparse matrix.

Note: Child classes may have this as an instance property instead of as a class property.
"""
return (
cls.compute_sparse_matrix != Operator.compute_sparse_matrix
or cls.sparse_matrix != Operator.sparse_matrix
)

def sparse_matrix(self, wire_order: Optional[WiresLike] = None) -> csr_matrix:
r"""Representation of the operator as a sparse matrix in the computational basis.

Expand Down
19 changes: 19 additions & 0 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from string import ascii_lowercase

import numpy as np
import scipy.sparse

import pennylane as qml
from pennylane.operation import EigvalsUndefinedError
Expand Down Expand Up @@ -112,6 +113,23 @@ def _check_matrix(op):
)()


def _check_sparse_matrix(op):
"""Check that if the operation says it has a sparse matrix, it does. Otherwise a ``SparseMatrixUndefinedError`` should be raised."""
if op.has_sparse_matrix:
mat = op.sparse_matrix()
assert isinstance(mat, scipy.sparse.csr_matrix), "matrix must be a TensorLike"
l = 2 ** len(op.wires)
failure_comment = f"matrix must be two dimensional with shape ({l}, {l})"
assert qml.math.shape(mat) == (l, l), failure_comment
else:
failure_comment = "If has_sparse_matrix is False, the matrix method must raise a ``SparseMatrixUndefinedError``."
_assert_error_raised(
op.sparse_matrix,
qml.operation.SparseMatrixUndefinedError,
failure_comment=failure_comment,
)()


def _check_matrix_matches_decomp(op):
"""Check that if both the matrix and decomposition are defined, they match."""
if op.has_matrix and op.has_decomposition:
Expand Down Expand Up @@ -332,5 +350,6 @@ def __init__(self, wires):
_check_decomposition(op, skip_wire_mapping)
_check_matrix(op)
_check_matrix_matches_decomp(op)
_check_sparse_matrix(op)
_check_eigendecomposition(op)
_check_capture(op)
5 changes: 5 additions & 0 deletions pennylane/ops/op_math/adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,11 @@ def matrix(self, wire_order=None):

return moveaxis(conj(base_matrix), -2, -1)

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
def has_sparse_matrix(self) -> bool:
return self.base.has_sparse_matrix

# pylint: disable=arguments-differ
def sparse_matrix(self, wire_order=None, format="csr"):
base_matrix = self.base.sparse_matrix(wire_order=wire_order)
Expand Down
5 changes: 5 additions & 0 deletions pennylane/ops/op_math/pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,11 @@ def _matrix(scalar, mat):

return fractional_matrix_power(mat, scalar)

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
def has_sparse_matrix(self) -> bool:
return self.base.has_sparse_matrix and isinstance(self.z, int)

# pylint: disable=arguments-differ
@staticmethod
def compute_sparse_matrix(*params, base=None, z=0):
Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,10 @@ def sparse_matrix(self, wire_order=None):
full_mat = reduce(sparse_kron, mats)
return math.expand_matrix(full_mat, self.wires, wire_order=wire_order)

@property
def has_sparse_matrix(self):
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)

mudit2812 marked this conversation as resolved.
Show resolved Hide resolved
# pylint: disable=protected-access
@property
def _queue_category(self):
Expand Down
4 changes: 4 additions & 0 deletions pennylane/ops/op_math/sprod.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,10 @@ def sparse_matrix(self, wire_order=None):
mat.eliminate_zeros()
return mat

@property
def has_sparse_matrix(self):
return self.pauli_rep is not None or self.base.has_sparse_matrix

@property
def has_matrix(self):
"""Bool: Whether or not the Operator returns a defined matrix."""
Expand Down
5 changes: 5 additions & 0 deletions pennylane/ops/op_math/sum.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,11 @@ def matrix(self, wire_order=None):

return math.expand_matrix(reduced_mat, sum_wires, wire_order=wire_order)

# pylint: disable=arguments-renamed, invalid-overridden-method
@property
def has_sparse_matrix(self) -> bool:
return self.pauli_rep is not None or all(op.has_sparse_matrix for op in self)

def sparse_matrix(self, wire_order=None):
if self.pauli_rep: # Get the sparse matrix from the PauliSentence representation
return self.pauli_rep.to_mat(wire_order=wire_order or self.wires, format="csr")
Expand Down
44 changes: 44 additions & 0 deletions tests/devices/qubit/test_measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ def test_sample_based_observable(self):
_ = measure(qml.sample(wires=0), state)


@pytest.mark.unit
class TestMeasurementDispatch:
"""Test that get_measurement_function dispatchs to the correct place."""

Expand Down Expand Up @@ -96,6 +97,49 @@ def test_sum_sum_of_terms_when_backprop(self):
state = qml.numpy.zeros(2)
assert get_measurement_function(qml.expval(S), state) is sum_of_terms_method

@pytest.mark.usefixtures("use_legacy_opmath")
def test_hamiltonian_with_multi_wire_obs(self):
"""Check that a Hamiltonian with a multi-wire observable uses the sum of terms method."""

S = qml.Hamiltonian(
[0.5, 0.5],
[
qml.X(0),
qml.Hermitian(
np.array(
[
[0.5, 1.0j, 0.0, -3j],
[-1.0j, -1.1, 0.0, -0.1],
[0.0, 0.0, -0.9, 12.0],
[3j, -0.1, 12.0, 0.0],
]
),
wires=[0, 1],
),
],
)
state = np.zeros(2)
assert get_measurement_function(qml.expval(S), state) is sum_of_terms_method

def test_no_sparse_matrix(self):
"""Tests Hamiltonians/Sums containing observables that do not have a sparse matrix."""

class DummyOp(qml.operation.Operator): # pylint: disable=too-few-public-methods
num_wires = 1

S1 = qml.Hamiltonian([0.5, 0.5], [qml.X(0), DummyOp(wires=1)])
state = np.zeros(2)
assert get_measurement_function(qml.expval(S1), state) is sum_of_terms_method

S2 = qml.X(0) + DummyOp(wires=1)
assert get_measurement_function(qml.expval(S2), state) is sum_of_terms_method

S3 = 0.5 * qml.X(0) + 0.5 * DummyOp(wires=1)
assert get_measurement_function(qml.expval(S3), state) is sum_of_terms_method

S4 = qml.Y(0) + qml.X(0) @ DummyOp(wires=1)
assert get_measurement_function(qml.expval(S4), state) is sum_of_terms_method


class TestMeasurements:
@pytest.mark.parametrize(
Expand Down
Loading