Skip to content

Commit

Permalink
Merge branch 'master' into ad/move-devs-to-devices-module
Browse files Browse the repository at this point in the history
  • Loading branch information
Shiro-Raven authored Jul 24, 2024
2 parents b1843b2 + be355e0 commit 98ba139
Show file tree
Hide file tree
Showing 3 changed files with 252 additions and 17 deletions.
46 changes: 46 additions & 0 deletions doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,51 @@
* Molecules and Hamiltonians can now be constructed for all the elements present in the periodic table.
[(#5821)](https://github.com/PennyLaneAI/pennylane/pull/5821)

* `qml.for_loop` and `qml.while_loop` now fallback to standard Python control
flow if `@qjit` is not present, allowing the same code to work with and without
`@qjit` without any rewrites.
[(#6014)](https://github.com/PennyLaneAI/pennylane/pull/6014)

```python
dev = qml.device("lightning.qubit", wires=3)

@qml.qnode(dev)
def circuit(x, n):

@qml.for_loop(0, n, 1)
def init_state(i):
qml.Hadamard(wires=i)

init_state()

@qml.for_loop(0, n, 1)
def apply_operations(i, x):
qml.RX(x, wires=i)

@qml.for_loop(i + 1, n, 1)
def inner(j):
qml.CRY(x**2, [i, j])

inner()
return jnp.sin(x)

apply_operations(x)
return qml.probs()
```

```pycon
>>> print(qml.draw(circuit)(0.5, 3))
0: ──H──RX(0.50)─╭●────────╭●──────────────────────────────────────┤ Probs
1: ──H───────────╰RY(0.25)─│──────────RX(0.48)─╭●──────────────────┤ Probs
2: ──H─────────────────────╰RY(0.25)───────────╰RY(0.23)──RX(0.46)─┤ Probs
>>> circuit(0.5, 3)
array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687])
>>> qml.qjit(circuit)(0.5, 3)
Array([0.125 , 0.125 , 0.09949758, 0.15050242, 0.07594666,
0.11917543, 0.08942104, 0.21545687], dtype=float64)
```

* The `qubit_observable` function is modified to return an ascending wire order for molecular
Hamiltonians.
[(#5950)](https://github.com/PennyLaneAI/pennylane/pull/5950)
Expand Down Expand Up @@ -172,6 +217,7 @@ Lillian M. A. Frederiksen,
Pietropaolo Frisoni,
Emiliano Godinez,
Renke Huang,
Josh Izaac,
Soran Jahangiri,
Christina Lee,
Austin Huang,
Expand Down
133 changes: 116 additions & 17 deletions pennylane/compiler/qjit_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""QJIT compatible quantum and compilation operations API"""
from collections.abc import Callable

from .compiler import (
AvailableCompilers,
Expand Down Expand Up @@ -377,20 +378,54 @@ def loop_rx(x):
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.while_loop(cond_fn)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the while loop
# via the Python interpretor.
def _decorator(body_fn: Callable) -> Callable:
"""Transform that will call the input ``body_fn`` until the closure variable ``cond_fn`` is met.
Args:
body_fn (Callable):
def for_loop(lower_bound, upper_bound, step):
"""A :func:`~.qjit` compatible for-loop for PennyLane programs.
Closure Variables:
cond_fn (Callable):
.. note::
Returns:
Callable: a callable with the same signature as ``body_fn`` and ``cond_fn``.
"""
return WhileLoopCallable(cond_fn, body_fn)

This function only supports the Catalyst compiler. See
:func:`catalyst.for_loop` for more details.
return _decorator

Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
page for an overview of the differences between Catalyst and PennyLane.

class WhileLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a while loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.
Args:
cond_fn (Callable): the condition function in the while loop
body_fn (Callable): the function that is executed within the while loop
"""

def __init__(self, cond_fn, body_fn):
self.cond_fn = cond_fn
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

while self.cond_fn(*args):
fn_res = self.body_fn(*args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res


def for_loop(lower_bound, upper_bound, step):
"""A :func:`~.qjit` compatible for-loop for PennyLane programs. When
used without :func:`~.qjit`, this function will fall back to a standard
Python for loop.
This decorator provides a functional version of the traditional
for-loop, similar to `jax.cond.fori_loop <https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.fori_loop.html>`__.
Expand Down Expand Up @@ -430,19 +465,14 @@ def for_loop(lower_bound, upper_bound, step, loop_fn, *args):
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.
Raises:
CompileError: if the compiler is not installed
.. seealso:: :func:`~.while_loop`, :func:`~.qjit`
**Example**
.. code-block:: python
dev = qml.device("lightning.qubit", wires=1)
@qml.qjit
@qml.qnode(dev)
def circuit(n: int, x: float):
Expand All @@ -457,15 +487,84 @@ def loop_rx(i, x):
# apply the for loop
final_x = loop_rx(x)
return qml.expval(qml.Z(0)), final_x
return qml.expval(qml.Z(0))
>>> circuit(7, 1.6)
(array(0.97926626), array(0.55395718))
array(0.97926626)
``for_loop`` is also :func:`~.qjit` compatible; when used with the
:func:`~.qjit` decorator, the for loop will not be unrolled, and instead
will be captured as-is during compilation and executed during runtime:
>>> qml.qjit(circuit)(7, 1.6)
Array(0.97926626, dtype=float64)
.. note::
Please see the Catalyst :doc:`quickstart guide <catalyst:dev/quick_start>`,
as well as the :doc:`sharp bits and debugging tips <catalyst:dev/sharp_bits>`
page for an overview of using quantum just-in-time compilation.
"""

if active_jit := active_compiler():
compilers = AvailableCompilers.names_entrypoints
ops_loader = compilers[active_jit]["ops"].load()
return ops_loader.for_loop(lower_bound, upper_bound, step)

raise CompileError("There is no active compiler package.") # pragma: no cover
# if there is no active compiler, simply interpret the for loop
# via the Python interpretor.
def _decorator(body_fn):
"""Transform that will call the input ``body_fn`` within a for loop defined by the closure variables lower_bound, upper_bound, and step.
Args:
body_fn (Callable): The function called within the for loop. Note that the loop body
function must always have the iteration index as its first
argument, which can be used arbitrarily inside the loop body. As the value of the index
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.
Closure Variables:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
Returns:
Callable: a callable with the same signature as ``body_fn``
"""
return ForLoopCallable(lower_bound, upper_bound, step, body_fn)

return _decorator


class ForLoopCallable: # pylint:disable=too-few-public-methods
"""Base class to represent a for loop. This class
when called with an initial state will execute the while
loop via the Python interpreter.
Args:
lower_bound (int): starting value of the iteration index
upper_bound (int): (exclusive) upper bound of the iteration index
step (int): increment applied to the iteration index at the end of each iteration
body_fn (Callable): The function called within the for loop. Note that the loop body
function must always have the iteration index as its first
argument, which can be used arbitrarily inside the loop body. As the value of the index
across iterations is handled automatically by the provided loop bounds, it must not be
returned from the function.
"""

def __init__(self, lower_bound, upper_bound, step, body_fn):
self.lower_bound = lower_bound
self.upper_bound = upper_bound
self.step = step
self.body_fn = body_fn

def __call__(self, *init_state):
args = init_state
fn_res = args if len(args) > 1 else args[0] if len(args) == 1 else None

for i in range(self.lower_bound, self.upper_bound, self.step):
fn_res = self.body_fn(i, *args)
args = fn_res if len(args) > 1 else (fn_res,) if len(args) == 1 else ()

return fn_res
90 changes: 90 additions & 0 deletions tests/test_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,45 @@ def inner(j):
assert circuit(5, 6) == 30 # 5 * 6
assert circuit(4, 7) == 28 # 4 * 7

def test_while_loop_python_fallback(self):
"""Test that qml.while_loop fallsback to
Python without qjit"""

def f(n, m):
@qml.while_loop(lambda i, _: i < n)
def outer(i, sm):
@qml.while_loop(lambda j: j < m)
def inner(j):
return j + 1

return i + 1, sm + inner(0)

return outer(0, 0)[1]

assert f(5, 6) == 30 # 5 * 6
assert f(4, 7) == 28 # 4 * 7

def test_fallback_while_loop_qnode(self):
"""Test that qml.while_loop inside a qnode fallsback to
Python without qjit"""
dev = qml.device("lightning.qubit", wires=1)

@qml.qnode(dev)
def circuit(n):
@qml.while_loop(lambda v: v[0] < v[1])
def loop(v):
qml.PauliX(wires=0)
return v[0] + 1, v[1]

loop((0, n))
return qml.expval(qml.PauliZ(0))

assert jnp.allclose(circuit(1), -1.0)

res = circuit.tape.operations
expected = [qml.PauliX(0) for i in range(4)]
_ = [qml.assert_equal(i, j) for i, j in zip(res, expected)]

def test_dynamic_wires_for_loops(self):
"""Test for loops with iteration index-dependant wires."""
dev = qml.device("lightning.qubit", wires=6)
Expand Down Expand Up @@ -405,6 +444,57 @@ def inner(j):

assert jnp.allclose(circuit(4), jnp.eye(2**4)[0])

def test_for_loop_python_fallback(self):
"""Test that qml.for_loop fallsback to Python
interpretation if Catalyst is not available"""
dev = qml.device("lightning.qubit", wires=3)

@qml.qnode(dev)
def circuit(x, n):

# for loop with dynamic bounds
@qml.for_loop(0, n, 1)
def loop_fn(i):
qml.Hadamard(wires=i)

# nested for loops.
# outer for loop updates x
@qml.for_loop(0, n, 1)
def loop_fn_returns(i, x):
qml.RX(x, wires=i)

# inner for loop
@qml.for_loop(i + 1, n, 1)
def inner(j):
qml.CRY(x**2, [i, j])

inner()

return x + 0.1

loop_fn()
loop_fn_returns(x)

return qml.expval(qml.PauliZ(0))

x = 0.5
assert jnp.allclose(circuit(x, 3), qml.qjit(circuit)(x, 3))

res = circuit.tape.operations
expected = [
qml.Hadamard(wires=[0]),
qml.Hadamard(wires=[1]),
qml.Hadamard(wires=[2]),
qml.RX(0.5, wires=[0]),
qml.CRY(0.25, wires=[0, 1]),
qml.CRY(0.25, wires=[0, 2]),
qml.RX(0.6, wires=[1]),
qml.CRY(0.36, wires=[1, 2]),
qml.RX(0.7, wires=[2]),
]

_ = [qml.assert_equal(i, j) for i, j in zip(res, expected)]

def test_cond(self):
"""Test condition with simple true_fn"""
dev = qml.device("lightning.qubit", wires=1)
Expand Down

0 comments on commit 98ba139

Please sign in to comment.