From 9f3aa79cf32b6e22a842fbab8436814ab6b6df99 Mon Sep 17 00:00:00 2001 From: Alberto Invernizzi Date: Wed, 25 Sep 2024 11:20:14 +0200 Subject: [PATCH] WIP: workaround for t-factor local --- include/dlaf/factorization/qr/t_factor_impl.h | 9 ++++++--- include/dlaf/matrix/panel.h | 4 ++++ 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/include/dlaf/factorization/qr/t_factor_impl.h b/include/dlaf/factorization/qr/t_factor_impl.h index 0bf1a2616c..49a9ac73b4 100644 --- a/include/dlaf/factorization/qr/t_factor_impl.h +++ b/include/dlaf/factorization/qr/t_factor_impl.h @@ -237,6 +237,9 @@ void QR_Tfactor::call(matrix::Panel& if (hh_panel.getWidth() == 0) return; + const SizeType bs = hh_panel.parentDistribution().blockSize().rows(); + const SizeType offset_lc = (bs - hh_panel.tile_size_of_local_head().rows()); + matrix::ReadWriteTileSender t_local = Helpers::set0(std::move(t)); // Note: @@ -257,15 +260,15 @@ void QR_Tfactor::call(matrix::Panel& // 1st step: compute the column partial result `t` // First we compute the matrix vector multiplication for each column // -tau(j) . V(j:, 0:j)* . V(j:, j) - for (const auto& i_lc : hh_panel.iteratorLocal()) { + for (const auto& v_i : hh_panel.iteratorLocal()) { const SizeType first_row_tile = - (i_lc.row() - hh_panel.rangeStartLocal()) * hh_panel.parentDistribution().tile_size().rows(); + std::max(0, (v_i.row() - hh_panel.rangeStartLocal()) * bs - offset_lc); // Note: // Since we are writing always on the same t, the gemv are serialized // A possible solution to this would be to have multiple places where to store partial // results, and then locally reduce them just before the reduce over ranks - t_local = Helpers::gemvColumnT(first_row_tile, hh_panel.read(i_lc), taus, std::move(t_local)); + t_local = Helpers::gemvColumnT(first_row_tile, hh_panel.read(v_i), taus, std::move(t_local)); } // 2nd step: compute the T factor, by performing the last step on each column diff --git a/include/dlaf/matrix/panel.h b/include/dlaf/matrix/panel.h index 498ed60697..16b79476d0 100644 --- a/include/dlaf/matrix/panel.h +++ b/include/dlaf/matrix/panel.h @@ -304,6 +304,10 @@ struct Panel { has_been_used_ = false; } + TileElementSize tile_size_of_local_head() const { + return tileSize(LocalTileIndex(coord, rangeStartLocal())); + } + protected: using ReadWriteSenderType = typename BaseT::ReadWriteSenderType;