Skip to content

Commit

Permalink
Fix jax enable x64 (#5960)
Browse files Browse the repository at this point in the history
### Before submitting

Please complete the following checklist when submitting a PR:

- [ ] All new features must include a unit test.
If you've fixed a bug or added code that should be tested, add a test to
the
      test directory!

- [ ] All new functions and code must be clearly commented and
documented.
If you do make documentation changes, make sure that the docs build and
      render correctly by running `make docs`.

- [ ] Ensure that the test suite passes, by running `make test`.

- [ ] Add a new entry to the `doc/releases/changelog-dev.md` file,
summarizing the
      change, and including a link back to the PR.

- [ ] The PennyLane source code conforms to
      [PEP8 standards](https://www.python.org/dev/peps/pep-0008/).
We check all of our code against [Pylint](https://www.pylint.org/).
      To lint modified files, simply `pip install pylint`, and then
      run `pylint pennylane/path/to/file.py`.

When all the above are checked, delete everything above the dashed
line and fill in the pull request template.


------------------------------------------------------------------------------------------------------------

**Context:**
@dwierichs reported the following 
```
import jax
import pennylane as qml

x = jax.numpy.array(0.4, dtype=jax.numpy.float32)

@qml.qnode(qml.device("default.qubit", shots=10))
def node(x):
    qml.RX(x, 0)
    qml.measure(0)
    return qml.expval(qml.Z(0))

print(jax.config.jax_enable_x64)
>>> False # Expected because it was never activated
out = node(x)
print(jax.config.jax_enable_x64)
>>> True # Not expected.
```

**Description of the Change:**
`qml.compiler.active` first checks whether Catalyst is imported at all
to avoid changing `jax_enable_x64` on module initialization.

**Benefits:**

**Possible Drawbacks:**

**Related GitHub Issues:**
[sc-67969]

---------

Co-authored-by: Astral Cai <[email protected]>
  • Loading branch information
vincentmr and astralcai authored Jul 5, 2024
1 parent add20fc commit 9a60574
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 1 deletion.
3 changes: 3 additions & 0 deletions doc/releases/changelog-0.37.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,9 @@ Stay tuned for an in-depth demonstration on using this feature with real-world e

<h3>Bug fixes 🐛</h3>

* `qml.compiler.active` first checks whether Catalyst is imported at all to avoid changing `jax_enable_x64` on module initialization.
[(#5960)](https://github.com/PennyLaneAI/pennylane/pull/5960)

* The `__invert__` dunder method of the `MeasurementValue` class uses an array-valued function.
[(#5955)](https://github.com/PennyLaneAI/pennylane/pull/5955)

Expand Down
4 changes: 3 additions & 1 deletion pennylane/compiler/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

import dataclasses
import re
import sys
from collections import defaultdict
from importlib import metadata, reload
from sys import version_info
Expand Down Expand Up @@ -206,6 +207,8 @@ def circuit(phi, theta):
"""

for name, eps in AvailableCompilers.names_entrypoints.items():
if name not in sys.modules:
continue
tracer_loader = eps["context"].load()
if tracer_loader.is_tracing():
return name
Expand Down Expand Up @@ -245,5 +248,4 @@ def circuit(phi, theta):
>>> qml.qjit(circuit)(np.pi, np.pi / 2)
-1.0
"""

return active_compiler() is not None
8 changes: 8 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,14 @@ def circuit(phi, theta):
assert jnp.allclose(circuit(jnp.pi, jnp.pi / 2), 1.0)
assert jnp.allclose(qml.qjit(circuit)(jnp.pi, jnp.pi / 2), -1.0)

@pytest.mark.parametrize("jax_enable_x64", [False, True])
def test_jax_enable_x64(self, jax_enable_x64):
"""Test whether `qml.compiler.active` changes `jax_enable_x64`."""
jax.config.update("jax_enable_x64", jax_enable_x64)
assert jax.config.jax_enable_x64 is jax_enable_x64
qml.compiler.active()
assert jax.config.jax_enable_x64 is jax_enable_x64

def test_qjit_circuit(self):
"""Test JIT compilation of a circuit with 2-qubit"""
dev = qml.device("lightning.qubit", wires=2)
Expand Down

0 comments on commit 9a60574

Please sign in to comment.