
Wrap communicators in RoundRobin and Pipeline in CommunicatorGrid #993

Merged
148 commits merged on Dec 12, 2023

Commits
23edf4f
TEMP: First hacky version of round-robin pipelined communicator grid
msimberg Sep 20, 2023
ccff64c
Add configuration option for number of communicators per pipeline in …
msimberg Sep 21, 2023
9908c47
Add POC support for share pipeline access from CommunicatorGrid
msimberg Sep 22, 2023
83209fe
Add workaround for TuneParameters getting pika thread pools
msimberg Sep 22, 2023
d32419b
Remove TODO about removing wait in reduction to band
msimberg Sep 25, 2023
e615b2a
Rename communicator grid num pipelines/communicators tune parameter
msimberg Sep 25, 2023
e152b39
Add with_result_of utility
msimberg Sep 25, 2023
f9374df
Add TODO for naming Pipeline::read
msimberg Sep 26, 2023
58736d2
Remove default constructor for Pipeline again
msimberg Sep 26, 2023
80eb161
Clean up sub pipeline implementation
msimberg Sep 26, 2023
ebe70ca
Rename CommunicatorGrid::numCommunicators to numPipelines
msimberg Sep 26, 2023
eb436cf
Rename *_pipeline_ members to *_pipelines_ in CommunicatorGrid
msimberg Sep 26, 2023
1618b50
Add TODO for communicator use in tridiagonal eigensolver merge
msimberg Sep 26, 2023
b61a150
Remove comm from wrapper variable names in Pipeline
msimberg Sep 26, 2023
b20981d
Use snake_case and CamelCase consistently in new functionality
msimberg Sep 26, 2023
0e37877
Assert that bt_band_to_tridiag has at least two communicator grid pip…
msimberg Sep 26, 2023
e94825c
Rename mpi_chain_col_shared to mpi_chain_col_p2p in bt_band_to_tridiag
msimberg Sep 26, 2023
d9db5f8
Pass CommunicatorGrid by reference
msimberg Sep 26, 2023
91d03cd
Add documentation for Pipeline::sub_pipeline
msimberg Sep 26, 2023
6cc0b2d
Add TODO about CommunicatorGrid copy test
msimberg Sep 26, 2023
e218323
Use communicator grid pipelines in tridiagonal eigensolver rot
msimberg Sep 26, 2023
73dff79
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Sep 26, 2023
7d2f767
Undo postprocess.py changes
msimberg Sep 27, 2023
4d65f62
Mark grid potentially unused in tridiag_solver
msimberg Sep 27, 2023
4f2f5d7
Format pipeline.h and remove TODO
msimberg Sep 27, 2023
2a3631c
Fix pipeline use in tridiag_solver rot
msimberg Sep 28, 2023
eff4377
Fix transform to actually move callables passed to it
msimberg Sep 28, 2023
9c4f0ef
Fix pipeline move constructors/assignment to reset optionals
msimberg Sep 28, 2023
ea29c92
Use pika@main temporarily for CUDA bugfix
msimberg Sep 28, 2023
10e078c
Manually format long ETI lines
msimberg Sep 28, 2023
156e7c2
Format postprocess.py
msimberg Sep 28, 2023
a7e648e
Add missing pragma once to with_result_of.h
msimberg Sep 28, 2023
8869afa
Use pika@main for ensure_started bugfix
msimberg Oct 4, 2023
e32a1f3
Update intel apt gpg key
msimberg Oct 4, 2023
b8f94da
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Oct 4, 2023
56a8caa
Take grid by reference in gen_eigensolver.h for C API
msimberg Oct 5, 2023
0289d7c
Pass grid by non-const reference in gen_eigensolver C API test
msimberg Oct 5, 2023
33edb2b
Explicitly initialize boolean in test_pipeline
msimberg Oct 6, 2023
f6f3d30
Add basic test for sub_pipeline with parent pipeline unused
msimberg Oct 6, 2023
6c2822a
Add random access test to test_pipeline
msimberg Oct 13, 2023
db0f70c
Add pipeline::read test to test_pipeline
msimberg Oct 13, 2023
bb26699
Refactor pipeline tests
msimberg Oct 13, 2023
096fe1c
Add another simple test for sub pipelines with access from a differen…
msimberg Oct 13, 2023
004773a
Add another sub pipeline test
msimberg Oct 13, 2023
09d67d8
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Oct 13, 2023
b27e2ec
Require pika 0.19.0 or newer for ensure_started fixes
msimberg Oct 13, 2023
094fa20
Fix formatting in test_pipeline
msimberg Oct 13, 2023
1484d76
Clean up any_sender workarounds in p2p_allsum
msimberg Oct 13, 2023
a44f1c6
Make CommunicatorGrid move-only
msimberg Oct 13, 2023
62ab922
Add TODO for checking if pika runtime is initialized before construct…
msimberg Oct 13, 2023
84f25a8
Assert that RoundRobin contains at least one element when accessing r…
msimberg Oct 13, 2023
2e2e70b
Add a test to check that communicator grid gives the expected communi…
msimberg Oct 13, 2023
259947c
Add comment about explicitly specifying ncommunicator_pipelines for C…
msimberg Oct 13, 2023
33bb88c
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Oct 13, 2023
bb4b596
Relax assumptions in test_pipeline
msimberg Oct 13, 2023
931ba16
Enable disabled test cases and fix indexing bug in CommunicatorGridTe…
msimberg Nov 1, 2023
076d7a1
Fix comment in communicator grid test
msimberg Nov 1, 2023
6d1e741
Fix documentation in tune.h
msimberg Nov 1, 2023
554b04b
Remove unused pika include from communicator_grid.cpp
msimberg Nov 1, 2023
05945b6
Add more sanity checks to test_pipeline.cpp
msimberg Nov 1, 2023
5497f13
Use DLAF_ASSERT_HEAVY for assertions in round_robin.h
msimberg Nov 1, 2023
e7e4072
Remove pika@main constraint from CI
msimberg Nov 1, 2023
4605a67
Rename Pipeline::operator() to readwrite
msimberg Nov 1, 2023
13c1d2a
Add deprecated operator() back to Pipeline
msimberg Nov 1, 2023
e2e0326
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 1, 2023
ba5b0e3
Update assertions in round_robin.h based on discussion
msimberg Nov 6, 2023
8c84ef1
Introduce new type for communicator pipeline
msimberg Nov 6, 2023
7548eac
Add size and rank members to CommunicatorPipeline
msimberg Nov 7, 2023
8085aac
Only pass CommunicatorPipelines, no CommunicatorGrid, to internal alg…
msimberg Nov 7, 2023
18cd1e4
Add CommunicatorGrid::communicator_pipeline(Coord)
msimberg Nov 7, 2023
a082ffe
Template CommunicatorPipeline on type of communicator
msimberg Nov 7, 2023
6300e4e
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 13, 2023
e165a5e
Rename TODOCoord to CommunicatorType
msimberg Nov 13, 2023
cb7d275
Update naming of rank/size members of CommunicatorPipeline
msimberg Nov 13, 2023
9e55234
Clean up broadcast_panel.h
msimberg Nov 13, 2023
de364b9
Move MPI typedefs for indices to separate header
msimberg Nov 13, 2023
b5de0e7
Move CommunicatorType to separate header
msimberg Nov 13, 2023
128f27b
Remove TODO
msimberg Nov 13, 2023
449fd66
Format files
msimberg Nov 13, 2023
7e4d014
Update auto return type to explicit type
msimberg Nov 13, 2023
3333faa
Add docstring to CommunicatorGrid::communicator_pipeline
msimberg Nov 13, 2023
1728b47
Remove unnecessary typedef
msimberg Nov 13, 2023
c1678c8
Minor cleanup
msimberg Nov 13, 2023
6748611
Update documentation for CommunicatorPipeline
msimberg Nov 13, 2023
665402e
Add missing index.h includes
msimberg Nov 13, 2023
a5e75e9
Format communicator_pipeline.h
msimberg Nov 13, 2023
a21b567
Formatting and miscellaneous test fixes
msimberg Nov 14, 2023
60fa809
Fix formatting in communicator_grid.h
msimberg Nov 14, 2023
078e97a
Add top-level docstring to CommunicatorPipeline
msimberg Nov 14, 2023
67c6e96
Remove stray comment line
msimberg Nov 14, 2023
51aab43
Add missing communicator_type.h header
msimberg Nov 14, 2023
0b90659
Fix formatting in communicator_type.h
msimberg Nov 14, 2023
9acff10
Minor fixes
msimberg Nov 14, 2023
8953122
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 14, 2023
8deca51
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 15, 2023
d26ae3f
Make Pipeline move constructor and assignment operator noexcept
msimberg Nov 20, 2023
c32ca21
Use Size2D::linear_size in CommunicatorPipeline
msimberg Nov 20, 2023
e411415
Explicitly delete CommunicatorGrid copy constructor and assignment op…
msimberg Nov 20, 2023
553a21e
Use default "invalid" indices in communicator pipeline
msimberg Nov 20, 2023
0081fc1
Add comment about pipelines in band_to_tridiag implementation
msimberg Nov 20, 2023
c68e274
Resolve TODO in t_factor implementation
msimberg Nov 20, 2023
b78e638
Rename CommunicatorPipeline::read/readwrite to shared/exclusive
msimberg Nov 22, 2023
48f2016
Add another note about using two pipelines in band_to_tridiag
msimberg Nov 22, 2023
b11bbc3
Remove manual use of shared communicator in permutations implementation
msimberg Nov 22, 2023
d977df9
Store communicator 1d rank and size in CommunicatorPipeline
msimberg Nov 22, 2023
0c17bc7
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 22, 2023
d8df279
Fix formatting
msimberg Nov 22, 2023
4927179
Rename CommunicatorPipeline::rankFullCommunicator to rank_full_commun…
msimberg Nov 22, 2023
d856040
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 23, 2023
286628b
Remove unnecessary note
msimberg Nov 23, 2023
bb4c341
Remove unused grid in test
msimberg Nov 23, 2023
a4cde45
Add default constructor for Pipeline
msimberg Nov 23, 2023
379617e
Remove deprecated operator() from Pipeline
msimberg Nov 23, 2023
e16a63b
Update note in band_to_tridiag about communicator pipelines
msimberg Nov 23, 2023
bafb3fe
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 27, 2023
f6203ca
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 27, 2023
19db4ed
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Nov 29, 2023
c5d00e8
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Dec 4, 2023
7f365a3
Update gemm API signatures to use new communicator pipeline
msimberg Dec 4, 2023
7787993
Fix assertions in gemm
msimberg Dec 4, 2023
4726236
Update assertions in gemm
msimberg Dec 4, 2023
1660eb8
Qualify calls to equal_process_grid
msimberg Dec 4, 2023
7dee6b1
Template equal_process_grid for comparison with communicator pipeline…
msimberg Dec 4, 2023
a242a83
Update assertions in gemm to first check for equality of communicator…
msimberg Dec 4, 2023
90cfa95
merge-squashed: propedeuthic changes towards gemm cost reduction
albestro Sep 15, 2023
0d17094
document i6/i2 for rank, embed i2 invert, remove redundant n paramete…
albestro Dec 11, 2023
4d6af6d
remove _futs suffix from parameters
albestro Dec 11, 2023
5ba97b5
update to snake case
albestro Dec 11, 2023
70d37c2
fix doc
albestro Dec 11, 2023
4ce60c0
add check for pre-weights and opt for not reducing over k_lc
albestro Dec 11, 2023
199cf80
fix TODO about reducing just needed sum_squares ("norms")
albestro Dec 11, 2023
84efe67
drop eg (extended-global) from variables (following doc decision)
albestro Dec 11, 2023
04a2587
merged squashed changes to make solveRank1 multi-threaded
albestro Dec 4, 2023
a538c2e
adapt with changes on step2
albestro Dec 12, 2023
887cc16
merge-squashed trisolver dist change to reduce gemm cost
albestro Sep 25, 2023
cc36539
merged squashed changes to make solveRank1 multi-threaded
albestro Dec 4, 2023
a2d6aa1
workload in terms of tilesize instead of blocksize
albestro Dec 12, 2023
ec9ebf6
snake_case
albestro Dec 12, 2023
8817224
merge-squashed trisolver dist change to reduce gemm cost
albestro Sep 25, 2023
941ffb4
Merge branch 'alby/trisolver-dist-opt-step3' into comm-grid-round-robin
msimberg Dec 12, 2023
e65a26d
Merge remote-tracking branch 'origin/alby/trisolver-dist-opt-step3' i…
msimberg Dec 12, 2023
e0b8b9e
merge-squashed trisolver dist change to reduce gemm cost
albestro Sep 25, 2023
4f31507
start factoring out gemm
albestro Dec 12, 2023
a824d40
doc + minor changes
albestro Dec 12, 2023
c422602
make gemm scheduler hp + fix doc + minor changes
albestro Dec 12, 2023
4eb08c6
snake case
albestro Dec 12, 2023
3353914
Merge remote-tracking branch 'origin/alby/trisolver-dist-opt-step3' i…
msimberg Dec 12, 2023
b72fde5
Merge remote-tracking branch 'origin/master' into comm-grid-round-robin
msimberg Dec 12, 2023
2 changes: 1 addition & 1 deletion include/dlaf/auxiliary/norm.h
@@ -32,7 +32,7 @@ namespace dlaf::auxiliary {
/// @pre @p A has tilesize (NB x NB)
/// @return the max norm of the Matrix @p A or 0 if `A.size().isEmpty()`
template <Backend backend, Device device, class T>
dlaf::BaseType<T> max_norm(comm::CommunicatorGrid grid, comm::Index2D rank, blas::Uplo uplo,
dlaf::BaseType<T> max_norm(comm::CommunicatorGrid& grid, comm::Index2D rank, blas::Uplo uplo,
Matrix<const T, device>& A) {
using dlaf::matrix::equal_process_grid;
using dlaf::matrix::single_tile_per_block;
4 changes: 2 additions & 2 deletions include/dlaf/auxiliary/norm/api.h
@@ -20,10 +20,10 @@ struct Norm {};

template <class T>
struct Norm<Backend::MC, Device::CPU, T> {
static dlaf::BaseType<T> max_L(comm::CommunicatorGrid comm_grid, comm::Index2D rank,
static dlaf::BaseType<T> max_L(comm::CommunicatorGrid& comm_grid, comm::Index2D rank,
Matrix<const T, Device::CPU>& matrix);

static dlaf::BaseType<T> max_G(comm::CommunicatorGrid comm_grid, comm::Index2D rank,
static dlaf::BaseType<T> max_G(comm::CommunicatorGrid& comm_grid, comm::Index2D rank,
Matrix<const T, Device::CPU>& matrix);
};

4 changes: 2 additions & 2 deletions include/dlaf/auxiliary/norm/mc.h
@@ -29,7 +29,7 @@ namespace dlaf::auxiliary::internal {
// - sy/he lower
// - tr lower non-unit
template <class T>
dlaf::BaseType<T> Norm<Backend::MC, Device::CPU, T>::max_L(comm::CommunicatorGrid comm_grid,
dlaf::BaseType<T> Norm<Backend::MC, Device::CPU, T>::max_L(comm::CommunicatorGrid& comm_grid,
comm::Index2D rank,
Matrix<const T, Device::CPU>& matrix) {
using namespace dlaf::matrix;
@@ -92,7 +92,7 @@ dlaf::BaseType<T> Norm<Backend::MC, Device::CPU, T>::max_L(comm::CommunicatorGri
}

template <class T>
dlaf::BaseType<T> Norm<Backend::MC, Device::CPU, T>::max_G(comm::CommunicatorGrid comm_grid,
dlaf::BaseType<T> Norm<Backend::MC, Device::CPU, T>::max_G(comm::CommunicatorGrid& comm_grid,
comm::Index2D rank,
Matrix<const T, Device::CPU>& matrix) {
using namespace dlaf::matrix;
84 changes: 78 additions & 6 deletions include/dlaf/common/pipeline.h
@@ -12,6 +12,8 @@

/// @file

#include <utility>

#include <pika/async_rw_mutex.hpp>
#include <pika/execution.hpp>

@@ -32,25 +32,78 @@ class Pipeline {
using AsyncRwMutex = pika::execution::experimental::async_rw_mutex<T>;

public:
using Wrapper = typename AsyncRwMutex::readwrite_access_type;
using Sender = pika::execution::experimental::unique_any_sender<Wrapper>;
using ReadOnlyWrapper = typename AsyncRwMutex::read_access_type;
using ReadWriteWrapper = typename AsyncRwMutex::readwrite_access_type;
using ReadOnlySender = pika::execution::experimental::any_sender<ReadOnlyWrapper>;
using ReadWriteSender = pika::execution::experimental::unique_any_sender<ReadWriteWrapper>;

/// Create an invalid Pipeline.
Pipeline() = default;

/// Create a Pipeline by moving in the resource (it takes the ownership).
explicit Pipeline(T object) : pipeline(std::move(object)) {}
Pipeline(Pipeline&&) = default;
Pipeline& operator=(Pipeline&&) = default;

Pipeline(Pipeline&& other) noexcept
: pipeline(std::exchange(other.pipeline, std::nullopt)),
nested_sender(std::exchange(other.nested_sender, std::nullopt)) {}

Pipeline& operator=(Pipeline&& other) noexcept {
if (this != &other) {
pipeline = std::exchange(other.pipeline, std::nullopt);
nested_sender = std::exchange(other.nested_sender, std::nullopt);
}

return *this;
};

Pipeline(const Pipeline&) = delete;
Pipeline& operator=(const Pipeline&) = delete;

/// Enqueue for the resource.
~Pipeline() {
release_parent_pipeline();
}

/// Enqueue for exclusive read-write access to the resource.
///
/// @return a sender that will become ready as soon as the previous user releases the resource.
/// @pre valid()
Sender operator()() {
ReadWriteSender readwrite() {
DLAF_ASSERT(valid(), "");
return pipeline->readwrite();
}

/// Enqueue for shared read-only access to the resource.
///
/// @return a sender that will become ready as soon as the previous user releases the resource.
/// @pre valid()
ReadOnlySender read() {
DLAF_ASSERT(valid(), "");
return pipeline->read();
}

/// Create a sub pipeline to the value contained in the current Pipeline
///
/// All accesses to the sub pipeline are sequenced after previous accesses and before later accesses to
/// the original pipeline, independently of when values are accessed in the sub pipeline.
Pipeline sub_pipeline() {
namespace ex = pika::execution::experimental;

// Move value from pipeline into sub pipeline, then store a sender of the wrapper of the pipeline in
// a sender which we will release when the sub pipeline is released. This ensures that all accesses
// to the sub pipeline happen after previous accesses and before later accesses to the pipeline.
Pipeline sub_pipeline(T{});
sub_pipeline.nested_sender =
ex::when_all(sub_pipeline.pipeline->readwrite(), this->pipeline->readwrite()) |
ex::then([](auto sub_wrapper, auto wrapper) {
sub_wrapper.get() = std::move(wrapper.get());

return wrapper;
}) |
ex::ensure_started();

return sub_pipeline;
}

/// Check if the pipeline is valid.
///
/// @return true if the pipeline hasn't been reset, otherwise false.
@@ -62,10 +117,27 @@ class Pipeline {
///
/// @post !valid()
void reset() noexcept {
release_parent_pipeline();
pipeline.reset();
}

private:
void release_parent_pipeline() {
namespace ex = pika::execution::experimental;

if (nested_sender) {
DLAF_ASSERT(valid(), "");

auto s =
ex::when_all(pipeline->readwrite(), std::move(nested_sender.value())) |
ex::then([](auto sub_wrapper, auto wrapper) { wrapper.get() = std::move(sub_wrapper.get()); });
ex::start_detached(std::move(s));
nested_sender.reset();
}
}

std::optional<AsyncRwMutex> pipeline = std::nullopt;
std::optional<pika::execution::experimental::unique_any_sender<ReadWriteWrapper>> nested_sender =
std::nullopt;
};
}
24 changes: 21 additions & 3 deletions include/dlaf/common/round_robin.h
@@ -13,28 +13,46 @@

#include <vector>

#include <dlaf/common/assert.h>

namespace dlaf {
namespace common {

template <class T>
struct RoundRobin {
class RoundRobin {
public:
RoundRobin() = default;

template <class... Args>
RoundRobin(std::size_t n, Args... args) : curr_index_(0) {
RoundRobin(std::size_t n, Args... args) {
pool_.reserve(n);
for (std::size_t i = 0; i < n; ++i)
pool_.emplace_back(args...);
}

RoundRobin(RoundRobin&&) = default;
RoundRobin(const RoundRobin&) = delete;
RoundRobin& operator=(RoundRobin&&) = default;
RoundRobin& operator=(const RoundRobin&) = delete;

T& currentResource() {
DLAF_ASSERT(curr_index_ < pool_.size(), curr_index_, pool_.size());
return pool_[curr_index_];
}

T& nextResource() {
DLAF_ASSERT(!pool_.empty(), "");
curr_index_ = (curr_index_ + 1) % pool_.size();
DLAF_ASSERT_HEAVY(curr_index_ < pool_.size(), curr_index_, pool_.size());
return pool_[curr_index_];
}

std::size_t curr_index_;
std::size_t size() const noexcept {
return pool_.size();
}

private:
std::size_t curr_index_ = 0;
std::vector<T> pool_;
};

41 changes: 41 additions & 0 deletions include/dlaf/common/with_result_of.h
@@ -0,0 +1,41 @@
//
// Distributed Linear Algebra with Future (DLAF)
//
// Copyright (c) 2018-2023, ETH Zurich
// All rights reserved.
//
// Please, refer to the LICENSE file in the root directory.
// SPDX-License-Identifier: BSD-3-Clause
//

#pragma once

/// @file

#include <type_traits>
#include <utility>

namespace dlaf::internal {
// Based on https://quuxplusone.github.io/blog/2018/05/17/super-elider-round-2/ and
// https://akrzemi1.wordpress.com/2018/05/16/rvalues-redefined/.
//
// Because of the conversion operator and guaranteed copy-elision, useful for
// emplacing immovable types into e.g. variants and optionals. Can also be used
// to construct new instances for each element in a vector, where the element
// type has reference semantics and regular copy construction is not what is
// wanted.
template <typename F>
class WithResultOf {
F&& f;

public:
using ResultType = std::invoke_result_t<F&&>;
explicit WithResultOf(F&& f) : f(std::forward<F>(f)) {}
operator ResultType() {
return std::forward<F>(f)();
}
};

template <typename F>
WithResultOf(F&&) -> WithResultOf<F&&>;
}
38 changes: 25 additions & 13 deletions include/dlaf/communication/broadcast_panel.h
@@ -17,6 +17,8 @@
#include <pika/execution.hpp>

#include <dlaf/common/index2d.h>
#include <dlaf/communication/communicator_pipeline.h>
#include <dlaf/communication/index.h>
#include <dlaf/communication/kernels/broadcast.h>
#include <dlaf/communication/message.h>
#include <dlaf/matrix/copy_tile.h>
@@ -56,7 +58,7 @@ std::pair<SizeType, comm::IndexT_MPI> transposedOwner(const matrix::Distribution
template <class T, Device D, Coord axis, matrix::StoreTransposed storage,
class = std::enable_if_t<!std::is_const_v<T>>>
void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& panel,
common::Pipeline<comm::Communicator>& serial_comm) {
comm::CommunicatorPipeline<coord_to_communicator_type(orthogonal(axis))>& serial_comm) {
constexpr auto comm_coord = axis;

// do not schedule communication tasks if there is no reason to do so...
@@ -68,12 +70,25 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& p
namespace ex = pika::execution::experimental;
for (const auto& index : panel.iteratorLocal()) {
if (rank == rank_root)
ex::start_detached(scheduleSendBcast(serial_comm(), panel.read(index)));
ex::start_detached(scheduleSendBcast(serial_comm.exclusive(), panel.read(index)));
else
ex::start_detached(scheduleRecvBcast(serial_comm(), rank_root, panel.readwrite(index)));
ex::start_detached(scheduleRecvBcast(serial_comm.exclusive(), rank_root, panel.readwrite(index)));
}
}

namespace internal {
template <Coord C>
auto& get_taskchain(comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain) {
if constexpr (C == Coord::Row) {
return row_task_chain;
}
else {
return col_task_chain;
}
}
} // namespace internal

/// Broadcast
///
/// Given a source panel on a rank, this communication pattern makes every rank access tiles of both:
@@ -107,17 +122,13 @@ template <class T, Device D, Coord axis, matrix::StoreTransposed storage,
matrix::StoreTransposed storageT, class = std::enable_if_t<!std::is_const_v<T>>>
void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& panel,
matrix::Panel<orthogonal(axis), T, D, storageT>& panelT,
common::Pipeline<comm::Communicator>& row_task_chain,
common::Pipeline<comm::Communicator>& col_task_chain) {
comm::CommunicatorPipeline<comm::CommunicatorType::Row>& row_task_chain,
comm::CommunicatorPipeline<comm::CommunicatorType::Col>& col_task_chain) {
constexpr Coord axisT = orthogonal(axis);

constexpr Coord coord = std::decay_t<decltype(panel)>::coord;
constexpr Coord coordT = std::decay_t<decltype(panelT)>::coord;

auto get_taskchain = [&](Coord comm_dir) -> auto& {
return comm_dir == Coord::Row ? row_task_chain : col_task_chain;
};

// Note:
// Given a source panel, this communication pattern makes every rank access tiles of both the
// source panel and it's tranposed variant (just tile coordinates, data is not transposed).
@@ -160,15 +171,15 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& p

// STEP 1
constexpr auto comm_dir_step1 = orthogonal(axis);
auto& chain_step1 = get_taskchain(comm_dir_step1);
auto& chain_step1 = internal::get_taskchain<comm_dir_step1>(row_task_chain, col_task_chain);

broadcast(rank_root, panel, chain_step1);

// STEP 2
constexpr auto comm_dir_step2 = orthogonal(axisT);
constexpr auto comm_coord_step2 = axisT;

auto& chain_step2 = get_taskchain(comm_dir_step2);
auto& chain_step2 = internal::get_taskchain<comm_dir_step2>(row_task_chain, col_task_chain);

const SizeType last_tile = std::max(panelT.rangeStart(), panelT.rangeEnd() - 1);
const auto owner = dist.template rankGlobalTile<coordT>(last_tile);
@@ -186,11 +197,12 @@ void broadcast(comm::IndexT_MPI rank_root, matrix::Panel<axis, T, D, storage>& p
panelT.setTile(indexT, panel.read({coord, index_diag_local}));

if (dist.commGridSize().get(comm_coord_step2) > 1)
ex::start_detached(scheduleSendBcast(chain_step2(), panelT.read(indexT)));
ex::start_detached(scheduleSendBcast(chain_step2.exclusive(), panelT.read(indexT)));
}
else {
if (dist.commGridSize().get(comm_coord_step2) > 1)
ex::start_detached(scheduleRecvBcast(chain_step2(), owner_diag, panelT.readwrite(indexT)));
ex::start_detached(scheduleRecvBcast(chain_step2.exclusive(), owner_diag,
panelT.readwrite(indexT)));
}
}
}
4 changes: 1 addition & 3 deletions include/dlaf/communication/communicator.h
@@ -15,13 +15,11 @@
#include <memory>

#include <dlaf/communication/error.h>
#include <dlaf/communication/index.h>

namespace dlaf {
namespace comm {

/// Type used for indexes in MPI API.
using IndexT_MPI = int;

class CommunicatorImpl;

/// MPI-compatible wrapper for the MPI_Comm.