Enable custom all-reduce for ROCm.
* add csrc custom_all_reduce into compilation
* add ops and pybindings
* do not disable custom all-reduce on ROCm in config
* port the custom all reduce source and test to HIP (a porting sketch follows the changed-files summary below)
* remove volatile signatures on ROCm
* optimized locking for ROCm
* increase default #threads and decrease #blocks after testing
iotamudelta committed Jun 5, 2024
1 parent 69ce080 commit 21b9048
Showing 8 changed files with 133 additions and 14 deletions.
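
The bullets above summarize the approach. As a quick illustration (not code from the commit), the port follows one conditional-compilation pattern throughout: CUDA-only headers and types stay behind #ifndef USE_ROCM and the HIP equivalents go in the #else branch. The bf16_t alias below is a hypothetical convenience for this sketch only; the diff spells the types out at each use.

#ifndef USE_ROCM
  #include <cuda_bf16.h>
  using bf16_t = nv_bfloat16;              // CUDA bfloat16, as used in the _all_reduce dispatch below
#else
  #include <hip/amd_detail/amd_hip_bf16.h>
  using bf16_t = __hip_bfloat16;           // HIP bfloat16 equivalent used on ROCm
#endif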
5 changes: 5 additions & 0 deletions CMakeLists.txt
@@ -187,6 +187,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
"csrc/custom_all_reduce.cu")
endif()

if(VLLM_GPU_LANG STREQUAL "HIP")
list(APPEND VLLM_EXT_SRC
"csrc/custom_all_reduce.cu")
endif()

define_gpu_extension_target(
_C
DESTINATION vllm
9 changes: 8 additions & 1 deletion csrc/custom_all_reduce.cu
@@ -84,13 +84,20 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
out.numel());
break;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__)) && !defined USE_ROCM
case at::ScalarType::BFloat16: {
fa->allreduce<nv_bfloat16>(
stream, reinterpret_cast<nv_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<nv_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#elif defined USE_ROCM
case at::ScalarType::BFloat16: {
fa->allreduce<__hip_bfloat16>(
stream, reinterpret_cast<__hip_bfloat16 *>(inp.data_ptr()),
reinterpret_cast<__hip_bfloat16 *>(out.data_ptr()), out.numel());
break;
}
#endif
default:
throw std::runtime_error(
84 changes: 82 additions & 2 deletions csrc/custom_all_reduce.cuh
@@ -1,7 +1,11 @@
#pragma once

#include <cuda.h>
#ifndef USE_ROCM
#include <cuda_bf16.h>
#else
#include <hip/amd_detail/amd_hip_bf16.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>

@@ -31,9 +35,17 @@ struct Signal {
alignas(128) uint32_t end[kMaxBlocks][8];
};

#ifndef USE_ROCM
struct __align__(16) RankData { const void *__restrict__ ptrs[8]; };
#else
struct __align__(16) RankData { const void * ptrs[8]; };
#endif

#ifndef USE_ROCM
struct __align__(16) RankSignals { volatile Signal *signals[8]; };
#else
struct __align__(16) RankSignals { Signal *signals[8]; };
#endif

// like std::array, but aligned
template <typename T, int sz>
@@ -74,6 +86,7 @@ DINLINE half &assign_add(half &a, half b) {
}
DINLINE float &assign_add(float &a, float b) { return a += b; }

#ifndef USE_ROCM
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE float upcast_s(nv_bfloat16 val) { return __bfloat162float(val); }
template <>
@@ -85,6 +98,17 @@ DINLINE nv_bfloat16 &assign_add(nv_bfloat16 &a, nv_bfloat16 b) {
return a;
}
#endif
#else
DINLINE float upcast_s(__hip_bfloat16 val) { return __bfloat162float(val); }
template <>
DINLINE __hip_bfloat16 downcast_s(float val) {
return __float2bfloat16(val);
}
DINLINE __hip_bfloat16 &assign_add(__hip_bfloat16 &a, __hip_bfloat16 b) {
a = __hadd(a, b);
return a;
}
#endif
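
For context, here is a hypothetical device helper (not part of the commit) showing how the ROCm bfloat16 specializations above compose: operands are upcast to float for accumulation and converted back afterwards, much as the packed reduction later in this header does for every element type.

DINLINE __hip_bfloat16 reduce_pair(__hip_bfloat16 a, __hip_bfloat16 b) {
  float acc = upcast_s(a);                  // bf16 -> float
  assign_add(acc, upcast_s(b));             // accumulate in float precision
  return downcast_s<__hip_bfloat16>(acc);   // float -> bf16 for the output buffer
}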

template <typename T, int N>
DINLINE array_t<T, N> &packed_assign_add(array_t<T, N> &a, array_t<T, N> b) {
@@ -128,16 +152,30 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template <int ngpus>
#ifndef USE_ROCM
DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
#else
DINLINE void start_sync(const RankSignals &sg, Signal *self_sg, int rank) {
#endif
if (threadIdx.x < ngpus) {
// reset flag for next time
#ifndef USE_ROCM
self_sg->end[blockIdx.x][threadIdx.x] = 0;
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
sg.signals[threadIdx.x]->start[blockIdx.x][rank] = 1;
// wait until we got true from all ranks
while (!self_sg->start[blockIdx.x][threadIdx.x])
#else
__atomic_store_n(&self_sg->end[blockIdx.x][threadIdx.x], 0, __ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], 1, __ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED))
#endif
;
}
__syncthreads();
@@ -147,13 +185,18 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template <int ngpus, bool final_sync = false>
#ifndef USE_ROCM
DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
int rank) {
#else
DINLINE void end_sync(const RankSignals &sg, Signal *self_sg, int rank) {
#endif
__syncthreads();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not manage to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
#ifndef USE_ROCM
if constexpr (!final_sync) __threadfence_system();
if (threadIdx.x < ngpus) {
// reset flag for next time
@@ -164,6 +207,18 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// wait until we got true from all ranks
while (!self_sg->end[blockIdx.x][threadIdx.x])
;
#else
if (threadIdx.x < ngpus) {
// reset flag for next time
__atomic_store_n(&self_sg->start[blockIdx.x][threadIdx.x], 0, __ATOMIC_RELAXED);
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], 1, __ATOMIC_RELAXED);
__atomic_thread_fence(__ATOMIC_ACQ_REL);
// wait until we got true from all ranks
while (!__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED))
;
#endif
}
if constexpr (!final_sync) __syncthreads();
}
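
A minimal host-side sketch of the signalling shape the ROCm branches above adopt (illustrative only, not code from the commit): a relaxed __atomic store to the peer's flag, an __atomic_thread_fence(__ATOMIC_ACQ_REL), then a relaxed spin on the local flag, replacing the volatile accesses of the CUDA path. Two std::threads stand in for two ranks; it builds with g++ or clang++ -std=c++17 -pthread.

#include <cstdint>
#include <cstdio>
#include <thread>

static uint32_t flag[2] = {0, 0};  // flag[r] plays the role of self_sg->start[...][r]

static void rank_barrier(int rank) {
  int peer = rank ^ 1;
  __atomic_store_n(&flag[peer], 1u, __ATOMIC_RELAXED);     // signal the peer (one p2p write in the kernel)
  __atomic_thread_fence(__ATOMIC_ACQ_REL);                 // order the signal against surrounding accesses
  while (!__atomic_load_n(&flag[rank], __ATOMIC_RELAXED))  // wait until the peer has signalled us
    ;
}

int main() {
  std::thread t0(rank_barrier, 0), t1(rank_barrier, 1);
  t0.join();
  t1.join();
  std::puts("both ranks passed the barrier");
  return 0;
}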
@@ -179,10 +234,16 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
}

template <typename T, int ngpus>
#ifndef USE_ROCM
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#else
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_1stage(RankData *_dp, RankSignals sg,
Signal *self_sg, T *__restrict__ result, int rank, int size) {
#endif
using P = typename packed_t<T>::P;
using A = typename packed_t<T>::A;
// note: we don't reorder the address so the accumulation order is the same
@@ -199,15 +260,26 @@ __global__ void __launch_bounds__(512, 1)
}
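
As a plain C++ analogue of the 1-stage kernel above (assumed sizes, not code from the commit): every thread walks the output with a grid-stride loop and sums the same index across all ranks' buffers in a fixed rank order, which is what keeps the accumulation order, and therefore the result, identical on every rank.

#include <cstdio>
#include <vector>

int main() {
  constexpr int ngpus = 4, size = 16, nthreads = 8;  // illustrative sizes
  std::vector<std::vector<float>> rank_buf(ngpus, std::vector<float>(size, 1.0f));
  std::vector<float> result(size, 0.0f);

  for (int tid = 0; tid < nthreads; ++tid) {            // "threads"
    for (int idx = tid; idx < size; idx += nthreads) {  // grid-stride loop
      float acc = 0.0f;
      for (int r = 0; r < ngpus; ++r)  // fixed rank order on every rank
        acc += rank_buf[r][idx];
      result[idx] = acc;
    }
  }
  std::printf("result[0] = %.1f (expected %d)\n", result[0], ngpus);
  return 0;
}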

template <typename P>
#ifndef USE_ROCM
DINLINE P *get_tmp_buf(volatile Signal *sg) {
#else
DINLINE P *get_tmp_buf(Signal *sg) {
#endif
return (P *)(((Signal *)sg) + 1);
}

template <typename T, int ngpus>
#ifndef USE_ROCM
__global__ void __launch_bounds__(512, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
volatile Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#else
__global__ void __launch_bounds__(1024, 1)
cross_device_reduce_2stage(RankData *_dp, RankSignals sg,
Signal *self_sg, T *__restrict__ result,
int rank, int size) {
#endif
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = gridDim.x * blockDim.x;
using P = typename packed_t<T>::P;
@@ -327,8 +399,12 @@ class CustomAllreduce {
// note: must share the base address of each allocation, or we get wrong
// address
if (cuPointerGetAttribute(&base_ptr,
#ifndef USE_ROCM
CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#else
HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
#endif
(CUdeviceptr)ptr) != CUDA_SUCCESS)
throw std::runtime_error("failed to get pointer attr");
CUDACHECK(cudaIpcGetMemHandle(
(cudaIpcMemHandle_t *)&handles[i * handle_sz], base_ptr));
@@ -406,7 +482,11 @@ class CustomAllreduce {
*/
template <typename T>
void allreduce(cudaStream_t stream, T *input, T *output, int size,
#ifndef USE_ROCM
int threads = 512, int block_limit = 36) {
#else
int threads = 1024, int block_limit = 36) {
#endif
auto d = packed_t<T>::P::size;
if (size % d != 0)
throw std::runtime_error(
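A usage sketch under stated assumptions: FA stands for the CustomAllreduce class from this header, and fa is an instance whose buffers have already been registered (setup omitted). It only shows how the per-platform defaults above can be overridden, the way the ROCm test below is driven with 1024 threads and a block limit of 16; it is not code from the commit.

template <typename FA, typename T>
void do_allreduce(FA &fa, cudaStream_t stream, T *in, T *out, int n) {
#ifndef USE_ROCM
  fa.allreduce(stream, in, out, n);            // defaults: 512 threads, block_limit 36
#else
  fa.allreduce(stream, in, out, n, 1024, 16);  // tuned ROCm launch, matching the test
#endif
}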
30 changes: 30 additions & 0 deletions csrc/custom_all_reduce_test.cu
@@ -4,6 +4,10 @@
* export MPI_HOME=XXX
* nvcc -O2 -arch=native -std=c++17 custom_all_reduce_test.cu -o
* custom_all_reduce_test -lnccl -I${MPI_HOME}/include -lmpi
* to hipify and compile
* export MPI_HOME=XXX
* hipify-perl custom_all_reduce_test.cu > custom_all_reduce_test.hip
* hipcc -O2 -std=c++17 custom_all_reduce_test.hip -o custom_all_reduce_test -lrccl -I${MPI_HOME}/include -L${MPI_HOME}/lib -lmpi -DUSE_ROCM=1
*
* Warning: this C++ test is not designed to be very readable and was used
* during the rapid prototyping process.
@@ -12,17 +16,29 @@
* mpirun -np 8 ./custom_all_reduce_test
*/
#include <cuda.h>
#ifndef USE_ROCM
#include <curand_kernel.h>
#else
#include <hiprand/hiprand_kernel.h>
#endif
#include <stdio.h>
#include <stdlib.h>

#include <limits>
#include <vector>

#include "cuda_profiler_api.h"
#ifndef USE_ROCM
#include "custom_all_reduce.cuh"
#else
#include "custom_all_reduce_hip.cuh"
#endif
#include "mpi.h"
#ifndef USE_ROCM
#include "nccl.h"
#else
#include <rccl/rccl.h>
#endif

#define MPICHECK(cmd) \
do { \
@@ -44,7 +60,12 @@
} while (0)

__global__ void dummy_kernel() {
#ifndef USE_ROCM
for (int i = 0; i < 100; i++) __nanosleep(1000000); // 100ms
#else
#pragma unroll
for (int i = 0; i < 100; i++) __builtin_amdgcn_s_sleep(127);
#endif
}

template <typename T>
@@ -164,7 +185,11 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
ncclDataType_t ncclDtype;
if (std::is_same<T, half>::value) {
ncclDtype = ncclFloat16;
#ifndef USE_ROCM
} else if (std::is_same<T, nv_bfloat16>::value) {
#else
} else if (std::is_same<T, __hip_bfloat16>::value) {
#endif
ncclDtype = ncclBfloat16;
} else {
ncclDtype = ncclFloat;
@@ -308,9 +333,14 @@ int main(int argc, char **argv) {
// }
// }
for (int sz = 512; sz <= (8 << 20); sz *= 2) {
#ifndef USE_ROCM
run<half>(myRank, nRanks, comm, 512, 36, sz + 8 * 47, performance_test);
#else
run<half>(myRank, nRanks, comm, 1024, 16, sz + 8 * 47, performance_test);
#endif
}

cudaProfilerStop();
MPICHECK(MPI_Finalize());
return EXIT_SUCCESS;
}
4 changes: 1 addition & 3 deletions csrc/ops.h
@@ -139,7 +139,6 @@ void moe_align_block_size(
torch::Tensor experts_ids,
torch::Tensor num_tokens_post_pad);

using fptr_t = uint64_t;
fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
const std::vector<std::string> &handles,
@@ -158,7 +157,6 @@ void register_buffer(fptr_t _fa, torch::Tensor &t,
std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(fptr_t _fa);
void register_graph_buffers(fptr_t _fa, const std::vector<std::string> &handles,
const std::vector<std::vector<int64_t>> &offsets);

void convert_fp8(
torch::Tensor& src_cache,
@@ -180,4 +178,4 @@ torch::Tensor fp8_gemm_16(
torch::Tensor& scaleA,
torch::Tensor& scaleB,
int algo_idx
);
2 changes: 0 additions & 2 deletions csrc/pybind.cpp
@@ -112,7 +112,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&get_max_shared_memory_per_block_device_attribute,
"Gets the maximum shared memory per block device attribute.");

// Custom all-reduce kernels
pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
@@ -126,6 +125,5 @@
"get_graph_buffer_ipc_meta");
custom_ar.def("register_graph_buffers", &register_graph_buffers,
"register_graph_buffers");

}
5 changes: 2 additions & 3 deletions vllm/config.py
@@ -526,10 +526,9 @@ def _verify_args(self) -> None:
"Pipeline parallelism is not supported yet.")
if not self.disable_custom_all_reduce and self.world_size > 1:
if is_hip():
self.disable_custom_all_reduce = False
logger.info(
"Enable the custom all-reduce kernel on AMD GPUs.")
elif self.pipeline_parallel_size > 1:
self.disable_custom_all_reduce = True
logger.info(
8 changes: 5 additions & 3 deletions vllm/model_executor/parallel_utils/custom_all_reduce.py
@@ -9,12 +9,14 @@
get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)

try:
from vllm._C import custom_ar
except ImportError:
# For AMD GPUs
custom_ar = None

try:
import pynvml
except ImportError:
# For AMD GPUs
pynvml = None

logger = init_logger(__name__)
Expand Down
