Skip to content

Commit

Permalink
WIP: workaround problem of last tile not communicated in panel bcast …
Browse files Browse the repository at this point in the history
…transpose
  • Loading branch information
albestro committed Sep 25, 2024
1 parent 9f3aa79 commit 350de05
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
41 changes: 38 additions & 3 deletions include/dlaf/communication/broadcast_panel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
#include <pika/execution.hpp>

#include <dlaf/common/index2d.h>
#include <dlaf/common/range2d.h>
#include <dlaf/communication/communicator_pipeline.h>
#include <dlaf/communication/index.h>
#include <dlaf/communication/kernels/broadcast.h>
#include <dlaf/communication/message.h>
#include <dlaf/matrix/copy_tile.h>
#include <dlaf/matrix/index.h>
#include <dlaf/matrix/panel.h>
#include <dlaf/matrix/tile.h>
#include <dlaf/types.h>
Expand Down Expand Up @@ -89,7 +91,6 @@ auto& get_taskchain(comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row
return col_task_chain;
}
}
} // namespace internal

/// Broadcast
///
Expand Down Expand Up @@ -122,7 +123,8 @@ template <class T, Device D, Coord axis, matrix::StoreTransposed storage,
void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& panel,
matrix::Panel<orthogonal(axis), T, D, storageT>& panelT,
comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain) {
comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain,
common::IterableRange2D<SizeType, matrix::LocalTile_TAG> range) {
constexpr Coord axisT = orthogonal(axis);

constexpr Coord coord = std::decay_t<decltype(panel)>::coord;
Expand Down Expand Up @@ -180,7 +182,7 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& p

auto& chain_step2 = internal::get_taskchain<comm_dir_step2>(row_task_chain, col_task_chain);

for (const auto& indexT : panelT.iteratorLocal()) {
for (const auto& indexT : range) {
auto [index_diag, owner_diag] = internal::transposedOwner<coordT>(dist, indexT);

namespace ex = pika::execution::experimental;
Expand All @@ -198,6 +200,39 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& p
}
}
}
} // namespace internal

/// Broadcast a panel and its transposed counterpart, restricting which local tiles of @p panelT
/// are handed to internal::broadcast.
///
/// Returns immediately, without communicating, when the range of @p panel is empty.
/// Otherwise the local tile range passed to internal::broadcast depends on whether this rank
/// (along the coordinate of @p panelT) owns the last global tile of the range of @p panelT:
/// - if it does not, the full panelT.iteratorLocal() range is used;
/// - if it does, the range goes from the first local tile up to
///   LocalTileIndex(coordT, panelT.rangeEndLocal() - 1, 1).
///   NOTE(review): assuming common::iterate_range2d treats its second argument as an exclusive
///   end, this leaves out the last local tile of @p panelT on the owning rank -- confirm.
///
/// Use broadcast_all() to pass every local tile of @p panelT unconditionally.
///
/// @param rank_root       forwarded unchanged to internal::broadcast
/// @param panel           source panel; its parent distribution is used to look up tile ownership
/// @param panelT          panel along the orthogonal axis
/// @param row_task_chain  row communicator pipeline, forwarded unchanged
/// @param col_task_chain  column communicator pipeline, forwarded unchanged
template <class T, Device D, Coord axis, matrix::StoreTransposed storage,
          matrix::StoreTransposed storageT, class = std::enable_if_t<!std::is_const_v<T>>>
void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& panel,
               matrix::Panel<orthogonal(axis), T, D, storageT>& panelT,
               comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
               comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain) {
  using PanelTType = std::decay_t<decltype(panelT)>;
  constexpr Coord coordT = PanelTType::coord;

  const auto& distribution = panel.parentDistribution();

  // Global index of the last tile in panelT's range; clamped to rangeStart so that it never
  // falls before the start of the range.
  const SizeType last_global_tile = std::max(panelT.rangeStart(), panelT.rangeEnd() - 1);

  // NOTE(review): the emptiness check looks at `panel`, while the last tile above is derived
  // from `panelT` -- confirm that both panels are expected to be empty together.
  const bool panel_is_empty = panel.rangeStart() == panel.rangeEnd();
  if (panel_is_empty)
    return;

  const auto rank_owning_last = distribution.template rankGlobalTile<coordT>(last_global_tile);
  const auto rank_this = distribution.rankIndex().get(coordT);

  // The conditional is kept lazy on purpose: the shortened range is only built on the rank
  // that owns the last tile, so rangeEndLocal() - 1 is never evaluated elsewhere.
  const auto local_tiles =
      rank_this != rank_owning_last
          ? panelT.iteratorLocal()
          : common::iterate_range2d(*panelT.iteratorLocal().begin(),
                                    LocalTileIndex(coordT, panelT.rangeEndLocal() - 1, 1));

  internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, local_tiles);
}

/// Broadcast a panel and its transposed counterpart over every local tile of @p panelT.
///
/// Thin wrapper that calls internal::broadcast with the complete panelT.iteratorLocal() range,
/// i.e. without the range restriction applied by broadcast().
///
/// @param rank_root       forwarded unchanged to internal::broadcast
/// @param panel           source panel, forwarded unchanged
/// @param panelT          panel along the orthogonal axis; all of its local tiles are passed on
/// @param row_task_chain  row communicator pipeline, forwarded unchanged
/// @param col_task_chain  column communicator pipeline, forwarded unchanged
template <class T, Device D, Coord axis, matrix::StoreTransposed storage,
          matrix::StoreTransposed storageT, class = std::enable_if_t<!std::is_const_v<T>>>
void broadcast_all(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& panel,
                   matrix::Panel<orthogonal(axis), T, D, storageT>& panelT,
                   comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
                   comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain) {
  // Unrestricted range: every local tile of the transposed panel.
  const auto all_local_tiles = panelT.iteratorLocal();

  internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, all_local_tiles);
}
}
}
4 changes: 2 additions & 2 deletions include/dlaf/eigensolver/reduction_to_band/ca-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ CARed2BandResult<T, D> CAReductionToBand<B, D, T>::call(comm::CommunicatorGrid&
ws_VT.setRange(at_offset, at_end_R);
ws_VT.setHeight(nrefls_step);

comm::broadcast(rank_panel, ws_V, ws_VT, mpi_row_chain, mpi_col_chain);
comm::broadcast_all(rank_panel, ws_V, ws_VT, mpi_row_chain, mpi_col_chain);

// Note:
// Differently from 1st pass, where transformations are independent one from the other,
Expand All @@ -981,7 +981,7 @@ CARed2BandResult<T, D> CAReductionToBand<B, D, T>::call(comm::CommunicatorGrid&
ws_W0T.setRange(at_offset, at_end_R);
ws_W0T.setHeight(nrefls_step);

comm::broadcast(rank_panel, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain);
comm::broadcast_all(rank_panel, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain);

// W1 = A W0
auto& ws_W1 = panels_w1.nextResource();
Expand Down
2 changes: 1 addition & 1 deletion include/dlaf/eigensolver/reduction_to_band/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -896,7 +896,7 @@ Matrix<T, Device::CPU> ReductionToBand<B, D, T>::call(comm::CommunicatorGrid& gr
xt.setRangeStart(at_offset);
xt.setHeight(nrefls_tile);

comm::broadcast(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain);
comm::broadcast_all(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain);

// TRAILING MATRIX UPDATE

Expand Down

0 comments on commit 350de05

Please sign in to comment.