
[BUG] BasisRotation is not JAX compatible #6004

Open
josh146 opened this issue Jul 17, 2024 · 2 comments · May be fixed by #6019
Labels
bug 🐛 Something isn't working

Comments

josh146 (Member) commented Jul 17, 2024

Expected behavior

qml.BasisRotation should be supported with jax.jit and qml.qjit.

Actual behavior

```python
import jax
import pennylane as qml

dev = qml.device("lightning.qubit", wires=4)

@jax.jit
@qml.qnode(dev)
def f(U):
    # Prepare a uniform superposition, then apply the basis rotation defined by U
    for i in range(4):
        qml.Hadamard(i)
    qml.BasisRotation(unitary_matrix=U, wires=[0, 1, 2, 3], check=False)
    return qml.expval(qml.PauliZ(0))
```

```pycon
>>> from scipy.stats import unitary_group
>>> U = unitary_group.rvs(4)
>>> f(U)
```
```text
/usr/local/lib/python3.10/dist-packages/pennylane/operation.py in decomposition(self)
   1318             list[Operator]: decomposition of the operator
   1319         """
-> 1320         return self.compute_decomposition(
   1321             *self.parameters, wires=self.wires, **self.hyperparameters
   1322         )

/usr/local/lib/python3.10/dist-packages/pennylane/templates/subroutines/basis_rotation.py in compute_decomposition(wires, unitary_matrix, check)
    169 
    170         op_list = []
--> 171         phase_list, givens_list = givens_decomposition(unitary_matrix)
    172 
    173         for idx, phase in enumerate(phase_list):

/usr/local/lib/python3.10/dist-packages/pennylane/qchem/givens_decomposition.py in givens_decomposition(unitary)
    149     """
    150 
--> 151     unitary, (M, N) = qml.math.toarray(unitary).copy(), unitary.shape
    152     if M != N:
    153         raise ValueError(f"The unitary matrix should be of shape NxN, got {unitary.shape}")

/usr/local/lib/python3.10/dist-packages/autoray/autoray.py in do(fn, like, *args, **kwargs)
     79     backend = _choose_backend(fn, args, kwargs, like=like)
     80     func = get_lib_fn(backend, fn)
---> 81     return func(*args, **kwargs)
     82 
     83 

/usr/local/lib/python3.10/dist-packages/pennylane/math/single_dispatch.py in _to_numpy_jax(x)
    783         return np.array(getattr(x, "val", x))
    784     except TracerArrayConversionError as e:
--> 785         raise ValueError(
    786             "Converting a JAX array to a NumPy array not supported when using the JAX JIT."
    787         ) from e

ValueError: Converting a JAX array to a NumPy array not supported when using the JAX JIT.
```
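The `ValueError` above is PennyLane re-raising JAX's `TracerArrayConversionError` (see `_to_numpy_jax` in the traceback). The JAX-only sketch below, with no PennyLane involved, reproduces the underlying failure: inside `jax.jit` the argument is an abstract tracer with no concrete values, so `np.asarray` cannot build an ndarray from it, while an equivalent `jax.numpy` computation traces fine. The function names are illustrative only.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def numpy_inside_jit(x):
    # Under jax.jit, x is an abstract tracer, so NumPy cannot convert it
    # to a concrete ndarray (similar to qml.math.toarray in the traceback).
    return np.asarray(x).copy()


@jax.jit
def jax_inside_jit(x):
    # Staying in jax.numpy keeps the computation traceable.
    return jnp.trace(x)


def raises_tracer_conversion_error(x) -> bool:
    """Return True if numpy_inside_jit fails with TracerArrayConversionError."""
    try:
        numpy_inside_jit(x)
    except jax.errors.TracerArrayConversionError:
        return True
    return False


x = jnp.eye(4)
print(raises_tracer_conversion_error(x))
print(jax_inside_jit(x))
```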

Additional information

This occurs because the decomposition of qml.BasisRotation calls the qml.qchem.givens_decomposition function, which is not JAX compatible:

  • The unitary matrix U is converted to a NumPy array (via qml.math.toarray) and copied, which fails for traced JAX arrays.
  • The copied matrix is then updated in place, which JAX arrays do not support (they are immutable).
  • NumPy functions are used instead of qml.math functions.
  • Exceptions are raised based on array values, which requires concrete (non-traced) values.
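As a rough sketch of what JAX-friendly versions of these steps could look like, the hypothetical helper `givens_zero_left` below zeroes a single matrix entry with a Givens rotation while staying traceable: it uses `jax.numpy` instead of NumPy, returns a new array via `.at[...].set(...)` instead of writing in place, replaces a Python `if` on a traced value with `jnp.where`, and only raises on the (static) shape. It is not PennyLane's `givens_decomposition` nor the fix proposed in #6019, and its column-rotation convention is an assumption that may differ from PennyLane's.

```python
import jax
import jax.numpy as jnp


def givens_zero_left(U, i, j):
    """Hypothetical helper: rotate columns j and j+1 of U so that U[i, j] == 0.

    i and j must be Python ints (static under jit); U is assumed to be a
    complex array and may be a traced JAX array.
    """
    M, N = U.shape  # shapes are static under jit, so this Python check is safe
    if M != N:
        raise ValueError(f"The unitary matrix should be of shape NxN, got {U.shape}")

    a, b = U[i, j], U[i, j + 1]
    norm = jnp.sqrt(jnp.abs(a) ** 2 + jnp.abs(b) ** 2)
    # Avoid dividing by zero without a Python `if` on a traced value.
    safe_norm = jnp.where(norm > 0, norm, 1.0)
    # 2x2 unitary G chosen so that [a, b] @ G == [0, norm]
    G = jnp.array([[b, jnp.conj(a)], [-a, jnp.conj(b)]]) / safe_norm
    # If a and b are both zero, fall back to the identity rotation.
    G = jnp.where(norm > 0, G, jnp.eye(2, dtype=G.dtype))

    # Functional update: returns a new array instead of mutating U in place.
    return U.at[:, j:j + 2].set(U[:, j:j + 2] @ G)


# i and j are static because they determine slice bounds.
givens_zero_left_jit = jax.jit(givens_zero_left, static_argnums=(1, 2))

U = jnp.array([[1 + 1j, 2, 3], [4, 5j, 6], [7, 8, 9 - 1j]], dtype=jnp.complex64)
V = givens_zero_left_jit(U, 0, 0)  # V[0, 0] is now numerically zero
```

Because the rotation is unitary, `V @ V.conj().T` matches `U @ U.conj().T`. A value-dependent check such as `if not jnp.allclose(U @ U.conj().T, jnp.eye(N))` would instead fail under `jax.jit` with a `jax.errors.ConcretizationTypeError` (or one of its subclasses).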
josh146 added the bug 🐛 label on Jul 17, 2024.
josh146 changed the title to "[BUG] BasisRotation is not JAX compatible" on Jul 17, 2024.
trbromley (Contributor) commented

@josh146, what priority would you assign to this? Perhaps a P1?

isaacdevlugt (Contributor) commented

@trbromley, there are a few other related bugs (I think?):

#6006
#6007
#6008

@KetpuntoG is working on simplifying our state-prep suite as part of this epic: https://app.shortcut.com/xanaduai/epic/66499?group_by=none&vc_group_by=day&ct_workflow=all&cf_workflow=500000005. He may run into some dragons related to these bugs while completing that work. It may be best to wait for him to get started and see whether these bugs are blockers, then assign a priority (PX)?

PabloAMC linked a pull request on Jul 22, 2024 that will close this issue (5 tasks).