diff --git a/src/core.mk b/src/core.mk
index 7025280a2..38639dcb1 100644
--- a/src/core.mk
+++ b/src/core.mk
@@ -44,11 +44,13 @@ GEN_CPU_SRC += core/comm/alltoall_thread_mpi.cc \
                core/comm/alltoallv_thread_mpi.cc \
                core/comm/gather_thread_mpi.cc \
                core/comm/allgather_thread_mpi.cc \
-               core/comm/bcast_thread_mpi.cc
+               core/comm/bcast_thread_mpi.cc \
+               core/comm/p2p_thread_mpi.cc
 else
 GEN_CPU_SRC += core/comm/alltoall_thread_local.cc \
                core/comm/alltoallv_thread_local.cc \
-               core/comm/allgather_thread_local.cc
+               core/comm/allgather_thread_local.cc \
+               core/comm/p2p_thread_local.cc
 endif

 # Source files for GPUs
diff --git a/src/core/comm/alltoallv_thread_mpi.cc b/src/core/comm/alltoallv_thread_mpi.cc
index 1bcc8806e..18ba483af 100644
--- a/src/core/comm/alltoallv_thread_mpi.cc
+++ b/src/core/comm/alltoallv_thread_mpi.cc
@@ -68,18 +68,20 @@ int alltoallvMPI(const void* sendbuf,
     int recv_tag = generateAlltoallvTag(global_rank, recvfrom_global_rank, global_comm);
 #ifdef DEBUG_LEGATE
     log_coll.debug(
-      "AlltoallvMPI i: %d === global_rank %d, mpi rank %d, send to %d (%d), send_tag %d, "
-      "recv from %d (%d), "
-      "recv_tag %d",
+      "AlltoallvMPI i: %d === global_rank %d, mpi rank %d, "
+      "send to %d (%d), send_tag %d, count %d, "
+      "recv from %d (%d), recv_tag %d, count %d",
       i,
       global_rank,
       global_comm->mpi_rank,
       sendto_global_rank,
       sendto_mpi_rank,
       send_tag,
+      scount,
       recvfrom_global_rank,
       recvfrom_mpi_rank,
-      recv_tag);
+      recv_tag,
+      rcount);
 #endif
     CHECK_MPI(MPI_Sendrecv(src,
                            scount,
diff --git a/src/core/comm/coll.cc b/src/core/comm/coll.cc
index 2f8472657..49f8c3fa6 100644
--- a/src/core/comm/coll.cc
+++ b/src/core/comm/coll.cc
@@ -45,6 +45,7 @@ enum CollTag : int {
   GATHER_TAG = 1,
   ALLTOALL_TAG = 2,
   ALLTOALLV_TAG = 3,
+  P2P_TAG = 4,
   MAX_TAG = 10,
 };

@@ -109,10 +110,15 @@ int collCommCreate(CollComm global_comm,
       (const void**)malloc(sizeof(void*) * global_comm_size);
     thread_comms[global_comm->unique_id]->displs =
       (const int**)malloc(sizeof(int*) * global_comm_size);
+    thread_comms[global_comm->unique_id]->buffer_ready =
+      (int*)malloc(sizeof(int) * global_comm_size * global_comm_size);
     for (int i = 0; i < global_comm_size; i++) {
       thread_comms[global_comm->unique_id]->buffers[i] = nullptr;
       thread_comms[global_comm->unique_id]->displs[i] = nullptr;
     }
+    for (int i = 0; i < global_comm_size * global_comm_size; i++) {
+      thread_comms[global_comm->unique_id]->buffer_ready[i] = 0;
+    }
     __sync_synchronize();
     thread_comms[global_comm->unique_id]->ready_flag = true;
   }
@@ -124,6 +130,7 @@ int collCommCreate(CollComm global_comm,
   assert(global_comm->comm->ready_flag == true);
   assert(global_comm->comm->buffers != nullptr);
   assert(global_comm->comm->displs != nullptr);
+  assert(global_comm->comm->buffer_ready != nullptr);
   global_comm->nb_threads = global_comm->global_comm_size;
 #endif
   return CollSuccess;
@@ -148,6 +155,8 @@ int collCommDestroy(CollComm global_comm)
     thread_comms[global_comm->unique_id]->buffers = nullptr;
     free(thread_comms[global_comm->unique_id]->displs);
     thread_comms[global_comm->unique_id]->displs = nullptr;
+    free(thread_comms[global_comm->unique_id]->buffer_ready);
+    thread_comms[global_comm->unique_id]->buffer_ready = nullptr;
     __sync_synchronize();
     thread_comms[global_comm->unique_id]->ready_flag = false;
   }
@@ -237,6 +246,57 @@ int collAllgather(
 #endif
 }

+int collSend(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
+{
+  if (dest == global_comm->global_rank) {
+    log_coll.error("Sending to self is not supported");
+    LEGATE_ABORT;
+  }
+  log_coll.debug(
+    "Send: global_rank %d, mpi_rank %d, unique_id %d, comm_size %d, "
+    "mpi_comm_size %d %d, nb_threads %d, dst %d, tag %d",
+    global_comm->global_rank,
+    global_comm->mpi_rank,
+    global_comm->unique_id,
+    global_comm->global_comm_size,
+    global_comm->mpi_comm_size,
+    global_comm->mpi_comm_size_actual,
+    global_comm->nb_threads,
+    dest,
+    tag);
+#ifdef LEGATE_USE_GASNET
+  return sendMPI(sendbuf, count, type, dest, tag, global_comm);
+#else
+  return sendLocal(sendbuf, count, type, dest, tag, global_comm);
+#endif
+}
+
+int collRecv(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
+{
+  if (source == global_comm->global_rank) {
+    log_coll.error("Receiving from self is not supported");
+    LEGATE_ABORT;
+  }
+  log_coll.debug(
+    "Recv: global_rank %d, mpi_rank %d, unique_id %d, comm_size %d, "
+    "mpi_comm_size %d %d, nb_threads %d, src %d, tag %d",
+    global_comm->global_rank,
+    global_comm->mpi_rank,
+    global_comm->unique_id,
+    global_comm->global_comm_size,
+    global_comm->mpi_comm_size,
+    global_comm->mpi_comm_size_actual,
+    global_comm->nb_threads,
+    source,
+    tag);
+#ifdef LEGATE_USE_GASNET
+  return recvMPI(recvbuf, count, type, source, tag, global_comm);
+#else
+  return recvLocal(recvbuf, count, type, source, tag, global_comm);
+#endif
+}
+
 // called from main thread
 int collInit(int argc, char* argv[])
 {
@@ -451,6 +511,13 @@ int generateGatherTag(int rank, CollComm global_comm)
   return tag;
 }

+int generateP2PTag(int user_tag)
+{
+  int tag = user_tag * CollTag::MAX_TAG + CollTag::P2P_TAG;
+  assert(tag <= mpi_tag_ub && tag > 0);
+  return tag;
+}
+
 #else // undef LEGATE_USE_GASNET
 size_t getDtypeSize(CollDataType dtype)
 {
diff --git a/src/core/comm/coll.h b/src/core/comm/coll.h
index 406887752..c4adee3e6 100644
--- a/src/core/comm/coll.h
+++ b/src/core/comm/coll.h
@@ -58,6 +58,7 @@ struct ThreadComm {
   bool ready_flag;
   const void** buffers;
   const int** displs;
+  int* buffer_ready;  // used for p2p; size = comm_size * comm_size
 };

 #endif
@@ -120,6 +121,12 @@ int collAlltoall(
 int collAllgather(
   const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);

+int collSend(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
+
+int collRecv(
+  void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
+
 int collInit(int argc, char* argv[]);

 int collFinalize();
@@ -150,6 +157,11 @@ int allgatherMPI(

 int bcastMPI(void* buf, int count, CollDataType type, int root, CollComm global_comm);

+int sendMPI(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
+
+int recvMPI(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
+
 MPI_Datatype dtypeToMPIDtype(CollDataType dtype);

 int generateAlltoallTag(int rank1, int rank2, CollComm global_comm);
@@ -159,6 +171,8 @@ int generateAlltoallvTag(int rank1, int rank2, CollComm global_comm);
 int generateBcastTag(int rank, CollComm global_comm);

 int generateGatherTag(int rank, CollComm global_comm);
+
+int generateP2PTag(int user_tag);
 #else
 size_t getDtypeSize(CollDataType dtype);

@@ -177,6 +191,12 @@ int alltoallLocal(
 int allgatherLocal(
   const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);

+int sendLocal(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);
+
+int recvLocal(
+  void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);
+
 void resetLocalBuffer(CollComm global_comm);

 void barrierLocal(CollComm global_comm);
diff --git a/src/core/comm/p2p_thread_local.cc b/src/core/comm/p2p_thread_local.cc
new file mode 100644
index 000000000..dff43ad64
--- /dev/null
+++ b/src/core/comm/p2p_thread_local.cc
@@ -0,0 +1,102 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include
+#include
+#include
+#include
+
+#include "coll.h"
+#include "legion.h"
+
+namespace legate {
+namespace comm {
+namespace coll {
+
+using namespace Legion;
+extern Logger log_coll;
+
+enum P2PTag : int {
+  INIT = 0,
+  SEND_BUFFER_READY = 1,
+  SEND_CP_DONE = 2,
+  SEND_BUFFER_RELEASE = 3,
+};
+
+int sendLocal(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
+{
+  int total_size = global_comm->global_comm_size;
+  int global_rank = global_comm->global_rank;
+
+  int type_extent = getDtypeSize(type);
+
+  int key = global_rank * total_size + dest;
+  global_comm->comm->buffers[global_rank] = sendbuf;
+  global_comm->comm->buffer_ready[key] = P2PTag::SEND_BUFFER_READY;
+  __sync_synchronize();
+
+  // wait for dest to copy
+  while (global_comm->comm->buffer_ready[key] != P2PTag::SEND_CP_DONE)
+    ;
+  __sync_synchronize();
+
+  // the remote thread is done with the copy, so reset the buffer
+  resetLocalBuffer(global_comm);
+  global_comm->comm->buffer_ready[key] = P2PTag::SEND_BUFFER_RELEASE;
+  __sync_synchronize();
+
+  // wait for dest to reset the flag back to INIT
+  while (global_comm->comm->buffer_ready[key] != P2PTag::INIT)
+    ;
+
+  return CollSuccess;
+}
+
+int recvLocal(
+  void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
+{
+  int total_size = global_comm->global_comm_size;
+  int global_rank = global_comm->global_rank;
+
+  int type_extent = getDtypeSize(type);
+
+  // wait for source to put the buffer
+  int key = source * total_size + global_rank;
+  while (global_comm->comm->buffer_ready[key] == P2PTag::INIT ||
+         global_comm->comm->buffers[source] == nullptr)
+    ;
+  __sync_synchronize();
+
+  // start memcpy
+  memcpy(recvbuf, global_comm->comm->buffers[source], count * type_extent);
+  __sync_synchronize();
+  global_comm->comm->buffer_ready[key] = P2PTag::SEND_CP_DONE;
+
+  // wait for source to reset the buffer
+  while (global_comm->comm->buffer_ready[key] != P2PTag::SEND_BUFFER_RELEASE ||
+         global_comm->comm->buffers[source] != nullptr)
+    ;
+  __sync_synchronize();
+
+  global_comm->comm->buffer_ready[key] = P2PTag::INIT;
+
+  return CollSuccess;
+}
+
+} // namespace coll
+} // namespace comm
+} // namespace legate
\ No newline at end of file
diff --git a/src/core/comm/p2p_thread_mpi.cc b/src/core/comm/p2p_thread_mpi.cc
new file mode 100644
index 000000000..5af175927
--- /dev/null
+++ b/src/core/comm/p2p_thread_mpi.cc
@@ -0,0 +1,77 @@
+/* Copyright 2022 NVIDIA Corporation
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ *
+ */
+
+#include
+#include
+#include
+#include
+#include
+
+#include "coll.h"
+#include "legion.h"
+
+namespace legate {
+namespace comm {
+namespace coll {
+
+using namespace Legion;
+extern Logger log_coll;
+
+int sendMPI(
+  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
+{
+  MPI_Datatype mpi_type = dtypeToMPIDtype(type);
+
+  int dest_mpi_rank = global_comm->mapping_table.mpi_rank[dest];
+  int send_tag = generateP2PTag(tag);
+#ifdef DEBUG_LEGATE
+  log_coll.debug("sendMPI global_rank %d, mpi rank %d, send to %d (%d), send_tag %d",
+                 global_comm->global_rank,
+                 global_comm->mpi_rank,
+                 dest,
+                 dest_mpi_rank,
+                 send_tag);
+#endif
+  CHECK_MPI(MPI_Send(sendbuf, count, mpi_type, dest_mpi_rank, send_tag, global_comm->comm));
+
+  return CollSuccess;
+}
+
+int recvMPI(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
+{
+  MPI_Status status;
+
+  MPI_Datatype mpi_type = dtypeToMPIDtype(type);
+
+  int source_mpi_rank = global_comm->mapping_table.mpi_rank[source];
+  int recv_tag = generateP2PTag(tag);
+#ifdef DEBUG_LEGATE
+  log_coll.debug("recvMPI global_rank %d, mpi rank %d, recv from %d (%d), recv_tag %d",
+                 global_comm->global_rank,
+                 global_comm->mpi_rank,
+                 source,
+                 source_mpi_rank,
+                 recv_tag);
+#endif
+  CHECK_MPI(
+    MPI_Recv(recvbuf, count, mpi_type, source_mpi_rank, recv_tag, global_comm->comm, &status));
+
+  return CollSuccess;
+}
+
+} // namespace coll
+} // namespace comm
+} // namespace legate
\ No newline at end of file
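
Usage sketch (illustrative, not part of the patch): the snippet below shows how the new collSend/collRecv entry points might be driven from task code once a CollComm has been created through collCommCreate, the same way the existing collectives are used. The exchange_with_neighbors helper, the ring-neighbor arithmetic, the include path, and the CollDouble enumerator name are assumptions made for the example; only collSend, collRecv, CollDataType, and CollComm come from this change and the existing headers. Both backends block (MPI_Send/MPI_Recv on the MPI path, the buffer_ready handshake on the shared-memory path), so the sketch orders the calls by rank parity to keep matching pairs from deadlocking.

// Illustrative sketch only; not part of the patch. Assumes the legate include
// paths, that `comm` was set up via collCommCreate, and that one instance of
// this function runs per global rank (as with the existing collectives).
// `CollDouble` is an assumed CollDataType enumerator name for double.
#include "core/comm/coll.h"

using namespace legate::comm::coll;

void exchange_with_neighbors(CollComm comm, const double* sendbuf, double* recvbuf, int count)
{
  const int rank = comm->global_rank;
  const int size = comm->global_comm_size;
  if (size == 1) return;  // collSend/collRecv abort on self sends

  const int right = (rank + 1) % size;
  const int left  = (rank - 1 + size) % size;
  const int tag   = 0;  // user-level tag; mapped into the MPI tag space by generateP2PTag(tag)

  // Even ranks send first, odd ranks receive first, so every blocking send is
  // matched by a receive without forming a wait cycle.
  if (rank % 2 == 0) {
    collSend(sendbuf, count, CollDataType::CollDouble, right, tag, comm);
    collRecv(recvbuf, count, CollDataType::CollDouble, left, tag, comm);
  } else {
    collRecv(recvbuf, count, CollDataType::CollDouble, left, tag, comm);
    collSend(sendbuf, count, CollDataType::CollDouble, right, tag, comm);
  }
}

The matching send and receive must agree on the user tag. Since generateP2PTag computes user_tag * MAX_TAG + P2P_TAG and asserts the result stays within mpi_tag_ub, user tags should be kept small and positive.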