Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] qml.BasisRotation is not jax compatible #6019

Open
wants to merge 55 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 43 commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
76abf63
JAX compatible givens decompositions. Unit tests passing.
PabloAMC Jul 19, 2024
c17cf83
Converting to jax.numpy the basis rotation
PabloAMC Jul 22, 2024
cb663d3
Update basis_rotation.py
PabloAMC Jul 22, 2024
cb4d1f7
Update changelog-dev.md
PabloAMC Jul 22, 2024
11a4b27
Update givens_decomposition.py
PabloAMC Jul 22, 2024
d811833
Reverting from jnp. to qml.math.
PabloAMC Jul 22, 2024
7edc75a
Remove import numpy
PabloAMC Jul 22, 2024
d47d601
Unit test for jax jitting
PabloAMC Jul 22, 2024
ca95f7a
Removing whitespaces
PabloAMC Jul 22, 2024
eb77772
qml.math.copy and (not qml.math.is_abstract(unitary_matrix))
PabloAMC Jul 22, 2024
7d6c223
Update basis_rotation.py
PabloAMC Jul 22, 2024
38e3e68
branching due to .at
PabloAMC Jul 22, 2024
07816df
Branching functions with/without jax backend
PabloAMC Jul 22, 2024
cd76203
Update givens_decomposition.py
PabloAMC Jul 22, 2024
75e4f1e
Update test_basis_rotation.py
PabloAMC Jul 22, 2024
8bfba22
Update givens_decomposition.py
PabloAMC Jul 22, 2024
34d4257
Update givens_decomposition.py
PabloAMC Jul 22, 2024
379b630
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 22, 2024
96028a2
Update test_basis_rotation.py
PabloAMC Jul 22, 2024
bc34c6f
Update basis_rotation.py
PabloAMC Jul 22, 2024
2ccc9ab
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
PabloAMC Jul 22, 2024
57bfc8a
Update changelog-dev.md
PabloAMC Jul 22, 2024
650379c
Update changelog-dev.md
PabloAMC Jul 22, 2024
a563754
Checks for abstract arrays
PabloAMC Jul 22, 2024
e33bf03
make jax-jit friendly
obliviateandsurrender Jul 23, 2024
43686b9
Update basis_rotation.py
PabloAMC Jul 23, 2024
2b107c1
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
PabloAMC Jul 23, 2024
75edef1
Update givens_decomposition.py
PabloAMC Jul 23, 2024
2c7fffd
Revert "Update givens_decomposition.py"
PabloAMC Jul 23, 2024
2b789cc
Updating the documentation
PabloAMC Jul 23, 2024
a8c2d99
minor fixes
obliviateandsurrender Jul 23, 2024
182b173
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
obliviateandsurrender Jul 23, 2024
7ef90fb
assertion for workflow comparison
obliviateandsurrender Jul 23, 2024
e8351dd
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
obliviateandsurrender Jul 23, 2024
5cfabaf
Update changelog-dev.md
PabloAMC Jul 23, 2024
b825617
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 23, 2024
9010de1
Update doc/releases/changelog-dev.md
PabloAMC Jul 24, 2024
5541c89
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 24, 2024
62c4c23
Updating docstrings
PabloAMC Jul 29, 2024
9fb4d26
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 29, 2024
8b4eef5
attemp warning fix
obliviateandsurrender Jul 30, 2024
545be3e
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 30, 2024
eec5149
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 30, 2024
4064fa6
Update pennylane/qchem/givens_decomposition.py
PabloAMC Jul 31, 2024
be3df11
Update pennylane/qchem/givens_decomposition.py
PabloAMC Jul 31, 2024
a512ae3
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 31, 2024
7914f53
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Jul 31, 2024
8ebfebe
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Aug 2, 2024
7da2f88
Update givens_decomposition.py
PabloAMC Aug 2, 2024
db61a73
Update changelog-dev.md
PabloAMC Aug 2, 2024
1c752a1
Update doc/releases/changelog-dev.md
PabloAMC Aug 2, 2024
200481d
code cleanup
obliviateandsurrender Aug 5, 2024
af011ac
Merge branch '6004-bug-basisrotation-is-not-jax-compatible' of https:…
obliviateandsurrender Aug 5, 2024
24d13c7
Update qdrift.py
PabloAMC Aug 7, 2024
def0979
Merge branch 'master' into 6004-bug-basisrotation-is-not-jax-compatible
PabloAMC Sep 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,9 @@
* The `qubit_observable` function is modified to return an ascending wire order for molecular
Hamiltonians.
[(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950)

* `qml.BasisRotation` and `qml.qchem.givens_decomposition` are now jit compatible.
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)

