diff --git a/test/unit/eigensolver/test_tridiag_solver_merge.cpp b/test/unit/eigensolver/test_tridiag_solver_merge.cpp index 468f184ad1..543bad1285 100644 --- a/test/unit/eigensolver/test_tridiag_solver_merge.cpp +++ b/test/unit/eigensolver/test_tridiag_solver_merge.cpp @@ -82,7 +82,7 @@ TYPED_TEST(TridiagEigensolverMergeTest, SortIndex) { } TEST(StablePartitionIndexOnDeflated, FullRange) { - const SizeType n = 10; + constexpr SizeType n = 10; const SizeType nb = 3; const LocalElementSize sz(n, 1); @@ -93,32 +93,62 @@ TEST(StablePartitionIndexOnDeflated, FullRange) { Matrix in(sz, bk); Matrix out(sz, bk); - std::vector c_arr{ColType::LowerHalf, ColType::Dense, ColType::Deflated, - ColType::Deflated, ColType::UpperHalf, ColType::UpperHalf, - ColType::LowerHalf, ColType::Dense, ColType::Deflated, - ColType::LowerHalf}; + // Note: + // UpperHalf -> u + // Dense -> f + // LowerHalf -> l + // Deflated -> d + + // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial + // | l| f|d1|d2| u| u| l| f|d3| l| c_arr + std::array c_arr{ColType::LowerHalf, ColType::Dense, + ColType::Deflated, ColType::Deflated, + ColType::UpperHalf, ColType::UpperHal ColType::LowerHalf, + ColType::Dense, ColType::Deflate ColType::LowerHalf}; DLAF_ASSERT(c_arr.size() == to_sizet(n), n); dlaf::matrix::util::set(c, [&c_arr](GlobalElementIndex i) { return c_arr[to_sizet(i.row())]; }); - // f, u, d, d, l, u, l, f, d, l - std::vector vals_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9}; - std::vector in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9}; + // | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr + // | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr + std::array in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9}; + dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; }); + + // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial + // |30|10| 2| 3|20|40|50|60| 1|70| vals_arr + // + // | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr + // | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr + // |10|20| 2| 3|30|40|50|60| 1|70| vals_arr permuted by in_arr + std::array vals_arr{30, 10, 2, 3, 20, 40, 50, 60, 1, 70}; dlaf::matrix::util::set(vals, [&vals_arr](GlobalElementIndex i) { return vals_arr[to_sizet(i.row())]; }); - dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; }); const SizeType i_begin = 0; const SizeType i_end = 4; auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out); - ASSERT_TRUE(tt::sync_wait(std::move(k)) == 7); + // | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial + // | l| f|d1|d2| u| u| l| f|d3| l| c_arr + // |30|10| 2| 3|20|40|50|60| 1|70| vals_arr + // + // | 1| 4| 0| 5| 6| 7| 9| 8| 2| 3| out_arr + // | f| u| l| u| l| f| l|d3|d1|d2| c_arr permuted by out_arr + // |10|20|30|40|50|60|70| 1| 2| 3| vals_arr permuted by out_arr + + const SizeType k_value = tt::sync_wait(std::move(k)); + ASSERT_TRUE(k_value == 7); - // f u l u l f l d d d - std::vector expected_out_arr{1, 4, 0, 5, 6, 7, 9, 2, 3, 8}; + std::array expected_out_arr{1, 4, 0, 5, 6, 7, 9, 8, 2, 3}; auto expected_out = [&expected_out_arr](GlobalElementIndex i) { return expected_out_arr[to_sizet(i.row())]; }; CHECK_MATRIX_EQ(expected_out, out); + + const SizeType* out_ptr = tt::sync_wait(out.read(LocalTileIndex(0, 0))).get().ptr(); + EXPECT_TRUE(std::is_sorted(out_ptr + k_value, out_ptr + n, + [&vals_arr](const SizeType i, const SizeType j) { + return vals_arr[to_sizet(i)] < vals_arr[tosizet(j)]; + })); } TYPED_TEST(TridiagEigensolverMergeTest, Deflation) {