Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add generator checks and differentiability checks to assert_valid #6282

Merged
merged 47 commits into from
Oct 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
47 commits
Select commit Hold shift + click to select a range
52e430b
Add check_generator to assert_valid
astralcai Sep 19, 2024
58179d5
make pylint happy
astralcai Sep 19, 2024
5e804f3
ooops
astralcai Sep 19, 2024
8f0531a
ooops * 2
astralcai Sep 19, 2024
76a2c72
ooops
astralcai Sep 19, 2024
90829a7
ooops
astralcai Sep 19, 2024
872ca86
fix test for evolution
astralcai Sep 19, 2024
7e3f5b0
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 20, 2024
b1b7ae2
Add check_differentiation
astralcai Sep 20, 2024
d8f389b
fix bugs
astralcai Sep 20, 2024
0e46fe4
fixes
astralcai Sep 20, 2024
8ab4f89
small change
astralcai Sep 24, 2024
71abf8b
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Sep 30, 2024
53eed70
revert special treatment
astralcai Sep 30, 2024
2f45948
skip test for stateprep
astralcai Sep 30, 2024
ef7fdfd
revert skip
astralcai Sep 30, 2024
e101951
revert change
astralcai Sep 30, 2024
1de730d
Merge branch 'master' into assert-valid
astralcai Sep 30, 2024
21abd17
add skip_differentiation
astralcai Oct 1, 2024
997f245
Merge branch 'master' into assert-valid
astralcai Oct 1, 2024
23db0c9
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Oct 1, 2024
3e6f5f9
fix black
astralcai Oct 1, 2024
9cf6cf6
skip differentiation for some tests
astralcai Oct 1, 2024
8f880ce
ooops missed one
astralcai Oct 1, 2024
517ad1a
adjust list of ops
astralcai Oct 1, 2024
6c5bb55
fix evolution
astralcai Oct 1, 2024
ec6374c
update tests
astralcai Oct 1, 2024
302695e
skip test_differentiation for all StatePrep
astralcai Oct 2, 2024
89e2eaf
minor fix
astralcai Oct 2, 2024
7a43049
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Oct 2, 2024
f1cbe39
Skip differentiation for a number of ops
astralcai Oct 2, 2024
19a42ae
fix bug in prepselprep
astralcai Oct 3, 2024
19afe24
add xfail to tests
astralcai Oct 3, 2024
7841fd7
Merge branch 'master' into assert-valid
astralcai Oct 3, 2024
04e9f7b
minor change
astralcai Oct 3, 2024
d9be06d
updates
astralcai Oct 3, 2024
ba158c5
add xfail
astralcai Oct 4, 2024
7f52d13
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Oct 4, 2024
3015173
ooops
astralcai Oct 4, 2024
33c966c
Merge branch 'master' into assert-valid
astralcai Oct 4, 2024
226ba2e
skip one more stateprep
astralcai Oct 4, 2024
b2e3494
apply suggestion from code review
astralcai Oct 8, 2024
0df5d77
Merge branch 'master' of https://github.com/PennyLaneAI/pennylane int…
astralcai Oct 8, 2024
d6a2de9
Merge branch 'master' into assert-valid
astralcai Oct 8, 2024
eed362e
add bad capture test
astralcai Oct 8, 2024
c84f291
Merge branch 'master' into assert-valid
astralcai Oct 8, 2024
5ff3b1e
ooops
astralcai Oct 8, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pennylane/operation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,7 +1461,7 @@ def generator(self): # pylint: disable=no-self-use
0.5 * Y(0) + Z(0) @ X(1)

The generator may also be provided in the form of a dense or sparse Hamiltonian
(using :class:`.Hermitian` and :class:`.SparseHamiltonian` respectively).
(using :class:`.Hamiltonian` and :class:`.SparseHamiltonian` respectively).

The default value to return is ``None``, indicating that the operation has
no defined generator.
Expand Down
64 changes: 63 additions & 1 deletion pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,25 @@ def _check_eigendecomposition(op):
assert qml.math.allclose(decomp_mat, original_mat), failure_comment


def _check_generator(op):
"""Checks that if an operator's has_generator property is True, it has a generator."""

