From 2ecc8716934f4082ec2275fe2f3dba1d16648cc5 Mon Sep 17 00:00:00 2001 From: Erik Schnetter Date: Fri, 28 Jul 2023 15:00:29 -0400 Subject: [PATCH] CarpetX: Correct interpolation on multiple MPI processes --- CarpetX/src/interpolate.cxx | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/CarpetX/src/interpolate.cxx b/CarpetX/src/interpolate.cxx index f75446382..ba04f1d18 100644 --- a/CarpetX/src/interpolate.cxx +++ b/CarpetX/src/interpolate.cxx @@ -798,33 +798,42 @@ extern "C" void CarpetX_Interpolate(const CCTK_POINTER_TO_CONST cctkGH_, const MPI_Comm comm = amrex::ParallelDescriptor::Communicator(); const MPI_Datatype datatype = mpi_datatype::value; + // int total_npoints; + // MPI_Allreduce(&npoints, &total_npoints, 1, MPI_INT, MPI_SUM, comm); + std::vector sendcounts(nprocs); std::vector senddispls(nprocs); - int sendcount = 0; + int total_sendcount = 0; for (int p = 0; p < nprocs; ++p) { const auto &result = results.at(p); sendcounts.at(p) = result.size(); - senddispls.at(p) = sendcount; - sendcount += sendcounts.at(p); + senddispls.at(p) = total_sendcount; + total_sendcount += sendcounts.at(p); } - assert(sendcount == (nvars + 1) * npoints); std::vector recvcounts(nprocs); MPI_Alltoall(sendcounts.data(), 1, MPI_INT, recvcounts.data(), 1, MPI_INT, comm); std::vector recvdispls(nprocs); - int recvcount = 0; + int total_recvcount = 0; for (int p = 0; p < nprocs; ++p) { - recvdispls.at(p) = recvcount; - recvcount += recvcounts.at(p); + recvdispls.at(p) = total_recvcount; + total_recvcount += recvcounts.at(p); } - assert(recvcount == (nvars + 1) * npoints); - std::vector sendbuf(sendcount); + std::vector sendbuf(total_sendcount); for (int p = 0; p < nprocs; ++p) { + // TODO: Don't copy, store data here right away + assert(p >= 0); + assert(p < int(results.size())); const auto &result = results.at(p); - copy(result.begin(), result.end(), &sendbuf.at(senddispls.at(p))); + assert(sendcounts.at(p) == result.size()); + assert(p >= 0); + assert(p < int(senddispls.size())); + assert(senddispls.at(p) >= 0); + assert(senddispls.at(p) + sendcounts.at(p) <= int(sendbuf.size())); + std::copy(result.begin(), result.end(), sendbuf.data() + senddispls.at(p)); } - std::vector recvbuf(recvcount); + std::vector recvbuf(total_recvcount); MPI_Alltoallv(sendbuf.data(), sendcounts.data(), senddispls.data(), datatype, recvbuf.data(), recvcounts.data(), recvdispls.data(), datatype, comm);