Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add tests for exponential extrapolation #953

Merged
merged 9 commits into from
Jul 30, 2024
27 changes: 21 additions & 6 deletions doc/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,20 @@

```

* Exponential extrapolation is now a supported method of extrapolation when using `mitigate_with_zne`.
[(#953)](https://github.com/PennyLaneAI/catalyst/pull/953)

This new functionality fits the data from noise-scaled circuits with an exponential function,
and returns the zero-noise value. This functionality is available through the pennylane module
as follows
```py
from pennylane.transforms import exponential_extrapolate

catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=exponential_extrapolate
)
```

<h3>Improvements</h3>

* Catalyst is now compatible with Enzyme `v0.0.130`
Expand Down Expand Up @@ -182,6 +196,12 @@
* Support for TOML files in Schema 1 has been disabled.
[(#960)](https://github.com/PennyLaneAI/catalyst/pull/960)

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

<h3>Bug fixes</h3>

* Static arguments can now be passed through a QNode when specified
Expand Down Expand Up @@ -288,6 +308,7 @@ Mehrdad Malekmohammadi,
Romain Moyard,
Erick Ochoa,
Mudit Pandey,
nate stemen,
Raul Torres,
Tzung-Han Juang,
Paul Haochen Wang,
Expand Down Expand Up @@ -803,12 +824,6 @@ Paul Haochen Wang,

<h3>Breaking changes</h3>

* The `mitigate_with_zne` function no longer accepts a `degree` parameter for polynomial fitting
and instead accepts a callable to perform extrapolation. Any qjit-compatible extrapolation
function is valid. Keyword arguments can be passed to this function using the
`extrapolate_kwargs` keyword argument in `mitigate_with_zne`.
[(#806)](https://github.com/PennyLaneAI/catalyst/pull/806)

* Binary distributions for Linux are now based on `manylinux_2_28` instead of `manylinux_2014`.
As a result, Catalyst will only be compatible on systems with `glibc` versions `2.28` and above
(e.g., Ubuntu 20.04 and above).
Expand Down
74 changes: 59 additions & 15 deletions frontend/test/pytest/test_mitigation.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,26 @@
import numpy as np
import pennylane as qml
import pytest
from pennylane.transforms import exponential_extrapolate

import catalyst
from catalyst.api_extensions.error_mitigation import polynomial_extrapolation
natestemen marked this conversation as resolved.
Show resolved Hide resolved

quadratic_extrapolation = polynomial_extrapolation(2)


def skip_if_exponential_extrapolation_unstable(circuit_param, extrapolation_func):
"""skip test if exponential extrapolation will be unstable"""
if circuit_param < 0.3 and extrapolation_func == exponential_extrapolate:
pytest.skip("Exponential extrapolation unstable in this region.")


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_single_measurement(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_single_measurement(params, extrapolation):
"""Test that without noise the same results are returned for single measurements."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -42,15 +52,18 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_multiple_measurements(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_multiple_measurements(params, extrapolation):
"""Test that without noise the same results are returned for multiple measurements"""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -65,7 +78,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

assert np.allclose(mitigated_qnode(params), circuit(params))
Expand Down Expand Up @@ -121,7 +134,8 @@ def mitigated_function(args):
mitigated_function(0.1)


def test_dtype_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_error(extrapolation):
"""Test that an error is raised when multiple results do not have the same dtype."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -137,7 +151,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -146,7 +160,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_dtype_not_float_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_dtype_not_float_error(extrapolation):
"""Test that an error is raised when results are not float."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -162,7 +177,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand All @@ -171,7 +186,8 @@ def mitigated_qnode(args):
mitigated_qnode(0.1)


def test_shape_error():
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_shape_error(extrapolation):
"""Test that an error is raised when results have shape."""
dev = qml.device("lightning.qubit", wires=2)

Expand All @@ -187,7 +203,7 @@ def circuit(x):
@catalyst.qjit
def mitigated_qnode(args):
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

with pytest.raises(
Expand Down Expand Up @@ -229,8 +245,11 @@ def mitigated_qnode():


@pytest.mark.parametrize("params", [0.1, 0.2, 0.3, 0.4, 0.5])
def test_zne_usage_patterns(params):
@pytest.mark.parametrize("extrapolation", [quadratic_extrapolation, exponential_extrapolate])
def test_zne_usage_patterns(params, extrapolation):
"""Test usage patterns of catalyst.zne."""
skip_if_exponential_extrapolation_unstable(params, extrapolation)

dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
Expand All @@ -245,13 +264,13 @@ def fn(x):
@catalyst.qjit
def mitigated_qnode_fn_as_argument(args):
return catalyst.mitigate_with_zne(
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
fn, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(args)

@catalyst.qjit
def mitigated_qnode_partial(args):
return catalyst.mitigate_with_zne(
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=quadratic_extrapolation
scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=extrapolation
)(fn)(args)

assert np.allclose(mitigated_qnode_fn_as_argument(params), fn(params))
Expand All @@ -271,13 +290,13 @@ def circuit():
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

def jax_extrap(scale_factors, results):
def jax_extrapolation(scale_factors, results):
return jax.numpy.polyfit(scale_factors, results, 2)[-1]

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrap
circuit, scale_factors=jax.numpy.array([1, 2, 3]), extrapolate=jax_extrapolation
)()

assert np.allclose(mitigated_qnode(), circuit())
Expand Down Expand Up @@ -308,5 +327,30 @@ def mitigated_qnode():
assert np.allclose(mitigated_qnode(), circuit())


def test_exponential_extrapolation_with_kwargs():
"""test mitigate_with_zne with keyword arguments for exponential extrapolation function"""
dev = qml.device("lightning.qubit", wires=2)

@qml.qnode(device=dev)
def circuit():
qml.Hadamard(wires=0)
qml.RZ(0.1, wires=0)
qml.RZ(0.2, wires=0)
qml.CNOT(wires=[1, 0])
qml.Hadamard(wires=1)
return qml.expval(qml.PauliY(wires=0))

@catalyst.qjit
def mitigated_qnode():
return catalyst.mitigate_with_zne(
circuit,
scale_factors=jax.numpy.array([1, 2, 3]),
extrapolate=qml.transforms.exponential_extrapolate,
extrapolate_kwargs={"asymptote": 3},
)()

assert np.allclose(mitigated_qnode(), circuit())


if __name__ == "__main__":
pytest.main(["-x", __file__])
Loading