diff --git a/src/ArborX_DistributedSearchTree.hpp b/src/ArborX_DistributedSearchTree.hpp index de2744be6..130ee6d73 100644 --- a/src/ArborX_DistributedSearchTree.hpp +++ b/src/ArborX_DistributedSearchTree.hpp @@ -19,6 +19,8 @@ #include +#include + #include namespace ArborX @@ -42,8 +44,6 @@ class DistributedSearchTree DistributedSearchTree(MPI_Comm comm, ExecutionSpace const &space, Primitives const &primitives); - ~DistributedSearchTree() { MPI_Comm_free(&_comm); } - /** Returns the smallest axis-aligned box able to contain all the objects * stored in the tree or an invalid box if the tree is empty. */ @@ -101,7 +101,8 @@ class DistributedSearchTree private: template friend struct Details::DistributedSearchTreeImpl; - MPI_Comm _comm; + MPI_Comm getComm() const { return *_comm_ptr; } + std::shared_ptr _comm_ptr; BVH _top_tree; // replicated BVH _bottom_tree; // local size_type _top_tree_size; @@ -118,12 +119,24 @@ DistributedSearchTree::DistributedSearchTree( // Create new context for the library to isolate library's communication from // user's - MPI_Comm_dup(comm, &_comm); + _comm_ptr.reset( + // duplicate the communicator and store it in a std::shared_ptr so that + // all copies of the distributed tree point to the same object + [comm]() { + auto p = std::make_unique(); + MPI_Comm_dup(comm, p.get()); + return p.release(); + }(), + // custom deleter to mark the communicator object for deallocation + [](MPI_Comm *p) { + MPI_Comm_free(p); + delete p; + }); int comm_rank; - MPI_Comm_rank(_comm, &comm_rank); + MPI_Comm_rank(getComm(), &comm_rank); int comm_size; - MPI_Comm_size(_comm, &comm_size); + MPI_Comm_size(getComm(), &comm_size); Kokkos::View boxes( Kokkos::ViewAllocateWithoutInitializing("rank_bounding_boxes"), @@ -134,7 +147,7 @@ DistributedSearchTree::DistributedSearchTree( boxes_host(comm_rank) = _bottom_tree.bounds(); MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, static_cast(boxes_host.data()), sizeof(Box), MPI_BYTE, - _comm); + getComm()); Kokkos::deep_copy(space, boxes, boxes_host); _top_tree = BVH{space, boxes}; @@ -146,7 +159,7 @@ DistributedSearchTree::DistributedSearchTree( bottom_tree_sizes_host(comm_rank) = _bottom_tree.size(); MPI_Allgather(MPI_IN_PLACE, 0, MPI_DATATYPE_NULL, static_cast(bottom_tree_sizes_host.data()), - sizeof(size_type), MPI_BYTE, _comm); + sizeof(size_type), MPI_BYTE, getComm()); Kokkos::deep_copy(space, _bottom_tree_sizes, bottom_tree_sizes_host); _top_tree_size = accumulate(space, _bottom_tree_sizes, 0); diff --git a/src/details/ArborX_DetailsDistributedSearchTreeImpl.hpp b/src/details/ArborX_DetailsDistributedSearchTreeImpl.hpp index 8147c1d07..13d5932ac 100644 --- a/src/details/ArborX_DetailsDistributedSearchTreeImpl.hpp +++ b/src/details/ArborX_DetailsDistributedSearchTreeImpl.hpp @@ -54,7 +54,7 @@ struct DistributedSearchTreeImpl IndicesAndRanks &values, Offset &offset) { int comm_rank; - MPI_Comm_rank(tree._comm, &comm_rank); + MPI_Comm_rank(tree.getComm(), &comm_rank); queryDispatch(SpatialPredicateTag{}, tree, space, queries, CallbackDefaultSpatialPredicateWithRank{comm_rank}, values, offset); @@ -379,7 +379,7 @@ DistributedSearchTreeImpl::queryDispatchImpl( Offset &offset, Ranks &ranks, Distances *distances_ptr) { auto const &bottom_tree = tree._bottom_tree; - auto comm = tree._comm; + auto comm = tree.getComm(); Distances distances("distances", 0); if (distances_ptr) @@ -452,7 +452,7 @@ DistributedSearchTreeImpl::queryDispatch( { auto const &top_tree = tree._top_tree; auto const &bottom_tree = tree._bottom_tree; - auto comm = tree._comm; + auto comm = tree.getComm(); Kokkos::View indices("indices", 0); Kokkos::View ranks("ranks", 0);