-
Notifications
You must be signed in to change notification settings - Fork 586
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Bugfix] qml.BasisRotation
is not jax compatible
#6019
base: master
Are you sure you want to change the base?
Conversation
…//github.com/PennyLaneAI/pennylane into 6004-bug-basisrotation-is-not-jax-compatible
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## master #6019 +/- ##
==========================================
- Coverage 99.59% 99.58% -0.01%
==========================================
Files 443 443
Lines 42255 42274 +19
==========================================
+ Hits 42082 42097 +15
- Misses 173 177 +4 ☔ View full report in Codecov by Sentry. |
…//github.com/PennyLaneAI/pennylane into 6004-bug-basisrotation-is-not-jax-compatible
Thanks @obliviateandsurrender and @PabloAMC! @obliviateandsurrender something odd I notice: using this PR, the following works: dev = qml.device("lightning.qubit", wires=4)
@jax.jit
@qml.qnode(dev)
def f(U):
for i in range(4):
qml.Hadamard(i)
qml.BasisRotation.compute_decomposition(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
return qml.expval(qml.PauliZ(0))
from scipy.stats import unitary_group
U = jnp.array(unitary_group.rvs(4))
f(U) But it no longer works if you remove the |
Hey @josh146! I am not sure if it is a problem with the template cause it works with dev = qml.device("lightning.qubit", wires=4)
@jax.jit
@qml.qnode(dev, interface="jax-jit")
def f(U):
for i in range(4):
qml.Hadamard(i)
qml.BasisRotation(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
return qml.expval(qml.PauliZ(0))
U = jnp.array(unitary_group.rvs(4)) >>> f(U)
Array(-0.02033669, dtype=float64) |
Co-authored-by: Josh Izaac <[email protected]>
Co-authored-by: Josh Izaac <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @PabloAMC, left some comments. The failing test can be fixed by running black on givens_decomposition.py:
black -l 100 givens_decomposition.py
@@ -147,43 +147,68 @@ def givens_decomposition(U): | |||
# Update U = T(N+j-i-1, N+j-i) @ U | |||
|
|||
""" | |||
interface = qml.math.get_deep_interface(unitary) | |||
unitary = qml.math.copy(unitary) if interface == "jax" else qml.math.toarray(unitary).copy() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
qml.math.copy
is not general enough for all interfaces?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it is not. 😅
if interface == "jax": | ||
for i in range(1, N): | ||
if i % 2: | ||
for j in range(0, i): | ||
indices = [i - j - 1, i - j] | ||
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True) | ||
unitary = unitary.at[:, indices].set(unitary[:, indices] @ grot_mat.T, indices_are_sorted=True, unique_indices=True) | ||
right_givens.append((qml.math.conj(grot_mat), indices)) | ||
else: | ||
for j in range(1, i + 1): | ||
indices = [N + j - i - 2, N + j - i - 1] | ||
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False) | ||
unitary = unitary.at[indices, :].set(grot_mat @ unitary[indices, :], indices_are_sorted=True, unique_indices=True) | ||
left_givens.append((grot_mat, indices)) | ||
else: | ||
for i in range(1, N): | ||
if i % 2: | ||
for j in range(0, i): | ||
indices = [i - j - 1, i - j] | ||
grot_mat = _givens_matrix(*unitary[N - j - 1, indices].T, left=True) | ||
unitary[:, indices] = unitary[:, indices] @ grot_mat.T | ||
right_givens.append((qml.math.conj(grot_mat), indices)) | ||
else: | ||
for j in range(1, i + 1): | ||
indices = [N + j - i - 2, N + j - i - 1] | ||
grot_mat = _givens_matrix(*unitary[indices, j - 1], left=False) | ||
unitary[indices] = grot_mat @ unitary[indices] | ||
left_givens.append((grot_mat, indices)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This could be come more compact. There is a lot of repetition and the interface check can be just put before the lines that do not work with jax.
Also, is it possible to be general enough such that there is no need for checking the interface and writting interface-specific code blocks?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, is it possible to be general enough such that there is no need for checking the interface and writing interface-specific code blocks?
Nope. The best I could do was to separate out the interface checking to a separate function altogether. I was trying to prevent it since each if
condition check might lead to a tiny overhead but at least this way codebase becomes more manageable and modular.
if interface == "jax": | ||
unitary = unitary.at[i, i].set(qml.math.diag(nphase_mat)[0]) | ||
unitary = unitary.at[j, j].set(qml.math.diag(nphase_mat)[1]) | ||
else: | ||
unitary[i, i], unitary[j, j] = qml.math.diag(nphase_mat) | ||
nleft_givens.append((qml.math.conj(givens_mat), (i, j))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nothing in qml.math
can make this block general enough to work with any interface?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.at
and .set
is unique to JAX array since those are immutable and do not support item assignments. I did try making it work via qml.math
but couldn't.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
So the closest thing I could find in qml.math
was this, but it doesn't completely capture our purpose and might cause performance degradation due to jax
array re-initializing. So, I have separated the index setting out to a separate private method.
if not np.allclose(umat @ umat.conj().T, np.eye(M, dtype=complex), atol=1e-4): | ||
raise ValueError("The provided transformation matrix should be unitary.") | ||
umat = qml.math.copy(unitary_matrix) | ||
if not qml.math.is_abstract(unitary_matrix): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this check is needed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This prevents the below unitary
check, which would be incompatible in a JIT-environment.
assert jnp.allclose(res, res2, atol=tol, rtol=0) | ||
assert qml.math.allclose(res, res3, atol=tol, rtol=0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this for checking both jax
and autograd
at the same time? Seems to me that assert qml.math.allclose
should be enough?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, qml.math.allclose
should be sufficient for both. I use jnp.allclose
in the first one as I know both res
and res2
must be JAX-arrays.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Not a blocker for this PR, but it seems that BasisRotation
is not included in the Templates doc here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Co-authored-by: soranjh <[email protected]>
…//github.com/PennyLaneAI/pennylane into 6004-bug-basisrotation-is-not-jax-compatible
qml.BasisRotation
is not jax compatible
Before submitting
Please complete the following checklist when submitting a PR:
All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to the
test directory!
All new functions and code must be clearly commented and documented.
If you do make documentation changes, make sure that the docs build and
render correctly by running
make docs
.Ensure that the test suite passes, by running
make test
.Add a new entry to the
doc/releases/changelog-dev.md
file, summarizing thechange, and including a link back to the PR.
The PennyLane source code conforms to
PEP8 standards.
We check all of our code against Pylint.
To lint modified files, simply
pip install pylint
, and thenrun
pylint pennylane/path/to/file.py
.When all the above are checked, delete everything above the dashed
line and fill in the pull request template.
Context:
qml.BasisRotation
is was not jit-compatible.Description of the Change: We modified all the numpy arrays to
jax.numpy
and ensure the tests were passingBenefits: The basis rotation is jittable and thus
qjit
compatible.Possible Drawbacks: jax numpy may be slower than basis numpy.
Related GitHub Issues: #6004