Skip to content

Commit

Permalink
[xla:gpu] Do not use ncclSend and ncclRecv directly and use NcclApi p…
Browse files Browse the repository at this point in the history
…art #3

PiperOrigin-RevId: 599039077
  • Loading branch information
ezhulenev authored and copybara-github committed Jan 17, 2024
1 parent 588171c commit 48a80dd
Showing 1 changed file with 7 additions and 32 deletions.
39 changes: 7 additions & 32 deletions xla/service/gpu/nccl_collective_permute_thunk.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,6 @@ limitations under the License.
#include "xla/xla_data.pb.h"
#include "tsl/platform/errors.h"

#if XLA_ENABLE_XCCL
#include "xla/stream_executor/gpu/gpu_stream.h"
#endif

namespace xla {
namespace gpu {

Expand Down Expand Up @@ -236,7 +232,6 @@ absl::Status RunCollectivePermute(
NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer,
se::Stream& stream, ncclComm_t comm, absl::string_view device_string,
int64_t current_id) {
#if XLA_ENABLE_XCCL
// Determine the source and target IDs for this instance. The source ID is the
// ID which will copy its data to this instance. The destination ID is the ID
// to which this instance will copy its data. Either are optional.
Expand Down Expand Up @@ -281,35 +276,20 @@ absl::Status RunCollectivePermute(
TF_RETURN_IF_ERROR(NcclApi::GroupStart());
}

TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
ToNcclDataTypeAndCountMultiplier(
buffer.element_type, Thunk::kNcclCollectivePermute));
ncclDataType_t dtype = dtype_and_multiplier.first;
int64_t element_count = buffer.element_count * dtype_and_multiplier.second;

se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);

// Send source buffer to target peer if needed.
if (target_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
"comm=%p, stream=%p)",
device_string, src_addr.opaque(), element_count, *target_id,
static_cast<const void*>(comm), gpu_stream);
XLA_NCCL_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
*target_id, comm, gpu_stream));
TF_RETURN_IF_ERROR(NcclApi::Send(
src_addr, buffer.element_type, buffer.element_count, *target_id,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
}

// Receive data from the source peer to the destination buffer.
if (source_id) {
VLOG(3) << absl::StreamFormat(
"%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
"stream=%p)",
device_string, dest_addr.opaque(), element_count, *source_id,
static_cast<const void*>(comm), gpu_stream);
XLA_NCCL_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
*source_id, comm, gpu_stream));
TF_RETURN_IF_ERROR(NcclApi::Recv(
dest_addr, buffer.element_type, buffer.element_count, *source_id,
reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
}

if (is_nccl_group_needed) {
TF_RETURN_IF_ERROR(NcclApi::GroupEnd());
}
Expand All @@ -322,11 +302,6 @@ absl::Status RunCollectivePermute(
stream.ThenMemZero(&dest_addr, dest_addr.size());
}
return absl::OkStatus();
#else // XLA_ENABLE_XCCL
return Unimplemented(
"NCCL support is not available: this binary was not built with a CUDA "
"compiler, which is necessary to build the NCCL source library.");
#endif // XLA_ENABLE_XCCL
}

} // namespace gpu
Expand Down

0 comments on commit 48a80dd

Please sign in to comment.