Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
astralcai committed Sep 20, 2024
1 parent d8f389b commit 0e46fe4
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions pennylane/ops/functions/assert_valid.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,14 @@ def _check_differentiation(op):
if op.num_params == 0:
return

if isinstance(op, qml.ops.qubit.BasisStateProjector):
return

data, struct = qml.pytrees.flatten(op)

def circuit(*args):
qml.apply(qml.pytrees.unflatten(args, struct))
return qml.probs()
return qml.probs(wires=op.wires)

qnode_ref = qml.QNode(circuit, device=qml.device("default.qubit"), diff_method="backprop")
qnode_ps = qml.QNode(
Expand All @@ -319,16 +322,16 @@ def circuit(*args):
params = [x if isinstance(x, int) else qml.numpy.array(x) for x in data]

ps = qml.jacobian(qnode_ps)(*params)
expected_ps = qml.jacobian(qnode_ref)(*params)
expected_bp = qml.jacobian(qnode_ref)(*params)

if isinstance(ps, tuple):
for actual, expected in zip(ps, expected_ps):
for actual, expected in zip(ps, expected_bp):
assert qml.math.allclose(
actual, expected
), "Backpropagation does not produce the expected Jacobian with this operator."
else:
assert qml.math.allclose(
ps, expected_ps
ps, expected_bp
), "Parameter shift does not produce the expected Jacobian with this operator."


Expand Down

0 comments on commit 0e46fe4

Please sign in to comment.