Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Annotate cholesky for use with apex/other #826

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 23 additions & 8 deletions include/dlaf/factorization/cholesky/impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,22 @@

namespace dlaf::factorization::internal {

// Task-annotation helper used when building Policy objects for the tile kernels.
//
// With PIKA_HAVE_APEX defined, ANNOTATE(NAME) expands to a `const char*` expression yielding the
// stringified NAME, prefixed with "HP_" when `priority` equals
// pika::execution::thread_priority::high. Without PIKA_HAVE_APEX it expands to nullptr and NAME
// is ignored.
//
// NOTE: the APEX variant reads a variable called `priority` from the scope where the macro is
// expanded; it is not a macro parameter, so the macro can only be used where such a variable exists.
#ifdef PIKA_HAVE_APEX
#define ANNOTATE(NAME) (priority == pika::execution::thread_priority::high ? "HP_" #NAME : #NAME)
#else
#define ANNOTATE(NAME) nullptr
#endif

namespace cholesky_l {

template <Backend backend, class MatrixTileSender>
void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&& matrix_tile) {
using pika::execution::thread_stacksize;

pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Uplo::Lower, std::forward<MatrixTileSender>(matrix_tile)) |
tile::potrf(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::potrf(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(potrfDiagTile))));
}

template <Backend backend, class KKTileSender, class MatrixTileSender>
Expand All @@ -63,7 +71,8 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
blas::Diag::NonUnit, ElementType(1.0),
std::forward<KKTileSender>(kk_tile),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::trsm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::trsm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(trsmPanelTile))));
}

template <Backend backend, class PanelTileSender, class MatrixTileSender>
Expand All @@ -76,7 +85,8 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
dlaf::internal::whenAllLift(blas::Uplo::Lower, blas::Op::NoTrans, BaseElementType(-1.0),
std::forward<PanelTileSender>(panel_tile), BaseElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::herk(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::herk(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(herkTrailingDiagTile))));
}

template <Backend backend, class PanelTileSender, class ColPanelSender, class MatrixTileSender>
Expand All @@ -90,7 +100,8 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
std::forward<PanelTileSender>(panel_tile),
std::forward<ColPanelSender>(col_panel), ElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(gemmTrailingMatrixTile))));
}
}

Expand All @@ -101,7 +112,8 @@ void potrfDiagTile(pika::execution::thread_priority priority, MatrixTileSender&&

pika::execution::experimental::start_detached(
dlaf::internal::whenAllLift(blas::Uplo::Upper, std::forward<MatrixTileSender>(matrix_tile)) |
tile::potrf(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::potrf(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(potrfDiagTile))));
}

template <Backend backend, class KKTileSender, class MatrixTileSender>
Expand All @@ -115,7 +127,8 @@ void trsmPanelTile(pika::execution::thread_priority priority, KKTileSender&& kk_
blas::Diag::NonUnit, ElementType(1.0),
std::forward<KKTileSender>(kk_tile),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::trsm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::trsm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(trsmPanelTile))));
}

template <Backend backend, class PanelTileSender, class MatrixTileSender>
Expand All @@ -128,7 +141,8 @@ void herkTrailingDiagTile(pika::execution::thread_priority priority, PanelTileSe
dlaf::internal::whenAllLift(blas::Uplo::Upper, blas::Op::ConjTrans, base_element_type(-1.0),
std::forward<PanelTileSender>(panel_tile), base_element_type(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::herk(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::herk(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(herkTrailingDiagTile))));
}

template <Backend backend, class PanelTileSender, class ColPanelSender, class MatrixTileSender>
Expand All @@ -142,7 +156,8 @@ void gemmTrailingMatrixTile(pika::execution::thread_priority priority, PanelTile
std::forward<PanelTileSender>(panel_tile),
std::forward<ColPanelSender>(col_panel), ElementType(1.0),
std::forward<MatrixTileSender>(matrix_tile)) |
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack)));
tile::gemm(dlaf::internal::Policy<backend>(priority, thread_stacksize::nostack,
ANNOTATE(gemmTrailingMatrixTile))));
}
}

Expand Down
7 changes: 0 additions & 7 deletions include/dlaf/init.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ struct configuration {
std::size_t num_gpu_lapack_handles = 16;
std::size_t umpire_host_memory_pool_initial_bytes = 1 << 30;
std::size_t umpire_device_memory_pool_initial_bytes = 1 << 30;
std::string mpi_pool = "mpi";
};

std::ostream& operator<<(std::ostream& os, const configuration& cfg);
Expand Down Expand Up @@ -98,10 +97,4 @@ struct [[nodiscard]] ScopedInitializer {
ScopedInitializer& operator=(ScopedInitializer&&) = delete;
ScopedInitializer& operator=(const ScopedInitializer&) = delete;
};

