Skip to content

Commit

Permalink
Add Matrix::subPipeline(Const)
Browse files Browse the repository at this point in the history
  • Loading branch information
msimberg committed Jun 6, 2023
1 parent 744997e commit 3b818d8
Show file tree
Hide file tree
Showing 4 changed files with 1,131 additions and 151 deletions.
17 changes: 17 additions & 0 deletions include/dlaf/matrix/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,15 @@ class Matrix : public Matrix<const T, D> {
return readwrite(this->distribution().localTileIndex(index));
}

private:
using typename Matrix<const T, D>::SubPipelineTag;
Matrix(Matrix& mat, const SubPipelineTag);

public:
Matrix subPipeline() {
return Matrix(*this, SubPipelineTag{});
}

protected:
using Matrix<const T, D>::tileLinearIndex;

Expand Down Expand Up @@ -186,10 +195,18 @@ class Matrix<const T, D> : public internal::MatrixBase {
/// involving any of the locally available tiles are completed.
void waitLocalTiles() noexcept;

Matrix subPipelineConst() {
return Matrix(*this, SubPipelineTag{});
}

protected:
Matrix(Distribution distribution) : internal::MatrixBase{std::move(distribution)} {}

struct SubPipelineTag {};
Matrix(Matrix& mat, const SubPipelineTag);

void setUpTiles(const memory::MemoryView<ElementType, D>& mem, const LayoutInfo& layout) noexcept;
void setUpSubPipelines(Matrix<const T, D>&) noexcept;

std::vector<internal::TilePipeline<T, D>> tile_managers_;
};
Expand Down
3 changes: 3 additions & 0 deletions include/dlaf/matrix/matrix.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,8 @@ Matrix<T, D>::Matrix(Distribution distribution, const LayoutInfo& layout, Elemen
template <class T, Device D>
Matrix<T, D>::Matrix(const LayoutInfo& layout, ElementType* ptr) : Matrix<const T, D>(layout, ptr) {}

template <class T, Device D>
Matrix<T, D>::Matrix(Matrix<T, D>& mat, const SubPipelineTag tag) : Matrix<const T, D>(mat, tag) {}

}
}
24 changes: 24 additions & 0 deletions include/dlaf/matrix/matrix_const.tpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,5 +72,29 @@ void Matrix<const T, D>::setUpTiles(const memory::MemoryView<ElementType, D>& me
}
}

template <class T, Device D>
Matrix<const T, D>::Matrix(Matrix<const T, D>& mat, const SubPipelineTag)
: MatrixBase(mat.distribution()) {
setUpSubPipelines(mat);
}

template <class T, Device D>
void Matrix<const T, D>::setUpSubPipelines(Matrix<const T, D>& mat) noexcept {
namespace ex = pika::execution::experimental;

// TODO: Optimize read-after-read. This is currently forced to access the base
// matrix in readwrite mode so that we can move the tile into the
// sub-pipeline. This is semantically not required and should eventually be
// optimized.
tile_managers_.reserve(mat.tile_managers_.size());
for (auto& tm : mat.tile_managers_) {
tile_managers_.emplace_back(Tile<T, D>());
auto s = ex::when_all(tile_managers_.back().readwrite_with_wrapper(), tm.readwrite()) |
ex::then([](internal::TileAsyncRwMutexReadWriteWrapper<T, D> empty_tile_wrapper,
Tile<T, D> tile) { empty_tile_wrapper.get() = std::move(tile); });
ex::start_detached(std::move(s));
}
}

}
}
Loading

0 comments on commit 3b818d8

Please sign in to comment.