Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move functions to base class #841

Merged
merged 13 commits into from
Aug 9, 2024
3 changes: 3 additions & 0 deletions .github/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
* Add `initial_state_prep` option to Catalyst TOML file.
[(#826)](https://github.com/PennyLaneAI/pennylane-lightning/pull/826)

* Move `setBasisState`, `setStateVector` and `resetStateVector` from `StateVectorLQubitManaged` to `StateVectorLQubit`.
[(#841)](https://github.com/PennyLaneAI/pennylane-lightning/pull/841)

### Documentation

* Updated the README and added citation format for Lightning arxiv preprint.
Expand Down
2 changes: 1 addition & 1 deletion pennylane_lightning/core/_version.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@
Version number (major.minor.patch[-label])
"""

__version__ = "0.38.0-dev26"
__version__ = "0.38.0-dev27"
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#pragma once
#include <complex>
#include <unordered_map>
#include <unordered_set>
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved

#include "CPUMemoryModel.hpp"
#include "GateOperation.hpp"
Expand Down Expand Up @@ -687,5 +688,50 @@ class StateVectorLQubit : public StateVectorBase<PrecisionT, Derived> {
arr[k] *= inv_norm;
}
}

/**
* @brief Prepares a single computational basis state.
*
* @param index Index of the target element.
*/
void setBasisState(const std::size_t index) {
auto length = this->getLength();
PL_ABORT_IF(index > length - 1, "Invalid index");

auto *arr = this->getData();
std::fill(arr, arr + length, 0.0);
arr[index] = {1.0, 0.0};
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
* @param values Values to be set for the target elements.
* @param indices Indices of the target elements.
*/
void setStateVector(const std::vector<std::size_t> &indices,
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
const std::vector<ComplexT> &values) {
auto num_indices = indices.size();
PL_ABORT_IF(num_indices != values.size(),
"Indices and values length must match");

auto *arr = this->getData();
auto length = this->getLength();
std::fill(arr, arr + length, 0.0);
for (std::size_t i = 0; i < num_indices; i++) {
PL_ABORT_IF(i >= length, "Invalid index");
arr[indices[i]] = values[i];
}
maliasadi marked this conversation as resolved.
Show resolved Hide resolved
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
*
*/
void resetStateVector() {
vincentmr marked this conversation as resolved.
Show resolved Hide resolved
if (this->getLength() > 0) {
erick-xanadu marked this conversation as resolved.
Show resolved Hide resolved
setBasisState(0U);
}
}
};
} // namespace Pennylane::LightningQubit
Original file line number Diff line number Diff line change
Expand Up @@ -140,39 +140,6 @@ class StateVectorLQubitManaged final

~StateVectorLQubitManaged() = default;

/**
* @brief Prepares a single computational basis state.
*
* @param index Index of the target element.
*/
void setBasisState(const std::size_t index) {
std::fill(data_.begin(), data_.end(), 0.0);
data_[index] = {1.0, 0.0};
}

/**
* @brief Set values for a batch of elements of the state-vector.
*
* @param values Values to be set for the target elements.
* @param indices Indices of the target elements.
*/
void setStateVector(const std::vector<std::size_t> &indices,
const std::vector<ComplexT> &values) {
for (std::size_t n = 0; n < indices.size(); n++) {
data_[indices[n]] = values[n];
}
}

/**
* @brief Reset the data back to the \f$\ket{0}\f$ state.
*
*/
void resetStateVector() {
if (this->getLength() > 0) {
setBasisState(0U);
}
}

[[nodiscard]] auto getData() -> ComplexT * { return data_.data(); }

[[nodiscard]] auto getData() const -> const ComplexT * {
Expand Down Expand Up @@ -218,4 +185,4 @@ class StateVectorLQubitManaged final
return data_.get_allocator();
}
};
} // namespace Pennylane::LightningQubit
} // namespace Pennylane::LightningQubit
Loading