Skip to content

Commit

Permalink
Test support for capturing nested control flows (#6111)
Browse files Browse the repository at this point in the history
**Context:** Adds test for asserting correct support for capturing
nested control flows

**Description of the Change:** Adds new tests

**Benefits:**

**Possible Drawbacks:**  N/A

**Related GitHub Issues:** [sc-66776]
  • Loading branch information
obliviateandsurrender authored Aug 31, 2024
1 parent 255cbe2 commit 46e1036
Show file tree
Hide file tree
Showing 4 changed files with 153 additions and 0 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@

<h3>Improvements 🛠</h3>

* Improve unit testing for capturing of nested control flows.
[(#6111)](https://github.com/PennyLaneAI/pennylane/pull/6111)

* Some custom primitives for the capture project can now be imported via
`from pennylane.capture.primitives import *`.
[(#6129)](https://github.com/PennyLaneAI/pennylane/pull/6129)
Expand All @@ -25,5 +28,6 @@

This release contains contributions from (in alphabetical order):

Utkarsh Azad
Jack Brown
Christina Lee
67 changes: 67 additions & 0 deletions tests/capture/test_capture_cond.py
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,73 @@ def f(*x):

assert np.allclose(res, expected, atol=atol, rtol=0), f"Expected {expected}, but got {res}"

@pytest.mark.parametrize("upper_bound, arg", [(3, [0.1, 0.3, 0.5]), (2, [2, 7, 12])])
def test_nested_cond_for_while_loop(self, upper_bound, arg):
"""Test that a nested control flows are correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

# Control flow for qml.conds
def true_fn(_):
@qml.for_loop(0, upper_bound, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

loop_fn()

def elif_fn(arg):
qml.RY(arg**2, wires=[2])

def false_fn(arg):
qml.RY(-arg, wires=[2])

@qml.qnode(dev)
def circuit(upper_bound, arg):
qml.RY(-np.pi / 2, wires=[2])
m_0 = qml.measure(2)

# NOTE: qml.cond(m_0, qml.RX)(arg[1], wires=1) doesn't work
def rx_fn():
qml.RX(arg[1], wires=1)

qml.cond(m_0, rx_fn)()

def ry_fn():
qml.RY(arg[1] ** 3, wires=1)

# nested for loops.
# outer for loop updates x
@qml.for_loop(0, upper_bound, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)
m_1 = qml.measure(0)
# NOTE: qml.cond(m_0, qml.RY)(arg[1], wires=1) doesn't work
qml.cond(m_1, ry_fn)()

# inner while loop
@qml.while_loop(lambda j: j < upper_bound)
def inner(j):
qml.RZ(j, wires=0)
qml.RY(x**2, wires=0)
m_2 = qml.measure(0)
qml.cond(m_2, true_fn=true_fn, false_fn=false_fn, elifs=((m_1, elif_fn)))(
arg[0]
)
return j + 1

inner(i + 1)
return x + 0.1

loop_fn_returns(arg[2])

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, upper_bound, *arg)
assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"


class TestPytree:
"""Test pytree support for cond."""
Expand Down
45 changes: 45 additions & 0 deletions tests/capture/test_capture_for_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,51 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize(
"upper_bound, arg, expected", [(3, 0.5, 0.00223126), (2, 12, 0.2653001)]
)
def test_nested_for_and_while_loop(self, upper_bound, arg, expected):
"""Test that a nested for loop and while loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

@qml.qnode(dev)
def circuit(upper_bound, arg):

# for loop with dynamic bounds
@qml.for_loop(0, upper_bound, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

# nested for-while loops.
@qml.for_loop(0, upper_bound, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)

# inner while loop
@qml.while_loop(lambda j: j < upper_bound)
def inner(j):
qml.RZ(j, wires=0)
qml.RY(x**2, wires=0)
return j + 1

inner(i + 1)

return x + 0.1

loop_fn()
loop_fn_returns(arg)

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
assert np.allclose(result, expected), f"Expected {expected}, but got {result}"

jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"


def test_pytree_inputs():
"""Test that for_loop works with pytree inputs and outputs."""
Expand Down
37 changes: 37 additions & 0 deletions tests/capture/test_capture_while_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,43 @@ def inner(j):
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(res_ev_jxpr, expected), f"Expected {expected}, but got {res_ev_jxpr}"

@pytest.mark.parametrize("upper_bound, arg", [(3, 0.5), (2, 12)])
def test_while_and_for_loop_nested(self, upper_bound, arg):
"""Test that a nested while and for loop is correctly captured into a jaxpr."""

dev = qml.device("default.qubit", wires=3)

def ry_fn(arg):
qml.RY(arg, wires=1)

@qml.qnode(dev)
def circuit(upper_bound, arg):

# while loop with dynamic bounds
@qml.while_loop(lambda i: i < upper_bound)
def loop_fn(i):
qml.Hadamard(wires=i)

@qml.for_loop(0, i, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)
m_0 = qml.measure(0)
qml.cond(m_0, ry_fn)(x)
return i + 1

loop_fn_returns(arg)
return i + 1

loop_fn(0)

return qml.expval(qml.Z(0))

args = [upper_bound, arg]
result = circuit(*args)
jaxpr = jax.make_jaxpr(circuit)(*args)
res_ev_jxpr = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *args)
assert np.allclose(result, res_ev_jxpr), f"Expected {result}, but got {res_ev_jxpr}"


def test_pytree_input_output():
"""Test that the while loop supports pytree input and output."""
Expand Down

0 comments on commit 46e1036

Please sign in to comment.