Skip to content

Commit

Permalink
apply suggestion
Browse files Browse the repository at this point in the history
  • Loading branch information
rasolca committed Sep 29, 2023
1 parent ce0b617 commit 4ad45a0
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 32 deletions.
16 changes: 15 additions & 1 deletion include/dlaf/matrix/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -672,7 +672,6 @@ class Distribution {
/// Computes and sets @p local_nr_tiles_ and @p local_size_.
///
/// @post local_nr_tiles_ and local_size_ are set.
/// @pre offset_ and src_rank_index_ are already normalized.
// Note: safe to use in constructors if:
// - size_, is already set correctly.
// - nr_tiles_, is already set correctly.
Expand All @@ -684,6 +683,21 @@ class Distribution {
// - src_rank_index_, is already set and normalized.
void compute_local_nr_tiles_and_local_size() noexcept;

/// Computes and returns the @p rc component (row or column) of the local size.
///
/// The result is 0 when local_nr_tiles_ has no tiles along @p rc. Otherwise it is the
/// extent of the local full tiles, reduced by the local element offset of the first tile
/// and, if this rank holds the last global tile, by the elements missing from that tile.
///
/// @tparam rc coordinate (Coord::Row or Coord::Col) along which the local size is computed.
// Note: safe to use in constructors if:
// - size_, is already set correctly.
// - nr_tiles_, is already set correctly.
// - local_nr_tiles_, is already set correctly.
// - block_size_, is already set correctly.
// - tile_size_, is already set correctly.
// - grid_size_, is already set correctly.
// - rank_index_, is already set correctly.
// - offset_, is already set and normalized.
// - src_rank_index_, is already set and normalized.
template <Coord rc>
SizeType compute_local_size() noexcept;

/// Normalizes @p offset_ and @p src_rank_index_ into a canonical form.
///
/// @post offset_.row() < block_size_.rows() && offset_.col() < block_size_.cols()
Expand Down
6 changes: 4 additions & 2 deletions include/dlaf/matrix/util_distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,13 @@ inline SizeType local_tile_from_global_tile(SizeType global_tile, SizeType tiles
DLAF_ASSERT_HEAVY(0 <= tile_offset && tile_offset < tiles_per_block, tile_offset, tiles_per_block);

if (rank == rank_global_tile(global_tile, tiles_per_block, grid_size, src_rank, tile_offset)) {
// tile_offset only affects the source rank
bool may_have_partial_first_block = rank == src_rank;
global_tile += tile_offset;

SizeType local_block = global_tile / tiles_per_block / grid_size;

// tile_offset only affects the source rank
bool may_have_partial_first_block = rank == src_rank;

return local_block * tiles_per_block + global_tile % tiles_per_block -
(may_have_partial_first_block ? tile_offset : 0);
}
Expand Down
52 changes: 23 additions & 29 deletions src/matrix/distribution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,37 +121,31 @@ void Distribution::compute_local_nr_tiles_and_local_size() noexcept {
// Set local_nr_tiles_.
local_nr_tiles_ = {tile_row, tile_col};

SizeType row = 0;
if (size_.rows() > 0) {
// Start from full tiles
row = tile_row * tile_size_.rows();

// Fix first tile size removing local offset
row -= local_tile_element_offset<Coord::Row>();

// Fix last tile size
if (rank_index_.row() == rank_global_tile<Coord::Row>(nr_tiles_.rows() - 1))
// remove the elements missing in the last tile
row -= nr_tiles_.rows() * tile_size_.rows() -
(size_.rows() + global_tile_element_offset<Coord::Row>());
}
SizeType col = 0;
if (size_.cols() > 0) {
// Start from full tiles
col = tile_col * tile_size_.cols();

// Fix first tile size removing local offset
col -= local_tile_element_offset<Coord::Col>();

// Fix last tile size
if (rank_index_.col() == rank_global_tile<Coord::Col>(nr_tiles_.cols() - 1))
// remove the elements missing in the last tile
col -= nr_tiles_.cols() * tile_size_.cols() -
(size_.cols() + global_tile_element_offset<Coord::Col>());
}
SizeType row = compute_local_size<Coord::Row>();
SizeType col = compute_local_size<Coord::Col>();

// Set local_size_.
local_size_ = LocalElementSize(row, col);
local_size_ = {row, col};
}

/// Returns the number of local elements along the @p rc coordinate.
///
/// The count starts from local_nr_tiles_ full tiles and is then corrected for the local
/// offset in the first tile and for the elements missing from the last global tile when
/// this rank holds it.
template <Coord rc>
SizeType Distribution::compute_local_size() noexcept {
  const auto n_local_tiles = local_nr_tiles_.get<rc>();

  // Nothing is stored locally along this coordinate.
  if (n_local_tiles == 0) {
    return 0;
  }

  const auto tile_extent = tile_size_.get<rc>();

  // Extent of the local tiles as if all were full, minus the local offset
  // that shortens the first tile.
  SizeType local_extent = n_local_tiles * tile_extent;
  local_extent -= local_tile_element_offset<rc>();

  // Only the rank holding the last global tile has to account for it being partial.
  const bool owns_last_tile =
      rank_index_.get<rc>() == rank_global_tile<rc>(nr_tiles_.get<rc>() - 1);
  if (owns_last_tile) {
    const auto padded_extent = nr_tiles_.get<rc>() * tile_extent;
    const auto used_extent = size_.get<rc>() + global_tile_element_offset<rc>();
    local_extent -= padded_extent - used_extent;
  }

  return local_extent;
}

void Distribution::normalize_source_rank_and_offset() noexcept {
Expand Down

0 comments on commit 4ad45a0

Please sign in to comment.