if op.has_generator:
gen = op.generator()
assert isinstance(gen, qml.operation.Operator)
new_op = qml.exp(gen, 1j * op.data[0])
assert qml.math.allclose(
qml.matrix(op, wire_order=op.wires), qml.matrix(new_op, wire_order=op.wires)
)
else:
failure_comment = (
"If has_generator is False, the matrix method must raise a ``GeneratorUndefinedError``."
)
_assert_error_raised(
op.generator, qml.operation.GeneratorUndefinedError, failure_comment=failure_comment
)()


def _check_copy(op):
"""Check that copies and deep copies give identical objects."""
copied_op = copy.copy(op)
Expand Down Expand Up @@ -276,6 +295,39 @@ def _check_bind_new_parameters(op):
assert qml.math.allclose(d1, d2), failure_comment


def _check_differentiation(op):
"""Checks that the operator can be executed and differentiated correctly."""

if op.num_params == 0:
return

data, struct = qml.pytrees.flatten(op)

def circuit(*args):
qml.apply(qml.pytrees.unflatten(args, struct))
return qml.probs(wires=op.wires)

qnode_ref = qml.QNode(circuit, qml.device("default.qubit"), diff_method="backprop")
qnode_ps = qml.QNode(circuit, qml.device("default.qubit"), diff_method="parameter-shift")

params = [x if isinstance(x, int) else qml.numpy.array(x) for x in data]

ps = qml.jacobian(qnode_ps)(*params)
expected_bp = qml.jacobian(qnode_ref)(*params)

error_msg = (
"Parameter-shift does not produce the same Jacobian as with backpropagation. "
"This might be a bug, or it might be expected due to the mathematical nature "
"of backpropagation, in which case, this test can be skipped for this operator."
)

if isinstance(ps, tuple):
for actual, expected in zip(ps, expected_bp):
assert qml.math.allclose(actual, expected), error_msg
else:
assert qml.math.allclose(ps, expected_bp), error_msg


def _check_wires(op, skip_wire_mapping):
"""Check that wires are a ``Wires`` class and can be mapped."""
assert isinstance(op.wires, qml.wires.Wires), "wires must be a wires instance"
Expand All @@ -288,7 +340,12 @@ def _check_wires(op, skip_wire_mapping):
assert mapped_op.wires == new_wires, "wires must be mappable with map_wires"


