Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] qml.BasisRotation is not jax compatible #6019

Merged
merged 83 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from 48 commits
Commits
Show all changes
83 commits
Select commit Hold shift + click to select a range
76abf63
JAX compatible givens decompositions. Unit tests passing.
PabloAMC Jul 19, 2024
c17cf83
Converting to jax.numpy the basis rotation
PabloAMC Jul 22, 2024
cb663d3
Update basis_rotation.py
PabloAMC Jul 22, 2024
cb4d1f7
Update changelog-dev.md
PabloAMC Jul 22, 2024
11a4b27
Update givens_decomposition.py
PabloAMC Jul 22, 2024
d811833
Reverting from jnp. to qml.math.
PabloAMC Jul 22, 2024
7edc75a
Remove import numpy
PabloAMC Jul 22, 2024
d47d601
Unit test for jax jitting
PabloAMC Jul 22, 2024
ca95f7a
Removing whitespaces
PabloAMC Jul 22, 2024
eb77772
qml.math.copy and (not qml.math.is_abstract(unitary_matrix))
PabloAMC Jul 22, 2024
7d6c223
Update basis_rotation.py
PabloAMC Jul 22, 2024
38e3e68
branching due to .at
PabloAMC Jul 22, 2024
07816df
Branching functions with/without jax backend
PabloAMC Jul 22, 2024
cd76203
Update givens_decomposition.py
PabloAMC Jul 22, 2024
75e4f1e
Update test_basis_rotation.py
PabloAMC Jul 22, 2024
8bfba22
Update givens_decomposition.py
PabloAMC Jul 22, 2024
34d4257
Update givens_decomposition.py
PabloAMC Jul 22, 2024
379b630
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 22, 2024
96028a2
Update test_basis_rotation.py
PabloAMC Jul 22, 2024
bc34c6f
Update basis_rotation.py
PabloAMC Jul 22, 2024
2ccc9ab
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
PabloAMC Jul 22, 2024
57bfc8a
Update changelog-dev.md
PabloAMC Jul 22, 2024
650379c
Update changelog-dev.md
PabloAMC Jul 22, 2024
a563754
Checks for abstract arrays
PabloAMC Jul 22, 2024
e33bf03
make jax-jit friendly
obliviateandsurrender Jul 23, 2024
43686b9
Update basis_rotation.py
PabloAMC Jul 23, 2024
2b107c1
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
PabloAMC Jul 23, 2024
75edef1
Update givens_decomposition.py
PabloAMC Jul 23, 2024
2c7fffd
Revert "Update givens_decomposition.py"
PabloAMC Jul 23, 2024
2b789cc
Updating the documentation
PabloAMC Jul 23, 2024
a8c2d99
minor fixes
obliviateandsurrender Jul 23, 2024
182b173
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
obliviateandsurrender Jul 23, 2024
7ef90fb
assertion for workflow comparison
obliviateandsurrender Jul 23, 2024
e8351dd
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
obliviateandsurrender Jul 23, 2024
5cfabaf
Update changelog-dev.md
PabloAMC Jul 23, 2024
b825617
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 23, 2024
9010de1
Update doc/releases/changelog-dev.md
PabloAMC Jul 24, 2024
5541c89
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 24, 2024
62c4c23
Updating docstrings
PabloAMC Jul 29, 2024
9fb4d26
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 29, 2024
8b4eef5
attemp warning fix
obliviateandsurrender Jul 30, 2024
545be3e
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 30, 2024
eec5149
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 30, 2024
4064fa6
Update pennylane/qchem/givens_decomposition.py
PabloAMC Jul 31, 2024
be3df11
Update pennylane/qchem/givens_decomposition.py
PabloAMC Jul 31, 2024
a512ae3
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 31, 2024
7914f53
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 31, 2024
8ebfebe
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Aug 2, 2024
7da2f88
Update givens_decomposition.py
PabloAMC Aug 2, 2024
db61a73
Update changelog-dev.md
PabloAMC Aug 2, 2024
1c752a1
Update doc/releases/changelog-dev.md
PabloAMC Aug 2, 2024
200481d
code cleanup
obliviateandsurrender Aug 5, 2024
af011ac
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
obliviateandsurrender Aug 5, 2024
24d13c7
Update qdrift.py
PabloAMC Aug 7, 2024
def0979
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Sep 16, 2024
51c6325
reverting changelog
willjmax Oct 11, 2024
c87a513
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 11, 2024
b3b0bba
reverting qdrift changes
willjmax Oct 11, 2024
b5adf4d
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of github…
willjmax Oct 11, 2024
99ee770
cleanup givens changes
willjmax Oct 11, 2024
443017e
update changelog
willjmax Oct 11, 2024
8192fa0
minor changes to test
willjmax Oct 11, 2024
e4cc0d2
black
willjmax Oct 11, 2024
18dd017
fixing minor bug
willjmax Oct 11, 2024
a11d2db
Update pennylane/templates/subroutines/basis_rotation.py
obliviateandsurrender Oct 11, 2024
2dcaf6c
black
willjmax Oct 11, 2024
d30acc5
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 11, 2024
3e57915
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 11, 2024
80e6654
Apply suggestions from code review
obliviateandsurrender Oct 16, 2024
db88688
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 16, 2024
25c3d6d
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 16, 2024
2b01f24
addressing soran's comments
willjmax Oct 16, 2024
ae58b5d
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of github…
willjmax Oct 16, 2024
5a2e2ec
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 16, 2024
63f2ea7
test for _set_untiary_matrix
willjmax Oct 16, 2024
c459862
linting
willjmax Oct 16, 2024
e10576b
removing jax import from test
willjmax Oct 16, 2024
f9c445b
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 17, 2024
69acdf3
happy `pylint`
obliviateandsurrender Oct 18, 2024
1448756
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
obliviateandsurrender Oct 18, 2024
06a3bad
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
willjmax Oct 18, 2024
b7e20d8
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
obliviateandsurrender Oct 24, 2024
56595d9
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Nov 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,9 @@
* The `qubit_observable` function is modified to return an ascending wire order for molecular
Hamiltonians.
[(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950)

* `qml.BasisRotation` and `qml.qchem.givens_decomposition` are now jit compatible.
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)

