Skip to content

Commit

Permalink
Merge pull request #268 from MichaelBroughton/bulk_set
Browse files Browse the repository at this point in the history
Bulk set
  • Loading branch information
95-martin-orion authored Jan 25, 2021
2 parents 3661250 + 94b1864 commit dbe31f4
Show file tree
Hide file tree
Showing 7 changed files with 156 additions and 0 deletions.
31 changes: 31 additions & 0 deletions lib/statespace_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,37 @@ class StateSpaceAVX : public StateSpace<StateSpaceAVX<For>, For, float> {
state.get()[k + 8] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
__m256 re_reg = _mm256_set1_ps(re);
__m256 im_reg = _mm256_set1_ps(im);

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m256 re_n, __m256 im_n, fp_type* p) {
__m256 ml =
_mm256_castsi256_ps(detail::GetZeroMaskAVX(8 * i, maskv, bitsv));

__m256 re = _mm256_load_ps(p + 16 * i);
__m256 im = _mm256_load_ps(p + 16 * i + 8);

re = _mm256_blendv_ps(re, re_n, ml);
im = _mm256_blendv_ps(im, im_n, ml);

_mm256_store_ps(p + 16 * i, re);
_mm256_store_ps(p + 16 * i + 8, im);
};

Base::for_.Run(MinSize(state.num_qubits()) / 16, f, mask, bits, re_reg,
im_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
22 changes: 22 additions & 0 deletions lib/statespace_basic.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,28 @@ class StateSpaceBasic : public StateSpace<StateSpaceBasic<For, FP>, For, FP> {
state.get()[p + 1] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, fp_type re_n, fp_type im_n, fp_type* p) {
auto s = p + 2 * i;
bool in_mask = (i & maskv) == bitsv;

s[0] = in_mask ? re_n : s[0];
s[1] = in_mask ? im_n : s[1];
};

Base::for_.Run(MinSize(state.num_qubits()) / 2, f, mask, bits, re, im,
state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
30 changes: 30 additions & 0 deletions lib/statespace_sse.h
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,36 @@ class StateSpaceSSE : public StateSpace<StateSpaceSSE<For>, For, float> {
state.get()[p + 4] = im;
}

// Sets state[i] = val where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits,
const std::complex<fp_type>& val) const {
BulkSetAmpl(state, mask, bits, std::real(val), std::imag(val));
}

// Sets state[i] = complex(re, im) where (i & mask) == bits
void BulkSetAmpl(State& state, uint64_t mask, uint64_t bits, fp_type re,
fp_type im) const {
__m128 re_reg = _mm_set1_ps(re);
__m128 im_reg = _mm_set1_ps(im);

auto f = [](unsigned n, unsigned m, uint64_t i, uint64_t maskv,
uint64_t bitsv, __m128 re_n, __m128 im_n, fp_type* p) {
__m128 ml = _mm_castsi128_ps(detail::GetZeroMaskSSE(4 * i, maskv, bitsv));

__m128 re = _mm_load_ps(p + 8 * i);
__m128 im = _mm_load_ps(p + 8 * i + 4);

re = _mm_blendv_ps(re, re_n, ml);
im = _mm_blendv_ps(im, im_n, ml);

_mm_store_ps(p + 8 * i, re);
_mm_store_ps(p + 8 * i + 4, im);
};

Base::for_.Run(MinSize(state.num_qubits()) / 8, f, mask, bits, re_reg,
im_reg, state.get());
}

// Does the equivalent of dest += src elementwise.
bool Add(const State& src, State& dest) const {
if (src.num_qubits() != dest.num_qubits()) {
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_avx_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceAVXTest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceAVX<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceAVX<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_basic_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceBasicTest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceBasic<For, float>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceBasic<For, float>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
4 changes: 4 additions & 0 deletions tests/statespace_sse_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ TEST(StateSpaceSSETest, InvalidStateSize) {
TestInvalidStateSize<StateSpaceSSE<For>>();
}

TEST(StateSpaceBasicTest, BulkSetAmpl) {
TestBulkSetAmplitude<StateSpaceSSE<For>>();
}

} // namespace qsim

int main(int argc, char** argv) {
Expand Down
61 changes: 61 additions & 0 deletions tests/statespace_testfixture.h
Original file line number Diff line number Diff line change
Expand Up @@ -809,6 +809,67 @@ void TestInvalidStateSize() {
EXPECT_FALSE(!std::isnan(state_space.RealInnerProduct(state1, state2)));
}

template <typename StateSpace>
void TestBulkSetAmplitude() {
using State = typename StateSpace::State;
unsigned num_qubits = 3;

StateSpace state_space(1);

State state = state_space.Create(num_qubits);
for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 1, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 2, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4, 0, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));

for(int i = 0; i < 8; i++) {
state_space.SetAmpl(state, i, 1, 1);
}
state_space.BulkSetAmpl(state, 4 | 1, 4, 0, 0);
EXPECT_EQ(state_space.GetAmpl(state, 0), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 1), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 2), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 3), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 4), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 5), std::complex<float>(1, 1));
EXPECT_EQ(state_space.GetAmpl(state, 6), std::complex<float>(0, 0));
EXPECT_EQ(state_space.GetAmpl(state, 7), std::complex<float>(1, 1));
}

} // namespace qsim

#endif // STATESPACE_TESTFIXTURE_H_

0 comments on commit dbe31f4

Please sign in to comment.