Skip to content

Commit

Permalink
Combine the two definition of getMatrixLayout
Browse files Browse the repository at this point in the history
  • Loading branch information
aurianer committed Sep 19, 2024
1 parent f1a43df commit 1ba8c14
Showing 1 changed file with 6 additions and 19 deletions.
25 changes: 6 additions & 19 deletions include/dlaf/permutations/general/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,29 +133,16 @@ void applyPermutationOnCPU(

#if defined(DLAF_WITH_GPU)

template <class T>
MatrixLayout getMatrixLayout(const matrix::Distribution& distr,
const std::vector<matrix::Tile<T, Device::GPU>>& tiles) {
template <class TileType>
MatrixLayout getMatrixLayout(const matrix::Distribution& distr, const std::vector<TileType>& tiles) {
LocalTileSize tile_sz = distr.localNrTiles();
MatrixLayout layout;
layout.nb = distr.blockSize().rows();
layout.ld = tiles[0].ld();
layout.row_offset = (tile_sz.rows() > 1) ? tiles[1].ptr() - tiles[0].ptr() : 0;
layout.col_offset = (tile_sz.cols() > 1) ? tiles[to_sizet(tile_sz.rows())].ptr() - tiles[0].ptr() : 0;
return layout;
}

template <class T>
MatrixLayout getMatrixLayout(
const matrix::Distribution& distr,
const std::vector<matrix::internal::TileAsyncRwMutexReadOnlyWrapper<T, Device::GPU>>& tiles) {
const LocalTileSize tile_sz = distr.localNrTiles();
MatrixLayout layout;
layout.nb = distr.blockSize().rows();
layout.ld = tiles[0].get().ld();
layout.row_offset = (tile_sz.rows() > 1) ? tiles[1].get().ptr() - tiles[0].get().ptr() : 0;
using dlaf::common::internal::unwrap;
layout.ld = unwrap(tiles[0]).ld();
layout.row_offset = (tile_sz.rows() > 1) ? unwrap(tiles[1]).ptr() - unwrap(tiles[0]).ptr() : 0;
layout.col_offset =
(tile_sz.cols() > 1) ? tiles[to_sizet(tile_sz.rows())].get().ptr() - tiles[0].get().ptr() : 0;
(tile_sz.cols() > 1) ? unwrap(tiles[to_sizet(tile_sz.rows())]).ptr() - unwrap(tiles[0]).ptr() : 0;
return layout;
}

Expand Down

0 comments on commit 1ba8c14

Please sign in to comment.