/// Initialize the MPI pool.
///
///
void initResourcePartitionerHandler(pika::resource::partitioner& rp,
const pika::program_options::variables_map& vm);
}
3 changes: 2 additions & 1 deletion include/dlaf/schedulers.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
/// @file

#include <pika/execution.hpp>
#include <pika/mpi.hpp>
#include <pika/runtime.hpp>
#include <pika/thread.hpp>

Expand Down Expand Up @@ -50,6 +51,6 @@ auto getBackendScheduler(

inline auto getMPIScheduler() {
return pika::execution::experimental::thread_pool_scheduler{
&pika::resource::get_thread_pool(getConfiguration().mpi_pool)};
&pika::resource::get_thread_pool(pika::mpi::experimental::get_pool_name())};
}
} // namespace dlaf::internal
13 changes: 10 additions & 3 deletions include/dlaf/sender/policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ class Policy {
private:
const pika::execution::thread_priority priority_ = pika::execution::thread_priority::normal;
const pika::execution::thread_stacksize stacksize_ = pika::execution::thread_stacksize::default_;
const char* annotation_ = nullptr;

public:
Policy() = default;
explicit Policy(
pika::execution::thread_priority priority,
pika::execution::thread_stacksize stacksize = pika::execution::thread_stacksize::default_)
: priority_(priority), stacksize_(stacksize) {}
explicit Policy(pika::execution::thread_stacksize stacksize) : stacksize_(stacksize) {}
pika::execution::thread_stacksize stacksize = pika::execution::thread_stacksize::default_,
const char* annotation = nullptr)
: priority_(priority), stacksize_(stacksize), annotation_(annotation) {}
explicit Policy(pika::execution::thread_stacksize stacksize, const char* annotation = nullptr)
: stacksize_(stacksize), annotation_(annotation) {}
Policy(Policy&&) = default;
Policy(const Policy&) = default;
Policy& operator=(Policy&&) = default;
Expand All @@ -45,6 +48,10 @@ class Policy {
/// Returns the thread stack size stored in this policy.
pika::execution::thread_stacksize stacksize() const noexcept {
return stacksize_;
}

/// Returns the annotation string pointer stored in this policy.
///
/// The pointer is returned exactly as stored (no copy of the string is made); it is nullptr when
/// no annotation was supplied at construction.
const char* annotation() const noexcept {
return annotation_;
}
};
}
}
11 changes: 10 additions & 1 deletion include/dlaf/sender/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,19 @@ template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B =
using pika::execution::experimental::drop_operation_state;
using pika::execution::experimental::then;
using pika::execution::experimental::transfer;
using pika::execution::experimental::with_annotation;

auto scheduler = getBackendScheduler<B>(policy.priority(), policy.stacksize());
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));

using dlaf::common::internal::ConsumeRvalues;
using dlaf::common::internal::Unwrapping;

if constexpr (B == Backend::MC) {
if (policy.annotation()) {
scheduler = with_annotation(scheduler, policy.annotation());
}
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));

return then(std::move(transfer_sender), ConsumeRvalues{Unwrapping{std::forward<F>(f)}}) |
drop_operation_state();
}
Expand All @@ -73,6 +78,10 @@ template <TransformDispatchType Tag = TransformDispatchType::Plain, Backend B =
using pika::cuda::experimental::then_with_cusolver;
using pika::cuda::experimental::then_with_stream;

