diff --git a/xla/service/gpu/nccl_collective_permute_thunk.cc b/xla/service/gpu/nccl_collective_permute_thunk.cc
index 1f47069059056..9b27d6611200d 100644
--- a/xla/service/gpu/nccl_collective_permute_thunk.cc
+++ b/xla/service/gpu/nccl_collective_permute_thunk.cc
@@ -39,10 +39,6 @@ limitations under the License.
 #include "xla/xla_data.pb.h"
 #include "tsl/platform/errors.h"
 
-#if XLA_ENABLE_XCCL
-#include "xla/stream_executor/gpu/gpu_stream.h"
-#endif
-
 namespace xla {
 namespace gpu {
 
@@ -236,7 +232,6 @@ absl::Status RunCollectivePermute(
     NcclP2PConfig::SourceTargetMapEntry source_target, DeviceBufferPair& buffer,
     se::Stream& stream, ncclComm_t comm, absl::string_view device_string,
     int64_t current_id) {
-#if XLA_ENABLE_XCCL
   // Determine the source and target IDs for this instance. The source ID is the
   // ID which will copy its data to this instance. The destination ID is the ID
   // to which this instance will copy its data. Either are optional.
@@ -281,35 +276,20 @@ absl::Status RunCollectivePermute(
     TF_RETURN_IF_ERROR(NcclApi::GroupStart());
   }
 
-  TF_ASSIGN_OR_RETURN(auto dtype_and_multiplier,
-                      ToNcclDataTypeAndCountMultiplier(
-                          buffer.element_type, Thunk::kNcclCollectivePermute));
-  ncclDataType_t dtype = dtype_and_multiplier.first;
-  int64_t element_count = buffer.element_count * dtype_and_multiplier.second;
-
-  se::gpu::GpuStreamHandle gpu_stream = se::gpu::AsGpuStreamValue(&stream);
-
   // Send source buffer to target peer if needed.
   if (target_id) {
-    VLOG(3) << absl::StreamFormat(
-        "%s : Calling ncclSend(sendbuff=%p, count=%d, peer=%d "
-        "comm=%p, stream=%p)",
-        device_string, src_addr.opaque(), element_count, *target_id,
-        static_cast<const void*>(comm), gpu_stream);
-    XLA_NCCL_RETURN_IF_ERROR(ncclSend(src_addr.opaque(), element_count, dtype,
-                                      *target_id, comm, gpu_stream));
+    TF_RETURN_IF_ERROR(NcclApi::Send(
+        src_addr, buffer.element_type, buffer.element_count, *target_id,
+        reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
   }
 
   // Receive data from the source peer to the destination buffer.
   if (source_id) {
-    VLOG(3) << absl::StreamFormat(
-        "%s : Calling ncclRecv(recvbuff=%p, count=%d, peer=%d comm=%p, "
-        "stream=%p)",
-        device_string, dest_addr.opaque(), element_count, *source_id,
-        static_cast<const void*>(comm), gpu_stream);
-    XLA_NCCL_RETURN_IF_ERROR(ncclRecv(dest_addr.opaque(), element_count, dtype,
-                                      *source_id, comm, gpu_stream));
+    TF_RETURN_IF_ERROR(NcclApi::Recv(
+        dest_addr, buffer.element_type, buffer.element_count, *source_id,
+        reinterpret_cast<NcclApi::NcclCommHandle>(comm), &stream));
   }
+
   if (is_nccl_group_needed) {
     TF_RETURN_IF_ERROR(NcclApi::GroupEnd());
   }
@@ -322,11 +302,6 @@ absl::Status RunCollectivePermute(
     stream.ThenMemZero(&dest_addr, dest_addr.size());
   }
   return absl::OkStatus();
-#else  // XLA_ENABLE_XCCL
-  return Unimplemented(
-      "NCCL support is not available: this binary was not built with a CUDA "
-      "compiler, which is necessary to build the NCCL source library.");
-#endif  // XLA_ENABLE_XCCL
 }
 
 }  // namespace gpu
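
Note on the pattern: the hunks above replace the raw ncclSend/ncclRecv calls
with the NcclApi wrapper and pass the communicator as an opaque handle that
only the wrapper casts back to the library's concrete type. Below is a
minimal, standalone sketch of that type-erasure pattern; the stand-in types
and the NcclApi::Send signature here are illustrative assumptions, not the
real XLA or NCCL declarations.

#include <cstddef>
#include <cstdio>

// Stand-in for the library's concrete communicator type (what ncclComm_t
// aliases in the real headers).
struct ncclCommImpl {};
using ncclComm_t = ncclCommImpl*;

struct NcclApi {
  // Forward-declared and never defined: callers can store and pass the
  // handle but cannot dereference it, so the concrete library type stays
  // hidden behind the wrapper API.
  struct NcclComm;
  using NcclCommHandle = NcclComm*;

  static int Send(const void* buffer, std::size_t count, int peer,
                  NcclCommHandle comm) {
    // Only the wrapper casts back to the concrete handle.
    ncclComm_t nccl_comm = reinterpret_cast<ncclComm_t>(comm);
    std::printf("Send(count=%zu, peer=%d, comm=%p)\n", count, peer,
                static_cast<const void*>(nccl_comm));
    return 0;  // The real wrapper would return absl::Status from ncclSend().
  }
};

int main() {
  ncclCommImpl impl;
  ncclComm_t comm = &impl;
  int data[4] = {0, 1, 2, 3};
  // Call sites mirror the diff: cast the raw handle once at the API boundary.
  return NcclApi::Send(data, sizeof(data) / sizeof(data[0]), /*peer=*/1,
                       reinterpret_cast<NcclApi::NcclCommHandle>(comm));
}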