Skip to content

Commit

Permalink
small extension for the test
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Aug 25, 2023
1 parent edeadf4 commit 1012688
Showing 1 changed file with 42 additions and 12 deletions.
54 changes: 42 additions & 12 deletions test/unit/eigensolver/test_tridiag_solver_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ TYPED_TEST(TridiagEigensolverMergeTest, SortIndex) {
}

TEST(StablePartitionIndexOnDeflated, FullRange) {
const SizeType n = 10;
constexpr SizeType n = 10;
const SizeType nb = 3;

const LocalElementSize sz(n, 1);
Expand All @@ -93,32 +93,62 @@ TEST(StablePartitionIndexOnDeflated, FullRange) {
Matrix<SizeType, Device::CPU> in(sz, bk);
Matrix<SizeType, Device::CPU> out(sz, bk);

std::vector<ColType> c_arr{ColType::LowerHalf, ColType::Dense, ColType::Deflated,
ColType::Deflated, ColType::UpperHalf, ColType::UpperHalf,
ColType::LowerHalf, ColType::Dense, ColType::Deflated,
ColType::LowerHalf};
// Note:
// UpperHalf -> u
// Dense -> f
// LowerHalf -> l
// Deflated -> d

// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// | l| f|d1|d2| u| u| l| f|d3| l| c_arr
std::array<ColType, n> c_arr{ColType::LowerHalf, ColType::Dense,
ColType::Deflated, ColType::Deflated,
ColType::UpperHalf, ColType::UpperHal ColType::LowerHalf,
ColType::Dense, ColType::Deflate ColType::LowerHalf};
DLAF_ASSERT(c_arr.size() == to_sizet(n), n);
dlaf::matrix::util::set(c, [&c_arr](GlobalElementIndex i) { return c_arr[to_sizet(i.row())]; });

// f, u, d, d, l, u, l, f, d, l
std::vector<double> vals_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9};
std::vector<SizeType> in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9};
// | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr
// | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr
std::array<SizeType, n> in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9};
dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; });

// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// |30|10| 2| 3|20|40|50|60| 1|70| vals_arr
//
// | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr
// | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr
// |10|20| 2| 3|30|40|50|60| 1|70| vals_arr permuted by in_arr
std::array<double, n> vals_arr{30, 10, 2, 3, 20, 40, 50, 60, 1, 70};
dlaf::matrix::util::set(vals,
[&vals_arr](GlobalElementIndex i) { return vals_arr[to_sizet(i.row())]; });
dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; });

const SizeType i_begin = 0;
const SizeType i_end = 4;
auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out);

ASSERT_TRUE(tt::sync_wait(std::move(k)) == 7);
// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// | l| f|d1|d2| u| u| l| f|d3| l| c_arr
// |30|10| 2| 3|20|40|50|60| 1|70| vals_arr
//
// | 1| 4| 0| 5| 6| 7| 9| 8| 2| 3| out_arr
// | f| u| l| u| l| f| l|d3|d1|d2| c_arr permuted by out_arr
// |10|20|30|40|50|60|70| 1| 2| 3| vals_arr permuted by out_arr

const SizeType k_value = tt::sync_wait(std::move(k));
ASSERT_TRUE(k_value == 7);

// f u l u l f l d d d
std::vector<SizeType> expected_out_arr{1, 4, 0, 5, 6, 7, 9, 2, 3, 8};
std::array<SizeType, n> expected_out_arr{1, 4, 0, 5, 6, 7, 9, 8, 2, 3};
auto expected_out = [&expected_out_arr](GlobalElementIndex i) {
return expected_out_arr[to_sizet(i.row())];
};
CHECK_MATRIX_EQ(expected_out, out);

const SizeType* out_ptr = tt::sync_wait(out.read(LocalTileIndex(0, 0))).get().ptr();
EXPECT_TRUE(std::is_sorted(out_ptr + k_value, out_ptr + n,
[&vals_arr](const SizeType i, const SizeType j) {
return vals_arr[to_sizet(i)] < vals_arr[tosizet(j)];
}));
}

TYPED_TEST(TridiagEigensolverMergeTest, Deflation) {
Expand Down

0 comments on commit 1012688

Please sign in to comment.