diff --git a/.github/CHANGELOG.md b/.github/CHANGELOG.md index 27c0442c8..fc02db7d1 100644 --- a/.github/CHANGELOG.md +++ b/.github/CHANGELOG.md @@ -27,6 +27,9 @@ ### Improvements +* The `setBasisState` and `setStateVector` methods of `StateVectorLQubit` and `StateVectorKokkos` are overloaded to support PennyLane-like parameters. + [(#843)](https://github.com/PennyLaneAI/pennylane-lightning/pull/843) + * `ENABLE_LAPACK` is off by default for all Lightning backends. [(#825)](https://github.com/PennyLaneAI/pennylane-lightning/pull/825) diff --git a/.github/workflows/tests_without_binary.yml b/.github/workflows/tests_without_binary.yml index 64c059bc2..7fb1e75bb 100644 --- a/.github/workflows/tests_without_binary.yml +++ b/.github/workflows/tests_without_binary.yml @@ -42,6 +42,13 @@ jobs: name: Python Tests without Binary (${{ matrix.pl_backend }}) steps: + - name: Make disk space + run: | + for DIR in /usr/share/dotnet /usr/local/share/powershell /usr/share/swift; do + sudo du -sh $DIR || echo $DIR not found + sudo rm -rf $DIR + done + - name: Checkout PennyLane-Lightning uses: actions/checkout@v4 with: diff --git a/.gitignore b/.gitignore index 8e422d276..dedcf8aa5 100644 --- a/.gitignore +++ b/.gitignore @@ -32,6 +32,7 @@ kokkos/ PennyLane_Lightning_Kokkos.egg-info/ PennyLane_Lightning.egg-info/ prototypes/ +pyproject.toml tests/__pycache__/ venv/ wheelhouse/ \ No newline at end of file diff --git a/Makefile b/Makefile index 8b673ed71..0d4a5ed73 100644 --- a/Makefile +++ b/Makefile @@ -63,6 +63,11 @@ clean: rm -rf pennylane_lightning/*_ops* rm -rf *.egg-info +.PHONY: python +python: + PL_BACKEND=$(PL_BACKEND) python scripts/configure_pyproject_toml.py + pip install -e . -vv + .PHONY: wheel wheel: PL_BACKEND=$(PL_BACKEND) python scripts/configure_pyproject_toml.py diff --git a/pennylane_lightning/core/_version.py b/pennylane_lightning/core/_version.py index 955e96232..1e786a92c 100644 --- a/pennylane_lightning/core/_version.py +++ b/pennylane_lightning/core/_version.py @@ -16,4 +16,4 @@ Version number (major.minor.patch[-label]) """ -__version__ = "0.38.0-dev31" +__version__ = "0.38.0-dev32" diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/StateVectorKokkos.hpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/StateVectorKokkos.hpp index 56fd2a42c..13f4889bb 100644 --- a/pennylane_lightning/core/src/simulators/lightning_kokkos/StateVectorKokkos.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/StateVectorKokkos.hpp @@ -144,6 +144,43 @@ class StateVectorKokkos final }); } + /** + * @brief Prepares a single computational basis state. + * + * @param state Binary number representing the index + * @param wires Wires. + */ + void setBasisState(const std::vector &state, + const std::vector &wires) { + PL_ABORT_IF_NOT(state.size() == wires.size(), + "state and wires must have equal dimensions."); + const auto num_qubits = this->getNumQubits(); + PL_ABORT_IF_NOT( + std::find_if(wires.begin(), wires.end(), + [&num_qubits](const auto i) { + return i >= num_qubits; + }) == wires.end(), + "wires must take values lower than the number of qubits."); + const auto n_wires = wires.size(); + std::size_t index{0U}; + for (std::size_t k = 0; k < n_wires; k++) { + const auto bit = static_cast(state[k]); + index |= bit << (num_qubits - 1 - wires[k]); + } + setBasisState(index); + } + + /** + * @brief Reset the data back to the \f$\ket{0}\f$ state. + * + * @param num_qubits Number of qubits + */ + void resetStateVector() { + if (this->getLength() > 0) { + setBasisState(0U); + } + } + /** * @brief Set values for a batch of elements of the state-vector. * @@ -164,14 +201,49 @@ class StateVectorKokkos final } /** - * @brief Reset the data back to the \f$\ket{0}\f$ state. + * @brief Set values for a batch of elements of the state-vector. * - * @param num_qubits Number of qubits + * @param state State. + * @param wires Wires. */ - void resetStateVector() { - if (this->getLength() > 0) { - setBasisState(0U); - } + void setStateVector(const std::vector &state, + const std::vector &wires) { + PL_ABORT_IF_NOT(state.size() == exp2(wires.size()), + "Inconsistent state and wires dimensions."); + setStateVector(state.data(), wires); + } + + /** + * @brief Set values for a batch of elements of the state-vector. + * + * @param state State. + * @param wires Wires. + */ + void setStateVector(const ComplexT *state, + const std::vector &wires) { + constexpr std::size_t one{1U}; + const auto num_qubits = this->getNumQubits(); + PL_ABORT_IF_NOT( + std::find_if(wires.begin(), wires.end(), + [&num_qubits](const auto i) { + return i >= num_qubits; + }) == wires.end(), + "wires must take values lower than the number of qubits."); + const auto num_state = exp2(wires.size()); + auto d_sv = getView(); + auto d_state = pointer2view(state, num_state); + auto d_wires = vector2view(wires); + initZeros(); + Kokkos::parallel_for( + num_state, KOKKOS_LAMBDA(const std::size_t i) { + std::size_t index{0U}; + for (std::size_t w = 0; w < d_wires.size(); w++) { + const std::size_t bit = (i & (one << w)) >> w; + index |= bit << (num_qubits - 1 - + d_wires(d_wires.size() - 1 - w)); + } + d_sv(index) = d_state(i); + }); } /** diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/bindings/LKokkosBindings.hpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/bindings/LKokkosBindings.hpp index 3e39388dd..4d8a70308 100644 --- a/pennylane_lightning/core/src/simulators/lightning_kokkos/bindings/LKokkosBindings.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/bindings/LKokkosBindings.hpp @@ -73,27 +73,20 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) { .def("resetStateVector", &StateVectorT::resetStateVector) .def( "setBasisState", - [](StateVectorT &sv, const std::size_t index) { - sv.setBasisState(index); + [](StateVectorT &sv, const std::vector &state, + const std::vector &wires) { + sv.setBasisState(state, wires); }, - "Create Basis State on Device.") + "Set the state vector to a basis state.") .def( "setStateVector", - [](StateVectorT &sv, const std::vector &indices, - const np_arr_c &state) { + [](StateVectorT &sv, const np_arr_c &state, + const std::vector &wires) { const auto buffer = state.request(); - std::vector> state_kok; - if (buffer.size) { - const auto ptr = - static_cast *>( - buffer.ptr); - state_kok = std::vector>{ - ptr, ptr + buffer.size}; - } - sv.setStateVector(indices, state_kok); + sv.setStateVector(static_cast(buffer.ptr), + wires); }, - "Set State Vector on device with values and their corresponding " - "indices for the state vector on device") + "Set the state vector to the data contained in `state`.") .def( "DeviceToHost", [](StateVectorT &device_sv, np_arr_c &host_sv) { diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/gates/tests/Test_StateVectorKokkos_NonParam.cpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/gates/tests/Test_StateVectorKokkos_NonParam.cpp index 30c182288..93c9c4423 100644 --- a/pennylane_lightning/core/src/simulators/lightning_kokkos/gates/tests/Test_StateVectorKokkos_NonParam.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/gates/tests/Test_StateVectorKokkos_NonParam.cpp @@ -837,15 +837,19 @@ TEMPLATE_TEST_CASE("StateVectorKokkos::SetStateVector", std::vector> values = { init_state[1], init_state[3], init_state[5], init_state[7], init_state[0], init_state[2], init_state[4], init_state[6]}; - kokkos_sv.setStateVector(indices, values); - kokkos_sv.DeviceToHost(result_sv.data(), result_sv.size()); - for (std::size_t j = 0; j < exp2(num_qubits); j++) { CHECK(imag(expected_state[j]) == Approx(imag(result_sv[j]))); CHECK(real(expected_state[j]) == Approx(real(result_sv[j]))); } + + kokkos_sv.setStateVector(init_state, {0, 1, 2}); + kokkos_sv.DeviceToHost(result_sv.data(), result_sv.size()); + for (std::size_t j = 0; j < exp2(num_qubits); j++) { + CHECK(imag(init_state[j]) == Approx(imag(result_sv[j]))); + CHECK(real(init_state[j]) == Approx(real(result_sv[j]))); + } } } diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/tests/Test_StateVectorLKokkos.cpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/tests/Test_StateVectorLKokkos.cpp index 71f49896e..ea82dfcb3 100644 --- a/pennylane_lightning/core/src/simulators/lightning_kokkos/tests/Test_StateVectorLKokkos.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/tests/Test_StateVectorLKokkos.cpp @@ -109,6 +109,40 @@ TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::applyMatrix with a std::vector", } } +TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::setState", "[errors]", + (StateVectorKokkos), (float, double)) { + using StateVectorT = TestType; + using ComplexT = typename StateVectorT::ComplexT; + const std::size_t num_qubits = 2; + StateVectorT sv(num_qubits); + + SECTION("setBasisState incompatible dimensions") { + REQUIRE_THROWS_WITH( + sv.setBasisState({0}, {0, 1}), + Catch::Contains("state and wires must have equal dimensions.")); + } + + SECTION("setBasisState high wire index") { + REQUIRE_THROWS_WITH( + sv.setBasisState({0, 0, 0}, {0, 1, 2}), + Catch::Contains( + "wires must take values lower than the number of qubits.")); + } + + SECTION("setStateVector incompatible dimensions") { + REQUIRE_THROWS_WITH( + sv.setStateVector(std::vector(2, 0.0), {0, 1}), + Catch::Contains("Inconsistent state and wires dimensions.")); + } + + SECTION("setStateVector high wire index") { + REQUIRE_THROWS_WITH( + sv.setStateVector(std::vector(8, 0.0), {0, 1, 2}), + Catch::Contains( + "wires must take values lower than the number of qubits.")); + } +} + TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::applyMatrix with a pointer", "[applyMatrix]", (StateVectorKokkos), (float, double)) { diff --git a/pennylane_lightning/core/src/simulators/lightning_kokkos/utils/UtilKokkos.hpp b/pennylane_lightning/core/src/simulators/lightning_kokkos/utils/UtilKokkos.hpp index 4b1d32e71..dac6411fa 100644 --- a/pennylane_lightning/core/src/simulators/lightning_kokkos/utils/UtilKokkos.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_kokkos/utils/UtilKokkos.hpp @@ -49,6 +49,23 @@ inline auto view2vector(const Kokkos::View view) -> std::vector { return vec; } +/** + * @brief Copy the content of a pointer to a Kokkos view. + * + * @tparam T Pointer data type. + * @param vec Pointer. + * @return Kokkos view pointing to a copy of the pointer. + */ +template +inline auto pointer2view(const T *vec, const std::size_t num) + -> Kokkos::View { + using UnmanagedView = Kokkos::View>; + Kokkos::View view("vec", num); + Kokkos::deep_copy(view, UnmanagedView(vec, num)); + return view; +} + /** * @brief Copy the content of an `std::vector` to a Kokkos view. * @@ -58,11 +75,7 @@ inline auto view2vector(const Kokkos::View view) -> std::vector { */ template inline auto vector2view(const std::vector &vec) -> Kokkos::View { - using UnmanagedView = Kokkos::View>; - Kokkos::View view("vec", vec.size()); - Kokkos::deep_copy(view, UnmanagedView(vec.data(), vec.size())); - return view; + return pointer2view(vec.data(), vec.size()); } /** diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp index f742dccec..509e343c8 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/StateVectorLQubit.hpp @@ -30,6 +30,7 @@ #include "KernelType.hpp" #include "StateVectorBase.hpp" #include "Threading.hpp" +#include "cpu_kernels/GateImplementationsLM.hpp" /// @cond DEV namespace { @@ -702,6 +703,34 @@ class StateVectorLQubit : public StateVectorBase { arr[index] = {1.0, 0.0}; } + /** + * @brief Prepares a single computational basis state. + * + * @param state Binary number representing the index + * @param wires Wires. + */ + void setBasisState(const std::vector &state, + const std::vector &wires) { + const auto n_wires = wires.size(); + const auto num_qubits = this->getNumQubits(); + std::size_t index{0U}; + for (std::size_t k = 0; k < n_wires; k++) { + const auto bit = static_cast(state[k]); + index |= bit << (num_qubits - 1 - wires[k]); + } + setBasisState(index); + } + + /** + * @brief Reset the data back to the \f$\ket{0}\f$ state. + * + */ + void resetStateVector() { + if (this->getLength() > 0) { + setBasisState(0U); + } + } + /** * @brief Set values for a batch of elements of the state-vector. * @@ -710,12 +739,12 @@ class StateVectorLQubit : public StateVectorBase { */ void setStateVector(const std::vector &indices, const std::vector &values) { - auto num_indices = indices.size(); + const auto num_indices = indices.size(); PL_ABORT_IF(num_indices != values.size(), "Indices and values length must match"); auto *arr = this->getData(); - auto length = this->getLength(); + const auto length = this->getLength(); std::fill(arr, arr + length, 0.0); for (std::size_t i = 0; i < num_indices; i++) { PL_ABORT_IF(i >= length, "Invalid index"); @@ -724,13 +753,54 @@ class StateVectorLQubit : public StateVectorBase { } /** - * @brief Reset the data back to the \f$\ket{0}\f$ state. + * @brief Set values for a batch of elements of the state-vector. * + * @param state State. + * @param wires Wires. */ - void resetStateVector() { - if (this->getLength()) { - setBasisState(0U); + void setStateVector(const std::vector &state, + const std::vector &wires) { + PL_ABORT_IF_NOT(state.size() == exp2(wires.size()), + "Inconsistent state and wires dimensions.") + setStateVector(state.data(), wires); + } + + /** + * @brief Set values for a batch of elements of the state-vector. + * + * @param state State. + * @param wires Wires. + */ + void setStateVector(const ComplexT *state, + const std::vector &wires) { + const std::size_t num_state = exp2(wires.size()); + const auto total_wire_count = this->getNumQubits(); + + std::vector reversed_sorted_wires(wires); + std::sort(reversed_sorted_wires.begin(), reversed_sorted_wires.end()); + std::reverse(reversed_sorted_wires.begin(), + reversed_sorted_wires.end()); + std::vector controlled_wires(total_wire_count); + std::iota(std::begin(controlled_wires), std::end(controlled_wires), 0); + for (auto wire : reversed_sorted_wires) { + // Reverse guarantees that we start erasing at the end of the array. + // Maybe this can be optimized. + controlled_wires.erase(controlled_wires.begin() + wire); } + + const std::vector controlled_values(controlled_wires.size(), + false); + auto core_function = + [num_state, &state](std::complex *arr, + const std::vector &indices, + const std::vector> &) { + for (std::size_t i = 0; i < num_state; i++) { + arr[indices[i]] = state[i]; + } + }; + GateImplementationsLM::applyNCN(this->getData(), total_wire_count, + controlled_wires, controlled_values, + wires, core_function); } }; } // namespace Pennylane::LightningQubit diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp index f6bf6c0db..188315900 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/bindings/LQubitBindings.hpp @@ -179,23 +179,20 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) { .def("resetStateVector", &StateVectorT::resetStateVector) .def( "setBasisState", - [](StateVectorT &sv, const std::size_t index) { - sv.setBasisState(index); + [](StateVectorT &sv, const std::vector &state, + const std::vector &wires) { + sv.setBasisState(state, wires); }, - "Create Basis State.") + "Set the state vector to a basis state.") .def( "setStateVector", - [](StateVectorT &sv, const std::vector &indices, - const np_arr_c &state) { + [](StateVectorT &sv, const np_arr_c &state, + const std::vector &wires) { const auto buffer = state.request(); - std::vector state_in; - if (buffer.size) { - const auto ptr = static_cast(buffer.ptr); - state_in = std::vector{ptr, ptr + buffer.size}; - } - sv.setStateVector(indices, state_in); + sv.setStateVector(static_cast(buffer.ptr), + wires); }, - "Set State Vector with values and their corresponding indices") + "Set the state vector to the data contained in `state`.") .def( "getState", [](const StateVectorT &sv, np_arr_c &state) { diff --git a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp index 8dc2967d0..775fdfd8c 100644 --- a/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp +++ b/pennylane_lightning/core/src/simulators/lightning_qubit/tests/Test_StateVectorLQubit.cpp @@ -49,6 +49,130 @@ TEMPLATE_TEST_CASE("StateVectorLQubit::Constructibility", } } +TEMPLATE_PRODUCT_TEST_CASE("StateVectorLQubit::setBasisState", + "[setBasisState]", + (StateVectorLQubitManaged, StateVectorLQubitRaw), + (float, double)) { + using StateVectorT = TestType; + using ComplexT = typename StateVectorT::ComplexT; + + SECTION("setBasisState") { + const ComplexT one{1.0}; + const ComplexT zero{0.0}; + + std::vector init_state_zeros = {zero, zero, zero, zero, + zero, zero, zero, zero}; + std::vector expected_state_000 = {one, zero, zero, zero, + zero, zero, zero, zero}; + + StateVectorT sv(init_state_zeros.data(), init_state_zeros.size()); + sv.setBasisState({0}, {0}); + REQUIRE(sv.getDataVector() == approx(expected_state_000)); + + std::vector expected_state_001 = {zero, one, zero, zero, + zero, zero, zero, zero}; + sv.setBasisState({1}, {2}); + REQUIRE(sv.getDataVector() == approx(expected_state_001)); + + std::vector expected_state_010 = {zero, zero, one, zero, + zero, zero, zero, zero}; + sv.setBasisState({1}, {1}); + REQUIRE(sv.getDataVector() == approx(expected_state_010)); + + std::vector expected_state_011 = {zero, zero, zero, one, + zero, zero, zero, zero}; + sv.setBasisState({1, 1}, {1, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_011)); + + std::vector expected_state_100 = {zero, zero, zero, zero, + one, zero, zero, zero}; + sv.setBasisState({1}, {0}); + REQUIRE(sv.getDataVector() == approx(expected_state_100)); + + std::vector expected_state_101 = {zero, zero, zero, zero, + zero, one, zero, zero}; + + sv.setBasisState({1, 0, 1}, {0, 1, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_101)); + + std::vector expected_state_110 = {zero, zero, zero, zero, + zero, zero, one, zero}; + sv.setBasisState({1, 1}, {0, 1}); + REQUIRE(sv.getDataVector() == approx(expected_state_110)); + + std::vector expected_state_111 = {zero, zero, zero, zero, + zero, zero, zero, one}; + sv.setBasisState({1, 1, 1}, {0, 1, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_111)); + } +} + +TEMPLATE_PRODUCT_TEST_CASE("StateVectorLQubit::setStateVector", + "[setStateVector]", + (StateVectorLQubitManaged, StateVectorLQubitRaw), + (float, double)) { + using StateVectorT = TestType; + using ComplexT = typename StateVectorT::ComplexT; + + SECTION("setStateVector") { + const ComplexT one{1.0}; + const ComplexT zero{0.0}; + + std::vector init_state_zeros = {zero, zero, zero, zero, + zero, zero, zero, zero}; + std::vector expected_state_000 = {one, zero, zero, zero, + zero, zero, zero, zero}; + + StateVectorT sv(init_state_zeros.data(), init_state_zeros.size()); + sv.setStateVector({one, zero}, {0}); + REQUIRE(sv.getDataVector() == approx(expected_state_000)); + + std::vector expected_state_001 = {zero, one, zero, zero, + zero, zero, zero, zero}; + sv.resetStateVector(); + sv.setStateVector({zero, one}, {2}); + REQUIRE(sv.getDataVector() == approx(expected_state_001)); + + std::vector expected_state_010 = {zero, zero, one, zero, + zero, zero, zero, zero}; + + sv.resetStateVector(); + sv.setStateVector({zero, one}, {1}); + REQUIRE(sv.getDataVector() == approx(expected_state_010)); + + std::vector expected_state_011 = {zero, zero, zero, one, + zero, zero, zero, zero}; + sv.resetStateVector(); + sv.setStateVector({zero, zero, zero, one}, {1, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_011)); + + std::vector expected_state_100 = {zero, zero, zero, zero, + one, zero, zero, zero}; + sv.resetStateVector(); + sv.setStateVector({zero, one}, {0}); + REQUIRE(sv.getDataVector() == approx(expected_state_100)); + + std::vector expected_state_101 = {zero, zero, zero, zero, + zero, one, zero, zero}; + sv.resetStateVector(); + sv.setStateVector({zero, zero, zero, one}, {0, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_101)); + + std::vector expected_state_110 = {zero, zero, zero, zero, + zero, zero, one, zero}; + sv.resetStateVector(); + sv.setStateVector({zero, zero, zero, one}, {0, 1}); + REQUIRE(sv.getDataVector() == approx(expected_state_110)); + + std::vector expected_state_111 = {zero, zero, zero, zero, + zero, zero, zero, one}; + sv.resetStateVector(); + sv.setStateVector({zero, zero, zero, zero, zero, zero, zero, one}, + {0, 1, 2}); + REQUIRE(sv.getDataVector() == approx(expected_state_111)); + } +} + TEMPLATE_PRODUCT_TEST_CASE("StateVectorLQubit::Constructibility", "[General Constructibility]", (StateVectorLQubitManaged, StateVectorLQubitRaw), diff --git a/pennylane_lightning/lightning_kokkos/lightning_kokkos.py b/pennylane_lightning/lightning_kokkos/lightning_kokkos.py index 652825326..6fea7c162 100644 --- a/pennylane_lightning/lightning_kokkos/lightning_kokkos.py +++ b/pennylane_lightning/lightning_kokkos/lightning_kokkos.py @@ -249,17 +249,6 @@ def _asarray(arr, dtype=None): arr = new_arr return arr - def _create_basis_state(self, index): - """Return a computational basis state over all wires. - Args: - index (int): integer representing the computational basis state - Returns: - array[complex]: complex array of shape ``[2]*self.num_wires`` - representing the statevector of the basis state - Note: This function does not support broadcasted inputs yet. - """ - self._kokkos_state.setBasisState(index) - def reset(self): """Reset the device""" super().reset() @@ -357,18 +346,15 @@ def _apply_state_vector(self, state, device_wires): state.DeviceToHost(state_data) state = state_data - ravelled_indices, state = self._preprocess_state_vector(state, device_wires) - # translate to wire labels used by device device_wires = self.map_wires(device_wires) - output_shape = [2] * self.num_wires if len(device_wires) == self.num_wires and Wires(sorted(device_wires)) == device_wires: + output_shape = (2,) * self.num_wires # Initialize the entire device state with the input state self.sync_h2d(self._reshape(state, output_shape)) return - - self._kokkos_state.setStateVector(ravelled_indices, state) # this operation on device + self._kokkos_state.setStateVector(state, list(device_wires)) # this operation on device def _apply_basis_state(self, state, wires): """Initialize the state vector in a specified computational basis state. @@ -380,8 +366,13 @@ def _apply_basis_state(self, state, wires): Note: This function does not support broadcasted inputs yet. """ - num = self._get_basis_state_index(state, wires) - self._create_basis_state(num) + if not set(state.tolist()).issubset({0, 1}): + raise ValueError("BasisState parameter must consist of 0 or 1 integers.") + + if len(state) != len(wires): + raise ValueError("BasisState parameter and wires must be of equal length.") + + self._kokkos_state.setBasisState(state, list(wires)) def _apply_lightning_midmeasure( self, operation: MidMeasureMP, mid_measurements: dict, postselect_mode: str diff --git a/pennylane_lightning/lightning_qubit/_state_vector.py b/pennylane_lightning/lightning_qubit/_state_vector.py index 7307b35d0..f9d305d52 100644 --- a/pennylane_lightning/lightning_qubit/_state_vector.py +++ b/pennylane_lightning/lightning_qubit/_state_vector.py @@ -24,8 +24,6 @@ except ImportError: pass -from itertools import product - import numpy as np import pennylane as qml from pennylane import BasisState, DeviceError, StatePrep @@ -111,75 +109,11 @@ def _state_dtype(self): """ return StateVectorC128 if self.dtype == np.complex128 else StateVectorC64 - def _create_basis_state(self, index): - """Return a computational basis state over all wires. - - Args: - index (int): integer representing the computational basis state. - """ - self._qubit_state.setBasisState(index) - def reset_state(self): """Reset the device's state""" # init the state vector to |00..0> self._qubit_state.resetStateVector() - def _preprocess_state_vector(self, state, device_wires): - """Initialize the internal state vector in a specified state. - - Args: - state (array[complex]): normalized input state of length ``2**len(wires)`` - or broadcasted state of shape ``(batch_size, 2**len(wires))`` - device_wires (Wires): wires that get initialized in the state - - Returns: - array[int]: indices for which the state is changed to input state vector elements - array[complex]: normalized input state of length ``2**len(wires)`` - or broadcasted state of shape ``(batch_size, 2**len(wires))`` - """ - # special case for integral types - if state.dtype.kind == "i": - state = np.array(state, dtype=self.dtype) - - if len(device_wires) == self._num_wires and Wires(sorted(device_wires)) == device_wires: - return None, state - - # generate basis states on subset of qubits via the cartesian product - basis_states = np.array(list(product([0, 1], repeat=len(device_wires)))) - - # get basis states to alter on full set of qubits - unravelled_indices = np.zeros((2 ** len(device_wires), self._num_wires), dtype=int) - unravelled_indices[:, device_wires] = basis_states - - # get indices for which the state is changed to input state vector elements - ravelled_indices = np.ravel_multi_index(unravelled_indices.T, [2] * self._num_wires) - return ravelled_indices, state - - def _get_basis_state_index(self, state, wires): - """Returns the basis state index of a specified computational basis state. - - Args: - state (array[int]): computational basis state of shape ``(wires,)`` - consisting of 0s and 1s - wires (Wires): wires that the provided computational state should be initialized on - - Returns: - int: basis state index - """ - # length of basis state parameter - n_basis_state = len(state) - - if not set(state.tolist()).issubset({0, 1}): - raise ValueError("BasisState parameter must consist of 0 or 1 integers.") - - if n_basis_state != len(wires): - raise ValueError("BasisState parameter and wires must be of equal length.") - - # get computational basis state number - basis_states = 2 ** (self._num_wires - 1 - np.array(wires)) - basis_states = qml.math.convert_like(basis_states, state) - return int(qml.math.dot(state, basis_states)) - def _apply_state_vector(self, state, device_wires: Wires): """Initialize the internal state vector in a specified state. Args: @@ -193,18 +127,14 @@ def _apply_state_vector(self, state, device_wires: Wires): state.getState(state_data) state = state_data - ravelled_indices, state = self._preprocess_state_vector(state, device_wires) - - # translate to wire labels used by device - output_shape = [2] * self._num_wires - if len(device_wires) == self._num_wires and Wires(sorted(device_wires)) == device_wires: # Initialize the entire device state with the input state + output_shape = (2,) * self._num_wires state = np.reshape(state, output_shape).ravel(order="C") self._qubit_state.UpdateData(state) return - self._qubit_state.setStateVector(ravelled_indices, state) + self._qubit_state.setStateVector(state, list(device_wires)) def _apply_basis_state(self, state, wires): """Initialize the state vector in a specified computational basis state. @@ -217,8 +147,13 @@ def _apply_basis_state(self, state, wires): Note: This function does not support broadcasted inputs yet. """ - num = self._get_basis_state_index(state, wires) - self._create_basis_state(num) + if not set(state.tolist()).issubset({0, 1}): + raise ValueError("BasisState parameter must consist of 0 or 1 integers.") + + if len(state) != len(wires): + raise ValueError("BasisState parameter and wires must be of equal length.") + + self._qubit_state.setBasisState(list(state), list(wires)) def _apply_lightning_controlled(self, operation): """Apply an arbitrary controlled operation to the state tensor. diff --git a/tests/test_gates.py b/tests/test_gates.py index 855619efc..98a84e3ec 100644 --- a/tests/test_gates.py +++ b/tests/test_gates.py @@ -358,6 +358,32 @@ def circuit(): assert np.allclose(circ(), circ_def(), tol) +@pytest.mark.skipif( + device_name not in ("lightning.qubit", "lightning.kokkos"), + reason="PennyLane-like StatePrep only implemented in lightning.qubit and lightning.kokkos.", +) +@pytest.mark.parametrize("n_targets", list(range(2, 8))) +def test_state_prep(n_targets, tol): + """Test that StatePrep is correctly applied to a state.""" + n_wires = 7 + dq = qml.device("default.qubit", wires=n_wires) + dev = qml.device(device_name, wires=n_wires) + init_state = np.random.rand(2**n_targets) + 1.0j * np.random.rand(2**n_targets) + init_state /= np.linalg.norm(init_state) + for i in range(10): + if i == 0: + wires = np.arange(n_targets, dtype=int) + else: + wires = np.random.permutation(n_wires)[0:n_targets] + tape = qml.tape.QuantumTape( + [qml.StatePrep(init_state, wires=wires)] + [qml.X(i) for i in range(n_wires)], + [qml.state()], + ) + ref = dq.execute([tape])[0] + res = dev.execute([tape])[0] if ld._new_API else dev.execute(tape) + assert np.allclose(res.ravel(), ref.ravel(), tol) + + @pytest.mark.skipif( device_name != "lightning.qubit", reason="N-controlled operations only implemented in lightning.qubit.",