Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add p2p for cpu communicator #289

Open
wants to merge 4 commits into
base: branch-24.03
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/core.mk
Original file line number Diff line number Diff line change
Expand Up @@ -44,11 +44,13 @@ GEN_CPU_SRC += core/comm/alltoall_thread_mpi.cc \
core/comm/alltoallv_thread_mpi.cc \
core/comm/gather_thread_mpi.cc \
core/comm/allgather_thread_mpi.cc \
core/comm/bcast_thread_mpi.cc
core/comm/bcast_thread_mpi.cc \
core/comm/p2p_thread_mpi.cc
else
GEN_CPU_SRC += core/comm/alltoall_thread_local.cc \
core/comm/alltoallv_thread_local.cc \
core/comm/allgather_thread_local.cc
core/comm/allgather_thread_local.cc \
core/comm/p2p_thread_local.cc
endif

# Source files for GPUs
Expand Down
10 changes: 6 additions & 4 deletions src/core/comm/alltoallv_thread_mpi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,20 @@ int alltoallvMPI(const void* sendbuf,
int recv_tag = generateAlltoallvTag(global_rank, recvfrom_global_rank, global_comm);
#ifdef DEBUG_LEGATE
log_coll.debug(
"AlltoallvMPI i: %d === global_rank %d, mpi rank %d, send to %d (%d), send_tag %d, "
"recv from %d (%d), "
"recv_tag %d",
"AlltoallvMPI i: %d === global_rank %d, mpi rank %d, "
"send to %d (%d), send_tag %d, count %d, "
"recv from %d (%d), recv_tag %d, count %d",
i,
global_rank,
global_comm->mpi_rank,
sendto_global_rank,
sendto_mpi_rank,
send_tag,
scount,
recvfrom_global_rank,
recvfrom_mpi_rank,
recv_tag);
recv_tag,
rcount);
#endif
CHECK_MPI(MPI_Sendrecv(src,
scount,
Expand Down
67 changes: 67 additions & 0 deletions src/core/comm/coll.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ enum CollTag : int {
GATHER_TAG = 1,
ALLTOALL_TAG = 2,
ALLTOALLV_TAG = 3,
P2P_TAG = 4,
MAX_TAG = 10,
};

Expand Down Expand Up @@ -109,10 +110,15 @@ int collCommCreate(CollComm global_comm,
(const void**)malloc(sizeof(void*) * global_comm_size);
thread_comms[global_comm->unique_id]->displs =
(const int**)malloc(sizeof(int*) * global_comm_size);
// One int handshake flag per ordered (sender, receiver) pair; the p2p code indexes it as
// sender * comm_size + receiver. Elements are int, so size the allocation with sizeof(int)
// (the previous sizeof(int*) allocated a pointer-sized slot per flag).
thread_comms[global_comm->unique_id]->buffer_ready =
  (int*)malloc(sizeof(int) * global_comm_size * global_comm_size);
for (int i = 0; i < global_comm_size; i++) {
thread_comms[global_comm->unique_id]->buffers[i] = nullptr;
thread_comms[global_comm->unique_id]->displs[i] = nullptr;
}
for (int i = 0; i < global_comm_size * global_comm_size; i++) {
thread_comms[global_comm->unique_id]->buffer_ready[i] = 0;
}
__sync_synchronize();
thread_comms[global_comm->unique_id]->ready_flag = true;
}
Expand All @@ -124,6 +130,7 @@ int collCommCreate(CollComm global_comm,
assert(global_comm->comm->ready_flag == true);
assert(global_comm->comm->buffers != nullptr);
assert(global_comm->comm->displs != nullptr);
assert(global_comm->comm->buffer_ready != nullptr);
global_comm->nb_threads = global_comm->global_comm_size;
#endif
return CollSuccess;
Expand All @@ -148,6 +155,8 @@ int collCommDestroy(CollComm global_comm)
thread_comms[global_comm->unique_id]->buffers = nullptr;
free(thread_comms[global_comm->unique_id]->displs);
thread_comms[global_comm->unique_id]->displs = nullptr;
free(thread_comms[global_comm->unique_id]->buffer_ready);
thread_comms[global_comm->unique_id]->buffer_ready = nullptr;
__sync_synchronize();
thread_comms[global_comm->unique_id]->ready_flag = false;
}
Expand Down Expand Up @@ -237,6 +246,57 @@ int collAllgather(
#endif
}

int collSend(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
{
if (dest == global_comm->global_rank) {
log_coll.error("Do not support sending to self");
LEGATE_ABORT;
}
log_coll.debug(
"Send: global_rank %d, mpi_rank %d, unique_id %d, comm_size %d, "
"mpi_comm_size %d %d, nb_threads %d, dst %d, tag %d",
global_comm->global_rank,
global_comm->mpi_rank,
global_comm->unique_id,
global_comm->global_comm_size,
global_comm->mpi_comm_size,
global_comm->mpi_comm_size_actual,
global_comm->nb_threads,
dest,
tag);
#ifdef LEGATE_USE_GASNET
return sendMPI(sendbuf, count, type, dest, tag, global_comm);
#else
return sendLocal(sendbuf, count, type, dest, tag, global_comm);
#endif
}

int collRecv(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
{
if (source == global_comm->global_rank) {
log_coll.error("Do not support receiving to self");
LEGATE_ABORT;
}
log_coll.debug(
"Recv: global_rank %d, mpi_rank %d, unique_id %d, comm_size %d, "
"mpi_comm_size %d %d, nb_threads %d, src %d, tag %d",
global_comm->global_rank,
global_comm->mpi_rank,
global_comm->unique_id,
global_comm->global_comm_size,
global_comm->mpi_comm_size,
global_comm->mpi_comm_size_actual,
global_comm->nb_threads,
source,
tag);
#ifdef LEGATE_USE_GASNET
return recvMPI(recvbuf, count, type, source, tag, global_comm);
#else
return recvLocal(recvbuf, count, type, source, tag, global_comm);
#endif
}

// called from main thread
int collInit(int argc, char* argv[])
{
Expand Down Expand Up @@ -451,6 +511,13 @@ int generateGatherTag(int rank, CollComm global_comm)
return tag;
}

// Maps a caller-supplied tag to the MPI tag used for point-to-point messages:
// user_tag * MAX_TAG + P2P_TAG. The result must lie in (0, mpi_tag_ub]; this is
// checked with assert, so the check is compiled out when NDEBUG is defined.
int generateP2PTag(int user_tag)
{
  // Do the arithmetic in 64 bits: in plain int a large user_tag would overflow (undefined
  // behavior) before the range check could reject it.
  const long long tag = static_cast<long long>(user_tag) * CollTag::MAX_TAG + CollTag::P2P_TAG;
  assert(tag <= mpi_tag_ub && tag > 0);
  return static_cast<int>(tag);
}

#else // undef LEGATE_USE_GASNET
size_t getDtypeSize(CollDataType dtype)
{
Expand Down
20 changes: 20 additions & 0 deletions src/core/comm/coll.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ struct ThreadComm {
bool ready_flag;
const void** buffers;
const int** displs;
int* buffer_ready; // use for p2p with size = comm_size*comm_size
};
#endif

Expand Down Expand Up @@ -120,6 +121,12 @@ int collAlltoall(
int collAllgather(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);

int collSend(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);

int collRecv(
void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);

int collInit(int argc, char* argv[]);

int collFinalize();
Expand Down Expand Up @@ -150,6 +157,11 @@ int allgatherMPI(

int bcastMPI(void* buf, int count, CollDataType type, int root, CollComm global_comm);

int sendMPI(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);

int recvMPI(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);

MPI_Datatype dtypeToMPIDtype(CollDataType dtype);

int generateAlltoallTag(int rank1, int rank2, CollComm global_comm);
Expand All @@ -159,6 +171,8 @@ int generateAlltoallvTag(int rank1, int rank2, CollComm global_comm);
int generateBcastTag(int rank, CollComm global_comm);

int generateGatherTag(int rank, CollComm global_comm);

int generateP2PTag(int user_tag);
#else
size_t getDtypeSize(CollDataType dtype);

Expand All @@ -177,6 +191,12 @@ int alltoallLocal(
int allgatherLocal(
const void* sendbuf, void* recvbuf, int count, CollDataType type, CollComm global_comm);

int sendLocal(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm);

int recvLocal(
void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm);

void resetLocalBuffer(CollComm global_comm);

void barrierLocal(CollComm global_comm);
Expand Down
102 changes: 102 additions & 0 deletions src/core/comm/p2p_thread_local.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
/* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "coll.h"
#include "legion.h"

namespace legate {
namespace comm {
namespace coll {

using namespace Legion;
extern Logger log_coll;

// States of the handshake flag kept in the communicator's buffer_ready array for one
// ordered (sender, receiver) pair. sendLocal and recvLocal step the flag through these
// values in declaration order and then return it to INIT.
enum P2PTag : int {
  INIT = 0,             // no transfer in flight (buffer_ready entries are created as 0)
  SEND_BUFFER_READY,    // = 1, set by sendLocal once its buffer address is published
  SEND_CP_DONE,         // = 2, set by recvLocal after it has copied the data out
  SEND_BUFFER_RELEASE,  // = 3, set by sendLocal after resetLocalBuffer
};

// Shared-memory send to the thread with global rank `dest`.
//
// The sender only publishes the address of `sendbuf`; the copy itself is done by the
// receiver in recvLocal. The call returns CollSuccess once the receiver has copied the
// data and the handshake flag for this (sender, dest) pair is back to P2PTag::INIT, so
// `sendbuf` must stay valid until then. `count`, `type` and `tag` are not consulted here.
int sendLocal(
  const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
{
  (void)count;
  (void)type;
  (void)tag;

  const int total_size  = global_comm->global_comm_size;
  const int global_rank = global_comm->global_rank;

  // Handshake flag for the ordered pair (this rank -> dest).
  int* flag = &global_comm->comm->buffer_ready[global_rank * total_size + dest];

  // Publish the buffer address, then raise the flag. The flag is written and polled by
  // different threads, so every access goes through an atomic builtin: a plain load in a
  // spin loop is a data race and may legally be hoisted out of the loop by the compiler.
  global_comm->comm->buffers[global_rank] = sendbuf;
  __atomic_store_n(flag, static_cast<int>(P2PTag::SEND_BUFFER_READY), __ATOMIC_RELEASE);
  __sync_synchronize();

  // wait for dest to copy
  while (__atomic_load_n(flag, __ATOMIC_ACQUIRE) != P2PTag::SEND_CP_DONE) {}
  __sync_synchronize();

  // remote thread done with the copy, let reset buffer
  resetLocalBuffer(global_comm);
  __atomic_store_n(flag, static_cast<int>(P2PTag::SEND_BUFFER_RELEASE), __ATOMIC_RELEASE);
  __sync_synchronize();

  // wait for dest to reset flag to init
  while (__atomic_load_n(flag, __ATOMIC_ACQUIRE) != P2PTag::INIT) {}

  return CollSuccess;
}

// Shared-memory receive from the thread with global rank `source`.
//
// Waits until the sender has published its buffer for this (source, receiver) pair, copies
// count * getDtypeSize(type) bytes from it into `recvbuf`, acknowledges the copy, waits for
// the sender to release the buffer, and finally returns the handshake flag to P2PTag::INIT.
// `tag` is not consulted here. Returns CollSuccess.
int recvLocal(
  void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
{
  (void)tag;

  const int total_size  = global_comm->global_comm_size;
  const int global_rank = global_comm->global_rank;

  // Compute the byte count in size_t; int * int could overflow for large messages.
  const size_t nbytes = static_cast<size_t>(count) * getDtypeSize(type);

  // Handshake flag and buffer slot for the ordered pair (source -> this rank). Both are
  // written by the sender thread, so they are read through atomic builtins: a plain load in
  // a spin loop is a data race and may legally be hoisted out of the loop by the compiler.
  int* flag         = &global_comm->comm->buffer_ready[source * total_size + global_rank];
  const void** slot = &global_comm->comm->buffers[source];

  // wait for source to put the buffer
  while (__atomic_load_n(flag, __ATOMIC_ACQUIRE) == P2PTag::INIT ||
         __atomic_load_n(slot, __ATOMIC_ACQUIRE) == nullptr) {}
  __sync_synchronize();

  // start memcpy
  memcpy(recvbuf, __atomic_load_n(slot, __ATOMIC_ACQUIRE), nbytes);
  __sync_synchronize();
  __atomic_store_n(flag, static_cast<int>(P2PTag::SEND_CP_DONE), __ATOMIC_RELEASE);

  // wait for source to reset the buffer
  while (__atomic_load_n(flag, __ATOMIC_ACQUIRE) != P2PTag::SEND_BUFFER_RELEASE ||
         __atomic_load_n(slot, __ATOMIC_ACQUIRE) != nullptr) {}
  __sync_synchronize();

  // hand the flag back so the sender can return and the pair can be reused
  __atomic_store_n(flag, static_cast<int>(P2PTag::INIT), __ATOMIC_RELEASE);

  return CollSuccess;
}

} // namespace coll
} // namespace comm
} // namespace legate
77 changes: 77 additions & 0 deletions src/core/comm/p2p_thread_mpi.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
/* Copyright 2022 NVIDIA Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
*/

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <algorithm>

#include "coll.h"
#include "legion.h"

namespace legate {
namespace comm {
namespace coll {

using namespace Legion;
extern Logger log_coll;

int sendMPI(
const void* sendbuf, int count, CollDataType type, int dest, int tag, CollComm global_comm)
{
MPI_Datatype mpi_type = dtypeToMPIDtype(type);

int dest_mpi_rank = global_comm->mapping_table.mpi_rank[dest];
int send_tag = generateP2PTag(tag);
#ifdef DEBUG_LEGATE
log_coll.debug("sendMPI global_rank %d, mpi rank %d, send to %d (%d), send_tag %d",
global_comm->global_rank,
global_comm->mpi_rank,
dest,
dest_mpi_rank,
send_tag);
#endif
CHECK_MPI(MPI_Send(sendbuf, count, mpi_type, dest_mpi_rank, send_tag, global_comm->comm));

return CollSuccess;
}

int recvMPI(void* recvbuf, int count, CollDataType type, int source, int tag, CollComm global_comm)
{
MPI_Status status;

MPI_Datatype mpi_type = dtypeToMPIDtype(type);

int source_mpi_rank = global_comm->mapping_table.mpi_rank[source];
int recv_tag = generateP2PTag(tag);
#ifdef DEBUG_LEGATE
log_coll.debug("recvMPI global_rank %d, mpi rank %d, recv from %d (%d), recv_tag %d",
global_comm->global_rank,
global_comm->mpi_rank,
source,
source_mpi_rank,
recv_tag);
#endif
CHECK_MPI(
MPI_Recv(recvbuf, count, mpi_type, source_mpi_rank, recv_tag, global_comm->comm, &status));

return CollSuccess;
}

} // namespace coll
} // namespace comm
} // namespace legate