Skip to content

Commit

Permalink
Upgrade and generalise basis state preparation (#6021)
Browse files Browse the repository at this point in the history
This PR complete part of this story:
[[sc-68521](https://app.shortcut.com/xanaduai/story/68521)]

Goal: `BasisEmbedding` is an alias of `BasisState`. This way, we don't
have duplicate code that does the same thing.
In unifying this, I have had to modify some tests due to:
- `BasisEmbedding` and `BasisState` throw errors such as "incorrect
length" with different messages. Now it will always be the same. (test
modified for this reason: `test_default_qubit_legacy.py`,
`test_default_qubit_tf.py`
`test_default_qubit_torch.py`, `test_state_prep.py`,
`test_all_singles_doubles.py` and test_uccsd`)

- In `BasisEmbedding`, errors were thrown in `__init__` while in
BasisState in `state_vector`. Now they are unified in `__init__`. For
this reason, there were tests where the operator was not initialized
correctly but no error was thrown since `state_vector` was not being
called but now they are detected. To correct this, I have modified the
tests: `test_qscript.py`, `test_state_prep.py`,

- Now `BasisState` does not decompose `BasisStatePreparation` since we
are going to deprecate it. This causes the number of gates after
expanding to be affected. In this case I had to modify some test in
`test_tape.py`.

This PR also solves:

- [issue 6008](#6008)
- [issue 6007](#6007)
- [issue 6006](#6006)

---------

Co-authored-by: Isaac De Vlugt <[email protected]>
Co-authored-by: soranjh <[email protected]>
Co-authored-by: Utkarsh <[email protected]>
  • Loading branch information
4 people authored Aug 21, 2024
1 parent 8bd226c commit f00c924
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 168 deletions.
5 changes: 4 additions & 1 deletion doc/releases/changelog-dev.md
Original file line number Diff line number Diff line change
Expand Up @@ -379,7 +379,10 @@
[(#5978)](https://github.com/PennyLaneAI/pennylane/pull/5978)

* `qml.AmplitudeEmbedding` has better support for features using low precision integer data types.
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)
[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969)

* `qml.BasisState` and `qml.BasisEmbedding` now works with jax.jit, lightning.qubit and give the correct decomposition.
[(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021)

* Jacobian shape is fixed for measurements with dimension in `qml.gradients.vjp.compute_vjp_single`.
[(5986)](https://github.com/PennyLaneAI/pennylane/pull/5986)
Expand Down
99 changes: 73 additions & 26 deletions pennylane/ops/qubit/state_preparation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@

import numpy as np

import pennylane as qml
from pennylane import math
from pennylane.operation import AnyWires, Operation, Operator, StatePrepBase
from pennylane.templates.state_preparations import BasisStatePreparation, MottonenStatePreparation
from pennylane.templates.state_preparations import MottonenStatePreparation
from pennylane.typing import TensorLike
from pennylane.wires import WireError, Wires, WiresLike

Expand All @@ -33,14 +34,14 @@


class BasisState(StatePrepBase):
r"""BasisState(n, wires)
r"""BasisState(state, wires)
Prepares a single computational basis state.
**Details:**
* Number of wires: Any (the operation can act on any number of wires)
* Number of parameters: 1
* Gradient recipe: None (integer parameters not supported)
* Gradient recipe: None
.. note::
Expand All @@ -54,9 +55,8 @@ class BasisState(StatePrepBase):
as :math:`U|0\rangle = |\psi\rangle`
Args:
n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an
array of integers from the set :math:`\{0, 1\}`, i.e.,
if ``n = np.array([0, 1, 0])``, prepares the state :math:`|010\rangle`.
state (tensor_like): binary input of shape ``(len(wires), )``, e.g., for ``state=np.array([0, 1, 0])`` or ``state=2`` (binary 010), the quantum system will be prepared in state :math:`|010 \rangle`.
wires (Sequence[int] or int): the wire(s) the operation acts on
id (str): custom label given to an operator instance,
can be useful for some applications where the instance has to be identified.
Expand All @@ -72,15 +72,51 @@ class BasisState(StatePrepBase):
[0.+0.j 0.+0.j 0.+0.j 1.+0.j]
"""

num_wires = AnyWires
num_params = 1
"""int: Number of trainable parameters that the operator depends on."""
def __init__(self, state, wires, id=None):

ndim_params = (1,)
"""int: Number of dimensions per trainable parameter of the operator."""
if isinstance(state, list):
state = qml.math.stack(state)

tracing = qml.math.is_abstract(state)

if not qml.math.shape(state):
if not tracing and state >= 2 ** len(wires):
raise ValueError(
f"Integer state must be < {2 ** len(wires)} to have a feasible binary representation, got {state}"
)
bin = 2 ** math.arange(len(wires))[::-1]
state = qml.math.where((state & bin) > 0, 1, 0)

wires = Wires(wires)
shape = qml.math.shape(state)

if len(shape) != 1:
raise ValueError(f"State must be one-dimensional; got shape {shape}.")

n_states = shape[0]
if n_states != len(wires):
raise ValueError(
f"State must be of length {len(wires)}; got length {n_states} (state={state})."
)

if not tracing:
state_list = list(qml.math.toarray(state))
if not set(state_list).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {state_list}")

super().__init__(state, wires=wires, id=id)

def _flatten(self):
state = self.parameters[0]
state = tuple(state) if isinstance(state, list) else state
return (state,), (self.wires,)

@classmethod
def _unflatten(cls, data, metadata) -> "BasisState":
return cls(data[0], wires=metadata[0])

@staticmethod
def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]:
def compute_decomposition(state: TensorLike, wires: WiresLike) -> list[Operator]:
r"""Representation of the operator as a product of other operators (static method). :
.. math:: O = O_1 O_2 \dots O_n.
Expand All @@ -89,8 +125,7 @@ def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]:
.. seealso:: :meth:`~.BasisState.decomposition`.
Args:
n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an
array of integers from the set :math:`\{0, 1\}`
state (array): the basis state to be prepared
wires (Iterable, Wires): the wire(s) the operation acts on
Returns:
Expand All @@ -99,33 +134,45 @@ def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]:
**Example:**
>>> qml.BasisState.compute_decomposition([1,0], wires=(0,1))
[BasisStatePreparation([1, 0], wires=[0, 1])]
[X(0)]
"""
return [BasisStatePreparation(n, wires)]

if not qml.math.is_abstract(state):
return [qml.X(wire) for wire, basis in zip(wires, state) if basis == 1]

op_list = []
for wire, basis in zip(wires, state):
op_list.append(qml.PhaseShift(basis * np.pi / 2, wire))
op_list.append(qml.RX(basis * np.pi, wire))
op_list.append(qml.PhaseShift(basis * np.pi / 2, wire))

return op_list

def state_vector(self, wire_order: Optional[WiresLike] = None) -> TensorLike:
"""Returns a statevector of shape ``(2,) * num_wires``."""
prep_vals = self.parameters[0]
if any(i not in [0, 1] for i in prep_vals):
raise ValueError("BasisState parameter must consist of 0 or 1 integers.")

if (num_wires := len(self.wires)) != len(prep_vals):
raise ValueError("BasisState parameter and wires must be of equal length.")
prep_vals_int = math.cast(self.parameters[0], int)

prep_vals = math.cast(prep_vals, int)
if wire_order is None:
indices = prep_vals
indices = prep_vals_int
num_wires = len(indices)
else:
if not Wires(wire_order).contains_wires(self.wires):
raise WireError("Custom wire_order must contain all BasisState wires")
num_wires = len(wire_order)
indices = [0] * num_wires
for base_wire_label, value in zip(self.wires, prep_vals):
for base_wire_label, value in zip(self.wires, prep_vals_int):
indices[wire_order.index(base_wire_label)] = value

ket = np.zeros((2,) * num_wires)
ket[tuple(indices)] = 1
if qml.math.get_interface(prep_vals_int) == "jax":
ket = math.array(math.zeros((2,) * num_wires), like="jax")
ket = ket.at[tuple(indices)].set(1)

else:
ket = math.zeros((2,) * num_wires)
ket[tuple(indices)] = 1

return math.convert_like(ket, prep_vals)


Expand Down
103 changes: 7 additions & 96 deletions pennylane/templates/embeddings/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,14 @@
Contains the BasisEmbedding template.
"""
# pylint: disable-msg=too-many-branches,too-many-arguments,protected-access
import numpy as np

import pennylane as qml
from pennylane.operation import AnyWires, Operation
from pennylane.wires import Wires
from pennylane.ops.qubit.state_preparation import BasisState


class BasisEmbedding(Operation):
class BasisEmbedding(BasisState):
r"""Encodes :math:`n` binary features into a basis state of :math:`n` qubits.
For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 10), the
For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 010), the
quantum system will be prepared in state :math:`|010 \rangle`.
.. warning::
Expand All @@ -35,8 +32,9 @@ class BasisEmbedding(Operation):
gradients with respect to the argument cannot be computed by PennyLane.
Args:
features (tensor_like): binary input of shape ``(len(wires), )``
wires (Any or Iterable[Any]): wires that the template acts on
features (tensor_like or int): binary input of shape ``(len(wires), )`` or integer
that represents the binary input.
wires (Any or Iterable[Any]): wires that the template acts on.
Example:
Expand Down Expand Up @@ -69,92 +67,5 @@ def circuit(feature_vector):
"""

num_wires = AnyWires
grad_method = None

def _flatten(self):
basis_state = self.hyperparameters["basis_state"]
basis_state = tuple(basis_state) if isinstance(basis_state, list) else basis_state
return tuple(), (self.wires, basis_state)

@classmethod
def _unflatten(cls, _, metadata) -> "BasisEmbedding":
return cls(features=metadata[1], wires=metadata[0])

def __init__(self, features, wires, id=None):
if isinstance(features, list):
features = qml.math.stack(features)

tracing = qml.math.is_abstract(features)

if qml.math.shape(features) == ():
if not tracing and features >= 2 ** len(wires):
raise ValueError(
f"Features must be of length {len(wires)}, got features={features} which is >= {2 ** len(wires)}"
)
bin = 2 ** np.arange(len(wires))[::-1]
features = qml.math.where((features & bin) > 0, 1, 0)

wires = Wires(wires)
shape = qml.math.shape(features)

if len(shape) != 1:
raise ValueError(f"Features must be one-dimensional; got shape {shape}.")

n_features = shape[0]
if n_features != len(wires):
raise ValueError(
f"Features must be of length {len(wires)}; got length {n_features} (features={features})."
)

if not tracing:
features = list(qml.math.toarray(features))
if not set(features).issubset({0, 1}):
raise ValueError(f"Basis state must only consist of 0s and 1s; got {features}")

self._hyperparameters = {"basis_state": features}

super().__init__(wires=wires, id=id)

@property
def num_params(self):
return 0

@staticmethod
def compute_decomposition(wires, basis_state): # pylint: disable=arguments-differ
r"""Representation of the operator as a product of other operators.
.. math:: O = O_1 O_2 \dots O_n.
.. seealso:: :meth:`~.BasisEmbedding.decomposition`.
Args:
features (tensor-like): binary input of shape ``(len(wires), )``
wires (Any or Iterable[Any]): wires that the operator acts on
Returns:
list[.Operator]: decomposition of the operator
**Example**
>>> features = torch.tensor([1, 0, 1])
>>> qml.BasisEmbedding.compute_decomposition(features, wires=["a", "b", "c"])
[X('a'),
X('c')]
"""
if not qml.math.is_abstract(basis_state):
ops_list = []
for wire, bit in zip(wires, basis_state):
if bit == 1:
ops_list.append(qml.X(wire))
return ops_list

ops_list = []
for wire, state in zip(wires, basis_state):
ops_list.append(qml.PhaseShift(state * np.pi / 2, wire))
ops_list.append(qml.RX(state * np.pi, wire))
ops_list.append(qml.PhaseShift(state * np.pi / 2, wire))

return ops_list
super().__init__(features, wires=wires, id=id)
5 changes: 3 additions & 2 deletions tests/devices/test_default_qubit_legacy.py
Original file line number Diff line number Diff line change
Expand Up @@ -652,13 +652,14 @@ def test_apply_errors_qubit_state_vector(self, qubit_device_2_wires):
)

def test_apply_errors_basis_state(self, qubit_device_2_wires):

with pytest.raises(
ValueError, match="BasisState parameter must consist of 0 or 1 integers."
ValueError, match=r"Basis state must only consist of 0s and 1s; got \[-0\.2, 4\.2\]"
):
qubit_device_2_wires.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])])

with pytest.raises(
ValueError, match="BasisState parameter and wires must be of equal length."
ValueError, match=r"State must be of length 1; got length 2 \(state=\[0 1\]\)\."
):
qubit_device_2_wires.apply([qml.BasisState(np.array([0, 1]), wires=[0])])

Expand Down
5 changes: 3 additions & 2 deletions tests/devices/test_default_qubit_tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,8 @@ def test_invalid_basis_state_length(self):
state = np.array([0, 0, 1, 0])

with pytest.raises(
ValueError, match=r"BasisState parameter and wires must be of equal length"
ValueError,
match=r"State must be of length 3; got length 4 \(state=\[0 0 1 0\]\)",
):
dev.apply([qml.BasisState(state, wires=[0, 1, 2])])

Expand All @@ -305,7 +306,7 @@ def test_invalid_basis_state(self):
state = np.array([0, 0, 1, 2])

with pytest.raises(
ValueError, match=r"BasisState parameter must consist of 0 or 1 integers"
ValueError, match=r"Basis state must only consist of 0s and 1s; got \[0, 0, 1, 2\]"
):
dev.apply([qml.BasisState(state, wires=[0, 1, 2, 3])])

Expand Down
5 changes: 3 additions & 2 deletions tests/devices/test_default_qubit_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def test_invalid_basis_state_length(self, device, torch_device):
state = torch.tensor([0, 0, 1, 0])

with pytest.raises(
ValueError, match=r"BasisState parameter and wires must be of equal length"
ValueError,
match=r"State must be of length 3; got length 4 \(state=tensor\(\[0, 0, 1, 0\]\)\)",
):
dev.apply([qml.BasisState(state, wires=[0, 1, 2])])

Expand All @@ -267,7 +268,7 @@ def test_invalid_basis_state(self, device, torch_device):
state = torch.tensor([0, 0, 1, 2])

with pytest.raises(
ValueError, match=r"BasisState parameter must consist of 0 or 1 integers"
ValueError, match=r"Basis state must only consist of 0s and 1s; got \[0, 0, 1, 2\]"
):
dev.apply([qml.BasisState(state, wires=[0, 1, 2, 3])])

Expand Down
Loading

0 comments on commit f00c924

Please sign in to comment.