diff --git a/include/dlaf/communication/broadcast_panel.h b/include/dlaf/communication/broadcast_panel.h index 2fa15e335b..c3cf7e0725 100644 --- a/include/dlaf/communication/broadcast_panel.h +++ b/include/dlaf/communication/broadcast_panel.h @@ -18,11 +18,13 @@ #include #include +#include #include #include #include #include #include +#include #include #include #include @@ -89,7 +91,6 @@ auto& get_taskchain(comm::CommunicatorPipeline& row return col_task_chain; } } -} // namespace internal /// Broadcast /// @@ -122,7 +123,8 @@ template & panel, matrix::Panel& panelT, comm::CommunicatorPipeline& row_task_chain, - comm::CommunicatorPipeline& col_task_chain) { + comm::CommunicatorPipeline& col_task_chain, + common::IterableRange2D range) { constexpr Coord axisT = orthogonal(axis); constexpr Coord coord = std::decay_t::coord; @@ -180,7 +182,7 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& p auto& chain_step2 = internal::get_taskchain(row_task_chain, col_task_chain); - for (const auto& indexT : panelT.iteratorLocal()) { + for (const auto& indexT : range) { auto [index_diag, owner_diag] = internal::transposedOwner(dist, indexT); namespace ex = pika::execution::experimental; @@ -198,6 +200,39 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& p } } } +} // namespace internal + +template >> +void broadcast(comm::IndexT_MPI rank_root, matrix::Panel& panel, + matrix::Panel& panelT, + comm::CommunicatorPipeline& row_task_chain, + comm::CommunicatorPipeline& col_task_chain) { + constexpr Coord coordT = std::decay_t::coord; + const auto& dist = panel.parentDistribution(); + + const SizeType last_tile = std::max(panelT.rangeStart(), panelT.rangeEnd() - 1); + + if (panel.rangeStart() == panel.rangeEnd()) + return; + + const auto owner = dist.template rankGlobalTile(last_tile); + const auto range = dist.rankIndex().get(coordT) == owner + ? common::iterate_range2d(*panelT.iteratorLocal().begin(), + LocalTileIndex(coordT, panelT.rangeEndLocal() - 1, 1)) + : panelT.iteratorLocal(); + + internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, range); +} + +template >> +void broadcast_all(comm::IndexT_MPI rank_root, matrix::Panel& panel, + matrix::Panel& panelT, + comm::CommunicatorPipeline& row_task_chain, + comm::CommunicatorPipeline& col_task_chain) { + internal::broadcast(rank_root, panel, panelT, row_task_chain, col_task_chain, panelT.iteratorLocal()); +} } } diff --git a/include/dlaf/eigensolver/reduction_to_band/ca-impl.h b/include/dlaf/eigensolver/reduction_to_band/ca-impl.h index 3b32b254ad..4a2b525234 100644 --- a/include/dlaf/eigensolver/reduction_to_band/ca-impl.h +++ b/include/dlaf/eigensolver/reduction_to_band/ca-impl.h @@ -962,7 +962,7 @@ CARed2BandResult CAReductionToBand::call(comm::CommunicatorGrid& ws_VT.setRange(at_offset, at_end_R); ws_VT.setHeight(nrefls_step); - comm::broadcast(rank_panel, ws_V, ws_VT, mpi_row_chain, mpi_col_chain); + comm::broadcast_all(rank_panel, ws_V, ws_VT, mpi_row_chain, mpi_col_chain); // Note: // Differently from 1st pass, where transformations are independent one from the other, @@ -981,7 +981,7 @@ CARed2BandResult CAReductionToBand::call(comm::CommunicatorGrid& ws_W0T.setRange(at_offset, at_end_R); ws_W0T.setHeight(nrefls_step); - comm::broadcast(rank_panel, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain); + comm::broadcast_all(rank_panel, ws_W0, ws_W0T, mpi_row_chain, mpi_col_chain); // W1 = A W0 auto& ws_W1 = panels_w1.nextResource(); diff --git a/include/dlaf/eigensolver/reduction_to_band/impl.h b/include/dlaf/eigensolver/reduction_to_band/impl.h index 75acd0a06f..165b87c546 100644 --- a/include/dlaf/eigensolver/reduction_to_band/impl.h +++ b/include/dlaf/eigensolver/reduction_to_band/impl.h @@ -896,7 +896,7 @@ Matrix ReductionToBand::call(comm::CommunicatorGrid& gr xt.setRangeStart(at_offset); xt.setHeight(nrefls_tile); - comm::broadcast(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain); + comm::broadcast_all(rank_v0.col(), x, xt, mpi_row_chain, mpi_col_chain); // TRAILING MATRIX UPDATE