Skip to content

Commit

Permalink
renamed global_nr_tiles_ to nr_tiles_ to match with the getter, impro…
Browse files Browse the repository at this point in the history
…ved the doc, and fixed calculation of local sizes in the constructor when offset != 0 and add extra tests
  • Loading branch information
rasolca committed Sep 29, 2023
1 parent 1548669 commit ce0b617
Show file tree
Hide file tree
Showing 4 changed files with 152 additions and 127 deletions.
114 changes: 78 additions & 36 deletions include/dlaf/matrix/distribution.h
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ class Distribution {

bool operator==(const Distribution& rhs) const noexcept {
return size_ == rhs.size_ && local_size_ == rhs.local_size_ && tile_size_ == rhs.tile_size_ &&
block_size_ == rhs.block_size_ && global_nr_tiles_ == rhs.global_nr_tiles_ &&
block_size_ == rhs.block_size_ && nr_tiles_ == rhs.nr_tiles_ &&
local_nr_tiles_ == rhs.local_nr_tiles_ && rank_index_ == rhs.rank_index_ &&
grid_size_ == rhs.grid_size_ && source_rank_index_ == rhs.source_rank_index_ &&
offset_ == rhs.offset_;
Expand All @@ -232,7 +232,7 @@ class Distribution {

/// Returns the number of tiles of the global matrix (2D size).
const GlobalTileSize& nr_tiles() const noexcept {
return global_nr_tiles_;
return nr_tiles_;
}

/// Returns the number of tiles stored locally (2D size).
Expand Down Expand Up @@ -271,7 +271,7 @@ class Distribution {
/// @pre tile_element.isIn(tile_size_of(global_tile)).
GlobalElementIndex global_element_index(const GlobalTileIndex& global_tile,
const TileElementIndex& tile_element) const noexcept {
DLAF_ASSERT_HEAVY(global_tile.isIn(global_nr_tiles_), global_tile, global_nr_tiles_);
DLAF_ASSERT_HEAVY(global_tile.isIn(nr_tiles_), global_tile, nr_tiles_);
DLAF_ASSERT_HEAVY(tile_element.isIn(tile_size_of(global_tile)), tile_element,
tile_size_of(global_tile));

Expand Down Expand Up @@ -306,7 +306,7 @@ class Distribution {
///
/// @pre global_tile.isIn(nr_tiles()).
comm::Index2D rank_global_tile(const GlobalTileIndex& global_tile) const noexcept {
DLAF_ASSERT_HEAVY(global_tile.isIn(global_nr_tiles_), global_tile, global_nr_tiles_);
DLAF_ASSERT_HEAVY(global_tile.isIn(nr_tiles_), global_tile, nr_tiles_);

return {rank_global_tile<Coord::Row>(global_tile.row()),
rank_global_tile<Coord::Col>(global_tile.col())};
Expand All @@ -317,7 +317,7 @@ class Distribution {
/// @pre global_tile.isIn(nr_tiles()),
/// @pre rank_index == rank_global_tile(global_tile).
LocalTileIndex local_tile_index(const GlobalTileIndex& global_tile) const {
DLAF_ASSERT_HEAVY(global_tile.isIn(global_nr_tiles_), global_tile, global_nr_tiles_);
DLAF_ASSERT_HEAVY(global_tile.isIn(nr_tiles_), global_tile, nr_tiles_);

DLAF_ASSERT_HEAVY(rank_index_ == rank_global_tile(global_tile), rank_index_,
rank_global_tile(global_tile));
Expand Down Expand Up @@ -357,8 +357,8 @@ class Distribution {
template <Coord rc>
SizeType global_element_from_global_tile_and_tile_element(SizeType global_tile,
SizeType tile_element) const noexcept {
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < global_nr_tiles_.get<rc>(), global_tile,
global_nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < nr_tiles_.get<rc>(), global_tile,
nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= tile_element && tile_element < tile_size_.get<rc>(), tile_element,
tile_size_.get<rc>());
return util::matrix::element_from_tile_and_tile_element(global_tile, tile_element,
Expand Down Expand Up @@ -463,8 +463,8 @@ class Distribution {
/// @pre 0 <= global_tile < nr_tiles().get<rc>().
template <Coord rc>
SizeType local_tile_from_global_tile(SizeType global_tile) const noexcept {
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < global_nr_tiles_.get<rc>(), global_tile,
global_nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < nr_tiles_.get<rc>(), global_tile,
nr_tiles_.get<rc>());
return util::matrix::local_tile_from_global_tile(global_tile, tiles_per_block<rc>(),
grid_size_.get<rc>(), rank_index_.get<rc>(),
source_rank_index_.get<rc>(),
Expand Down Expand Up @@ -499,10 +499,18 @@ class Distribution {
/// index, the local tile grid size along @p rc is returned.
///
/// @pre 0 <= global_tile <= nr_tiles().get<rc>().
// Note: safe to use in constructors if:
// - nr_tiles_, is already set correctly.
// - block_size_, is already set correctly.
// - tile_size_, is already set correctly.
// - grid_size_, is already set correctly.
// - rank_index_, is already set correctly.
// - offset_, is already set and normalized.
// - source_rank_index_, is already set and normalized.
template <Coord rc>
SizeType next_local_tile_from_global_tile(SizeType global_tile) const noexcept {
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile <= global_nr_tiles_.get<rc>(), global_tile,
global_nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile <= nr_tiles_.get<rc>(), global_tile,
nr_tiles_.get<rc>());
return util::matrix::next_local_tile_from_global_tile(global_tile, tiles_per_block<rc>(),
grid_size_.get<rc>(), rank_index_.get<rc>(),
source_rank_index_.get<rc>(),
Expand Down Expand Up @@ -550,10 +558,17 @@ class Distribution {
/// Returns the rank index of the process that stores the tile with global index @p global_tile.
///
/// @pre 0 <= global_tile < nr_tiles().get<rc>().
// Note: safe to use in constructors if:
// - nr_tiles_, is already set correctly.
// - block_size_, is already set correctly.
// - tile_size_, is already set correctly.
// - grid_size_, is already set correctly.
// - offset_, is already set and normalized.
// - source_rank_index_, is already set and normalized.
template <Coord rc>
int rank_global_tile(SizeType global_tile) const noexcept {
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < global_nr_tiles_.get<rc>(), global_tile,
global_nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < nr_tiles_.get<rc>(), global_tile,
nr_tiles_.get<rc>());
return util::matrix::rank_global_tile(global_tile, tiles_per_block<rc>(), grid_size_.get<rc>(),
source_rank_index_.get<rc>(), global_tile_offset<rc>());
}
Expand All @@ -567,8 +582,8 @@ class Distribution {
/// @pre 0 <= global_tile < nr_tiles().get<rc>().
template <Coord rc>
SizeType tile_size_of(SizeType global_tile) const noexcept {
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < global_nr_tiles_.get<rc>(), global_tile,
global_nr_tiles_.get<rc>());
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < nr_tiles_.get<rc>(), global_tile,
nr_tiles_.get<rc>());
SizeType n = size_.get<rc>();
SizeType nb = tile_size_.get<rc>();
if (global_tile == 0) {
Expand All @@ -578,31 +593,47 @@ class Distribution {
}

private:
/// @pre block_size_, and tile_size_ are already set correctly.
// Note: safe to use in constructors if:
// - tile_size_, is already set correctly.
// - block_size_, is already set correctly.
template <Coord rc>
SizeType tiles_per_block() const noexcept {
return block_size_.get<rc>() / tile_size_.get<rc>();
}

/// Returns true if the current rank is the source rank along the @p rc coordinate, otherwise false.
// Note: safe to use in constructors if:
// - rank_index_, is already set correctly.
// - source_rank_index_, is already set and normalized.
template <Coord rc>
bool is_source_rank() const noexcept {
return rank_index_.get<rc>() == source_rank_index_.get<rc>();
}

/// Computes the offset inside the first global block in terms of tiles along the @p rc coordinate.
// Note: safe to use in constructors if:
// - tile_size_, is already set correctly.
// - offset_, is already set and normalized.
template <Coord rc>
SizeType global_tile_offset() const noexcept {
return offset_.get<rc>() / tile_size_.get<rc>();
}

/// Computes the offset inside the first global tile in terms of elements along the @p rc coordinate.
// Note: safe to use in constructors if:
// - tile_size_, is already set correctly.
// - offset_, is already set and normalized.
template <Coord rc>
SizeType global_tile_element_offset() const noexcept {
return offset_.get<rc>() % tile_size_.get<rc>();
}

/// Computes the offset inside the first local block in terms of tiles along the @p rc coordinate.
// Note: safe to use in constructors if:
// - tile_size_, is already set correctly.
// - rank_index_, is already set correctly.
// - offset_, is already set and normalized.
// - source_rank_index_, is already set and normalized.
template <Coord rc>
SizeType local_tile_offset() const noexcept {
if (is_source_rank<rc>()) {
Expand All @@ -614,6 +645,11 @@ class Distribution {
}

/// Computes the offset inside the first local tile in terms of elements along the @p rc coordinate.
// Note: safe to use in constructors if:
// - tile_size_, is already set correctly.
// - rank_index_, is already set correctly.
// - offset_, is already set and normalized.
// - source_rank_index_, is already set and normalized.
template <Coord rc>
SizeType local_tile_element_offset() const noexcept {
if (is_source_rank<rc>()) {
Expand All @@ -624,51 +660,57 @@ class Distribution {
}
}

/// Computes and sets @p size_.
/// Computes and sets @p nr_tiles_.
///
/// @pre local_size_, is already set correctly.
/// @pre grid_size_ == {1,1}.
void compute_global_size_for_non_distr() noexcept;

/// computes and sets global_tiles_.
///
/// @pre local_size_, and tile_size_ are already set correctly.
/// @post nr_tiles_ is set.
// Note: safe to use in constructors if:
// - size_, is already set correctly.
// - tile_size_, is already set correctly.
// - offset_, is already set and normalized.
void compute_global_nr_tiles() noexcept;

/// Computes and sets @p global_tiles_, @p local_tiles_ and @p local_size_.
///
/// @pre size_, block_size_, tile_size_, grid_size_, rank_index and source_rank_index are already set correctly.
void compute_global_and_local_nr_tiles_and_local_size() noexcept;

/// computes and sets @p local_tiles_.
///
/// @pre local_size_, and tile_size_ are already set correctly.
void compute_local_nr_tiles() noexcept;
/// Computes and sets @p local_nr_tiles_ and @p local_size_.
///
/// @post local_nr_tiles_ and local_size_ are set.
/// @pre offset_ and source_rank_index_ are already normalized.
// Note: safe to use in constructors if:
// - size_, is already set correctly.
// - nr_tiles_, is already set correctly.
// - block_size_, is already set correctly.
// - tile_size_, is already set correctly.
// - grid_size_, is already set correctly.
// - rank_index_, is already set correctly.
// - offset_, is already set and normalized.
// - source_rank_index_, is already set and normalized.
void compute_local_nr_tiles_and_local_size() noexcept;

/// Normalizes @p offset_ and @p source_rank_index_ into a canonical form.
///
/// @pre offset_ and source_rank_index_ are already set correctly.
/// @post offset_.row() < block_size_.rows() && offset_.col() < block_size_.cols()
// Note: safe to use in constructors if:
// - offset_, is already set correctly.
// - source_rank_index_, is already set correctly.
void normalize_source_rank_and_offset() noexcept;

/// Sets default values.
///
/// offset_ = {0, 0}
/// size_ = {0, 0}
/// local_size_ = {0, 0}
/// global_nr_tiles_ = {0, 0}
/// nr_tiles_ = {0, 0}
/// local_nr_tiles_ = {0, 0}
/// block_size_ = {1, 1}
/// tile_size_ = {1, 1}
/// rank_index_ = {0, 0}
/// grid_size_ = {1, 1}
/// source_rank_index_ = {0, 0}
// Note: safe to use in constructors.
void set_default_sizes() noexcept;

GlobalElementIndex offset_;
GlobalElementSize size_;
LocalElementSize local_size_;
GlobalTileSize global_nr_tiles_;
GlobalTileSize nr_tiles_;
LocalTileSize local_nr_tiles_;
TileElementSize block_size_;
TileElementSize tile_size_;
Expand Down
37 changes: 20 additions & 17 deletions include/dlaf/matrix/distribution_extensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,42 +85,45 @@ SizeType local_tile_element_offset_on_rank(const Distribution& dist, comm::Index

/// Returns the local index of the tile which contains the element with local index @p local_element.
///
/// @pre 0 <= local_element < dist.local_size().get<rc>().
/// @pre 0 <= local_element < local_size.get<rc>() on given rank.
template <Coord rc>
SizeType local_tile_from_local_element_on_rank(const Distribution& dist, comm::IndexT_MPI rank,
SizeType local_element) noexcept {
DLAF_ASSERT_HEAVY(0 <= local_element && local_element < dist.local_size().get<rc>(), local_element,
dist.local_size().get<rc>());
return util::matrix::tile_from_element(local_element, dist.tile_size().get<rc>(),
local_tile_element_offset_on_rank<rc>(dist, rank));
using util::matrix::tile_from_element;
// Assertion 0 <= local_element is performed by tile_from_element
return tile_from_element(local_element, dist.tile_size().get<rc>(),
local_tile_element_offset_on_rank<rc>(dist, rank));
}

/// Returns the index within the tile of the local element with index @p local_element.
///
/// @pre 0 <= local_element < dist.local_size().get<rc>().
/// @pre 0 <= local_element < local_size.get<rc>() on given rank.
template <Coord rc>
SizeType tile_element_from_local_element_on_rank(const Distribution& dist, comm::IndexT_MPI rank,
SizeType local_element) noexcept {
DLAF_ASSERT_HEAVY(0 <= local_element && local_element < dist.local_size().get<rc>(), local_element,
dist.local_size().get<rc>());
return util::matrix::tile_element_from_element(local_element, dist.tile_size().get<rc>(),
local_tile_element_offset_on_rank<rc>(dist, rank));
using util::matrix::tile_element_from_element;
// Assertion 0 <= local_element is performed by tile_element_from_element
return tile_element_from_element(local_element, dist.tile_size().get<rc>(),
local_tile_element_offset_on_rank<rc>(dist, rank));
}

/// Returns the global index of the tile that has index @p local_tile
/// on the rank with index @p rank.
///
/// @pre 0 <= local_tile < dist.local_nr_tiles().get<rc>().
/// @pre 0 <= local_tile < local_nr_tiles.get<rc>() on given rank.
template <Coord rc>
SizeType global_tile_from_local_tile_on_rank(const Distribution& dist, comm::IndexT_MPI rank,
SizeType local_tile) noexcept {
DLAF_ASSERT_HEAVY(0 <= local_tile && local_tile < dist.local_nr_tiles().get<rc>(), local_tile,
dist.local_nr_tiles().get<rc>());
using util::matrix::global_tile_from_local_tile;
// Assertion 0 <= local_tile is performed by global_tile_from_local_tile
const SizeType tiles_per_block = dist.block_size().get<rc>() / dist.tile_size().get<rc>();
return util::matrix::global_tile_from_local_tile(local_tile, tiles_per_block,
dist.grid_size().get<rc>(), rank,
dist.source_rank_index().get<rc>(),
global_tile_offset<rc>(dist));
const SizeType global_tile =
global_tile_from_local_tile(local_tile, tiles_per_block, dist.grid_size().get<rc>(), rank,
dist.source_rank_index().get<rc>(), global_tile_offset<rc>(dist));
// Assert on the result to avoid computing the local number of tiles on the given rank
DLAF_ASSERT_HEAVY(0 <= global_tile && global_tile < dist.nr_tiles().get<rc>(), global_tile,
dist.nr_tiles().get<rc>());
return global_tile;
}

/// Returns the global index of the element
Expand Down
Loading

0 comments on commit ce0b617

Please sign in to comment.