def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mapping=False) -> None:
def assert_valid(
op: qml.operation.Operator,
skip_pickle=False,
skip_wire_mapping=False,
skip_differentiation=False,
) -> None:
"""Runs basic validation checks on an :class:`~.operation.Operator` to make
sure it has been correctly defined.

Expand All @@ -298,6 +355,8 @@ def assert_valid(op: qml.operation.Operator, skip_pickle=False, skip_wire_mappin
Keyword Args:
skip_pickle=False : If ``True``, pickling tests are not run. Set to ``True`` when
testing a locally defined operator, as pickle cannot handle local objects
skip_differentiation: If ``True``, differentiation tests are not run. Set to `True` when
the operator is parametrized but not differentiable.

**Examples:**

Expand Down Expand Up @@ -352,4 +411,7 @@ def __init__(self, wires):
_check_matrix_matches_decomp(op)
_check_sparse_matrix(op)
_check_eigendecomposition(op)
_check_generator(op)
if not skip_differentiation:
_check_differentiation(op)
_check_capture(op)
2 changes: 1 addition & 1 deletion pennylane/ops/op_math/evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ def generator(self):
f"The operator coefficient {self.coeff} is not imaginary; the expected format is exp(-ixG)."
f"The generator is not defined."
)
return self.base
return -1 * self.base

astralcai marked this conversation as resolved.
Show resolved Hide resolved
def __copy__(self):
copied = super().__copy__()
Expand Down
1 change: 0 additions & 1 deletion pennylane/ops/op_math/exp.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,6 @@ def __init__(self, base, coeff=1, num_steps=None, id=None):
super().__init__(base, scalar=coeff, id=id)
self.grad_recipe = [None]
self.num_steps = num_steps

self.hyperparameters["num_steps"] = num_steps

def __repr__(self):
Expand Down
3 changes: 2 additions & 1 deletion pennylane/ops/op_math/prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def circuit(weights):

_op_symbol = "@"
_math_op = math.prod
grad_method = None

@property
def is_hermitian(self):
Expand Down Expand Up @@ -359,7 +360,7 @@ def arithmetic_depth(self) -> int:
def _build_pauli_rep(self):
"""PauliSentence representation of the Product of operations."""
if all(operand_pauli_reps := [op.pauli_rep for op in self.operands]):
return reduce(lambda a, b: a @ b, operand_pauli_reps)
return reduce(lambda a, b: a @ b, operand_pauli_reps) if operand_pauli_reps else None
return None

def _simplify_factors(self, factors: tuple[Operator]) -> tuple[complex, Operator]:
Expand Down
1 change: 1 addition & 0 deletions pennylane/templates/subroutines/prepselprep.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
def _get_new_terms(lcu):
"""Compute a new sum of unitaries with positive coefficients"""
coeffs, ops = lcu.terms()
coeffs = qml.math.stack(coeffs)
angles = qml.math.angle(coeffs)
new_ops = []

Expand Down
79 changes: 46 additions & 33 deletions tests/ops/functions/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,43 +35,55 @@
from pennylane.ops.op_math.pow import PowObs, PowOperation, PowOpObs

_INSTANCES_TO_TEST = [
qml.sum(qml.PauliX(0), qml.PauliZ(0)),
qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)),
qml.BasisState([1], wires=[0]),
qml.ControlledQubitUnitary(np.eye(2), control_wires=1, wires=0),
qml.QubitChannel([np.array([[1, 0], [0, 0.8]]), np.array([[0, 0.6], [0, 0]])], wires=0),
qml.MultiControlledX(wires=[0, 1]),
qml.Projector([1], 0), # the state-vector version is already tested
qml.SpecialUnitary([1, 1, 1], 0),
qml.IntegerComparator(1, wires=[0, 1]),
qml.PauliRot(1.1, "X", wires=[0]),
qml.StatePrep([0, 1], 0),
qml.PCPhase(0.27, dim=2, wires=[0, 1]),
qml.BlockEncode([[0.1, 0.2], [0.3, 0.4]], wires=[0, 1]),
qml.adjoint(qml.PauliX(0)),
qml.adjoint(qml.RX(1.1, 0)),
Tensor(qml.PauliX(0), qml.PauliX(1)),
qml.ops.LinearCombination([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)]),
qml.s_prod(1.1, qml.RX(1.1, 0)),
qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)),
qml.ctrl(qml.RX(1.1, 0), 1),
qml.exp(qml.PauliX(0), 1.1),
qml.pow(qml.IsingXX(1.1, [0, 1]), 2.5),
qml.ops.Evolution(qml.PauliX(0), 5.2),
qml.QutritBasisState([1, 2, 0], wires=[0, 1, 2]),
qml.resource.FirstQuantization(1, 2, 1),
qml.prod(qml.RX(1.1, 0), qml.RY(2.2, 0), qml.RZ(3.3, 1)),
qml.Snapshot(measurement=qml.expval(qml.Z(0)), tag="hi"),
qml.Snapshot(tag="tag"),
(qml.sum(qml.PauliX(0), qml.PauliZ(0)), {}),
(qml.sum(qml.X(0), qml.X(0), qml.Z(0), qml.Z(0)), {}),
(qml.BasisState([1], wires=[0]), {"skip_differentiation": True}),
(
qml.ControlledQubitUnitary(np.eye(2), control_wires=1, wires=0),
{"skip_differentiation": True},
),
(
qml.QubitChannel([np.array([[1, 0], [0, 0.8]]), np.array([[0, 0.6], [0, 0]])], wires=0),
{"skip_differentiation": True},
),
(qml.MultiControlledX(wires=[0, 1]), {}),
(qml.Projector([1], 0), {"skip_differentiation": True}),
(qml.Projector([1, 0], 0), {"skip_differentiation": True}),
(qml.DiagonalQubitUnitary([1, 1, 1, 1], wires=[0, 1]), {"skip_differentiation": True}),
(qml.QubitUnitary(np.eye(2), wires=[0]), {"skip_differentiation": True}),
(qml.SpecialUnitary([1, 1, 1], 0), {"skip_differentiation": True}),
(qml.IntegerComparator(1, wires=[0, 1]), {"skip_differentiation": True}),
(qml.PauliRot(1.1, "X", wires=[0]), {}),
(qml.StatePrep([0, 1], 0), {"skip_differentiation": True}),
(qml.PCPhase(0.27, dim=2, wires=[0, 1]), {}),
(qml.BlockEncode([[0.1, 0.2], [0.3, 0.4]], wires=[0, 1]), {"skip_differentiation": True}),
(qml.adjoint(qml.PauliX(0)), {}),
(qml.adjoint(qml.RX(1.1, 0)), {}),
(Tensor(qml.PauliX(0), qml.PauliX(1)), {}),
(qml.ops.LinearCombination([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)]), {}),
(qml.s_prod(1.1, qml.RX(1.1, 0)), {}),
(qml.prod(qml.PauliX(0), qml.PauliY(1), qml.PauliZ(0)), {}),
(qml.ctrl(qml.RX(1.1, 0), 1), {}),
(qml.exp(qml.PauliX(0), 1.1), {}),
(qml.pow(qml.IsingXX(1.1, [0, 1]), 2.5), {}),
(qml.ops.Evolution(qml.PauliX(0), 5.2), {}),
(qml.QutritBasisState([1, 2, 0], wires=[0, 1, 2]), {"skip_differentiation": True}),
(qml.resource.FirstQuantization(1, 2, 1), {}),
(qml.prod(qml.RX(1.1, 0), qml.RY(2.2, 0), qml.RZ(3.3, 1)), {}),
(qml.Snapshot(measurement=qml.expval(qml.Z(0)), tag="hi"), {}),
(qml.Snapshot(tag="tag"), {}),
]
"""Valid operator instances that could not be auto-generated."""

with warnings.catch_warnings():
warnings.filterwarnings("ignore", "qml.ops.Hamiltonian uses", qml.PennyLaneDeprecationWarning)
_INSTANCES_TO_TEST.append(
qml.operation.convert_to_legacy_H(
qml.Hamiltonian([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)])
),
(
qml.operation.convert_to_legacy_H(
qml.Hamiltonian([1.1, 2.2], [qml.PauliX(0), qml.PauliZ(0)])
),
{},
)
)


Expand Down Expand Up @@ -130,6 +142,7 @@
Operation,
Observable,
Channel,
qml.ops.Projector,
qml.ops.SymbolicOp,
qml.ops.ScalarSymbolicOp,
qml.ops.Pow,
Expand Down Expand Up @@ -164,7 +177,7 @@ def get_all_classes(c):
_CLASSES_TO_TEST = (
set(get_all_classes(Operator))
- {i[1] for i in getmembers(qml.templates) if isclass(i[1]) and issubclass(i[1], Operator)}
- {type(op) for op in _INSTANCES_TO_TEST}
- {type(op) for (op, _) in _INSTANCES_TO_TEST}
- {type(op) for (op, _) in _INSTANCES_TO_FAIL}
)
"""All operators, except those tested manually, abstract/meta classes, and templates."""
Expand All @@ -176,7 +189,7 @@ def class_to_validate(request):


@pytest.fixture(params=_INSTANCES_TO_TEST)
def valid_instance(request):
def valid_instance_and_kwargs(request):
yield request.param


Expand Down
30 changes: 27 additions & 3 deletions tests/ops/functions/test_assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import pennylane as qml
from pennylane.operation import Operator
from pennylane.ops.functions import assert_valid
from pennylane.ops.functions.assert_valid import _check_capture


class TestDecompositionErrors:
Expand Down Expand Up @@ -303,6 +304,28 @@ def _unflatten(cls, data, _):
assert_valid(op, skip_pickle=True)


@pytest.mark.jax
def test_bad_capture():
"""Tests that the correct error is raised when something goes wrong with program capture."""

class MyBadOp(qml.operation.Operator):

def _flatten(self):
return (self.hyperparameters["target_op"], self.data[0]), ()

@classmethod
def _unflatten(cls, data, metadata):
return cls(*data)

def __init__(self, target_op, val):
super().__init__(val, wires=target_op.wires)
self.hyperparameters["target_op"] = target_op

op = MyBadOp(qml.X(0), 2)
with pytest.raises(ValueError, match=r"The capture of the operation into jaxpr failed"):
_check_capture(op)


def test_data_is_tuple():
"""Check that the data property is a tuple."""

Expand Down Expand Up @@ -376,13 +399,14 @@ def test_generated_list_of_ops(class_to_validate, str_wires):


@pytest.mark.jax
def test_explicit_list_of_ops(valid_instance):
def test_explicit_list_of_ops(valid_instance_and_kwargs):
"""Test the validity of operators that could not be auto-generated."""
valid_instance, kwargs = valid_instance_and_kwargs
if valid_instance.name == "Hamiltonian":
with qml.operation.disable_new_opmath_cm(warn=False):
assert_valid(valid_instance)
assert_valid(valid_instance, **kwargs)
else:
assert_valid(valid_instance)
assert_valid(valid_instance, **kwargs)


@pytest.mark.jax
Expand Down
4 changes: 2 additions & 2 deletions tests/ops/op_math/test_evolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ def test_has_generator_false(self):

def test_generator(self):
U = Evolution(qml.PauliX(0), 3)
assert U.base == U.generator()
assert U.generator() == -1 * U.base

@pytest.mark.usefixtures("legacy_opmath_only")
def test_num_params_for_parametric_base_legacy_opmath(self):
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_generator_not_observable_class(self, base):
"""Test that qml.generator will return generator if it is_hermitian, but is not a subclass of Observable"""
op = Evolution(base, 1)
gen, c = qml.generator(op)
qml.assert_equal(gen if c == 1 else qml.s_prod(c, gen), base)
qml.assert_equal(gen if c == 1 else qml.s_prod(c, gen), -1 * base)

def test_generator_error_if_not_hermitian(self):
"""Tests that an error is raised if the generator is not hermitian."""
Expand Down
2 changes: 1 addition & 1 deletion tests/ops/qubit/test_observables.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,7 +539,7 @@ def test_basisstate_projector(self):
second_projector = qml.Projector(basis_state, wires)
qml.assert_equal(second_projector, basis_state_projector)

qml.ops.functions.assert_valid(basis_state_projector)
qml.ops.functions.assert_valid(basis_state_projector, skip_differentiation=True)

def test_statevector_projector(self):
"""Test that we obtain a _StateVectorProjector when input is a state vector."""
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_amplitude.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_standard_validity():

op = qml.AmplitudeEmbedding(features=FEATURES[0], wires=range(2))

qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True)


class TestDecomposition:
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_angle.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
op = qml.AngleEmbedding(features=[1, 2, 3], wires=range(3), rotation="Z")
op = qml.AngleEmbedding(features=[1.0, 2.0, 3.0], wires=range(3), rotation="Z")
qml.ops.functions.assert_valid(op)


Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def test_standard_validity():
"""Check the operation using the assert_valid function."""
wires = qml.wires.Wires((0, 1, 2))
op = qml.BasisEmbedding(features=np.array([1, 1, 1]), wires=wires)
qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True)


# pylint: disable=protected-access
Expand Down
4 changes: 2 additions & 2 deletions tests/templates/test_embeddings/test_displacement_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,9 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
feature_vector = [1, 2, 3]
feature_vector = [1.0, 2.0, 3.0]
op = qml.DisplacementEmbedding(features=feature_vector, wires=range(3), method="phase", c=0.5)
qml.ops.functions.assert_valid(op)
qml.ops.functions.assert_valid(op, skip_differentiation=True) # Skip because it's CV op.


def test_flatten_unflatten_methods():
Expand Down
2 changes: 1 addition & 1 deletion tests/templates/test_embeddings/test_iqp_emb.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

def test_standard_validity():
"""Check the operation using the assert_valid function."""
features = (0, 1, 2)
features = (0.0, 1.0, 2.0)

op = qml.IQPEmbedding(features, wires=(0, 1, 2))
qml.ops.functions.assert_valid(op)
Expand Down
Loading
Loading