Catalyst does not support QJIT-compiling a parameterized circuit with qml.FlipSign #1265

Open · joeycarter opened this issue on Nov 1, 2024 · 0 comments
Labels: bug (Something isn't working)

@joeycarter (Contributor)
We discovered this issue when attempting to QJIT-compile a circuit implementing Grover's algorithm.

Consider the following PennyLane program that applies the `qml.FlipSign` operator:

```python
import numpy as np
import pennylane as qml

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = np.array([0., 0.])
state = circuit(basis_state)
```

As expected, the circuit flips the sign of the $|00\rangle$ basis state:

```pycon
>>> print(state)
[-1.-0.j  0.+0.j  0.+0.j  0.+0.j]
```
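For reference, the `qml.FlipSign` documentation defines the operator by its action on computational basis states. That action is the same as a reflection about the chosen basis state. For the two-qubit all-zeros input it is a diagonal matrix with $-1$ in the first entry, which is consistent with the output above.

$$
\mathrm{FlipSign}(n)\,|m\rangle =
\begin{cases}
-|m\rangle, & m = n,\\
\phantom{-}|m\rangle, & m \neq n,
\end{cases}
\qquad\text{so}\qquad
\mathrm{FlipSign}(n) = I - 2\,|n\rangle\langle n|,
\qquad
\mathrm{FlipSign}(00) = \operatorname{diag}(-1, 1, 1, 1).
$$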

When we attempt to QJIT-compile and execute this circuit, we get an error:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0., 0.])
state = circuit(basis_state)
```

```
Traceback (most recent call last):
...
  File ".../venv/lib/python3.12/site-packages/catalyst/device/decomposition.py", line 82, in catalyst_decomposer
    return op.decomposition()
           ^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/operation.py", line 1337, in decomposition
    return self.compute_decomposition(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/pennylane/templates/subroutines/flip_sign.py", line 144, in compute_decomposition
    if arr_bin[-1] == 0:
       ^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 712, in __bool__
    return self.aval._bool(self)
           ^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/jax/_src/core.py", line 1475, in error
    raise TracerBoolConversionError(arg)
jax.errors.TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError
```

The error occurred in the `FlipSign.compute_decomposition()` method:

```python
@staticmethod
def compute_decomposition(wires, arr_bin):
    op_list = []

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    op_list.append(qml.ctrl(qml.Z(wires[-1]), control=wires[:-1], control_values=arr_bin[:-1]))

    if arr_bin[-1] == 0:
        op_list.append(qml.X(wires[-1]))

    return op_list
```
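This sketch checks the logic of the decomposition above with concrete values, with no tracing involved. It uses only NumPy to build the dense matrix of the X, multi-controlled-Z, X sequence for a given bit string. The result should be the identity with $-1$ at the basis-state index. `flip_sign_matrix` is a hypothetical helper written for illustration, not a PennyLane function. It assumes PennyLane's ordering, in which wire 0 is the most significant bit. The helper branches on `bits[-1]` exactly like `compute_decomposition` does. That works for concrete bits, and it is precisely the step that fails for traced ones.

```python
import numpy as np

PAULI_X = np.array([[0.0, 1.0], [1.0, 0.0]])


def flip_sign_matrix(bits):
    """Hypothetical helper: dense matrix of the X, multi-controlled-Z, X sequence."""
    n = len(bits)
    dim = 2 ** n
    # Multi-controlled Z with control_values=bits[:-1]: phase -1 on basis states whose
    # first n-1 bits equal bits[:-1] and whose last (target) bit is 1.
    cz = np.eye(dim)
    for idx in range(dim):
        state_bits = [(idx >> (n - 1 - k)) & 1 for k in range(n)]  # wire 0 = MSB
        if state_bits[:-1] == list(bits[:-1]) and state_bits[-1] == 1:
            cz[idx, idx] = -1.0
    if bits[-1] == 0:
        # X on the last wire before and after, mirroring compute_decomposition
        x_last = np.kron(np.eye(2 ** (n - 1)), PAULI_X)
        return x_last @ cz @ x_last
    return cz


print(np.diag(flip_sign_matrix([0, 0])))
```

When the last bit is 0, the surrounding X gates move the controlled-Z phase from the target's $|1\rangle$ onto its $|0\rangle$. This is why the decomposition needs to know the value of `arr_bin[-1]`.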

The problem is statements like `if arr_bin[-1] == 0:`. In the QJIT-compiled case, `arr_bin` is a traced JAX array. Python control flow cannot branch on a traced value because its concrete value is not known at trace time.
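The same failure can be reproduced with plain `jax.jit`, without PennyLane or Catalyst. In this minimal sketch the function names are hypothetical. The first function branches in Python on a traced comparison and fails at trace time. The second expresses the same selection with `jnp.where`, which keeps the decision inside the compiled program.

```python
import jax
import jax.numpy as jnp


def last_bit_is_zero_python(arr_bin):
    # Python `if` forces bool() on a tracer under jax.jit, which raises an error
    if arr_bin[-1] == 0:
        return 1.0
    return 0.0


def last_bit_is_zero_traced(arr_bin):
    # jnp.where keeps the comparison and the selection inside the traced computation
    return jnp.where(arr_bin[-1] == 0, 1.0, 0.0)


basis_state = jnp.array([0.0, 0.0])

try:
    jax.jit(last_bit_is_zero_python)(basis_state)
    python_failed = False
except jax.errors.ConcretizationTypeError as err:
    # TracerBoolConversionError, as in the traceback above, subclasses this error
    python_failed = True
    print("Python control flow failed:", type(err).__name__)

result = jax.jit(last_bit_is_zero_traced)(basis_state)
print("Traced selection:", result)
```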

Compiling the circuit with AutoGraph enabled, `@qjit(autograph=True)`, gives the same error, because by default AutoGraph does not convert code in any PennyLane module. To work around this, we followed the "Adding modules for Autograph conversion" docs and tried the following, which results in a different error:

```python
import jax.numpy as jnp
import pennylane as qml
from catalyst import qjit

NUM_QUBITS = 2

dev = qml.device("lightning.qubit", wires=NUM_QUBITS)

@qjit(autograph=True, autograph_include=["pennylane.templates.subroutines.flip_sign"])
@qml.qnode(dev)
def circuit(basis_state):
    wires = list(range(NUM_QUBITS))
    qml.FlipSign(basis_state, wires=wires)
    return qml.state()

basis_state = jnp.array([0.0, 0.0])
state = circuit(basis_state)
```

```
Traceback (most recent call last):
...
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 579, in converted_call
    return ag_converted_call(fn, args, kwargs, caller_fn_scope, options)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/malt/impl/api.py", line 380, in converted_call
    result = converted_f(*effective_args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/__autograph_generated_file_1_ucoey.py", line 35, in ag____call__
    ag__.if_stmt(ag__.converted_call(ag__.ld(enabled), (), None, fscope), if_body, else_body, get_state, set_state, ('do_return', 'retval_'), 2)
  File ".../venv/lib/python3.12/site-packages/catalyst/autograph/ag_primitives.py", line 132, in if_stmt
    results = functional_cond()
              ^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 736, in __call__
    return self._call_with_quantum_ctx(ctx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 662, in _call_with_quantum_ctx
    _assert_cond_result_structure([s.out_tree() for s in out_sigs])
  File ".../venv/lib/python3.12/site-packages/catalyst/api_extensions/control_flow.py", line 1319, in _assert_cond_result_structure
    raise TypeError(
TypeError: Conditional requires a consistent return structure across all branches! Got PyTreeDef((*, CustomNode(FlipSign[(Wires([0, 1]), (('n', (Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>, Traced<ShapedArray(float64[])>with<DynamicJaxprTrace(level=3/1)>)),))], []))) and PyTreeDef((*, *)).
```
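The second error comes from Catalyst's conditional, which requires every branch to return the same pytree structure. The message shows one branch returning a pytree that contains a `FlipSign` operator node, while the other returns two plain leaves. `jax.lax.cond` enforces the same rule, so the pure-JAX sketch below serves as an analogy. It does not go through Catalyst or AutoGraph, and the function names are hypothetical. Branches with mismatched outputs are rejected with a `TypeError`, while branches with matching outputs compile.

```python
import jax
import jax.numpy as jnp


def mismatched(pred, x):
    # True branch returns a 2-tuple, false branch a single array: different pytrees
    return jax.lax.cond(pred, lambda v: (v, v), lambda v: v, x)


def matched(pred, x):
    # Both branches return a 2-tuple of arrays with identical shapes and dtypes
    return jax.lax.cond(pred, lambda v: (v, v), lambda v: (v, -v), x)


x = jnp.array([1.0, 2.0])

try:
    jax.jit(mismatched)(True, x)
    mismatch_rejected = False
except TypeError as err:
    mismatch_rejected = True
    print("cond rejected mismatched branches:", err)

first, second = jax.jit(matched)(False, x)
print(first, second)
```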

Catalyst and/or PennyLane should be changed to support `qml.FlipSign` in QJIT-compiled circuits where the basis state passed to `qml.FlipSign` is an input argument of the parameterized circuit.
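For reference, the requested action itself does not require Python branching on the basis state. This sketch negates the amplitude of the chosen basis state directly on a state vector under `jax.jit`, computing the basis-state index from the traced bits. It illustrates the math only and assumes big-endian wire ordering (wire 0 is the most significant bit). `flip_sign_statevector` is a hypothetical helper, not a PennyLane or Catalyst API, and it makes no claim about how `qml.FlipSign` should be decomposed.

```python
import jax
import jax.numpy as jnp


@jax.jit
def flip_sign_statevector(state, basis_state):
    """Hypothetical helper: negate the amplitude of `basis_state`, leave the rest unchanged."""
    num_wires = basis_state.shape[0]  # shapes are static under jit
    # Wire 0 is the most significant bit (big-endian), as assumed above
    weights = 2 ** jnp.arange(num_wires - 1, -1, -1)
    index = jnp.sum(basis_state.astype(jnp.int32) * weights)
    # Select -1 or +1 per amplitude with jnp.where instead of a Python `if`
    signs = jnp.where(jnp.arange(state.shape[0]) == index, -1.0, 1.0)
    return state * signs


zero_state = jnp.array([1.0, 0.0, 0.0, 0.0], dtype=jnp.complex64)
print(flip_sign_statevector(zero_state, jnp.array([0.0, 0.0])))
```

Applied to $|00\rangle$ with `basis_state = [0.0, 0.0]`, this gives amplitude $-1$ on $|00\rangle$, the same effect as the non-compiled circuit in the first example.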

joeycarter added the bug (Something isn't working) label on Nov 1, 2024.