diff --git a/include/dlaf/matrix/distribution.h b/include/dlaf/matrix/distribution.h index 0ae5fe0f38..f8ba15c752 100644 --- a/include/dlaf/matrix/distribution.h +++ b/include/dlaf/matrix/distribution.h @@ -672,7 +672,6 @@ class Distribution { /// Computes and sets @p local_tiles_ and @p local_size_. /// /// @post local_nr_tiles_ and local_size_ are set. - /// @pre offset_ and src_rank_index_ are already normalized. // Note: safe to use in constructors if: // - size_, is already set correctly. // - nr_tiles_, is already set correctly. @@ -684,6 +683,21 @@ class Distribution { // - src_rank_index_, is already set and normalized. void compute_local_nr_tiles_and_local_size() noexcept; + /// Computes and returns the rc coord of local_size. + /// + // Note: safe to use in constructors if: + // - size_, is already set correctly. + // - nr_tiles_, is already set correctly. + // - local_nr_tiles_, is already set correctly. + // - block_size_, is already set correctly. + // - tile_size_, is already set correctly. + // - grid_size_, is already set correctly. + // - rank_index_, is already set correctly. + // - offset_, is already set and normalized. + // - src_rank_index_, is already set and normalized. + template + SizeType compute_local_size() noexcept; + /// Normalizes @p offset_ and @p source_rank_index_ into a canonical form. /// /// @post offset_.row() < block_size_.rows() && offset_.col() < block_size_.cols() diff --git a/include/dlaf/matrix/util_distribution.h b/include/dlaf/matrix/util_distribution.h index 9d8464affb..7da44a6ee1 100644 --- a/include/dlaf/matrix/util_distribution.h +++ b/include/dlaf/matrix/util_distribution.h @@ -111,11 +111,13 @@ inline SizeType local_tile_from_global_tile(SizeType global_tile, SizeType tiles DLAF_ASSERT_HEAVY(0 <= tile_offset && tile_offset < tiles_per_block, tile_offset, tiles_per_block); if (rank == rank_global_tile(global_tile, tiles_per_block, grid_size, src_rank, tile_offset)) { - // tile_offset only affects the source rank - bool may_have_partial_first_block = rank == src_rank; global_tile += tile_offset; SizeType local_block = global_tile / tiles_per_block / grid_size; + + // tile_offset only affects the source rank + bool may_have_partial_first_block = rank == src_rank; + return local_block * tiles_per_block + global_tile % tiles_per_block - (may_have_partial_first_block ? tile_offset : 0); } diff --git a/src/matrix/distribution.cpp b/src/matrix/distribution.cpp index 385dc725ac..ecb8afee64 100644 --- a/src/matrix/distribution.cpp +++ b/src/matrix/distribution.cpp @@ -121,37 +121,31 @@ void Distribution::compute_local_nr_tiles_and_local_size() noexcept { // Set local_nr_tiles_. local_nr_tiles_ = {tile_row, tile_col}; - SizeType row = 0; - if (size_.rows() > 0) { - // Start from full tiles - row = tile_row * tile_size_.rows(); - - // Fix first tile size removing local offset - row -= local_tile_element_offset(); - - // Fix last tile size - if (rank_index_.row() == rank_global_tile(nr_tiles_.rows() - 1)) - // remove the elements missing in the last tile - row -= nr_tiles_.rows() * tile_size_.rows() - - (size_.rows() + global_tile_element_offset()); - } - SizeType col = 0; - if (size_.cols() > 0) { - // Start from full tiles - col = tile_col * tile_size_.cols(); - - // Fix first tile size removing local offset - col -= local_tile_element_offset(); - - // Fix last tile size - if (rank_index_.col() == rank_global_tile(nr_tiles_.cols() - 1)) - // remove the elements missing in the last tile - col -= nr_tiles_.cols() * tile_size_.cols() - - (size_.cols() + global_tile_element_offset()); - } + SizeType row = compute_local_size(); + SizeType col = compute_local_size(); // Set local_size_. - local_size_ = LocalElementSize(row, col); + local_size_ = {row, col}; +} + +template +SizeType Distribution::compute_local_size() noexcept { + if (local_nr_tiles_.get() == 0) + return 0; + + // Start from full tiles + SizeType ret = local_nr_tiles_.get() * tile_size_.get(); + + // Fix first tile size removing local offset + ret -= local_tile_element_offset(); + + // Fix last tile size + if (rank_index_.get() == rank_global_tile(nr_tiles_.get() - 1)) + // remove the elements missing in the last tile + ret -= nr_tiles_.get() * tile_size_.get() - + (size_.get() + global_tile_element_offset()); + + return ret; } void Distribution::normalize_source_rank_and_offset() noexcept {