Skip to content

Commit

Permalink
TridiagSolver: fix missing sort in the deflation (#960)
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro authored Aug 30, 2023
1 parent dd22d71 commit 7f96b89
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 100 deletions.
5 changes: 0 additions & 5 deletions include/dlaf/eigensolver/tridiag_solver/kernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,6 @@ void initIndexTileAsync(SizeType tile_row, TileSender&& tile) {

#ifdef DLAF_WITH_GPU

// Partitions the index array `in_ptr` into `out_ptr` on the device, such that the indices whose column
// type in `c_ptr` is ColType::Deflated are placed at the end.
//
// The function itself returns nothing: the number of non-deflated entries is computed on the device
// into `device_k_ptr` and copied to `host_k_ptr` by work enqueued on `stream`.
void stablePartitionIndexOnDevice(SizeType n, const ColType* c_ptr, const SizeType* in_ptr,
SizeType* out_ptr, SizeType* host_k_ptr, SizeType* device_k_ptr,
whip::stream_t stream);

template <class T>
void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr, const SizeType* end_ptr,
SizeType* out_ptr, const T* v_ptr, whip::stream_t stream);
Expand Down
122 changes: 69 additions & 53 deletions include/dlaf/eigensolver/tridiag_solver/merge.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,77 +243,91 @@ auto calcTolerance(const SizeType i_begin, const SizeType i_end, Matrix<const T,
ex::ensure_started();
}

// The index array `out_ptr` holds the indices of elements of `c_ptr` that order it such that
// ColType::Deflated entries are moved to the end. The `c_ptr` array is implicitly ordered according to
// `in_ptr` on entry.
// This function returns the number of non-deflated eigenvectors, together with a permutation @p out_ptr
// that represents the mapping (sorted non-deflated | sorted deflated) -> initial.
//
inline SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* c_ptr,
const SizeType* in_ptr, SizeType* out_ptr) {
// The permutation allows keeping the mapping between sorted eigenvalues and unsorted eigenvectors,
// which is useful since eigenvectors are more expensive to permute, so we can keep them in their initial order.
//
// @param n number of eigenvalues
// @param c_ptr array[n] containing the column type of each eigenvector after deflation (initial order)
// @param evals_ptr array[n] of eigenvalues in initial order, i.e. indexed like c_ptr (accessed as evals_ptr[in_ptr[i]])
// @param in_ptr array[n] representing permutation current -> initial (i.e. entry i refers to c_ptr[in_ptr[i]])
// @param out_ptr array[n] permutation (sorted non-deflated | sorted deflated) -> initial
//
// @return k number of non-deflated eigenvectors
template <class T>
SizeType stablePartitionIndexForDeflationArrays(const SizeType n, const ColType* c_ptr,
const T* evals_ptr, const SizeType* in_ptr,
SizeType* out_ptr) {
// Get the number of non-deflated entries
SizeType k = 0;
for (SizeType i = 0; i < n; ++i) {
if (c_ptr[i] != ColType::Deflated)
++k;
}

// Create the permutation (sorted non-deflated | sorted deflated) -> initial
// Note:
// Since eigenvalues related to deflated eigenvectors might not be sorted anymore after deflation,
// this step also takes care of sorting them (actually just their related indices) in ascending order.
SizeType i1 = 0; // index of non-deflated values in out
SizeType i2 = k; // index of deflated values
for (SizeType i = 0; i < n; ++i) {
const SizeType ii = in_ptr[i];
SizeType& io = (c_ptr[ii] != ColType::Deflated) ? i1 : i2;
out_ptr[io] = ii;
++io;

// non-deflated are untouched, just squeeze them at the beginning as they appear
if (c_ptr[ii] != ColType::Deflated) {
out_ptr[i1] = ii;
++i1;
}
// deflated entries are the ones that may have been moved "out-of-order" by deflation...
// ... so each one is inserted in the right place based on its eigenvalue
else {
const T a = evals_ptr[ii];

SizeType j = i2;
// shift to right all greater values (shift just indices)
for (; j > k; --j) {
const T b = evals_ptr[out_ptr[j - 1]];
if (a > b) {
break;
}
out_ptr[j] = out_ptr[j - 1];
}
// and insert the current index in the empty place, such that eigenvalues are sorted.
out_ptr[j] = ii;
++i2;
}
}
return k;
}

template <Device D>
template <class T>
auto stablePartitionIndexForDeflation(const SizeType i_begin, const SizeType i_end,
Matrix<const ColType, D>& c, Matrix<const SizeType, D>& in,
Matrix<SizeType, D>& out) {
Matrix<const ColType, Device::CPU>& c,
Matrix<const T, Device::CPU>& evals,
Matrix<const SizeType, Device::CPU>& in,
Matrix<SizeType, Device::CPU>& out) {
namespace ex = pika::execution::experimental;
namespace di = dlaf::internal;

constexpr auto backend = dlaf::DefaultBackend_v<D>;

const SizeType n = problemSize(i_begin, i_end, in.distribution());
if constexpr (D == Device::CPU) {
auto part_fn = [n](const auto& c_tiles_futs, const auto& in_tiles_futs, const auto& out_tiles) {
const TileElementIndex zero_idx(0, 0);
const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx);
const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx);
SizeType* out_ptr = out_tiles[0].ptr(zero_idx);

return stablePartitionIndexForDeflationArrays(n, c_ptr, in_ptr, out_ptr);
};

TileCollector tc{i_begin, i_end};
return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(in)),
ex::when_all_vector(tc.readwrite(out))) |
di::transform(di::Policy<backend>(), std::move(part_fn));
}
else {
#ifdef DLAF_WITH_GPU
auto part_fn = [n](const auto& c_tiles_futs, const auto& in_tiles_futs, const auto& out_tiles,
auto& host_k, auto& device_k) {
const TileElementIndex zero_idx(0, 0);
const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx);
const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx);
SizeType* out_ptr = out_tiles[0].ptr(zero_idx);

