diff --git a/include/dlaf/matrix/distribution.h b/include/dlaf/matrix/distribution.h
index 10cb35b914..007201df8e 100644
--- a/include/dlaf/matrix/distribution.h
+++ b/include/dlaf/matrix/distribution.h
@@ -20,10 +20,15 @@ namespace dlaf {
 namespace matrix {
 
+/// Contains information to create a sub-distribution.
+struct SubDistributionSpec {
+  GlobalElementIndex origin;
+  GlobalElementSize size;
+};
+
 /// Distribution contains the information about the size and distribution of a matrix.
 ///
 /// More details available in misc/matrix_distribution.md.
-
 class Distribution {
 public:
   /// Constructs a distribution for a non distributed matrix of size {0, 0} and block size {1, 1}.
@@ -119,6 +124,15 @@ class Distribution {
 
   Distribution& operator=(Distribution&& rhs) noexcept;
 
+  /// Constructs a sub-distribution of the given distribution @p dist, as specified by @p spec.
+  ///
+  /// @param[in] dist is the input distribution,
+  /// @param[in] spec contains the origin and size of the new distribution relative to the input distribution,
+  /// @pre spec.origin.isValid()
+  /// @pre spec.size.isValid()
+  /// @pre spec.origin + spec.size <= dist.size()
+  Distribution(Distribution dist, const SubDistributionSpec& spec);
+
   bool operator==(const Distribution& rhs) const noexcept {
     return size_ == rhs.size_ && local_size_ == rhs.local_size_ && tile_size_ == rhs.tile_size_ &&
            block_size_ == rhs.block_size_ && global_nr_tiles_ == rhs.global_nr_tiles_ &&
@@ -490,6 +504,30 @@ class Distribution {
            localElementDistanceFromLocalTile(begin.col(), end.col())};
   }
 
+  /// Returns the tile index in the current distribution corresponding to a tile index @p sub_index in a
+  /// sub-distribution (defined by @p sub_offset and @p sub_distribution).
+  GlobalTileIndex globalTileIndexFromSubDistribution(const GlobalElementIndex& sub_offset,
+                                                     const Distribution& sub_distribution,
+                                                     const GlobalTileIndex& sub_index) const noexcept {
+    DLAF_ASSERT(sub_index.isIn(sub_distribution.nrTiles()), sub_index, sub_distribution.nrTiles());
+    DLAF_ASSERT(isCompatibleSubDistribution(sub_offset, sub_distribution), "");
+    const GlobalTileIndex tile_offset = globalTileIndex(sub_offset);
+    return tile_offset + common::sizeFromOrigin(sub_index);
+  }
+
+  /// Returns the element offset within the tile in the current distribution corresponding to a tile
+  /// index @p sub_index in a sub-distribution (defined by @p sub_offset and @p sub_distribution).
+  TileElementIndex tileElementOffsetFromSubDistribution(
+      const GlobalElementIndex& sub_offset, const Distribution& sub_distribution,
+      const GlobalTileIndex& sub_index) const noexcept {
+    DLAF_ASSERT(sub_index.isIn(sub_distribution.nrTiles()), sub_index, sub_distribution.nrTiles());
+    DLAF_ASSERT(isCompatibleSubDistribution(sub_offset, sub_distribution), "");
+    return {
+        sub_index.row() == 0 ? tileElementFromGlobalElement<Coord::Row>(sub_offset.row()) : 0,
+        sub_index.col() == 0 ? tileElementFromGlobalElement<Coord::Col>(sub_offset.col()) : 0,
+    };
+  }
+
 private:
   /// @pre block_size_, and tile_size_ are already set correctly.
   template <Coord rc>
@@ -564,6 +602,25 @@ class Distribution {
   /// @post offset_.row() < block_size_.rows() && offset_.col() < block_size_.cols()
   void normalizeSourceRankAndOffset() noexcept;
 
+  /// Checks if another distribution is a compatible sub-distribution of the current distribution.
+  ///
+  /// Compatible means that the block size, tile size, rank index, and grid size are equal.
+  /// Sub-distribution means that the source rank index of the sub-distribution is the rank index
+  /// of the tile at @p sub_offset in the current distribution. Additionally, the size and offset of
+  /// the sub-distribution must be within the size of the current distribution.
+  bool isCompatibleSubDistribution(const GlobalElementIndex& sub_offset,
+                                   const Distribution& sub_distribution) const noexcept {
+    const bool compatibleGrid = blockSize() == sub_distribution.blockSize() &&
+                                baseTileSize() == sub_distribution.baseTileSize() &&
+                                rankIndex() == sub_distribution.rankIndex() &&
+                                commGridSize() == sub_distribution.commGridSize();
+    const bool compatibleSourceRankIndex =
+        rankGlobalTile(globalTileIndex(sub_offset)) == sub_distribution.sourceRankIndex();
+    const bool compatibleSize = sub_offset.row() + sub_distribution.size().rows() <= size().rows() &&
+                                sub_offset.col() + sub_distribution.size().cols() <= size().cols();
+    return compatibleGrid && compatibleSourceRankIndex && compatibleSize;
+  }
+
   /// Sets default values.
   ///
   /// offset_ = {0, 0}
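For orientation, here is a minimal sketch of how the new pieces compose. The concrete numbers are taken from one of the tests_sub_distribution entries added further down (a {3, 4} matrix with {2, 2} tiles and a sub-matrix starting at element {1, 2}); the variable names are illustrative only.

  // Parent: non-distributed 3x4 matrix with 2x2 tiles.
  Distribution dist(GlobalElementSize{3, 4}, TileElementSize{2, 2}, comm::Size2D{1, 1},
                    comm::Index2D{0, 0}, comm::Index2D{0, 0}, GlobalElementIndex{0, 0});
  // 2x1 sub-distribution starting at element {1, 2}.
  Distribution sub_dist(dist, SubDistributionSpec{GlobalElementIndex{1, 2}, GlobalElementSize{2, 1}});

  // Map sub-distribution tile {0, 0} back to the parent: it lives in parent tile {0, 1}
  // and starts at element {1, 0} within that tile.
  const GlobalTileIndex parent_tile =
      dist.globalTileIndexFromSubDistribution(GlobalElementIndex{1, 2}, sub_dist, GlobalTileIndex{0, 0});
  const TileElementIndex offset_in_tile =
      dist.tileElementOffsetFromSubDistribution(GlobalElementIndex{1, 2}, sub_dist, GlobalTileIndex{0, 0});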
diff --git a/include/dlaf/matrix/matrix_ref.h b/include/dlaf/matrix/matrix_ref.h
new file mode 100644
index 0000000000..d8fd64258a
--- /dev/null
+++ b/include/dlaf/matrix/matrix_ref.h
@@ -0,0 +1,194 @@
+//
+// Distributed Linear Algebra with Future (DLAF)
+//
+// Copyright (c) 2018-2023, ETH Zurich
+// All rights reserved.
+//
+// Please, refer to the LICENSE file in the root directory.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+
+#pragma once
+
+/// @file
+
+#include
+#include
+#include
+#include
+#include
+
+namespace dlaf::matrix::internal {
+
+/// Contains information to create a sub-matrix.
+using SubMatrixSpec = SubDistributionSpec;
+
+/// A @c MatrixRef represents a sub-matrix of a @c Matrix.
+///
+/// The class has reference semantics, meaning accesses to a @c MatrixRef and
+/// its corresponding @c Matrix are interleaved if calls to read/readwrite are
+/// interleaved. Access to a @c MatrixRef and its corresponding @c Matrix is not
+/// thread-safe. A @c MatrixRef must not outlive its corresponding @c Matrix.
+template <class T, Device D>
+class MatrixRef;
+
+template <class T, Device D>
+class MatrixRef<const T, D> : public internal::MatrixBase {
+public:
+  static constexpr Device device = D;
+
+  using ElementType = T;
+  using TileType = Tile<T, D>;
+  using ConstTileType = Tile<const T, D>;
+  using TileDataType = internal::TileData<T, D>;
+  using ReadOnlySenderType = ReadOnlyTileSender<T, D>;
+
+  /// Create a sub-matrix of @p mat specified by @p spec.
+  ///
+  /// @param[in] mat is the input matrix,
+  /// @param[in] spec contains the origin and size of the new matrix relative to the input matrix,
+  /// @pre spec.origin.isValid(),
+  /// @pre spec.size.isValid(),
+  /// @pre spec.origin + spec.size <= mat.size().
+  MatrixRef(Matrix<const T, D>& mat, const SubMatrixSpec& spec)
+      : internal::MatrixBase(Distribution(mat.distribution(), spec)), mat_const_(mat),
+        origin_(spec.origin) {}
+
+  MatrixRef() = delete;
+  MatrixRef(MatrixRef&&) = delete;
+  MatrixRef(const MatrixRef&) = delete;
+  MatrixRef& operator=(MatrixRef&&) = delete;
+  MatrixRef& operator=(const MatrixRef&) = delete;
+
+  /// Returns a read-only sender of the Tile with local index @p index.
+  ///
+  /// @pre index.isIn(distribution().localNrTiles()).
+  ReadOnlySenderType read(const LocalTileIndex& index) noexcept {
+    // Note: this forwards to the overload with GlobalTileIndex which will
+    // handle taking a subtile if needed
+    return read(distribution().globalTileIndex(index));
+  }
+
+  /// Returns a read-only sender of the Tile with global index @p index.
+  ///
+  /// @pre the global tile is stored in the current process,
+  /// @pre index.isIn(globalNrTiles()).
+  ReadOnlySenderType read(const GlobalTileIndex& index) {
+    DLAF_ASSERT(index.isIn(distribution().nrTiles()), index, distribution().nrTiles());
+
+    const auto parent_index(
+        mat_const_.distribution().globalTileIndexFromSubDistribution(origin_, distribution(), index));
+    auto tile_sender = mat_const_.read(parent_index);
+
+    const auto parent_dist = mat_const_.distribution();
+    const auto parent_tile_size = parent_dist.tileSize(parent_index);
+    const auto tile_size = tileSize(index);
+
+    // If the corresponding tile in the parent distribution is exactly the same
+    // size as the tile in the sub-distribution, we don't need to take a subtile
+    // and can return the tile sender directly. This avoids unnecessary wrapping.
+    if (parent_tile_size == tile_size) {
+      return tile_sender;
+    }
+
+    // Otherwise we have to extract a subtile from the tile in the parent
+    // distribution.
+    const auto ij_tile =
+        parent_dist.tileElementOffsetFromSubDistribution(origin_, distribution(), index);
+    return splitTile(std::move(tile_sender), SubTileSpec{ij_tile, tile_size});
+  }
+
+private:
+  Matrix<const T, D>& mat_const_;
+
+protected:
+  GlobalElementIndex origin_;
+};
+
+template <class T, Device D>
+class MatrixRef : public MatrixRef<const T, D> {
+public:
+  static constexpr Device device = D;
+
+  using ElementType = T;
+  using TileType = Tile<T, D>;
+  using ConstTileType = Tile<const T, D>;
+  using TileDataType = internal::TileData<T, D>;
+  using ReadWriteSenderType = ReadWriteTileSender<T, D>;
+
+  /// Create a sub-matrix of @p mat specified by @p spec.
+  ///
+  /// @param[in] mat is the input matrix,
+  /// @param[in] spec contains the origin and size of the new matrix relative to the input matrix,
+  /// @pre spec.origin.isValid(),
+  /// @pre spec.size.isValid(),
+  /// @pre spec.origin + spec.size <= mat.size().
+  MatrixRef(Matrix<T, D>& mat, const SubMatrixSpec& spec)
+      : MatrixRef<const T, D>(mat, spec), mat_(mat) {}
+
+  MatrixRef() = delete;
+  MatrixRef(MatrixRef&&) = delete;
+  MatrixRef(const MatrixRef&) = delete;
+  MatrixRef& operator=(MatrixRef&&) = delete;
+  MatrixRef& operator=(const MatrixRef&) = delete;
+
+  /// Returns a sender of the Tile with local index @p index.
+  ///
+  /// @pre index.isIn(distribution().localNrTiles()).
+  ReadWriteSenderType readwrite(const LocalTileIndex& index) noexcept {
+    // Note: this forwards to the overload with GlobalTileIndex which will
+    // handle taking a subtile if needed
+    return readwrite(this->distribution().globalTileIndex(index));
+  }
+
+  /// Returns a sender of the Tile with global index @p index.
+  ///
+  /// @pre the global tile is stored in the current process,
+  /// @pre index.isIn(globalNrTiles()).
+  ReadWriteSenderType readwrite(const GlobalTileIndex& index) {
+    DLAF_ASSERT(index.isIn(this->distribution().nrTiles()), index, this->distribution().nrTiles());
+
+    const auto parent_index(
+        mat_.distribution().globalTileIndexFromSubDistribution(origin_, this->distribution(), index));
+    auto tile_sender = mat_.readwrite(parent_index);
+
+    const auto parent_dist = mat_.distribution();
+    const auto parent_tile_size = parent_dist.tileSize(parent_index);
+    const auto tile_size = this->tileSize(index);
+
+    // If the corresponding tile in the parent distribution is exactly the same
+    // size as the tile in the sub-distribution, we don't need to take a subtile
+    // and can return the tile sender directly. This avoids unnecessary wrapping.
+    if (parent_tile_size == tile_size) {
+      return tile_sender;
+    }
+
+    // Otherwise we have to extract a subtile from the tile in the parent
+    // distribution.
+    const auto ij_tile =
+        parent_dist.tileElementOffsetFromSubDistribution(origin_, this->distribution(), index);
+    return splitTile(std::move(tile_sender), SubTileSpec{ij_tile, tile_size});
+  }
+
+private:
+  Matrix<T, D>& mat_;
+  using MatrixRef<const T, D>::origin_;
+};
+
+// ETI
+
+#define DLAF_MATRIX_REF_ETI(KWORD, DATATYPE, DEVICE) \
+  KWORD template class MatrixRef<DATATYPE, DEVICE>;  \
+  KWORD template class MatrixRef<const DATATYPE, DEVICE>;
+
+DLAF_MATRIX_REF_ETI(extern, float, Device::CPU)
+DLAF_MATRIX_REF_ETI(extern, double, Device::CPU)
+DLAF_MATRIX_REF_ETI(extern, std::complex<float>, Device::CPU)
+DLAF_MATRIX_REF_ETI(extern, std::complex<double>, Device::CPU)
+
+#if defined(DLAF_WITH_GPU)
+DLAF_MATRIX_REF_ETI(extern, float, Device::GPU)
+DLAF_MATRIX_REF_ETI(extern, double, Device::GPU)
+DLAF_MATRIX_REF_ETI(extern, std::complex<float>, Device::GPU)
+DLAF_MATRIX_REF_ETI(extern, std::complex<double>, Device::GPU)
+#endif
+}
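As a short usage sketch (not part of the diff): wrapping a matrix in a MatrixRef and accessing a tile of the sub-matrix. It assumes the same namespaces and using-directives as test_matrix_ref.cpp below; the sizes, element type and variable names are illustrative only.

  // Local (non-distributed) 10x15 matrix with 5x5 tiles.
  Matrix<double, Device::CPU> mat(LocalElementSize{10, 15}, TileElementSize{5, 5});
  // Read-only view of the 4x7 sub-matrix starting at element {6, 7}.
  MatrixRef<const double, Device::CPU> mat_sub(mat, SubMatrixSpec{GlobalElementIndex{6, 7},
                                                                  GlobalElementSize{4, 7}});

  // Sub-tile {0, 0} starts at element {1, 2} of parent tile {1, 1} and is only 4x3 elements,
  // so read() wraps the parent tile sender with splitTile(); a sub-tile that coincides
  // exactly with a parent tile would be forwarded without wrapping.
  auto first_tile_sender = mat_sub.read(GlobalTileIndex{0, 0});  // sender not consumed here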
diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 9a03606bb3..eb38366dfe 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -195,6 +195,7 @@ DLAF_addSublibrary(
     init.cpp
     matrix/distribution.cpp
     matrix/layout_info.cpp
+    matrix/matrix_ref.cpp
    matrix/tile.cpp
     matrix.cpp
     matrix_mirror.cpp
diff --git a/src/matrix/distribution.cpp b/src/matrix/distribution.cpp
index b25af8909d..d4d96e0f72 100644
--- a/src/matrix/distribution.cpp
+++ b/src/matrix/distribution.cpp
@@ -88,6 +88,19 @@ Distribution& Distribution::operator=(Distribution&& rhs) noexcept {
   return *this;
 }
 
+Distribution::Distribution(Distribution rhs, const SubDistributionSpec& spec)
+    : Distribution(std::move(rhs)) {
+  DLAF_ASSERT(spec.origin.isValid(), spec.origin);
+  DLAF_ASSERT(spec.size.isValid(), spec.size);
+  DLAF_ASSERT(spec.origin.row() + spec.size.rows() <= size_.rows(), spec.origin, spec.size, size_);
+  DLAF_ASSERT(spec.origin.col() + spec.size.cols() <= size_.cols(), spec.origin, spec.size, size_);
+
+  offset_ = offset_ + sizeFromOrigin(spec.origin);
+  size_ = spec.size;
+
+  computeGlobalAndLocalNrTilesAndLocalSize();
+}
+
 void Distribution::computeGlobalSizeForNonDistr() noexcept {
   size_ = GlobalElementSize(local_size_.rows(), local_size_.cols());
 }
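To make the offset arithmetic concrete, take the {5, 9} entry of tests_sub_distribution added below (a sketch; the rank index {1, 1} is just one member of the 2x4 grid and the variable names are illustrative):

  // Parent: 5x9 matrix, 3x2 blocks, 2x4 grid, source rank {0, 2}, element offset {1, 1}.
  Distribution dist(GlobalElementSize{5, 9}, TileElementSize{3, 2}, comm::Size2D{2, 4},
                    comm::Index2D{1, 1}, comm::Index2D{0, 2}, GlobalElementIndex{1, 1});
  // 2x3 sub-distribution starting at element {3, 4}.
  Distribution sub_dist(dist, SubDistributionSpec{GlobalElementIndex{3, 4}, GlobalElementSize{2, 3}});

  // The constructor folds the origin into the element offset: {1, 1} + {3, 4} = {4, 5}.
  // Normalized against the 3x2 block size this is offset {1, 1} with source rank {1, 0},
  // i.e. the rank owning the parent tile that contains element {3, 4}; the SubDistribution
  // test below checks exactly this property through rankGlobalTile().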
diff --git a/src/matrix/matrix_ref.cpp b/src/matrix/matrix_ref.cpp
new file mode 100644
index 0000000000..6c9e759d18
--- /dev/null
+++ b/src/matrix/matrix_ref.cpp
@@ -0,0 +1,26 @@
+//
+// Distributed Linear Algebra with Future (DLAF)
+//
+// Copyright (c) 2018-2023, ETH Zurich
+// All rights reserved.
+//
+// Please, refer to the LICENSE file in the root directory.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+
+#include
+
+namespace dlaf::matrix::internal {
+
+DLAF_MATRIX_REF_ETI(, float, Device::CPU)
+DLAF_MATRIX_REF_ETI(, double, Device::CPU)
+DLAF_MATRIX_REF_ETI(, std::complex<float>, Device::CPU)
+DLAF_MATRIX_REF_ETI(, std::complex<double>, Device::CPU)
+
+#if defined(DLAF_WITH_GPU)
+DLAF_MATRIX_REF_ETI(, float, Device::GPU)
+DLAF_MATRIX_REF_ETI(, double, Device::GPU)
+DLAF_MATRIX_REF_ETI(, std::complex<float>, Device::GPU)
+DLAF_MATRIX_REF_ETI(, std::complex<double>, Device::GPU)
+#endif
+}
diff --git a/test/unit/matrix/CMakeLists.txt b/test/unit/matrix/CMakeLists.txt
index 978bddf6cd..773f0fd81f 100644
--- a/test/unit/matrix/CMakeLists.txt
+++ b/test/unit/matrix/CMakeLists.txt
@@ -76,6 +76,14 @@ DLAF_addTest(
   MPIRANKS 6
 )
 
+DLAF_addTest(
+  test_matrix_ref
+  SOURCES test_matrix_ref.cpp
+  LIBRARIES dlaf.core
+  USE_MAIN MPIPIKA
+  MPIRANKS 6
+)
+
 DLAF_addTest(
   test_panel
   SOURCES test_panel.cpp
diff --git a/test/unit/matrix/test_distribution.cpp b/test/unit/matrix/test_distribution.cpp
index c99043605d..6dd48be898 100644
--- a/test/unit/matrix/test_distribution.cpp
+++ b/test/unit/matrix/test_distribution.cpp
@@ -633,3 +633,79 @@ TEST(DistributionTest, LocalElementDistanceFromGlobalTile) {
               obj.localElementDistanceFromGlobalTile(test.global_tile_begin, test.global_tile_end));
   }
 }
+
+struct ParametersSubDistribution {
+  // Distribution settings
+  GlobalElementSize size;
+  TileElementSize block_size;
+  comm::Index2D rank;
+  comm::Size2D grid_size;
+  comm::Index2D src_rank;
+  GlobalElementIndex offset;
+  // Sub-distribution settings
+  GlobalElementIndex sub_origin;
+  GlobalElementSize sub_size;
+  // Valid indices
+  GlobalElementIndex global_element;
+  GlobalTileIndex global_tile;
+  comm::Index2D rank_tile;
+  std::array<SizeType, 2> local_tile;  // can be an invalid LocalTileIndex
+};
+
+const std::vector<ParametersSubDistribution> tests_sub_distribution = {
+    // {size, block_size, rank, grid_size, src_rank, offset, sub_origin, sub_size,
+    //  global_element, global_tile, rank_tile, local_tile}
+    // Empty distribution
+    {{0, 0}, {2, 5}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    {{0, 0}, {2, 5}, {0, 0}, {1, 1}, {0, 0}, {4, 8}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    // Empty sub-distribution
+    {{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    {{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {2, 3}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    {{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {4, 5}, {0, 0}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    // Sub-distribution == distribution
+    {{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {0, 0}, {3, 4}, {1, 3}, {0, 1}, {0, 0}, {0, 1}},
+    {{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {0, 0}, {5, 9}, {1, 3}, {0, 2}, {0, 0}, {-1, -1}},
+    // clang-format off
+    {{123, 59}, {32, 16}, {3, 3}, {5, 7}, {3, 1}, {1, 1}, {0, 0}, {123, 59}, {30, 30}, {0, 1}, {3, 2}, {0, -1}},
+    // clang-format on
+    // Other sub-distributions
+    {{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 2}, {2, 1}, {0, 0}, {0, 0}, {0, 0}, {0, 0}},
+    {{3, 4}, {2, 2}, {0, 0}, {1, 1}, {0, 0}, {0, 0}, {1, 2}, {2, 1}, {1, 0}, {1, 0}, {0, 0}, {1, 0}},
+    {{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {3, 4}, {2, 3}, {0, 0}, {0, 0}, {1, 0}, {0, -1}},
+    {{5, 9}, {3, 2}, {1, 1}, {2, 4}, {0, 2}, {1, 1}, {3, 4}, {2, 3}, {1, 2}, {0, 1}, {1, 1}, {0, 0}},
+    // clang-format off
+    {{123, 59}, {32, 16}, {3, 3}, {5, 7}, {3, 1}, {1, 1}, {50, 17}, {40, 20}, {20, 10}, {1, 0}, {0, 2}, {-1, -1}},
+    // clang-format on
+};
+
+TEST(DistributionTest, SubDistribution) {
+  for (const auto& test : tests_sub_distribution) {
+    Distribution dist(test.size, test.block_size, test.grid_size, test.rank, test.src_rank, test.offset);
+    const SubDistributionSpec spec{test.sub_origin, test.sub_size};
+    Distribution sub_dist(dist, spec);
+
+    EXPECT_EQ(sub_dist.size(), test.sub_size);
+
+    EXPECT_EQ(sub_dist.blockSize(), dist.blockSize());
+    EXPECT_EQ(sub_dist.baseTileSize(), dist.baseTileSize());
+    EXPECT_EQ(sub_dist.rankIndex(), dist.rankIndex());
+    EXPECT_EQ(sub_dist.commGridSize(), dist.commGridSize());
+
+    EXPECT_LE(sub_dist.localSize().rows(), dist.localSize().rows());
+    EXPECT_LE(sub_dist.localSize().cols(), dist.localSize().cols());
+    EXPECT_LE(sub_dist.localNrTiles().rows(), dist.localNrTiles().rows());
+    EXPECT_LE(sub_dist.localNrTiles().cols(), dist.localNrTiles().cols());
+    EXPECT_LE(sub_dist.nrTiles().rows(), dist.nrTiles().rows());
+    EXPECT_LE(sub_dist.nrTiles().cols(), dist.nrTiles().cols());
+
+    if (!test.sub_size.isEmpty()) {
+      EXPECT_EQ(sub_dist.globalTileIndex(test.global_element), test.global_tile);
+      EXPECT_EQ(sub_dist.rankGlobalTile(sub_dist.globalTileIndex(test.global_element)), test.rank_tile);
+
+      EXPECT_EQ(sub_dist.localTileFromGlobalElement<Coord::Row>(test.global_element.get<Coord::Row>()),
+                test.local_tile[0]);
+      EXPECT_EQ(sub_dist.localTileFromGlobalElement<Coord::Col>(test.global_element.get<Coord::Col>()),
+                test.local_tile[1]);
+    }
+  }
+}
diff --git a/test/unit/matrix/test_matrix_ref.cpp b/test/unit/matrix/test_matrix_ref.cpp
new file mode 100644
index 0000000000..0d29798685
--- /dev/null
+++ b/test/unit/matrix/test_matrix_ref.cpp
@@ -0,0 +1,203 @@
+//
+// Distributed Linear Algebra with Future (DLAF)
+//
+// Copyright (c) 2018-2023, ETH Zurich
+// All rights reserved.
+//
+// Please, refer to the LICENSE file in the root directory.
+// SPDX-License-Identifier: BSD-3-Clause
+//
+
+#include
+
+#include
+
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+#include
+
+using namespace dlaf;
+using namespace dlaf::matrix;
+using namespace dlaf::matrix::internal;
+using namespace dlaf::matrix::test;
+using namespace dlaf::test;
+using namespace testing;
+
+namespace ex = pika::execution::experimental;
+namespace tt = pika::this_thread::experimental;
+
+::testing::Environment* const comm_grids_env =
+    ::testing::AddGlobalTestEnvironment(new CommunicatorGrid6RanksEnvironment);
+
+template <class T>
+struct MatrixRefTest : public TestWithCommGrids {};
+
+TYPED_TEST_SUITE(MatrixRefTest, MatrixElementTypes);
+
+struct TestSubMatrix {
+  GlobalElementSize size;
+  TileElementSize block_size;
+  GlobalElementIndex sub_origin;
+  GlobalElementSize sub_size;
+};
+
+const std::vector<TestSubMatrix> tests_sub_matrix({
+    // Empty matrix
+    {{0, 0}, {1, 1}, {0, 0}, {0, 0}},
+    // Empty sub-matrices
+    {{3, 4}, {3, 4}, {0, 0}, {0, 0}},
+    {{3, 4}, {3, 4}, {2, 3}, {0, 0}},
+    // Single-block matrix
+    {{3, 4}, {3, 4}, {0, 0}, {3, 4}},
+    {{3, 4}, {3, 4}, {0, 0}, {2, 1}},
+    {{3, 4}, {3, 4}, {1, 2}, {2, 1}},
+    {{3, 4}, {8, 6}, {0, 0}, {3, 4}},
+    {{3, 4}, {8, 6}, {0, 0}, {2, 1}},
+    {{3, 4}, {8, 6}, {1, 2}, {2, 1}},
+    // Larger matrices
+    {{10, 15}, {5, 5}, {6, 7}, {0, 0}},
+    {{10, 15}, {5, 5}, {6, 7}, {0, 0}},
+    {{10, 15}, {5, 5}, {1, 2}, {0, 0}},
+    {{10, 15}, {5, 5}, {0, 0}, {10, 15}},
+    {{10, 15}, {5, 5}, {0, 0}, {10, 15}},
+    {{10, 15}, {5, 5}, {0, 0}, {10, 15}},
+    {{10, 15}, {5, 5}, {6, 7}, {2, 2}},
+    {{10, 15}, {5, 5}, {6, 7}, {4, 7}},
+    {{10, 15}, {5, 5}, {1, 2}, {8, 7}},
+    {{256, 512}, {32, 16}, {45, 71}, {87, 55}},
+});
+
+inline bool indexInSubMatrix(const GlobalElementIndex& index, const SubMatrixSpec& spec) {
+  bool r = spec.origin.row() <= index.row() && index.row() < spec.origin.row() + spec.size.rows() &&
+           spec.origin.col() <= index.col() && index.col() < spec.origin.col() + spec.size.cols();
+  return r;
+}
+
+TYPED_TEST(MatrixRefTest, Basic) {
+  using Type = TypeParam;
+  constexpr Device device = Device::CPU;
+
+  for (const auto& comm_grid : this->commGrids()) {
+    for (const auto& test : tests_sub_matrix) {
+      Matrix<Type, device> mat(test.size, test.block_size, comm_grid);
+      Matrix<const Type, device>& mat_const = mat;
+
+      const SubMatrixSpec spec{test.sub_origin, test.sub_size};
+      MatrixRef<Type, device> mat_ref(mat, spec);
+      MatrixRef<const Type, device> mat_const_ref1(mat, spec);
+      MatrixRef<const Type, device> mat_const_ref2(mat_const, spec);
+
+      EXPECT_EQ(mat_ref.distribution(), mat_const_ref1.distribution());
+      EXPECT_EQ(mat_ref.distribution(), mat_const_ref2.distribution());
+      EXPECT_EQ(mat_ref.size(), test.sub_size);
+      EXPECT_EQ(mat_ref.blockSize(), mat.blockSize());
+      EXPECT_EQ(mat_ref.baseTileSize(), mat.baseTileSize());
+      EXPECT_EQ(mat_ref.rankIndex(), mat.rankIndex());
+      EXPECT_EQ(mat_ref.commGridSize(), mat.commGridSize());
+      if (test.sub_origin.isIn(GlobalElementSize(test.block_size.rows(), test.block_size.cols()))) {
+        EXPECT_EQ(mat_ref.sourceRankIndex(), mat.sourceRankIndex());
+      }
+    }
+  }
+}
+
+TYPED_TEST(MatrixRefTest, NonConstRefFromNonConstMatrix) {
+  using Type = TypeParam;
+  constexpr Device device = Device::CPU;
+  constexpr Type el_submatrix(1);
+  constexpr Type el_border(-1);
+
+  const auto f_el_submatrix = [=](const GlobalElementIndex&) { return el_submatrix; };
+  const auto f_el_border = [=](const GlobalElementIndex&) { return el_border; };
+
+  for (const auto& comm_grid : this->commGrids()) {
+    for (const auto& test : tests_sub_matrix) {
+      const SubMatrixSpec spec{test.sub_origin, test.sub_size};
+      const auto f_el_full = [=](const GlobalElementIndex& index) {
+        return indexInSubMatrix(index, spec) ? el_submatrix : el_border;
+      };
+
+      Matrix<Type, device> mat_expected(test.size, test.block_size, comm_grid);
+      Matrix<Type, device> mat(test.size, test.block_size, comm_grid);
+      MatrixRef<Type, device> mat_ref(mat, spec);
+
+      set(mat_expected, f_el_full);
+      set(mat, f_el_border);
+      for (const auto& ij_local : iterate_range2d(mat_ref.distribution().localNrTiles())) {
+        ex::start_detached(mat_ref.readwrite(ij_local) |
+                           dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(),
+                                                     [=](const auto& tile) {
+                                                       set(tile, el_submatrix);
+                                                     }));
+      }
+
+      CHECK_MATRIX_EQ(f_el_full, mat);
+      CHECK_MATRIX_EQ(f_el_submatrix, mat_ref);
+    }
+  }
+}
+
+TYPED_TEST(MatrixRefTest, ConstRefFromNonConstMatrix) {
+  using Type = TypeParam;
+  constexpr Device device = Device::CPU;
+  constexpr Type el_submatrix(1);
+  constexpr Type el_border(-1);
+
+  const auto f_el_submatrix = [=](const GlobalElementIndex&) { return el_submatrix; };
+
+  for (const auto& comm_grid : this->commGrids()) {
+    for (const auto& test : tests_sub_matrix) {
+      const SubMatrixSpec spec{test.sub_origin, test.sub_size};
+      const auto f_el_full = [=](const GlobalElementIndex& index) {
+        return indexInSubMatrix(index, spec) ? el_submatrix : el_border;
+      };
+
+      Matrix<Type, device> mat_expected(test.size, test.block_size, comm_grid);
+      Matrix<Type, device> mat(test.size, test.block_size, comm_grid);
+      MatrixRef<const Type, device> mat_const_ref(mat, spec);
+
+      set(mat_expected, f_el_full);
+      set(mat, f_el_full);
+
+      CHECK_MATRIX_EQ(f_el_full, mat);
+      CHECK_MATRIX_EQ(f_el_submatrix, mat_const_ref);
+    }
+  }
+}
+
+TYPED_TEST(MatrixRefTest, ConstRefFromConstMatrix) {
+  using Type = TypeParam;
+  constexpr Device device = Device::CPU;
+  constexpr Type el_submatrix(1);
+  constexpr Type el_border(-1);
+
+  const auto f_el_submatrix = [=](const GlobalElementIndex&) { return el_submatrix; };
+
+  for (const auto& comm_grid : this->commGrids()) {
+    for (const auto& test : tests_sub_matrix) {
+      const SubMatrixSpec spec{test.sub_origin, test.sub_size};
+      const auto f_el_full = [=](const GlobalElementIndex& index) {
+        return indexInSubMatrix(index, spec) ? el_submatrix : el_border;
+      };
+
+      Matrix<Type, device> mat_expected(test.size, test.block_size, comm_grid);
+      Matrix<Type, device> mat(test.size, test.block_size, comm_grid);
+      Matrix<const Type, device>& mat_const = mat;
+      MatrixRef<const Type, device> mat_const_ref(mat_const, spec);
+
+      set(mat_expected, f_el_full);
+      set(mat, f_el_full);
+
+      CHECK_MATRIX_EQ(f_el_full, mat);
+      CHECK_MATRIX_EQ(f_el_submatrix, mat_const_ref);
+    }
+  }
+}
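Beyond the unit tests, the intended use is to let algorithms operate on a sub-matrix in place. A hedged sketch, not part of this change: scale_in_place is a hypothetical helper, and the snippet assumes the same using-directives as test_matrix_ref.cpp; the access pattern is the one used in the tests above.

  // Hypothetical helper: scale every element of the referenced sub-matrix in place.
  template <class T>
  void scale_in_place(MatrixRef<T, Device::CPU>& a, T alpha) {
    for (const auto& ij : iterate_range2d(a.distribution().localNrTiles())) {
      ex::start_detached(a.readwrite(ij) |
                         dlaf::internal::transform(dlaf::internal::Policy<Backend::MC>(),
                                                   [alpha](const auto& tile) {
                                                     for (const auto& el : iterate_range2d(tile.size()))
                                                       tile(el) *= alpha;
                                                   }));
    }
  }

  // Usage (sizes taken from the last tests_sub_matrix entry):
  // Matrix<double, Device::CPU> mat(GlobalElementSize{256, 512}, TileElementSize{32, 16}, comm_grid);
  // MatrixRef<double, Device::CPU> sub(mat, {GlobalElementIndex{45, 71}, GlobalElementSize{87, 55}});
  // scale_in_place(sub, 2.0);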