* The `CNOT` operator no longer decomposes to itself. Instead, it raises a `qml.DecompositionUndefinedError`.
[(#6039)](https://github.com/PennyLaneAI/pennylane/pull/6039)
Expand Down Expand Up @@ -291,14 +294,18 @@
[(#5978)](https://github.com/PennyLaneAI/pennylane/pull/5978)

* `qml.AmplitudeEmbedding` has better support for features using low precision integer data types.
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)

* Jacobian shape is fixed for measurements with dimension in `qml.gradients.vjp.compute_vjp_single`.
[(5986)](https://github.com/PennyLaneAI/pennylane/pull/5986)

* `qml.lie_closure` works with sums of Paulis.
[(#6023)](https://github.com/PennyLaneAI/pennylane/pull/6023)

* `qml.BasisRotation` works with qjit.
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved


<h3>Contributors ✍️</h3>

This release contains contributions from (in alphabetical order):
Expand All @@ -324,7 +331,8 @@ Christina Lee,
William Maxwell,
Vincent Michaud-Rioux,
Anurav Modak,
Pablo A. Moreno Casares,
Mudit Pandey,
Erik Schultheis,
nate stemen,
David Wierichs,
Nate Stemen,
David Wierichs.
99 changes: 62 additions & 37 deletions pennylane/qchem/givens_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
This module contains the functions needed for performing basis transformations defined by a set of fermionic ladder operators.
"""

import numpy as np

import pennylane as qml


Expand All @@ -38,26 +36,28 @@ def _givens_matrix(a, b, left=True, tol=1e-8):
tol (float): determines tolerance limits for :math:`|a|` and :math:`|b|` under which they are considered as zero.

Returns:
np.ndarray (or tensor): Givens rotation matrix
tensor_like: Givens rotation matrix

"""
abs_a, abs_b = np.abs(a), np.abs(b)
if abs_a < tol:
cosine, sine, phase = 1.0, 0.0, 1.0
elif abs_b < tol:
cosine, sine, phase = 0.0, 1.0, 1.0
else:
hypot = np.hypot(abs_a, abs_b)
cosine = abs_b / hypot
sine = abs_a / hypot
phase = 1.0 * b / abs_b * a.conjugate() / abs_a
abs_a, abs_b, interface = qml.math.abs(a), qml.math.abs(b), qml.math.get_interface(a)
willjmax marked this conversation as resolved.
Show resolved Hide resolved
aprod = qml.math.nan_to_num(abs_b * abs_a)
hypot = qml.math.hypot(abs_a, abs_b)

cosine = qml.math.where(abs_a < tol, 1.0, qml.math.where(abs_b < tol, 0.0, abs_b / hypot))
sine = qml.math.where(abs_a < tol, 0.0, qml.math.where(abs_b < tol, 1.0, abs_a / hypot))
phase = qml.math.where(
abs_a < tol,
1.0,
qml.math.where(abs_b < tol, 1.0, (1.0 * b * qml.math.conj(a)) / (aprod + 1e-15)),
)

if left:
return np.array([[phase * cosine, -sine], [phase * sine, cosine]])
return qml.math.array([[phase * cosine, -sine], [phase * sine, cosine]], like=interface)

return np.array([[phase * sine, cosine], [-phase * cosine, sine]])
return qml.math.array([[phase * sine, cosine], [-phase * cosine, sine]], like=interface)


# pylint:disable = too-many-branches
def givens_decomposition(unitary):
r"""Decompose a unitary into a sequence of Givens rotation gates with phase shifts and a diagonal phase matrix.

Expand Down Expand Up @@ -110,7 +110,7 @@ def givens_decomposition(unitary):
unitary (tensor): unitary matrix on which decomposition will be performed

Returns:
(np.ndarray, list[(np.ndarray, tuple)]): diagonal elements of the phase matrix :math:`D` and Givens rotation matrix :math:`T` with their indices.
(tensor_like, list[(tensor_like, tuple)]): diagonal elements of the phase matrix :math:`D` and Givens rotation matrix :math:`T` with their indices.
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved

Raises:
ValueError: if the provided matrix is not square.
Expand Down Expand Up @@ -147,43 +147,68 @@ def givens_decomposition(U):
# Update U = T(N+j-i-1, N+j-i) @ U

"""
interface = qml.math.get_deep_interface(unitary)
unitary = qml.math.copy(unitary) if interface == "jax" else qml.math.toarray(unitary).copy()
willjmax marked this conversation as resolved.
Show resolved Hide resolved
M, N = qml.math.shape(unitary)

unitary, (M, N) = qml.math.toarray(unitary).copy(), unitary.shape
if M != N:
raise ValueError(f"The unitary matrix should be of shape NxN, got {unitary.shape}")

left_givens, right_givens = [], []
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary[:, indices] = unitary[:, indices] @ grot_mat.T
right_givens.append((grot_mat.conj(), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary[indices] = grot_mat @ unitary[indices]
left_givens.append((grot_mat, indices))
if interface == "jax":
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary = unitary.at[:, indices].set(unitary[:, indices] @ grot_mat.T, indices_are_sorted=True, unique_indices=True)
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary = unitary.at[indices, :].set(grot_mat @ unitary[indices, :], indices_are_sorted=True, unique_indices=True)
left_givens.append((grot_mat, indices))
else:
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary[:, indices] = unitary[:, indices] @ grot_mat.T
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary[indices] = grot_mat @ unitary[indices]
left_givens.append((grot_mat, indices))
willjmax marked this conversation as resolved.
Show resolved Hide resolved

nleft_givens = []
for grot_mat, (i, j) in reversed(left_givens):
sphase_mat = np.diag(np.diag(unitary)[[i, j]])
decomp_mat = grot_mat.conj().T @ sphase_mat
sphase_mat = qml.math.diag(qml.math.diag(unitary)[qml.math.array([i, j])])
decomp_mat = qml.math.conj(grot_mat).T @ sphase_mat
givens_mat = _givens_matrix(*decomp_mat[1, :].T)
nphase_mat = decomp_mat @ givens_mat.T

# check for T_{m,n}^{-1} x D = D x T.
if not np.allclose(nphase_mat @ givens_mat.conj(), decomp_mat): # pragma: no cover
if not qml.math.is_abstract(decomp_mat) and not qml.math.allclose(
nphase_mat @ qml.math.conj(givens_mat), decomp_mat
): # pragma: no cover
raise ValueError("Failed to shift phase transposition.")

unitary[i, i], unitary[j, j] = np.diag(nphase_mat)
nleft_givens.append((givens_mat.conj(), (i, j)))
if interface == "jax":
unitary = unitary.at[i, i].set(qml.math.diag(nphase_mat)[0])
unitary = unitary.at[j, j].set(qml.math.diag(nphase_mat)[1])
else:
unitary[i, i], unitary[j, j] = qml.math.diag(nphase_mat)
nleft_givens.append((qml.math.conj(givens_mat), (i, j)))
willjmax marked this conversation as resolved.
Show resolved Hide resolved

phases, ordered_rotations = np.diag(unitary), []
phases, ordered_rotations = qml.math.diag(unitary), []
for grot_mat, (i, j) in list(reversed(nleft_givens)) + list(reversed(right_givens)):
if not np.all(np.isreal(grot_mat[0, 1]) and np.isreal(grot_mat[1, 1])): # pragma: no cover
if not qml.math.is_abstract(grot_mat) and not qml.math.all(
qml.math.isreal(grot_mat[0, 1]) and qml.math.isreal(grot_mat[1, 1])
): # pragma: no cover
raise ValueError(f"Incorrect Givens Rotation encountered, {grot_mat}")
ordered_rotations.append((grot_mat, (i, j)))

Expand Down
33 changes: 19 additions & 14 deletions pennylane/templates/subroutines/basis_rotation.py
willjmax marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
This module contains the template for performing basis transformation defined by a set of fermionic ladder operators.
"""

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.qchem.givens_decomposition import givens_decomposition
Expand Down Expand Up @@ -110,15 +108,18 @@ def _primitive_bind_call(cls, wires, unitary_matrix, check=False, id=None):
return cls._primitive.bind(*wires, unitary_matrix, check=check, id=id)

def __init__(self, wires, unitary_matrix, check=False, id=None):
M, N = unitary_matrix.shape
M, N = qml.math.shape(unitary_matrix)

if M != N:
raise ValueError(
f"The unitary matrix should be of shape NxN, got {unitary_matrix.shape}"
f"The unitary matrix should be of shape NxN, got {qml.math.shape(unitary_matrix)}"
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
)

if check:
umat = qml.math.toarray(unitary_matrix)
if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-6):
umat = qml.math.copy(unitary_matrix)
if not qml.math.allclose(
umat @ umat.conj().T, qml.math.eye(M, dtype=complex), atol=1e-6
):
raise ValueError("The provided transformation matrix should be unitary.")

if len(wires) < 2:
Expand Down Expand Up @@ -153,35 +154,39 @@ def compute_decomposition(
list[.Operator]: decomposition of the operator
"""

M, N = unitary_matrix.shape
M, N = qml.math.shape(unitary_matrix)
if M != N:
raise ValueError(
f"The unitary matrix should be of shape NxN, got {unitary_matrix.shape}"
)

if check:
umat = qml.math.toarray(unitary_matrix)
if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-4):
raise ValueError("The provided transformation matrix should be unitary.")
umat = qml.math.copy(unitary_matrix)
if not qml.math.is_abstract(unitary_matrix):
willjmax marked this conversation as resolved.
Show resolved Hide resolved
if not qml.math.allclose(
umat @ umat.conj().T, qml.math.eye(M, dtype=complex), atol=1e-4
):
raise ValueError("The provided transformation matrix should be unitary.")

if len(wires) < 2:
raise ValueError(f"This template requires at least two wires, got {len(wires)}")

op_list = []

phase_list, givens_list = givens_decomposition(unitary_matrix)

for idx, phase in enumerate(phase_list):
op_list.append(qml.PhaseShift(np.angle(phase), wires=wires[idx]))
op_list.append(qml.PhaseShift(qml.math.angle(phase), wires=wires[idx]))

for grot_mat, indices in givens_list:
theta = np.arccos(np.real(grot_mat[1, 1]))
phi = np.angle(grot_mat[0, 0])
theta = qml.math.arccos(qml.math.real(grot_mat[1, 1]))
phi = qml.math.angle(grot_mat[0, 0])

op_list.append(
qml.SingleExcitation(2 * theta, wires=[wires[indices[0]], wires[indices[1]]])
)

if not np.isclose(phi, 0.0):
if qml.math.is_abstract(phi) or not qml.math.isclose(phi, 0.0):
op_list.append(qml.PhaseShift(phi, wires=wires[indices[0]]))

return op_list
Expand Down
34 changes: 12 additions & 22 deletions tests/templates/test_subroutines/test_basis_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,12 @@ def test_id(self):
assert template.id == "a"


def circuit_template(unitary_matrix):
qml.BasisState(np.array([1, 1, 0]), wires=[0, 1, 2])
def circuit_template(unitary_matrix, check=False):
qml.BasisState(qml.math.array([1, 1, 0]), wires=[0, 1, 2])
qml.BasisRotation(
wires=range(3),
unitary_matrix=unitary_matrix,
check=check,
)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

Expand Down Expand Up @@ -402,7 +403,7 @@ def test_autograd(self, tol):
assert np.allclose(grads, np.zeros_like(unitary_matrix, dtype=complex), atol=tol, rtol=0)

@pytest.mark.jax
def test_jax(self, tol):
def test_jax_jit(self, tol):
"""Test the jax interface."""

import jax
Expand All @@ -415,36 +416,25 @@ def test_jax(self, tol):
[-0.58608928 + 0.0j, 0.03902657 + 0.04633548j, -0.57220635 + 0.57044649j],
]
)
weights = jnp.array(
[
2.2707802713289267,
2.9355948424220206,
-1.4869222527726533,
1.2601662579297865,
2.3559705032936717,
1.1748572730890159,
2.2500537657656356,
-0.7251404204443089,
2.3577346350335198,
]
)

dev = qml.device("default.qubit", wires=3)

circuit = qml.QNode(circuit_template, dev)
circuit2 = qml.QNode(circuit_decomposed, dev)
circuit = jax.jit(qml.QNode(circuit_template, dev), static_argnames="check")
circuit2 = qml.QNode(circuit_template, dev)
soranjh marked this conversation as resolved.
Show resolved Hide resolved

res = circuit(unitary_matrix)
res2 = circuit2(weights)
assert qml.math.allclose(res, res2, atol=tol, rtol=0)
res2 = circuit2(unitary_matrix)
res3 = circuit2(qml.math.toarray(unitary_matrix))
assert jnp.allclose(res, res2, atol=tol, rtol=0)
assert qml.math.allclose(res, res3, atol=tol, rtol=0)
willjmax marked this conversation as resolved.
Show resolved Hide resolved

grad_fn = jax.grad(circuit)
grads = grad_fn(unitary_matrix)

grad_fn2 = jax.grad(circuit2)
grads2 = grad_fn2(weights)
grads2 = grad_fn2(unitary_matrix)

assert np.allclose(grads[0], grads2[0], atol=tol, rtol=0)
assert qml.math.allclose(grads[0], grads2[0], atol=tol, rtol=0)

@pytest.mark.tf
def test_tf(self, tol):
Expand Down
Loading