return ex::just(n, c_ptr, in_ptr, out_ptr, host_k(), device_k()) |
di::transform(di::Policy<backend>(), stablePartitionIndexOnDevice) |
ex::then([&host_k]() { return *host_k(); });
};

TileCollector tc{i_begin, i_end};
return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(in)),
ex::when_all_vector(tc.readwrite(out)),
ex::just(memory::MemoryChunk<SizeType, Device::CPU>{1},
memory::MemoryChunk<SizeType, Device::GPU>{1})) |
ex::let_value(std::move(part_fn));
#endif
}
auto part_fn = [n](const auto& c_tiles_futs, const auto& evals_tiles_fut, const auto& in_tiles_futs,
const auto& out_tiles) {
const TileElementIndex zero_idx(0, 0);
const ColType* c_ptr = c_tiles_futs[0].get().ptr(zero_idx);
const T* evals_ptr = evals_tiles_fut[0].get().ptr(zero_idx);
const SizeType* in_ptr = in_tiles_futs[0].get().ptr(zero_idx);
SizeType* out_ptr = out_tiles[0].ptr(zero_idx);

return stablePartitionIndexForDeflationArrays(n, c_ptr, evals_ptr, in_ptr, out_ptr);
};

TileCollector tc{i_begin, i_end};
return ex::when_all(ex::when_all_vector(tc.read(c)), ex::when_all_vector(tc.read(evals)),
ex::when_all_vector(tc.read(in)), ex::when_all_vector(tc.readwrite(out))) |
di::transform(di::Policy<Backend::MC>(), std::move(part_fn));
}