if (policy.annotation()) {
scheduler = with_annotation(scheduler, policy.annotation());
}
auto transfer_sender = transfer(std::forward<Sender>(sender), std::move(scheduler));
if constexpr (Tag == TransformDispatchType::Plain) {
return then_with_stream(std::move(transfer_sender),
ConsumeRvalues{Unwrapping{std::forward<F>(f)}}) |
Expand Down
106 changes: 60 additions & 46 deletions include/dlaf/sender/transform_mpi.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <type_traits>
#include <utility>

#include <pika/debugging/print.hpp>
#include <pika/execution.hpp>

#include <dlaf/common/consume_rvalues.h>
Expand All @@ -20,9 +21,18 @@
#include <dlaf/communication/communicator_pipeline.h>
#include <dlaf/sender/transform.h>
#include <dlaf/sender/when_all_lift.h>
//
#include <pika/mpi.hpp>
//
#ifdef EXTRA_MPI_TYPES_DEBUGGING
#include <pika/debugging/demangle_helper.hpp>
#endif

namespace dlaf::comm::internal {

/// Debug-print threshold object, one per integer @p Level, constructed with the string "DLA_MPI".
///
/// NOTE(review): this is declared `static` at namespace scope inside a header, so every
/// translation unit that includes the header gets its own instance for each Level it uses.
/// Confirm that is intended (as opposed to an `inline` variable template).
template <int Level>
static pika::debug::detail::print_threshold<Level, 0> dla_debug("DLA_MPI");

/// This helper "consumes" a CommunicatorPipelineExclusiveWrapper, ensuring that after this call
/// the one passed as argument gets destroyed. All other types are left as they are
/// by the second overload.
Expand All @@ -45,17 +55,12 @@ void consumeCommunicatorWrapper(T&) {}
/// least until version 12 fails with an internal compiler error with a trailing
/// decltype for SFINAE. GCC has no problems with a lambda.
template <typename F>
struct MPICallHelper {
struct MPIYieldWhileCallHelper {
std::decay_t<F> f;
template <typename... Ts>
auto operator()(Ts&&... ts) -> decltype(std::move(f)(dlaf::common::internal::unwrap(ts)...,
std::declval<MPI_Request*>())) {
auto operator()(Ts&&... ts) {
namespace mpid = pika::mpi::experimental::detail;
MPI_Request req;
auto is_request_completed = [&req] {
int flag;
MPI_Test(&req, &flag, MPI_STATUS_IGNORE);
return flag == 0;
};

// Note:
// Callables passed to transformMPI have their arguments passed by reference, but doing so
Expand All @@ -71,17 +76,41 @@ struct MPICallHelper {
if constexpr (std::is_void_v<result_type>) {
std::move(f)(dlaf::common::internal::unwrap(ts)..., &req);
(internal::consumeCommunicatorWrapper(ts), ...);
pika::util::yield_while(is_request_completed);
pika::util::yield_while([req]() { return !mpid::poll_request(req); });
}
else {
/*auto r = */ std::move(f)(dlaf::common::internal::unwrap(ts)..., &req);
(internal::consumeCommunicatorWrapper(ts), ...);
pika::util::yield_while([req]() { return !mpid::poll_request(req); });
}
}
};

/// Helper type for wrapping MPI calls.
template <typename F>
struct MPICallHelper {
std::decay_t<F> f;

template <typename... Ts>
auto operator()(Ts&&... ts) -> decltype(std::move(f)(dlaf::common::internal::unwrap(ts)...)) {
using namespace pika::debug::detail;
PIKA_DETAIL_DP(dla_debug<5>, debug(str<>("MPICallHelper"), pika::debug::print_type<Ts...>(", ")));
using result_type = decltype(std::move(f)(dlaf::common::internal::unwrap(ts)...));
if constexpr (std::is_void_v<result_type>) {
std::move(f)(dlaf::common::internal::unwrap(ts)...);
(internal::consumeCommunicatorWrapper(ts), ...);
}
else {
auto r = std::move(f)(dlaf::common::internal::unwrap(ts)..., &req);
auto r = std::move(f)(dlaf::common::internal::unwrap(ts)...);
(internal::consumeCommunicatorWrapper(ts), ...);
pika::util::yield_while(is_request_completed);
return r;
}
}
};

/// Deduction guide: constructing MPIYieldWhileCallHelper from a callable of type F (lvalue or
/// rvalue) deduces the class template argument as std::decay_t<F>.
template <typename F>
MPIYieldWhileCallHelper(F&&) -> MPIYieldWhileCallHelper<std::decay_t<F>>;

/// Deduction guide: constructing MPICallHelper from a callable of type F (lvalue or rvalue)
/// deduces the class template argument as std::decay_t<F>.
template <typename F>
MPICallHelper(F&&) -> MPICallHelper<std::decay_t<F>>;

Expand All @@ -90,10 +119,26 @@ template <typename F, typename Sender,
typename = std::enable_if_t<pika::execution::experimental::is_sender_v<Sender>>>
[[nodiscard]] decltype(auto) transformMPI(F&& f, Sender&& sender) {
namespace ex = pika::execution::experimental;

return ex::transfer(std::forward<Sender>(sender), dlaf::internal::getMPIScheduler()) |
ex::then(dlaf::common::internal::ConsumeRvalues{MPICallHelper{std::forward<F>(f)}}) |
ex::drop_operation_state();
namespace mpi = pika::mpi::experimental;
namespace mpid = pika::mpi::experimental::detail;

#ifdef EXTRA_MPI_TYPES_DEBUGGING
auto snd1 =
std::forward<Sender>(sender) |
ex::let_value([=, f = std::move(f)]<typename... LArgs>(LArgs&&... largs) {
PIKA_DETAIL_DP(dla_debug<2>, debug(str<>("Args to MPI fn\n"),
pika::debug::print_type<LArgs...>(", "), "\nValues\n"));
return ex::just(std::move(largs)...) |
mpi::transform_mpi(dlaf::common::internal::ConsumeRvalues{MPICallHelper{std::move(f)}});
});
return ex::make_unique_any_sender(std::move(snd1));
#else
PIKA_DETAIL_DP(dla_debug<5>, debug(str<>("MPI fn\n")));
auto snd1 =
std::forward<Sender>(sender) |
mpi::transform_mpi(dlaf::common::internal::ConsumeRvalues{MPICallHelper{std::forward<F>(f)}});
return ex::make_unique_any_sender(std::move(snd1));
#endif
}

/// Fire-and-forget transformMPI. This submits the work and returns void.
Expand Down Expand Up @@ -148,29 +193,6 @@ class PartialTransformMPI : private PartialTransformMPIBase<F> {
template <typename F>
PartialTransformMPI(F&& f) -> PartialTransformMPI<std::decay_t<F>>;

/// A partially applied transformMPIDetach, with the callable object given, but
/// the predecessor sender missing. The predecessor sender is applied when
/// calling the operator| overload.
///
/// @tparam F type of the stored callable (the deduction guide below decays it).
template <typename F>
class PartialTransformMPIDetach : private PartialTransformMPIBase<F> {
public:
/// Forwards the callable to the PartialTransformMPIBase<F> base.
template <typename F_>
PartialTransformMPIDetach(F_&& f) : PartialTransformMPIBase<F>{std::forward<F_>(f)} {}
PartialTransformMPIDetach(PartialTransformMPIDetach&&) = default;
PartialTransformMPIDetach(const PartialTransformMPIDetach&) = default;
PartialTransformMPIDetach& operator=(PartialTransformMPIDetach&&) = default;
PartialTransformMPIDetach& operator=(const PartialTransformMPIDetach&) = default;

/// Applies the stored callable to @p sender through transformMPI and hands the resulting
/// sender to start_detached.
///
/// @p pa is taken by value, so moving from pa.f_ only affects this local copy.
/// NOTE(review): f_ is not declared in this class; presumably it is a member of
/// PartialTransformMPIBase<F> — confirm against the base definition.
template <typename Sender>
friend auto operator|(Sender&& sender, PartialTransformMPIDetach pa) {
return pika::execution::experimental::start_detached(transformMPI(std::move(pa.f_),
std::forward<Sender>(sender)));
}
};

/// Deduction guide: PartialTransformMPIDetach{f} deduces the template argument as the decayed
/// type of f.
template <typename F>
PartialTransformMPIDetach(F&& f) -> PartialTransformMPIDetach<std::decay_t<F>>;

/// \overload transformMPI
///
/// This overload partially applies the MPI transform for later use with
Expand All @@ -180,12 +202,4 @@ template <typename F>
return PartialTransformMPI{std::forward<F>(f)};
}

/// \overload transformMPIDetach
///
/// This overload partially applies transformMPIDetach for later use with
/// operator| with a sender on the left-hand side.
///
/// The callable @p f is forwarded into a PartialTransformMPIDetach, which is returned by value.
template <typename F>
[[nodiscard]] decltype(auto) transformMPIDetach(F&& f) {
return PartialTransformMPIDetach{std::forward<F>(f)};
}
}
1 change: 0 additions & 1 deletion miniapp/miniapp_band_to_tridiag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,5 @@ int main(int argc, char** argv) {

pika::init_params p;
p.desc_cmdline = desc_commandline;
p.rp_callback = dlaf::initResourcePartitionerHandler;
return pika::init(pika_main, argc, argv, p);
}
1 change: 0 additions & 1 deletion miniapp/miniapp_bt_band_to_tridiag.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,5 @@ int main(int argc, char** argv) {

pika::init_params p;
p.desc_cmdline = desc_commandline;
p.rp_callback = dlaf::initResourcePartitionerHandler;
return pika::init(pika_main, argc, argv, p);
}
1 change: 0 additions & 1 deletion miniapp/miniapp_bt_reduction_to_band.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,5 @@ int main(int argc, char** argv) {

pika::init_params p;
p.desc_cmdline = desc_commandline;
p.rp_callback = dlaf::initResourcePartitionerHandler;
return pika::init(pika_main, argc, argv, p);
}
1 change: 0 additions & 1 deletion miniapp/miniapp_cholesky.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@ int main(int argc, char** argv) {

pika::init_params p;
p.desc_cmdline = desc_commandline;
p.rp_callback = dlaf::initResourcePartitionerHandler;
return pika::init(pika_main, argc, argv, p);
}

Expand Down
1 change: 0 additions & 1 deletion miniapp/miniapp_communication.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -606,6 +606,5 @@ int main(int argc, char** argv) {

pika::init_params p;
p.desc_cmdline = desc_commandline;
p.rp_callback = dlaf::initResourcePartitionerHandler;
return pika::init(pika_main, argc, argv, p);
}
Loading