* The `CNOT` operator no longer decomposes to itself. Instead, it raises a `qml.DecompositionUndefinedError`.
[(#6039)](https://github.com/PennyLaneAI/pennylane/pull/6039)
Expand Down Expand Up @@ -257,11 +260,13 @@
[(#5978)](https://github.com/PennyLaneAI/pennylane/pull/5978)

* `qml.AmplitudeEmbedding` has better support for features using low precision integer data types.
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)

* `qml.lie_closure` works with sums of Paulis.
[(#6023)](https://github.com/PennyLaneAI/pennylane/pull/6023)

* `qml.BasisRotation` works with qjit.
[(#6019)](https://github.com/PennyLaneAI/pennylane/pull/6019)
obliviateandsurrender marked this conversation as resolved.
Show resolved Hide resolved
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved

<h3>Contributors ✍️</h3>

Expand All @@ -287,7 +292,8 @@ Christina Lee,
William Maxwell,
Vincent Michaud-Rioux,
Anurav Modak,
Pablo A. Moreno Casares,
Mudit Pandey,
Erik Schultheis,
nate stemen,
David Wierichs,
Nate Stemen,
David Wierichs.
99 changes: 62 additions & 37 deletions pennylane/qchem/givens_decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
This module contains the functions needed for performing basis transformations defined by a set of fermionic ladder operators.
"""

import numpy as np

import pennylane as qml


Expand All @@ -38,26 +36,28 @@ def _givens_matrix(a, b, left=True, tol=1e-8):
tol (float): determines tolerance limits for :math:`|a|` and :math:`|b|` under which they are considered as zero.

Returns:
np.ndarray (or tensor): Givens rotation matrix
tensor_like: Givens rotation matrix

"""
abs_a, abs_b = np.abs(a), np.abs(b)
if abs_a < tol:
cosine, sine, phase = 1.0, 0.0, 1.0
elif abs_b < tol:
cosine, sine, phase = 0.0, 1.0, 1.0
else:
hypot = np.hypot(abs_a, abs_b)
cosine = abs_b / hypot
sine = abs_a / hypot
phase = 1.0 * b / abs_b * a.conjugate() / abs_a
abs_a, abs_b, interface = qml.math.abs(a), qml.math.abs(b), qml.math.get_interface(a)
aprod = qml.math.nan_to_num(abs_b * abs_a)
hypot = qml.math.hypot(abs_a, abs_b)

cosine = qml.math.where(abs_a < tol, 1.0, qml.math.where(abs_b < tol, 0.0, abs_b / hypot))
sine = qml.math.where(abs_a < tol, 0.0, qml.math.where(abs_b < tol, 1.0, abs_a / hypot))
phase = qml.math.where(
abs_a < tol,
1.0,
qml.math.where(abs_b < tol, 1.0, (1.0 * b * qml.math.conj(a)) / (aprod + 1e-15)),
)

if left:
return np.array([[phase * cosine, -sine], [phase * sine, cosine]])
return qml.math.array([[phase * cosine, -sine], [phase * sine, cosine]], like=interface)

return np.array([[phase * sine, cosine], [-phase * cosine, sine]])
return qml.math.array([[phase * sine, cosine], [-phase * cosine, sine]], like=interface)


# pylint:disable = too-many-branches
def givens_decomposition(unitary):
r"""Decompose a unitary into a sequence of Givens rotation gates with phase shifts and a diagonal phase matrix.

Expand Down Expand Up @@ -110,7 +110,7 @@ def givens_decomposition(unitary):
unitary (tensor): unitary matrix on which decomposition will be performed

Returns:
(np.ndarray, list[(np.ndarray, tuple)]): diagonal elements of the phase matrix :math:`D` and Givens rotation matrix :math:`T` with their indices.
(tensor_like, list[(tensor_like, tuple)]): diagonal elements of the phase matrix :math:`D` and Givens rotation matrix :math:`T` with their indices.

Raises:
ValueError: if the provided matrix is not square.
Expand Down Expand Up @@ -147,43 +147,68 @@ def givens_decomposition(U):
# Update U = T(N+j-i-1, N+j-i) @ U

"""
interface = qml.math.get_deep_interface(unitary)
unitary = qml.math.copy(unitary) if interface == "jax" else qml.math.toarray(unitary).copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qml.math.copy is not general enough for all interfaces?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is not. 😅

M, N = qml.math.shape(unitary)

unitary, (M, N) = qml.math.toarray(unitary).copy(), unitary.shape
if M != N:
raise ValueError(f"The unitary matrix should be of shape NxN, got {unitary.shape}")

left_givens, right_givens = [], []
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary[:, indices] = unitary[:, indices] @ grot_mat.T
right_givens.append((grot_mat.conj(), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary[indices] = grot_mat @ unitary[indices]
left_givens.append((grot_mat, indices))
if interface == "jax":
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary = unitary.at[:, indices].set(unitary[:, indices] @ grot_mat.T)
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary = unitary.at[indices, :].set(grot_mat @ unitary[indices, :])
PabloAMC marked this conversation as resolved.
Show resolved Hide resolved
left_givens.append((grot_mat, indices))
else:
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary[:, indices] = unitary[:, indices] @ grot_mat.T
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary[indices] = grot_mat @ unitary[indices]
left_givens.append((grot_mat, indices))

nleft_givens = []
for grot_mat, (i, j) in reversed(left_givens):
sphase_mat = np.diag(np.diag(unitary)[[i, j]])
decomp_mat = grot_mat.conj().T @ sphase_mat
sphase_mat = qml.math.diag(qml.math.diag(unitary)[qml.math.array([i, j])])
decomp_mat = qml.math.conj(grot_mat).T @ sphase_mat
givens_mat = _givens_matrix(*decomp_mat[1, :].T)
nphase_mat = decomp_mat @ givens_mat.T

# check for T_{m,n}^{-1} x D = D x T.
if not np.allclose(nphase_mat @ givens_mat.conj(), decomp_mat): # pragma: no cover
if not qml.math.is_abstract(decomp_mat) and not qml.math.allclose(
nphase_mat @ qml.math.conj(givens_mat), decomp_mat
): # pragma: no cover
raise ValueError("Failed to shift phase transposition.")

unitary[i, i], unitary[j, j] = np.diag(nphase_mat)
nleft_givens.append((givens_mat.conj(), (i, j)))
if interface == "jax":
unitary = unitary.at[i, i].set(qml.math.diag(nphase_mat)[0])
unitary = unitary.at[j, j].set(qml.math.diag(nphase_mat)[1])
else:
unitary[i, i], unitary[j, j] = qml.math.diag(nphase_mat)
nleft_givens.append((qml.math.conj(givens_mat), (i, j)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing in qml.math can make this block general enough to work with any interface?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.at and .set is unique to JAX array since those are immutable and do not support item assignments. I did try making it work via qml.math but couldn't.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the closest thing I could find in qml.math was this, but it doesn't completely capture our purpose and might cause performance degradation due to jax array re-initializing. So, I have separated the index setting out to a separate private method.


phases, ordered_rotations = np.diag(unitary), []
phases, ordered_rotations = qml.math.diag(unitary), []
for grot_mat, (i, j) in list(reversed(nleft_givens)) + list(reversed(right_givens)):
if not np.all(np.isreal(grot_mat[0, 1]) and np.isreal(grot_mat[1, 1])): # pragma: no cover
if not qml.math.is_abstract(grot_mat) and not qml.math.all(
qml.math.isreal(grot_mat[0, 1]) and qml.math.isreal(grot_mat[1, 1])
): # pragma: no cover
raise ValueError(f"Incorrect Givens Rotation encountered, {grot_mat}")
ordered_rotations.append((grot_mat, (i, j)))

Expand Down
33 changes: 19 additions & 14 deletions pennylane/templates/subroutines/basis_rotation.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for this PR, but it seems that BasisRotation is not included in the Templates doc here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
It is present 🤔

Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
This module contains the template for performing basis transformation defined by a set of fermionic ladder operators.
"""

import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.qchem.givens_decomposition import givens_decomposition
Expand Down Expand Up @@ -110,15 +108,18 @@ def _primitive_bind_call(cls, wires, unitary_matrix, check=False, id=None):
return cls._primitive.bind(*wires, unitary_matrix, check=check, id=id)

def __init__(self, wires, unitary_matrix, check=False, id=None):
M, N = unitary_matrix.shape
M, N = qml.math.shape(unitary_matrix)

if M != N:
raise ValueError(
f"The unitary matrix should be of shape NxN, got {unitary_matrix.shape}"
f"The unitary matrix should be of shape NxN, got {qml.math.shape(unitary_matrix)}"
)

if check:
umat = qml.math.toarray(unitary_matrix)
if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-6):
umat = qml.math.copy(unitary_matrix)
if not qml.math.allclose(
umat @ umat.conj().T, qml.math.eye(M, dtype=complex), atol=1e-6
):
raise ValueError("The provided transformation matrix should be unitary.")

if len(wires) < 2:
Expand Down Expand Up @@ -153,35 +154,39 @@ def compute_decomposition(
list[.Operator]: decomposition of the operator
"""

M, N = unitary_matrix.shape
M, N = qml.math.shape(unitary_matrix)
if M != N:
raise ValueError(
f"The unitary matrix should be of shape NxN, got {unitary_matrix.shape}"
)

if check:
umat = qml.math.toarray(unitary_matrix)
if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-4):
raise ValueError("The provided transformation matrix should be unitary.")
umat = qml.math.copy(unitary_matrix)
if not qml.math.is_abstract(unitary_matrix):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this check is needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This prevents the below unitary check, which would be incompatible in a JIT-environment.

if not qml.math.allclose(
umat @ umat.conj().T, qml.math.eye(M, dtype=complex), atol=1e-4
):
raise ValueError("The provided transformation matrix should be unitary.")

if len(wires) < 2:
raise ValueError(f"This template requires at least two wires, got {len(wires)}")

op_list = []

phase_list, givens_list = givens_decomposition(unitary_matrix)

for idx, phase in enumerate(phase_list):
op_list.append(qml.PhaseShift(np.angle(phase), wires=wires[idx]))
op_list.append(qml.PhaseShift(qml.math.angle(phase), wires=wires[idx]))

for grot_mat, indices in givens_list:
theta = np.arccos(np.real(grot_mat[1, 1]))
phi = np.angle(grot_mat[0, 0])
theta = qml.math.arccos(qml.math.real(grot_mat[1, 1]))
phi = qml.math.angle(grot_mat[0, 0])

op_list.append(
qml.SingleExcitation(2 * theta, wires=[wires[indices[0]], wires[indices[1]]])
)

if not np.isclose(phi, 0.0):
if qml.math.is_abstract(phi) or not qml.math.isclose(phi, 0.0):
op_list.append(qml.PhaseShift(phi, wires=wires[indices[0]]))

return op_list
Expand Down
34 changes: 12 additions & 22 deletions tests/templates/test_subroutines/test_basis_rotation.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,12 @@ def test_id(self):
assert template.id == "a"


def circuit_template(unitary_matrix):
qml.BasisState(np.array([1, 1, 0]), wires=[0, 1, 2])
def circuit_template(unitary_matrix, check=False):
qml.BasisState(qml.math.array([1, 1, 0]), wires=[0, 1, 2])
qml.BasisRotation(
wires=range(3),
unitary_matrix=unitary_matrix,
check=check,
)
return qml.expval(qml.PauliZ(0) @ qml.PauliZ(1))

Expand Down Expand Up @@ -402,7 +403,7 @@ def test_autograd(self, tol):
assert np.allclose(grads, np.zeros_like(unitary_matrix, dtype=complex), atol=tol, rtol=0)

@pytest.mark.jax
def test_jax(self, tol):
def test_jax_jit(self, tol):
"""Test the jax interface."""

import jax
Expand All @@ -415,36 +416,25 @@ def test_jax(self, tol):
[-0.58608928 + 0.0j, 0.03902657 + 0.04633548j, -0.57220635 + 0.57044649j],
]
)
weights = jnp.array(
[
2.2707802713289267,
2.9355948424220206,
-1.4869222527726533,
1.2601662579297865,
2.3559705032936717,
1.1748572730890159,
2.2500537657656356,
-0.7251404204443089,
2.3577346350335198,
]
)

dev = qml.device("default.qubit", wires=3)

circuit = qml.QNode(circuit_template, dev)
circuit2 = qml.QNode(circuit_decomposed, dev)
circuit = jax.jit(qml.QNode(circuit_template, dev), static_argnames="check")
circuit2 = qml.QNode(circuit_template, dev)

res = circuit(unitary_matrix)
res2 = circuit2(weights)
assert qml.math.allclose(res, res2, atol=tol, rtol=0)
res2 = circuit2(unitary_matrix)
res3 = circuit2(qml.math.toarray(unitary_matrix))
assert jnp.allclose(res, res2, atol=tol, rtol=0)
assert qml.math.allclose(res, res3, atol=tol, rtol=0)
Comment on lines +428 to +429
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for checking both jax and autograd at the same time? Seems to me that assert qml.math.allclose should be enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, qml.math.allclose should be sufficient for both. I use jnp.allclose in the first one as I know both res and res2 must be JAX-arrays.


grad_fn = jax.grad(circuit)
grads = grad_fn(unitary_matrix)

grad_fn2 = jax.grad(circuit2)
grads2 = grad_fn2(weights)
grads2 = grad_fn2(unitary_matrix)

assert np.allclose(grads[0], grads2[0], atol=tol, rtol=0)
assert qml.math.allclose(grads[0], grads2[0], atol=tol, rtol=0)

@pytest.mark.tf
def test_tf(self, tol):
Expand Down
Loading