diff --git a/include/dlaf/factorization/qr/t_factor_impl.h b/include/dlaf/factorization/qr/t_factor_impl.h
index 5e5b00da36..7fbc0c2b6a 100644
--- a/include/dlaf/factorization/qr/t_factor_impl.h
+++ b/include/dlaf/factorization/qr/t_factor_impl.h
@@ -33,6 +33,7 @@
 #include "dlaf/common/range2d.h"
 #include "dlaf/common/vector.h"
 #include "dlaf/communication/kernels/all_reduce.h"
+#include "dlaf/communication/sync/all_reduce.h"
 #include "dlaf/lapack/tile.h"
 #include "dlaf/matrix/matrix.h"
 #include "dlaf/matrix/views.h"
@@ -284,8 +285,8 @@ void QR_Tfactor<Backend::MC, Device::CPU, T>::call(matrix::Panel<Coord::Col, T,
   ex::start_detached(
       ex::when_all(ex::just(std::make_shared<pika::barrier<>>(nthreads)),
                    ex::when_all_vector(std::move(panel_tiles)), std::move(taus), std::move(t)) |
-      ex::let_value([=](std::shared_ptr<pika::barrier<>>& barrier_ptr, auto&& panel,
-                        const common::internal::vector<T>& taus, matrix::Tile<T, Device::CPU>& t) {
+      ex::let_value([=](auto& barrier_ptr, auto&& panel, const common::internal::vector<T>& taus,
+                        matrix::Tile<T, Device::CPU>& t) {
         matrix::Matrix<T, Device::CPU> t_all({t.size().rows() * to_SizeType(nthreads - 1), t.size().cols()},
                                              t.size());
         return ex::when_all_vector(
@@ -399,16 +400,13 @@ void QR_TfactorDistributed<Backend::MC, Device::CPU, T>::call(matrix::Panel<Coord::Col, T, Device::CPU>& hh_pa
                                  common::Pipeline<comm::Communicator>& mpi_col_task_chain) {
   namespace ex = pika::execution::experimental;
 
-  using Helpers = tfactor_l::Helpers<Backend::MC, Device::CPU, T>;
-
   // Fast return in case of no reflectors
   if (hh_panel.getWidth() == 0)
     return;
 
-  const auto v_start = hh_panel.offsetElement();
-  auto dist = hh_panel.parentDistribution();
-
-  ex::unique_any_sender<matrix::Tile<T, Device::CPU>> t_local = Helpers::set0(std::move(t));
+  std::vector<decltype(hh_panel.read_sender(std::declval<LocalTileIndex>()))> panel_tiles;
+  for (const auto idx : hh_panel.iteratorLocal())
+    panel_tiles.push_back(hh_panel.read_sender(idx));
 
   // Note:
   // T factor is an upper triangular square matrix, built column by column
@@ -425,29 +423,76 @@ void QR_TfactorDistributed<Backend::MC, Device::CPU, T>::call(matrix::Panel<Coord::Col, T, Device::CPU>& hh_pa
   // 1) t = -tau(j) . V(j:, 0:j)* . V(j:, j)
   // 2) T(0:j, j) = T(0:j, 0:j) . t
 
-  // 1st step: compute the column partial result `t`
-  // First we compute the matrix vector multiplication for each column
-  // -tau(j) . V(j:, 0:j)* . V(j:, j)
-  for (const auto& v_i_loc : hh_panel.iteratorLocal()) {
-    const SizeType v_i = dist.template globalTileFromLocalTile<Coord::Row>(v_i_loc.row());
-    const SizeType first_row_tile = std::max<SizeType>(0, v_i * dist.blockSize().rows() - v_start);
+  const auto dist = hh_panel.parentDistribution();
 
-    // TODO
-    // Note:
-    // Since we are writing always on the same t, the gemv are serialized
-    // A possible solution to this would be to have multiple places where to store partial
-    // results, and then locally reduce them just before the reduce over ranks
-    t_local = Helpers::gemvColumnT(first_row_tile, hh_panel.read(v_i_loc), taus, std::move(t_local));
-  }
+  const SizeType v_start = hh_panel.offsetElement();
+  const SizeType bsRows = hh_panel.parentDistribution().blockSize().rows();
+  const SizeType panelRowBegin = hh_panel.iteratorLocal().begin()->row();
+
+  const size_t nthreads = std::max<size_t>(1, (pika::get_num_worker_threads() / 2));
+  ex::start_detached(
+      ex::when_all(ex::just(std::make_shared<pika::barrier<>>(nthreads)),
+                   ex::when_all_vector(std::move(panel_tiles)), std::move(taus), std::move(t),
+                   mpi_col_task_chain()) |
+      ex::let_value([=](auto& barrier_ptr, auto&& panel, const common::internal::vector<T>& taus,
+                        matrix::Tile<T, Device::CPU>& t, auto&& pcomm) {
+        matrix::Matrix<T, Device::CPU> t_all({t.size().rows() * to_SizeType(nthreads - 1), t.size().cols()},
+                                             t.size());
+        return ex::when_all_vector(
+                   select(t_all, common::iterate_range2d(t_all.distribution().localNrTiles()))) |
+               ex::transfer(
+                   dlaf::internal::getBackendScheduler<Backend::MC>(pika::execution::thread_priority::high)) |
+               ex::bulk(nthreads, [=, &barrier_ptr, &t, &taus, &panel,
+                                   &pcomm](const size_t index, std::vector<matrix::Tile<T, Device::CPU>>& t_all) {
+                 using Helpers = tfactor_l::Helpers<Backend::MC, Device::CPU, T>;
 
-  // at this point each rank has its partial result for each column
-  // so, let's reduce the results (on all ranks, so that everyone can independently compute T factor)
-  if (true)  // TODO if the column communicator has more than 1 tile...but I just have the pipeline
-    t_local = dlaf::comm::scheduleAllReduceInPlace(mpi_col_task_chain(), MPI_SUM, std::move(t_local));
+                 tile::internal::set0(index == 0 ? t : t_all[index - 1]);
 
-  // 2nd step: compute the T factor, by performing the last step on each column
-  // each column depends on the previous part (all reflectors that comes before)
-  // so it is performed sequentially
-  ex::start_detached(Helpers::trmvUpdateColumn(std::move(t_local)));
+                 // 1st step
+                 // compute the column partial result `t` (multi-threaded)
+                 // First we compute the matrix vector multiplication for each column
+                 // -tau(j) . V(j:, 0:j)* . V(j:, j)
+                 const size_t chunk_size = util::ceilDiv(panel.size(), nthreads);
+                 const size_t begin = index * chunk_size;
+                 const size_t end = std::min(index * chunk_size + chunk_size, panel.size());
+
+                 for (size_t i = begin; i < end; ++i) {
+                   const matrix::Tile<const T, Device::CPU>& tile_v = panel[i].get();
+
+                   const SizeType first_row_tile =
+                       std::max<SizeType>(0, dist.template globalTileFromLocalTile<Coord::Row>(
+                                                 panelRowBegin + to_SizeType(i)) *
+                                                     bsRows -
+                                                 v_start);
+
+                   if (index == 0)
+                     t = Helpers::gemv_func(first_row_tile, tile_v, taus, std::move(t));
+                   else
+                     t_all[index - 1] =
+                         Helpers::gemv_func(first_row_tile, tile_v, taus, std::move(t_all[index - 1]));
+                 }
+
+                 barrier_ptr->arrive_and_wait();
+
+                 // (single-threaded)
+                 if (index == 0) {
+                   // reduce
+                   for (auto& partial_t : t_all)
+                     tile::internal::add(T(1), partial_t, t);
+
+                   // at this point each rank has its partial result for each column
+                   // so, let's reduce the results (on all ranks, so that everyone can
+                   // independently compute T factor)
+                   if (pcomm.ref().size() > 1)
+                     comm::sync::allReduceInPlace(pcomm.ref(), MPI_SUM, common::make_data(t));
+
+                   // 2nd step
+                   // compute the T factor, by performing the last step on each column
+                   // (single-threaded) each column depends on the previous part (all reflectors
+                   // that comes before) so it is performed sequentially
+                   t = Helpers::trmv_func(std::move(t));
+                 }
+               });
+      }));
 }
 }
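
For reference, the two per-column steps named in the comments above (1st step: t = -tau(j) . V(j:, 0:j)* . V(j:, j); 2nd step: T(0:j, j) = T(0:j, 0:j) . t, with T(j, j) = tau(j)) are the usual compact-WY recurrence. The stand-alone sketch below is not DLA-Future code: it spells the recurrence out with plain loops for a real-valued, column-major V whose columns store the reflectors with an explicit unit diagonal and zeros above it; the names build_t_factor and ldv are made up for the example.

#include <cstddef>
#include <vector>

// Builds the k x k upper triangular T in-place in `t` (column-major, ld = k).
// V is m x k, column-major with leading dimension ldv; column j holds the j-th
// reflector with zeros in rows 0..j-1 and an explicit 1 in row j (assumption).
void build_t_factor(std::size_t m, std::size_t k, const std::vector<double>& v, std::size_t ldv,
                    const std::vector<double>& tau, std::vector<double>& t) {
  for (std::size_t j = 0; j < k; ++j) {
    // 1st step: t(0:j, j) = -tau(j) . V(j:, 0:j)^T . V(j:, j)  (for complex T this
    // would be the conjugate transpose, i.e. the left factor gets conjugated)
    for (std::size_t c = 0; c < j; ++c) {
      double dot = 0.0;
      for (std::size_t r = j; r < m; ++r)
        dot += v[r + c * ldv] * v[r + j * ldv];
      t[c + j * k] = -tau[j] * dot;
    }
    // 2nd step: T(0:j, j) = T(0:j, 0:j) . t(0:j, j)  (in-place upper triangular
    // matrix-vector product; row c only reads rows >= c, so nothing is clobbered)
    for (std::size_t c = 0; c < j; ++c) {
      double acc = 0.0;
      for (std::size_t i = c; i < j; ++i)
        acc += t[c + i * k] * t[i + j * k];
      t[c + j * k] = acc;
    }
    // the diagonal entry closes the column
    t[j + j * k] = tau[j];
  }
}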
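
The parallelization pattern added here for the distributed case is the one already used by the local QR_Tfactor above: ex::bulk runs nthreads workers over ceilDiv-sized chunks of the panel, worker 0 accumulates its gemv results directly into t while worker i > 0 writes into its own scratch tile t_all[i - 1], a pika::barrier separates accumulation from reduction, and worker 0 alone sums the partials and then performs the single synchronous column all-reduce. A minimal stand-alone sketch of that pattern, using C++20 std::jthread/std::barrier and plain doubles instead of pika and tiles (all names illustrative, nthreads assumed >= 1):

#include <algorithm>
#include <barrier>
#include <cstddef>
#include <thread>
#include <vector>

// Each worker reduces its own chunk; worker 0 owns the final result, the others own
// one scratch accumulator each (hence nthreads - 1 of them, like t_all above).
void reduce_in_chunks(const std::vector<double>& input, double& result, std::size_t nthreads) {
  std::vector<double> partial(nthreads - 1, 0.0);
  std::barrier<> sync(static_cast<std::ptrdiff_t>(nthreads));
  const std::size_t chunk_size = (input.size() + nthreads - 1) / nthreads;  // util::ceilDiv

  result = 0.0;
  std::vector<std::jthread> workers;
  for (std::size_t index = 0; index < nthreads; ++index) {
    workers.emplace_back([&, index] {
      // accumulation phase (stand-in for the per-tile gemv_func calls)
      double& acc = (index == 0) ? result : partial[index - 1];
      const std::size_t begin = index * chunk_size;
      const std::size_t end = std::min(begin + chunk_size, input.size());
      for (std::size_t i = begin; i < end; ++i)
        acc += input[i];

      sync.arrive_and_wait();  // all partial accumulators are published past this point

      // reduction phase, single worker (stand-in for the tile::internal::add loop plus
      // the one allReduceInPlace over the column communicator in the distributed code)
      if (index == 0) {
        for (double p : partial)
          result += p;
      }
    });
  }
}  // std::jthread joins on destruction, so `result` is ready when the function returns

This is also why the old per-column chain of scheduleAllReduceInPlace calls collapses into a single comm::sync::allReduceInPlace: the partial results are combined locally before any communication happens, so only one reduction over the column communicator is needed.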