template <Device D>
Expand Down Expand Up @@ -370,7 +384,7 @@ std::vector<GivensRotation<T>> applyDeflationToArrays(T rho, T tol, const SizeTy
// `s` is not negated.
//
// [1] LAPACK 3.10.0, file dlaed2.f, line 393
T r = std::sqrt(z1 * z1 + z2 * z2);
T r = std::hypot(z1, z2);
T c = z1 / r;
T s = z2 / r;

Expand Down Expand Up @@ -696,7 +710,8 @@ void mergeSubproblems(const SizeType i_begin, const SizeType i_split, const Size
// - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
// - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
//
auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
auto k =
stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3) | ex::split();

applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
Expand Down Expand Up @@ -1294,7 +1309,8 @@ void mergeDistSubproblems(comm::CommunicatorGrid grid,
// - solve the rank-1 problem and save eigenvalues in `d0` and `d1` (copy) and eigenvectors in `e2`.
// - set deflated diagonal entries of `U` to 1 (temporary solution until optimized GEMM is implemented)
//
auto k = stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_hm.i2, ws_h.i3) | ex::split();
auto k =
stablePartitionIndexForDeflation(i_begin, i_end, ws_h.c, ws_h.d0, ws_hm.i2, ws_h.i3) | ex::split();
applyIndex(i_begin, i_end, ws_h.i3, ws_h.d0, ws_hm.d1);
applyIndex(i_begin, i_end, ws_h.i3, ws_hm.z0, ws_hm.z1);
copy(idx_begin_tiles_vec, sz_tiles_vec, ws_hm.d1, ws_h.d0);
Expand Down
35 changes: 0 additions & 35 deletions src/eigensolver/tridiag_solver/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -185,41 +185,6 @@ void initIndexTile(SizeType offset, const matrix::Tile<SizeType, Device::GPU>& t
initIndexTile<<<nr_blocks, nr_threads, 0, stream>>>(offset, len, index_arr);
}

// -----------------------------------------
// Predicate used to partition indices: an index `i` satisfies it when the column type stored at
// `c_ptr[i]` is anything other than ColType::Deflated.
//
// This is a separate struct with a call operator instead of a lambda, because
// nvcc does not compile the file with a lambda.
struct PartitionIndicesPredicate {
// Column types, looked up by the index passed to the call operator.
const ColType* c_ptr;
__device__ bool operator()(const SizeType i) {
return c_ptr[i] != ColType::Deflated;
}
};

// Device kernel that counts the non-deflated entries and partitions the indices in `in_ptr` accordingly.
//
// - `*device_k_ptr` is set to k = n minus the number of ColType::Deflated entries in c_ptr[0, n).
// - the values of in_ptr[0, n) for which c_ptr[value] != ColType::Deflated are copied starting at
//   `out_ptr`, all the others starting at `out_ptr + k`; both groups keep their relative order.
__global__ void stablePartitionIndexOnDevice(SizeType n, const ColType* c_ptr, const SizeType* in_ptr,
SizeType* out_ptr, SizeType* device_k_ptr) {
// Pick the thrust execution policy matching the GPU backend selected at build time.
#ifdef DLAF_WITH_CUDA
constexpr auto par = thrust::cuda::par;
#elif defined(DLAF_WITH_HIP)
constexpr auto par = thrust::hip::par;
#endif

// `k` aliases the device-side output location, so assigning to it publishes the count.
SizeType& k = *device_k_ptr;

// The number of non-deflated values
k = n - thrust::count(par, c_ptr, c_ptr + n, ColType::Deflated);

// Partition while preserving relative order such that deflated entries are at the end
thrust::stable_partition_copy(par, in_ptr, in_ptr + n, out_ptr, out_ptr + k,
PartitionIndicesPredicate{c_ptr});
}

// Host-side launcher for the partition kernel.
//
// Enqueues the kernel on `stream` with a single block of a single thread, then enqueues an asynchronous
// device-to-host copy of the resulting count from `device_k_ptr` to `host_k_ptr` on the same stream.
// NOTE(review): both operations are only enqueued here; presumably `*host_k_ptr` must not be read
// before the work on `stream` has completed - confirm against the caller.
void stablePartitionIndexOnDevice(SizeType n, const ColType* c_ptr, const SizeType* in_ptr,
SizeType* out_ptr, SizeType* host_k_ptr, SizeType* device_k_ptr,
whip::stream_t stream) {
stablePartitionIndexOnDevice<<<1, 1, 0, stream>>>(n, c_ptr, in_ptr, out_ptr, device_k_ptr);
whip::memcpy_async(host_k_ptr, device_k_ptr, sizeof(SizeType), whip::memcpy_device_to_host, stream);
}

template <class T>
__global__ void mergeIndicesOnDevice(const SizeType* begin_ptr, const SizeType* split_ptr,
const SizeType* end_ptr, SizeType* out_ptr, const T* v_ptr) {
Expand Down
48 changes: 41 additions & 7 deletions test/unit/eigensolver/test_tridiag_solver_merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,39 +82,73 @@ TYPED_TEST(TridiagEigensolverMergeTest, SortIndex) {
}

TEST(StablePartitionIndexOnDeflated, FullRange) {
const SizeType n = 10;
constexpr SizeType n = 10;
const SizeType nb = 3;

const LocalElementSize sz(n, 1);
const TileElementSize bk(nb, 1);

Matrix<ColType, Device::CPU> c(sz, bk);
Matrix<double, Device::CPU> vals(sz, bk);
Matrix<SizeType, Device::CPU> in(sz, bk);
Matrix<SizeType, Device::CPU> out(sz, bk);

// Note:
// UpperHalf -> u
// Dense -> f
// LowerHalf -> l
// Deflated -> d

// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// | l| f|d1|d2| u| u| l| f|d3| l| c_arr
std::vector<ColType> c_arr{ColType::LowerHalf, ColType::Dense, ColType::Deflated,
ColType::Deflated, ColType::UpperHalf, ColType::UpperHalf,
ColType::LowerHalf, ColType::Dense, ColType::Deflated,
ColType::LowerHalf};
DLAF_ASSERT(c_arr.size() == to_sizet(n), n);
dlaf::matrix::util::set(c, [&c_arr](GlobalElementIndex i) { return c_arr[to_sizet(i.row())]; });

// f, u, d, d, l, u, l, f, d, l
std::vector<SizeType> in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9};
// | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr
// | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr
std::array<SizeType, n> in_arr{1, 4, 2, 3, 0, 5, 6, 7, 8, 9};
dlaf::matrix::util::set(in, [&in_arr](GlobalElementIndex i) { return in_arr[to_sizet(i.row())]; });

// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// |30|10| 2| 3|20|40|50|60| 1|70| vals_arr
//
// | 1| 4| 2| 3| 0| 5| 6| 7| 8| 9| in_arr
// | f| u|d1|d2| l| u| l| f|d3| l| c_arr permuted by in_arr
// |10|20| 2| 3|30|40|50|60| 1|70| vals_arr permuted by in_arr
std::array<double, n> vals_arr{30, 10, 2, 3, 20, 40, 50, 60, 1, 70};
dlaf::matrix::util::set(vals,
[&vals_arr](GlobalElementIndex i) { return vals_arr[to_sizet(i.row())]; });

const SizeType i_begin = 0;
const SizeType i_end = 4;
auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, in, out);
auto k = stablePartitionIndexForDeflation(i_begin, i_end, c, vals, in, out);

ASSERT_TRUE(tt::sync_wait(std::move(k)) == 7);
// | 0| 1| 2| 3| 4| 5| 6| 7| 8| 9| initial
// | l| f|d1|d2| u| u| l| f|d3| l| c_arr
// |30|10| 2| 3|20|40|50|60| 1|70| vals_arr
//
// | 1| 4| 0| 5| 6| 7| 9| 8| 2| 3| out_arr
// | f| u| l| u| l| f| l|d3|d1|d2| c_arr permuted by out_arr
// |10|20|30|40|50|60|70| 1| 2| 3| vals_arr permuted by out_arr

// f u l u l f l d d d
std::vector<SizeType> expected_out_arr{1, 4, 0, 5, 6, 7, 9, 2, 3, 8};
const SizeType k_value = tt::sync_wait(std::move(k));
ASSERT_TRUE(k_value == 7);

std::array<SizeType, n> expected_out_arr{1, 4, 0, 5, 6, 7, 9, 8, 2, 3};
auto expected_out = [&expected_out_arr](GlobalElementIndex i) {
return expected_out_arr[to_sizet(i.row())];
};
CHECK_MATRIX_EQ(expected_out, out);

const SizeType* out_ptr = tt::sync_wait(out.read(LocalTileIndex(0, 0))).get().ptr();
EXPECT_TRUE(std::is_sorted(out_ptr + k_value, out_ptr + n,
[&vals_arr](const SizeType i, const SizeType j) {
return vals_arr[to_sizet(i)] < vals_arr[to_sizet(j)];
}));
}

TYPED_TEST(TridiagEigensolverMergeTest, Deflation) {
Expand Down

0 comments on commit 7f96b89

Please sign in to comment.