diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8829bb2f323..7848e547ba7 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -379,7 +379,10 @@ [(#5978)](https://github.com/PennyLaneAI/pennylane/pull/5978) * `qml.AmplitudeEmbedding` has better support for features using low precision integer data types. -[(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969) + [(#5969)](https://github.com/PennyLaneAI/pennylane/pull/5969) + +* `qml.BasisState` and `qml.BasisEmbedding` now works with jax.jit, lightning.qubit and give the correct decomposition. + [(#6021)](https://github.com/PennyLaneAI/pennylane/pull/6021) * Jacobian shape is fixed for measurements with dimension in `qml.gradients.vjp.compute_vjp_single`. [(5986)](https://github.com/PennyLaneAI/pennylane/pull/5986) diff --git a/pennylane/ops/qubit/state_preparation.py b/pennylane/ops/qubit/state_preparation.py index 0871f3f6d5a..e7fe6c77ece 100644 --- a/pennylane/ops/qubit/state_preparation.py +++ b/pennylane/ops/qubit/state_preparation.py @@ -20,9 +20,10 @@ import numpy as np +import pennylane as qml from pennylane import math from pennylane.operation import AnyWires, Operation, Operator, StatePrepBase -from pennylane.templates.state_preparations import BasisStatePreparation, MottonenStatePreparation +from pennylane.templates.state_preparations import MottonenStatePreparation from pennylane.typing import TensorLike from pennylane.wires import WireError, Wires, WiresLike @@ -33,14 +34,14 @@ class BasisState(StatePrepBase): - r"""BasisState(n, wires) + r"""BasisState(state, wires) Prepares a single computational basis state. **Details:** * Number of wires: Any (the operation can act on any number of wires) * Number of parameters: 1 - * Gradient recipe: None (integer parameters not supported) + * Gradient recipe: None .. note:: @@ -54,9 +55,8 @@ class BasisState(StatePrepBase): as :math:`U|0\rangle = |\psi\rangle` Args: - n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an - array of integers from the set :math:`\{0, 1\}`, i.e., - if ``n = np.array([0, 1, 0])``, prepares the state :math:`|010\rangle`. + state (tensor_like): binary input of shape ``(len(wires), )``, e.g., for ``state=np.array([0, 1, 0])`` or ``state=2`` (binary 010), the quantum system will be prepared in state :math:`|010 \rangle`. + wires (Sequence[int] or int): the wire(s) the operation acts on id (str): custom label given to an operator instance, can be useful for some applications where the instance has to be identified. @@ -72,15 +72,51 @@ class BasisState(StatePrepBase): [0.+0.j 0.+0.j 0.+0.j 1.+0.j] """ - num_wires = AnyWires - num_params = 1 - """int: Number of trainable parameters that the operator depends on.""" + def __init__(self, state, wires, id=None): - ndim_params = (1,) - """int: Number of dimensions per trainable parameter of the operator.""" + if isinstance(state, list): + state = qml.math.stack(state) + + tracing = qml.math.is_abstract(state) + + if not qml.math.shape(state): + if not tracing and state >= 2 ** len(wires): + raise ValueError( + f"Integer state must be < {2 ** len(wires)} to have a feasible binary representation, got {state}" + ) + bin = 2 ** math.arange(len(wires))[::-1] + state = qml.math.where((state & bin) > 0, 1, 0) + + wires = Wires(wires) + shape = qml.math.shape(state) + + if len(shape) != 1: + raise ValueError(f"State must be one-dimensional; got shape {shape}.") + + n_states = shape[0] + if n_states != len(wires): + raise ValueError( + f"State must be of length {len(wires)}; got length {n_states} (state={state})." + ) + + if not tracing: + state_list = list(qml.math.toarray(state)) + if not set(state_list).issubset({0, 1}): + raise ValueError(f"Basis state must only consist of 0s and 1s; got {state_list}") + + super().__init__(state, wires=wires, id=id) + + def _flatten(self): + state = self.parameters[0] + state = tuple(state) if isinstance(state, list) else state + return (state,), (self.wires,) + + @classmethod + def _unflatten(cls, data, metadata) -> "BasisState": + return cls(data[0], wires=metadata[0]) @staticmethod - def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]: + def compute_decomposition(state: TensorLike, wires: WiresLike) -> list[Operator]: r"""Representation of the operator as a product of other operators (static method). : .. math:: O = O_1 O_2 \dots O_n. @@ -89,8 +125,7 @@ def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]: .. seealso:: :meth:`~.BasisState.decomposition`. Args: - n (array): prepares the basis state :math:`\ket{n}`, where ``n`` is an - array of integers from the set :math:`\{0, 1\}` + state (array): the basis state to be prepared wires (Iterable, Wires): the wire(s) the operation acts on Returns: @@ -99,33 +134,45 @@ def compute_decomposition(n: TensorLike, wires: WiresLike) -> list[Operator]: **Example:** >>> qml.BasisState.compute_decomposition([1,0], wires=(0,1)) - [BasisStatePreparation([1, 0], wires=[0, 1])] + [X(0)] """ - return [BasisStatePreparation(n, wires)] + + if not qml.math.is_abstract(state): + return [qml.X(wire) for wire, basis in zip(wires, state) if basis == 1] + + op_list = [] + for wire, basis in zip(wires, state): + op_list.append(qml.PhaseShift(basis * np.pi / 2, wire)) + op_list.append(qml.RX(basis * np.pi, wire)) + op_list.append(qml.PhaseShift(basis * np.pi / 2, wire)) + + return op_list def state_vector(self, wire_order: Optional[WiresLike] = None) -> TensorLike: """Returns a statevector of shape ``(2,) * num_wires``.""" prep_vals = self.parameters[0] - if any(i not in [0, 1] for i in prep_vals): - raise ValueError("BasisState parameter must consist of 0 or 1 integers.") - - if (num_wires := len(self.wires)) != len(prep_vals): - raise ValueError("BasisState parameter and wires must be of equal length.") + prep_vals_int = math.cast(self.parameters[0], int) - prep_vals = math.cast(prep_vals, int) if wire_order is None: - indices = prep_vals + indices = prep_vals_int + num_wires = len(indices) else: if not Wires(wire_order).contains_wires(self.wires): raise WireError("Custom wire_order must contain all BasisState wires") num_wires = len(wire_order) indices = [0] * num_wires - for base_wire_label, value in zip(self.wires, prep_vals): + for base_wire_label, value in zip(self.wires, prep_vals_int): indices[wire_order.index(base_wire_label)] = value - ket = np.zeros((2,) * num_wires) - ket[tuple(indices)] = 1 + if qml.math.get_interface(prep_vals_int) == "jax": + ket = math.array(math.zeros((2,) * num_wires), like="jax") + ket = ket.at[tuple(indices)].set(1) + + else: + ket = math.zeros((2,) * num_wires) + ket[tuple(indices)] = 1 + return math.convert_like(ket, prep_vals) diff --git a/pennylane/templates/embeddings/basis.py b/pennylane/templates/embeddings/basis.py index 18175a8fd6a..9c95f7f253a 100644 --- a/pennylane/templates/embeddings/basis.py +++ b/pennylane/templates/embeddings/basis.py @@ -15,17 +15,14 @@ Contains the BasisEmbedding template. """ # pylint: disable-msg=too-many-branches,too-many-arguments,protected-access -import numpy as np -import pennylane as qml -from pennylane.operation import AnyWires, Operation -from pennylane.wires import Wires +from pennylane.ops.qubit.state_preparation import BasisState -class BasisEmbedding(Operation): +class BasisEmbedding(BasisState): r"""Encodes :math:`n` binary features into a basis state of :math:`n` qubits. - For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 10), the + For example, for ``features=np.array([0, 1, 0])`` or ``features=2`` (binary 010), the quantum system will be prepared in state :math:`|010 \rangle`. .. warning:: @@ -35,8 +32,9 @@ class BasisEmbedding(Operation): gradients with respect to the argument cannot be computed by PennyLane. Args: - features (tensor_like): binary input of shape ``(len(wires), )`` - wires (Any or Iterable[Any]): wires that the template acts on + features (tensor_like or int): binary input of shape ``(len(wires), )`` or integer + that represents the binary input. + wires (Any or Iterable[Any]): wires that the template acts on. Example: @@ -69,92 +67,5 @@ def circuit(feature_vector): """ - num_wires = AnyWires - grad_method = None - - def _flatten(self): - basis_state = self.hyperparameters["basis_state"] - basis_state = tuple(basis_state) if isinstance(basis_state, list) else basis_state - return tuple(), (self.wires, basis_state) - - @classmethod - def _unflatten(cls, _, metadata) -> "BasisEmbedding": - return cls(features=metadata[1], wires=metadata[0]) - def __init__(self, features, wires, id=None): - if isinstance(features, list): - features = qml.math.stack(features) - - tracing = qml.math.is_abstract(features) - - if qml.math.shape(features) == (): - if not tracing and features >= 2 ** len(wires): - raise ValueError( - f"Features must be of length {len(wires)}, got features={features} which is >= {2 ** len(wires)}" - ) - bin = 2 ** np.arange(len(wires))[::-1] - features = qml.math.where((features & bin) > 0, 1, 0) - - wires = Wires(wires) - shape = qml.math.shape(features) - - if len(shape) != 1: - raise ValueError(f"Features must be one-dimensional; got shape {shape}.") - - n_features = shape[0] - if n_features != len(wires): - raise ValueError( - f"Features must be of length {len(wires)}; got length {n_features} (features={features})." - ) - - if not tracing: - features = list(qml.math.toarray(features)) - if not set(features).issubset({0, 1}): - raise ValueError(f"Basis state must only consist of 0s and 1s; got {features}") - - self._hyperparameters = {"basis_state": features} - - super().__init__(wires=wires, id=id) - - @property - def num_params(self): - return 0 - - @staticmethod - def compute_decomposition(wires, basis_state): # pylint: disable=arguments-differ - r"""Representation of the operator as a product of other operators. - - .. math:: O = O_1 O_2 \dots O_n. - - - - .. seealso:: :meth:`~.BasisEmbedding.decomposition`. - - Args: - features (tensor-like): binary input of shape ``(len(wires), )`` - wires (Any or Iterable[Any]): wires that the operator acts on - - Returns: - list[.Operator]: decomposition of the operator - - **Example** - - >>> features = torch.tensor([1, 0, 1]) - >>> qml.BasisEmbedding.compute_decomposition(features, wires=["a", "b", "c"]) - [X('a'), - X('c')] - """ - if not qml.math.is_abstract(basis_state): - ops_list = [] - for wire, bit in zip(wires, basis_state): - if bit == 1: - ops_list.append(qml.X(wire)) - return ops_list - - ops_list = [] - for wire, state in zip(wires, basis_state): - ops_list.append(qml.PhaseShift(state * np.pi / 2, wire)) - ops_list.append(qml.RX(state * np.pi, wire)) - ops_list.append(qml.PhaseShift(state * np.pi / 2, wire)) - - return ops_list + super().__init__(features, wires=wires, id=id) diff --git a/tests/devices/test_default_qubit_legacy.py b/tests/devices/test_default_qubit_legacy.py index ac2f753ab01..ad60caa0ad7 100644 --- a/tests/devices/test_default_qubit_legacy.py +++ b/tests/devices/test_default_qubit_legacy.py @@ -652,13 +652,14 @@ def test_apply_errors_qubit_state_vector(self, qubit_device_2_wires): ) def test_apply_errors_basis_state(self, qubit_device_2_wires): + with pytest.raises( - ValueError, match="BasisState parameter must consist of 0 or 1 integers." + ValueError, match=r"Basis state must only consist of 0s and 1s; got \[-0\.2, 4\.2\]" ): qubit_device_2_wires.apply([qml.BasisState(np.array([-0.2, 4.2]), wires=[0, 1])]) with pytest.raises( - ValueError, match="BasisState parameter and wires must be of equal length." + ValueError, match=r"State must be of length 1; got length 2 \(state=\[0 1\]\)\." ): qubit_device_2_wires.apply([qml.BasisState(np.array([0, 1]), wires=[0])]) diff --git a/tests/devices/test_default_qubit_tf.py b/tests/devices/test_default_qubit_tf.py index 93b5472996f..09158a19941 100644 --- a/tests/devices/test_default_qubit_tf.py +++ b/tests/devices/test_default_qubit_tf.py @@ -295,7 +295,8 @@ def test_invalid_basis_state_length(self): state = np.array([0, 0, 1, 0]) with pytest.raises( - ValueError, match=r"BasisState parameter and wires must be of equal length" + ValueError, + match=r"State must be of length 3; got length 4 \(state=\[0 0 1 0\]\)", ): dev.apply([qml.BasisState(state, wires=[0, 1, 2])]) @@ -305,7 +306,7 @@ def test_invalid_basis_state(self): state = np.array([0, 0, 1, 2]) with pytest.raises( - ValueError, match=r"BasisState parameter must consist of 0 or 1 integers" + ValueError, match=r"Basis state must only consist of 0s and 1s; got \[0, 0, 1, 2\]" ): dev.apply([qml.BasisState(state, wires=[0, 1, 2, 3])]) diff --git a/tests/devices/test_default_qubit_torch.py b/tests/devices/test_default_qubit_torch.py index 29f0caee8f5..aab77ece1e7 100644 --- a/tests/devices/test_default_qubit_torch.py +++ b/tests/devices/test_default_qubit_torch.py @@ -257,7 +257,8 @@ def test_invalid_basis_state_length(self, device, torch_device): state = torch.tensor([0, 0, 1, 0]) with pytest.raises( - ValueError, match=r"BasisState parameter and wires must be of equal length" + ValueError, + match=r"State must be of length 3; got length 4 \(state=tensor\(\[0, 0, 1, 0\]\)\)", ): dev.apply([qml.BasisState(state, wires=[0, 1, 2])]) @@ -267,7 +268,7 @@ def test_invalid_basis_state(self, device, torch_device): state = torch.tensor([0, 0, 1, 2]) with pytest.raises( - ValueError, match=r"BasisState parameter must consist of 0 or 1 integers" + ValueError, match=r"Basis state must only consist of 0s and 1s; got \[0, 0, 1, 2\]" ): dev.apply([qml.BasisState(state, wires=[0, 1, 2, 3])]) diff --git a/tests/ops/qubit/test_state_prep.py b/tests/ops/qubit/test_state_prep.py index b89b2ef5197..342aaff5df0 100644 --- a/tests/ops/qubit/test_state_prep.py +++ b/tests/ops/qubit/test_state_prep.py @@ -26,7 +26,7 @@ @pytest.mark.parametrize( "op", [ - qml.BasisState(np.array([0, 1]), wires=0), + qml.BasisState(np.array([0, 1]), wires=[0, 1]), qml.StatePrep(np.array([1.0, 0.0]), wires=0), qml.QubitDensityMatrix(densitymat0, wires=0), ], @@ -66,8 +66,8 @@ def test_BasisState_decomposition(self): ops2 = qml.BasisState(n, wires=wires).decomposition() assert len(ops1) == len(ops2) == 1 - assert isinstance(ops1[0], qml.BasisStatePreparation) - assert isinstance(ops2[0], qml.BasisStatePreparation) + assert isinstance(ops1[0], qml.X) + assert isinstance(ops2[0], qml.X) def test_StatePrep_decomposition(self): """Test the decomposition for StatePrep.""" @@ -392,18 +392,9 @@ def test_BasisState_state_vector_bad_wire_order(self): with pytest.raises(WireError, match="wire_order must contain all BasisState wires"): basis_op.state_vector(wire_order=[1, 2]) - def test_BasisState_explicitly_checks_0_1(self): - """Tests that BasisState gives a clear error if a value other than 0 or 1 is given.""" - op = qml.BasisState([2, 1], wires=[0, 1]) - with pytest.raises( - ValueError, match="BasisState parameter must consist of 0 or 1 integers." - ): - _ = op.state_vector() - def test_BasisState_wrong_param_size(self): """Tests that the parameter must be of length num_wires.""" - op = qml.BasisState([0], wires=[0, 1]) with pytest.raises( - ValueError, match="BasisState parameter and wires must be of equal length." + ValueError, match=r"State must be of length 2; got length 1 \(state=\[0\]\)." ): - _ = op.state_vector() + _ = qml.BasisState([0], wires=[0, 1]) diff --git a/tests/tape/test_qscript.py b/tests/tape/test_qscript.py index f72b9f98651..bcafc1da65d 100644 --- a/tests/tape/test_qscript.py +++ b/tests/tape/test_qscript.py @@ -641,7 +641,7 @@ def test_deep_copy(self): def test_adjoint(): """Tests taking the adjoint of a quantum script.""" ops = [ - qml.BasisState([1, 1], wires=0), + qml.BasisState([1, 1], wires=[0, 1]), qml.RX(1.2, wires=0), qml.S(0), qml.CNOT((0, 1)), diff --git a/tests/tape/test_tape.py b/tests/tape/test_tape.py index 8492817ffdc..227d4025a3b 100644 --- a/tests/tape/test_tape.py +++ b/tests/tape/test_tape.py @@ -896,8 +896,7 @@ def test_decomposition_removing_parameters(self): with QuantumTape() as tape: qml.BasisState(np.array([1]), wires=0) - # since expansion calls `BasisStatePreparation` we have to expand twice - new_tape = tape.expand(depth=2) + new_tape = tape.expand(depth=1) assert len(new_tape.operations) == 1 assert new_tape.operations[0].name == "PauliX" @@ -958,7 +957,8 @@ def test_nesting_and_decomposition(self): qml.probs(wires="a") new_tape = tape.expand() - assert len(new_tape.operations) == 4 + + assert len(new_tape.operations) == 5 assert new_tape.shots is tape.shots def test_stopping_criterion(self): @@ -991,7 +991,7 @@ def test_depth_expansion(self): qml.probs(wires=0) qml.probs(wires="a") - new_tape = tape.expand(depth=3) + new_tape = tape.expand(depth=2) assert len(new_tape.operations) == 11 @pytest.mark.parametrize("skip_first", (True, False)) @@ -1005,11 +1005,9 @@ def test_depth_expansion(self): qml.PauliZ(0), ], [ - qml.BasisStatePreparation([1, 0], wires=[0, 1]), + qml.PauliX(0), + qml.MottonenStatePreparation([0, 1, 0, 0], wires=[0, 1]), qml.MottonenStatePreparation([0, 1, 0, 0], wires=[0, 1]), - qml.MottonenStatePreparation( - [0, 1, 0, 0], wires=[0, 1] - ), # still a StatePrepBase :/ qml.PauliZ(0), ], ), @@ -1036,7 +1034,6 @@ def test_expansion_state_prep(self, skip_first, op, decomp): true_decomposition += [ qml.PauliZ(wires=0), qml.Rot(0.1, 0.2, 0.3, wires=0), - qml.BasisStatePreparation([0], wires=[1]), qml.MottonenStatePreparation([0, 1], wires=[0]), ] @@ -1081,7 +1078,7 @@ def test_measurement_expansion(self): new_tape = tape.expand(expand_measurements=True) - assert len(new_tape.operations) == 5 + assert len(new_tape.operations) == 6 expected = [ qml.measurements.Probability, diff --git a/tests/templates/test_embeddings/test_basis.py b/tests/templates/test_embeddings/test_basis.py index 31bc736cff6..3b8b93b79e6 100644 --- a/tests/templates/test_embeddings/test_basis.py +++ b/tests/templates/test_embeddings/test_basis.py @@ -34,9 +34,8 @@ def test_flatten_unflatten(): wires = qml.wires.Wires((0, 1, 2)) op = qml.BasisEmbedding(features=[1, 1, 1], wires=wires) data, metadata = op._flatten() - assert data == tuple() + assert np.allclose(data[0], [1, 1, 1]) assert metadata[0] == wires - assert metadata[1] == (1, 1, 1) # make sure metadata hashable assert hash(metadata) @@ -111,13 +110,17 @@ def test_features_as_int_conversion(self, feat, wires, expected): """checks conversion from features as int to a list of binary digits with length = len(wires)""" - assert ( - qml.BasisEmbedding(features=feat, wires=wires).hyperparameters["basis_state"] - == expected - ) + assert np.allclose(qml.BasisEmbedding(features=feat, wires=wires).parameters[0], expected) - @pytest.mark.parametrize("x", [[0], [0, 1, 1], 4]) - def test_wrong_input_bits_exception(self, x): + @pytest.mark.parametrize( + "x, msg", + [ + ([0], "State must be of length"), + ([0, 1, 1], "State must be of length"), + (4, "Integer state must be"), + ], + ) + def test_wrong_input_bits_exception(self, x, msg): """Checks exception if number of features is not same as number of qubits.""" dev = qml.device("default.qubit", wires=2) @@ -127,7 +130,7 @@ def circuit(): qml.BasisEmbedding(features=x, wires=range(2)) return qml.expval(qml.PauliZ(0)) - with pytest.raises(ValueError, match="Features must be of length"): + with pytest.raises(ValueError, match=msg): circuit() def test_input_not_binary_exception(self): @@ -153,7 +156,7 @@ def circuit(x=None): qml.BasisEmbedding(features=x, wires=2) return qml.expval(qml.PauliZ(0)) - with pytest.raises(ValueError, match="Features must be one-dimensional"): + with pytest.raises(ValueError, match="State must be one-dimensional"): circuit(x=[[1], [0]]) def test_id(self): @@ -236,8 +239,12 @@ def test_jax(self, tol): res = circuit(jnp.array(2)) assert qml.math.allclose(res, res2, atol=tol, rtol=0) + @pytest.mark.parametrize( + "device_name", + ["default.qubit", "lightning.qubit"], + ) @pytest.mark.jax - def test_jax_jit(self, tol): + def test_jax_jit(self, tol, device_name): """Tests the jax-jit interface.""" import jax @@ -245,7 +252,7 @@ def test_jax_jit(self, tol): features = jnp.array([0, 1, 0]) - dev = qml.device("default.qubit", wires=3) + dev = qml.device(device_name, wires=3) circuit = qml.QNode(circuit_template, dev) circuit2 = qml.QNode(circuit_decomposed, dev) diff --git a/tests/templates/test_subroutines/test_all_singles_doubles.py b/tests/templates/test_subroutines/test_all_singles_doubles.py index 7d20863135d..f7844dcb807 100644 --- a/tests/templates/test_subroutines/test_all_singles_doubles.py +++ b/tests/templates/test_subroutines/test_all_singles_doubles.py @@ -206,7 +206,7 @@ class TestInputs: [[0, 2]], [[0, 1, 2, 3]], np.array([1, 1, 0, 0, 0]), - "Basis states must be of length 4", + "State must be of length 4", ), ( np.array([-2.8, 1.6]), diff --git a/tests/templates/test_subroutines/test_uccsd.py b/tests/templates/test_subroutines/test_uccsd.py index f335851bdd6..62476ad6ad6 100644 --- a/tests/templates/test_subroutines/test_uccsd.py +++ b/tests/templates/test_subroutines/test_uccsd.py @@ -291,7 +291,7 @@ class TestInputs: [], np.array([1, 1, 0, 0, 0]), 1, - "Basis states must be of length 4", + "State must be of length 4", ), ( np.array([-2.8, 1.6]),