Skip to content

Commit

Permalink
WIP: workaround for t-factor local
Browse files Browse the repository at this point in the history
  • Loading branch information
albestro committed Sep 25, 2024
1 parent 045c5bf commit 9f3aa79
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 3 deletions.
9 changes: 6 additions & 3 deletions include/dlaf/factorization/qr/t_factor_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ void QR_Tfactor<backend, device, T>::call(matrix::Panel<Coord::Col, T, device>&
if (hh_panel.getWidth() == 0)
return;

const SizeType bs = hh_panel.parentDistribution().blockSize().rows();
const SizeType offset_lc = (bs - hh_panel.tile_size_of_local_head().rows());

matrix::ReadWriteTileSender<T, device> t_local = Helpers::set0(std::move(t));

// Note:
Expand All @@ -257,15 +260,15 @@ void QR_Tfactor<backend, device, T>::call(matrix::Panel<Coord::Col, T, device>&
// 1st step: compute the column partial result `t`
// First we compute the matrix vector multiplication for each column
// -tau(j) . V(j:, 0:j)* . V(j:, j)
for (const auto& i_lc : hh_panel.iteratorLocal()) {
for (const auto& v_i : hh_panel.iteratorLocal()) {
const SizeType first_row_tile =
(i_lc.row() - hh_panel.rangeStartLocal()) * hh_panel.parentDistribution().tile_size().rows();
std::max<SizeType>(0, (v_i.row() - hh_panel.rangeStartLocal()) * bs - offset_lc);

// Note:
// Since we always write to the same t, the gemv calls are serialized.
// A possible solution would be to store partial results in multiple places
// and reduce them locally just before the reduction over ranks.
t_local = Helpers::gemvColumnT(first_row_tile, hh_panel.read(i_lc), taus, std::move(t_local));
t_local = Helpers::gemvColumnT(first_row_tile, hh_panel.read(v_i), taus, std::move(t_local));
}

// 2nd step: compute the T factor, by performing the last step on each column
Expand Down
4 changes: 4 additions & 0 deletions include/dlaf/matrix/panel.h
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,10 @@ struct Panel<axis, const T, D, StoreTransposed::No> {
has_been_used_ = false;
}

/// Returns the size of the tile addressed by the panel's local range start.
///
/// The tile index is built from `coord` and `rangeStartLocal()` and then forwarded to `tileSize()`.
TileElementSize tile_size_of_local_head() const {
  const LocalTileIndex head_lc(coord, rangeStartLocal());
  return tileSize(head_lc);
}

protected:
using ReadWriteSenderType = typename BaseT::ReadWriteSenderType;

Expand Down

0 comments on commit 9f3aa79

Please sign in to comment.