Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bugfix] qml.BasisRotation is not jax compatible #6019

Open
wants to merge 55 commits into
base: master
Choose a base branch
from

Conversation

PabloAMC
Copy link
Contributor

@PabloAMC PabloAMC commented Jul 22, 2024

Before submitting

Please complete the following checklist when submitting a PR:

  • All new features must include a unit test.
    If you've fixed a bug or added code that should be tested, add a test to the
    test directory!

  • All new functions and code must be clearly commented and documented.
    If you do make documentation changes, make sure that the docs build and
    render correctly by running make docs.

  • Ensure that the test suite passes, by running make test.

  • Add a new entry to the doc/releases/changelog-dev.md file, summarizing the
    change, and including a link back to the PR.

  • The PennyLane source code conforms to
    PEP8 standards.
    We check all of our code against Pylint.
    To lint modified files, simply pip install pylint, and then
    run pylint pennylane/path/to/file.py.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


Context: qml.BasisRotation is was not jit-compatible.

Description of the Change: We modified all the numpy arrays to jax.numpy and ensure the tests were passing

Benefits: The basis rotation is jittable and thus qjit compatible.

Possible Drawbacks: jax numpy may be slower than basis numpy.

Related GitHub Issues: #6004

@PabloAMC PabloAMC linked an issue Jul 22, 2024 that may be closed by this pull request
Copy link

codecov bot commented Jul 23, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 99.58%. Comparing base (d0344b0) to head (def0979).
Report is 4 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #6019      +/-   ##
==========================================
- Coverage   99.59%   99.58%   -0.01%     
==========================================
  Files         443      443              
  Lines       42255    42274      +19     
==========================================
+ Hits        42082    42097      +15     
- Misses        173      177       +4     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

pennylane/qchem/givens_decomposition.py Outdated Show resolved Hide resolved
pennylane/qchem/givens_decomposition.py Outdated Show resolved Hide resolved
pennylane/qchem/givens_decomposition.py Outdated Show resolved Hide resolved
@josh146
Copy link
Member

josh146 commented Jul 26, 2024

Thanks @obliviateandsurrender and @PabloAMC!

@obliviateandsurrender something odd I notice: using this PR, the following works:

dev = qml.device("lightning.qubit", wires=4)

@jax.jit
@qml.qnode(dev)
def f(U):
    for i in range(4):
        qml.Hadamard(i)
    qml.BasisRotation.compute_decomposition(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
    return qml.expval(qml.PauliZ(0))

from scipy.stats import unitary_group
U = jnp.array(unitary_group.rvs(4))
f(U)

But it no longer works if you remove the compute_decomposition (I still get an error message ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.).

@obliviateandsurrender
Copy link
Contributor

obliviateandsurrender commented Jul 26, 2024

Hey @josh146! I am not sure if it is a problem with the template cause it works with default.qubit.
I suspect the error is coming from the convert_to_numpy_parameters transform being used in the lightning.qubit for under _execute_jvp function. I think the error comes from interface being recognized as jax instead of jax-jit for lightning.qubit. Manually setting the interface like below works for me -

dev = qml.device("lightning.qubit", wires=4)

@jax.jit
@qml.qnode(dev, interface="jax-jit")
def f(U):
    for i in range(4):
        qml.Hadamard(i)
    qml.BasisRotation(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
    return qml.expval(qml.PauliZ(0))
U = jnp.array(unitary_group.rvs(4))
>>> f(U)
Array(-0.02033669, dtype=float64)

Copy link
Contributor

@soranjh soranjh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @PabloAMC, left some comments. The failing test can be fixed by running black on givens_decomposition.py:

black -l 100 givens_decomposition.py

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
@@ -147,43 +147,68 @@ def givens_decomposition(U):
# Update U = T(N+j-i-1, N+j-i) @ U

"""
interface = qml.math.get_deep_interface(unitary)
unitary = qml.math.copy(unitary) if interface == "jax" else qml.math.toarray(unitary).copy()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

qml.math.copy is not general enough for all interfaces?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, it is not. 😅

Comment on lines 158 to 185
if interface == "jax":
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary = unitary.at[:, indices].set(unitary[:, indices] @ grot_mat.T, indices_are_sorted=True, unique_indices=True)
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary = unitary.at[indices, :].set(grot_mat @ unitary[indices, :], indices_are_sorted=True, unique_indices=True)
left_givens.append((grot_mat, indices))
else:
for i in range(1, N):
if i % 2:
for j in range(0, i):
indices = [i - j - 1, i - j]
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True)
unitary[:, indices] = unitary[:, indices] @ grot_mat.T
right_givens.append((qml.math.conj(grot_mat), indices))
else:
for j in range(1, i + 1):
indices = [N + j - i - 2, N + j - i - 1]
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False)
unitary[indices] = grot_mat @ unitary[indices]
left_givens.append((grot_mat, indices))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be come more compact. There is a lot of repetition and the interface check can be just put before the lines that do not work with jax.

Also, is it possible to be general enough such that there is no need for checking the interface and writting interface-specific code blocks?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, is it possible to be general enough such that there is no need for checking the interface and writing interface-specific code blocks?

Nope. The best I could do was to separate out the interface checking to a separate function altogether. I was trying to prevent it since each if condition check might lead to a tiny overhead but at least this way codebase becomes more manageable and modular.

Comment on lines 200 to 205
if interface == "jax":
unitary = unitary.at[i, i].set(qml.math.diag(nphase_mat)[0])
unitary = unitary.at[j, j].set(qml.math.diag(nphase_mat)[1])
else:
unitary[i, i], unitary[j, j] = qml.math.diag(nphase_mat)
nleft_givens.append((qml.math.conj(givens_mat), (i, j)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nothing in qml.math can make this block general enough to work with any interface?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

.at and .set is unique to JAX array since those are immutable and do not support item assignments. I did try making it work via qml.math but couldn't.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the closest thing I could find in qml.math was this, but it doesn't completely capture our purpose and might cause performance degradation due to jax array re-initializing. So, I have separated the index setting out to a separate private method.

if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-4):
raise ValueError("The provided transformation matrix should be unitary.")
umat = qml.math.copy(unitary_matrix)
if not qml.math.is_abstract(unitary_matrix):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why this check is needed?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This prevents the below unitary check, which would be incompatible in a JIT-environment.

Comment on lines +428 to +429
assert jnp.allclose(res, res2, atol=tol, rtol=0)
assert qml.math.allclose(res, res3, atol=tol, rtol=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this for checking both jax and autograd at the same time? Seems to me that assert qml.math.allclose should be enough?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, qml.math.allclose should be sufficient for both. I use jnp.allclose in the first one as I know both res and res2 must be JAX-arrays.

doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
doc/releases/changelog-dev.md Outdated Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not a blocker for this PR, but it seems that BasisRotation is not included in the Templates doc here.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
It is present 🤔

@Alex-Preciado Alex-Preciado changed the title 6004 bug basisrotation is not jax compatible [Bugfix] qml.BasisRotation is not jax compatible Aug 12, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[BUG] BasisRotation is not JAX compatible
5 participants