Skip to content

Commit

Permalink
Add setBasisState and setStateVector with wire operands. (#843)
Browse files Browse the repository at this point in the history
**Context:** Runtime implementation of `qml.BasisState` and
`qml.StatePrep`.

**Description of the Change:** Add StatePrep methods to Lightning Qubit
and Lightning Kokkos.

[sc-71129]

---------

Co-authored-by: ringo-but-quantum <[email protected]>
Co-authored-by: Vincent Michaud-Rioux <[email protected]>
Co-authored-by: Vincent Michaud-Rioux <[email protected]>
  • Loading branch information
4 people authored Aug 13, 2024
1 parent 09d626b commit c6b86a5
Show file tree
Hide file tree
Showing 16 changed files with 416 additions and 141 deletions.
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

### Improvements

* The `setBasisState` and `setStateVector` methods of `StateVectorLQubit` and `StateVectorKokkos` are overloaded to support PennyLane-like parameters.
[(#843)](https://github.com/PennyLaneAI/pennylane-lightning/pull/843)

* `ENABLE_LAPACK` is off by default for all Lightning backends.
[(#825)](https://github.com/PennyLaneAI/pennylane-lightning/pull/825)

Expand Down
7 changes: 7 additions & 0 deletions .github/workflows/tests_without_binary.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ jobs:
name: Python Tests without Binary (${{ matrix.pl_backend }})

steps:
- name: Make disk space
run: |
for DIR in /usr/share/dotnet /usr/local/share/powershell /usr/share/swift; do
sudo du -sh $DIR || echo $DIR not found
sudo rm -rf $DIR
done
- name: Checkout PennyLane-Lightning
uses: actions/checkout@v4
with:
Expand Down
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ kokkos/
PennyLane_Lightning_Kokkos.egg-info/
PennyLane_Lightning.egg-info/
prototypes/
pyproject.toml
tests/__pycache__/
venv/
wheelhouse/
5 changes: 5 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,11 @@ clean:
rm -rf pennylane_lightning/*_ops*
rm -rf *.egg-info

.PHONY: python
python:
PL_BACKEND=$(PL_BACKEND) python scripts/configure_pyproject_toml.py
pip install -e . -vv

.PHONY: wheel
wheel:
PL_BACKEND=$(PL_BACKEND) python scripts/configure_pyproject_toml.py
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.38.0-dev31"
__version__ = "0.38.0-dev32"
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,43 @@ class StateVectorKokkos final
});
}

/**
* @brief Prepares a single computational basis state.
*
* @param state Binary number representing the index
* @param wires Wires.
*/
void setBasisState(const std::vector<std::size_t> &state,
const std::vector<std::size_t> &wires) {
PL_ABORT_IF_NOT(state.size() == wires.size(),
"state and wires must have equal dimensions.");
const auto num_qubits = this->getNumQubits();
PL_ABORT_IF_NOT(
std::find_if(wires.begin(), wires.end(),
[&num_qubits](const auto i) {
return i >= num_qubits;
}) == wires.end(),
"wires must take values lower than the number of qubits.");
const auto n_wires = wires.size();
std::size_t index{0U};
for (std::size_t k = 0; k < n_wires; k++) {
const auto bit = static_cast<std::size_t>(state[k]);
index |= bit << (num_qubits - 1 - wires[k]);
}
setBasisState(index);
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
*
* @param num_qubits Number of qubits
*/
void resetStateVector() {
if (this->getLength() > 0) {
setBasisState(0U);
}
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
Expand All @@ -164,14 +201,49 @@ class StateVectorKokkos final
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
* @brief Set values for a batch of elements of the state-vector.
*
* @param num_qubits Number of qubits
* @param state State.
* @param wires Wires.
*/
void resetStateVector() {
if (this->getLength() > 0) {
setBasisState(0U);
}
void setStateVector(const std::vector<ComplexT> &state,
const std::vector<std::size_t> &wires) {
PL_ABORT_IF_NOT(state.size() == exp2(wires.size()),
"Inconsistent state and wires dimensions.");
setStateVector(state.data(), wires);
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
* @param state State.
* @param wires Wires.
*/
void setStateVector(const ComplexT *state,
const std::vector<std::size_t> &wires) {
constexpr std::size_t one{1U};
const auto num_qubits = this->getNumQubits();
PL_ABORT_IF_NOT(
std::find_if(wires.begin(), wires.end(),
[&num_qubits](const auto i) {
return i >= num_qubits;
}) == wires.end(),
"wires must take values lower than the number of qubits.");
const auto num_state = exp2(wires.size());
auto d_sv = getView();
auto d_state = pointer2view(state, num_state);
auto d_wires = vector2view(wires);
initZeros();
Kokkos::parallel_for(
num_state, KOKKOS_LAMBDA(const std::size_t i) {
std::size_t index{0U};
for (std::size_t w = 0; w < d_wires.size(); w++) {
const std::size_t bit = (i & (one << w)) >> w;
index |= bit << (num_qubits - 1 -
d_wires(d_wires.size() - 1 - w));
}
d_sv(index) = d_state(i);
});
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,27 +73,20 @@ void registerBackendClassSpecificBindings(PyClass &pyclass) {
.def("resetStateVector", &StateVectorT::resetStateVector)
.def(
"setBasisState",
[](StateVectorT &sv, const std::size_t index) {
sv.setBasisState(index);
[](StateVectorT &sv, const std::vector<std::size_t> &state,
const std::vector<std::size_t> &wires) {
sv.setBasisState(state, wires);
},
"Create Basis State on Device.")
"Set the state vector to a basis state.")
.def(
"setStateVector",
[](StateVectorT &sv, const std::vector<std::size_t> &indices,
const np_arr_c &state) {
[](StateVectorT &sv, const np_arr_c &state,
const std::vector<std::size_t> &wires) {
const auto buffer = state.request();
std::vector<Kokkos::complex<ParamT>> state_kok;
if (buffer.size) {
const auto ptr =
static_cast<const Kokkos::complex<ParamT> *>(
buffer.ptr);
state_kok = std::vector<Kokkos::complex<ParamT>>{
ptr, ptr + buffer.size};
}
sv.setStateVector(indices, state_kok);
sv.setStateVector(static_cast<const ComplexT *>(buffer.ptr),
wires);
},
"Set State Vector on device with values and their corresponding "
"indices for the state vector on device")
"Set the state vector to the data contained in `state`.")
.def(
"DeviceToHost",
[](StateVectorT &device_sv, np_arr_c &host_sv) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -837,15 +837,19 @@ TEMPLATE_TEST_CASE("StateVectorKokkos::SetStateVector",
std::vector<Kokkos::complex<PrecisionT>> values = {
init_state[1], init_state[3], init_state[5], init_state[7],
init_state[0], init_state[2], init_state[4], init_state[6]};

kokkos_sv.setStateVector(indices, values);

kokkos_sv.DeviceToHost(result_sv.data(), result_sv.size());

for (std::size_t j = 0; j < exp2(num_qubits); j++) {
CHECK(imag(expected_state[j]) == Approx(imag(result_sv[j])));
CHECK(real(expected_state[j]) == Approx(real(result_sv[j])));
}

kokkos_sv.setStateVector(init_state, {0, 1, 2});
kokkos_sv.DeviceToHost(result_sv.data(), result_sv.size());
for (std::size_t j = 0; j < exp2(num_qubits); j++) {
CHECK(imag(init_state[j]) == Approx(imag(result_sv[j])));
CHECK(real(init_state[j]) == Approx(real(result_sv[j])));
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,40 @@ TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::applyMatrix with a std::vector",
}
}

TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::setState", "[errors]",
(StateVectorKokkos), (float, double)) {
using StateVectorT = TestType;
using ComplexT = typename StateVectorT::ComplexT;
const std::size_t num_qubits = 2;
StateVectorT sv(num_qubits);

SECTION("setBasisState incompatible dimensions") {
REQUIRE_THROWS_WITH(
sv.setBasisState({0}, {0, 1}),
Catch::Contains("state and wires must have equal dimensions."));
}

SECTION("setBasisState high wire index") {
REQUIRE_THROWS_WITH(
sv.setBasisState({0, 0, 0}, {0, 1, 2}),
Catch::Contains(
"wires must take values lower than the number of qubits."));
}

SECTION("setStateVector incompatible dimensions") {
REQUIRE_THROWS_WITH(
sv.setStateVector(std::vector<ComplexT>(2, 0.0), {0, 1}),
Catch::Contains("Inconsistent state and wires dimensions."));
}

SECTION("setStateVector high wire index") {
REQUIRE_THROWS_WITH(
sv.setStateVector(std::vector<ComplexT>(8, 0.0), {0, 1, 2}),
Catch::Contains(
"wires must take values lower than the number of qubits."));
}
}

TEMPLATE_PRODUCT_TEST_CASE("StateVectorKokkos::applyMatrix with a pointer",
"[applyMatrix]", (StateVectorKokkos),
(float, double)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,23 @@ inline auto view2vector(const Kokkos::View<T *> view) -> std::vector<T> {
return vec;
}

/**
* @brief Copy the content of a pointer to a Kokkos view.
*
* @tparam T Pointer data type.
* @param vec Pointer.
* @return Kokkos view pointing to a copy of the pointer.
*/
template <typename T>
inline auto pointer2view(const T *vec, const std::size_t num)
-> Kokkos::View<T *> {
using UnmanagedView = Kokkos::View<const T *, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::View<T *> view("vec", num);
Kokkos::deep_copy(view, UnmanagedView(vec, num));
return view;
}

/**
* @brief Copy the content of an `std::vector` to a Kokkos view.
*
Expand All @@ -58,11 +75,7 @@ inline auto view2vector(const Kokkos::View<T *> view) -> std::vector<T> {
*/
template <typename T>
inline auto vector2view(const std::vector<T> &vec) -> Kokkos::View<T *> {
using UnmanagedView = Kokkos::View<const T *, Kokkos::HostSpace,
Kokkos::MemoryTraits<Kokkos::Unmanaged>>;
Kokkos::View<T *> view("vec", vec.size());
Kokkos::deep_copy(view, UnmanagedView(vec.data(), vec.size()));
return view;
return pointer2view(vec.data(), vec.size());
}

/**
Expand Down
Loading

0 comments on commit c6b86a5

Please sign in to comment.