Skip to content

Commit

Permalink
CarpetX: Correct interpolation on multiple MPI processes
Browse files Browse the repository at this point in the history
  • Loading branch information
eschnett committed Jul 28, 2023
1 parent 92bbcf2 commit 2ecc871
Showing 1 changed file with 20 additions and 11 deletions.
31 changes: 20 additions & 11 deletions CarpetX/src/interpolate.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -798,33 +798,42 @@ extern "C" void CarpetX_Interpolate(const CCTK_POINTER_TO_CONST cctkGH_,
const MPI_Comm comm = amrex::ParallelDescriptor::Communicator();
const MPI_Datatype datatype = mpi_datatype<CCTK_REAL>::value;

// int total_npoints;
// MPI_Allreduce(&npoints, &total_npoints, 1, MPI_INT, MPI_SUM, comm);

std::vector<int> sendcounts(nprocs);
std::vector<int> senddispls(nprocs);
int sendcount = 0;
int total_sendcount = 0;
for (int p = 0; p < nprocs; ++p) {
const auto &result = results.at(p);
sendcounts.at(p) = result.size();
senddispls.at(p) = sendcount;
sendcount += sendcounts.at(p);
senddispls.at(p) = total_sendcount;
total_sendcount += sendcounts.at(p);
}
assert(sendcount == (nvars + 1) * npoints);
std::vector<int> recvcounts(nprocs);
MPI_Alltoall(sendcounts.data(), 1, MPI_INT, recvcounts.data(), 1, MPI_INT,
comm);
std::vector<int> recvdispls(nprocs);
int recvcount = 0;
int total_recvcount = 0;
for (int p = 0; p < nprocs; ++p) {
recvdispls.at(p) = recvcount;
recvcount += recvcounts.at(p);
recvdispls.at(p) = total_recvcount;
total_recvcount += recvcounts.at(p);
}
assert(recvcount == (nvars + 1) * npoints);

std::vector<CCTK_REAL> sendbuf(sendcount);
std::vector<CCTK_REAL> sendbuf(total_sendcount);
for (int p = 0; p < nprocs; ++p) {
// TODO: Don't copy, store data here right away
assert(p >= 0);
assert(p < int(results.size()));
const auto &result = results.at(p);
copy(result.begin(), result.end(), &sendbuf.at(senddispls.at(p)));
assert(sendcounts.at(p) == result.size());
assert(p >= 0);
assert(p < int(senddispls.size()));
assert(senddispls.at(p) >= 0);
assert(senddispls.at(p) + sendcounts.at(p) <= int(sendbuf.size()));
std::copy(result.begin(), result.end(), sendbuf.data() + senddispls.at(p));
}
std::vector<CCTK_REAL> recvbuf(recvcount);
std::vector<CCTK_REAL> recvbuf(total_recvcount);
MPI_Alltoallv(sendbuf.data(), sendcounts.data(), senddispls.data(), datatype,
recvbuf.data(), recvcounts.data(), recvdispls.data(), datatype,
comm);
Expand Down

0 comments on commit 2ecc871

Please sign in to comment.