From fa7840346b28edde9ea87bd90501807b3981dabd Mon Sep 17 00:00:00 2001
From: Wenkai Du <43822138+wenkaidu@users.noreply.github.com>
Date: Mon, 24 Jun 2024 08:46:30 -0700
Subject: [PATCH] [Kernel] Enable custom AR on ROCm (#27)

* [Kernel] Enable custom AR on ROCm

* Install amdsmi in Docker in preparation for custom all reduce

(cherry picked from commit f6cfb9bf31e9feeefbdedecf2165f80dd0564b75)

* Fix for yapf

* Linting and small fixes to vLLM syntax

(cherry picked from commit 2cf8103bfb0afce59b28a06c5bbe905983c42728)

---------

Co-authored-by: Matthew Wong
---
 CMakeLists.txt                                     |   5 +
 Dockerfile.rocm                                    |  13 ++-
 benchmarks/benchmark_latency.py                    |   2 +
 csrc/custom_all_reduce.cu                          |  39 +++++++
 csrc/custom_all_reduce.cuh                         |  54 ++++++++-
 csrc/custom_all_reduce_test.cu                     |  27 ++++-
 csrc/ops.h                                         |   5 +-
 csrc/pybind.cpp                                    |   7 +-
 vllm/config.py                                     |  17 +--
 .../device_communicators/custom_all_reduce.py      | 105 +++++++++++++-----
 10 files changed, 230 insertions(+), 44 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 15b9cfe677a57..2b3679a0548c9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -222,6 +222,11 @@ if(VLLM_GPU_LANG STREQUAL "CUDA")
 
 endif()
 
+if(VLLM_GPU_LANG STREQUAL "HIP")
+  list(APPEND VLLM_EXT_SRC
+    "csrc/custom_all_reduce.cu")
+endif()
+
 define_gpu_extension_target(
   _C
   DESTINATION vllm
diff --git a/Dockerfile.rocm b/Dockerfile.rocm
index ac0d2d8a6aa5e..46f1edd405c4f 100644
--- a/Dockerfile.rocm
+++ b/Dockerfile.rocm
@@ -117,6 +117,13 @@ COPY --from=build_triton ${COMMON_WORKDIR}/triton/python/dist/*.whl /
 FROM scratch AS export_triton_0
 FROM export_triton_${BUILD_TRITON} AS export_triton
 
+# AMD-SMI build stages
+FROM base AS build_amdsmi
+RUN cd /opt/rocm/share/amd_smi \
+    && pip wheel . --wheel-dir=dist
+FROM scratch AS export_amdsmi
+COPY --from=build_amdsmi /opt/rocm/share/amd_smi/dist/*.whl /
+
 # -----------------------
 # vLLM (and gradlib) fetch stages
 FROM base AS fetch_vllm_0
@@ -201,7 +208,10 @@ RUN --mount=type=bind,from=export_triton,src=/,target=/install \
     pip install /install/*.whl; \
     fi
 
-RUN python3 -m pip install --upgrade numba
+RUN --mount=type=bind,from=export_amdsmi,src=/,target=/install \
+    pip install /install/*.whl;
+
+RUN python3 -m pip install --upgrade numba scipy huggingface-hub[cli]
 
 # Install vLLM (and gradlib)
 # Make sure punica kernels are built (for LoRA)
@@ -221,6 +231,7 @@ RUN --mount=type=bind,from=export_vllm,src=/,target=/install \
 COPY --from=export_vllm /benchmarks ${COMMON_WORKDIR}/vllm/benchmarks
 
 ENV RAY_EXPERIMENTAL_NOSET_ROCR_VISIBLE_DEVICES=1
+ENV TOKENIZERS_PARALLELISM=false
 
 # Performance environment variable.
 ENV HIP_FORCE_DEV_KERNARG=1
diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py
index 7eeb90516bdfa..400bb9936e02d 100644
--- a/benchmarks/benchmark_latency.py
+++ b/benchmarks/benchmark_latency.py
@@ -38,6 +38,7 @@ def main(args: argparse.Namespace):
         enable_chunked_prefill=args.enable_chunked_prefill,
         download_dir=args.download_dir,
         block_size=args.block_size,
+        disable_custom_all_reduce=args.disable_custom_all_reduce,
         gpu_memory_utilization=args.gpu_memory_utilization)
 
     sampling_params = SamplingParams(
@@ -229,6 +230,7 @@ def run_to_completion(profile_dir: Optional[str] = None):
                         type=str,
                         default=None,
                         help='Path to save the latency results in JSON format.')
+    parser.add_argument('--disable_custom_all_reduce', action='store_true')
     parser.add_argument('--gpu-memory-utilization',
                         type=float,
                         default=0.9,
diff --git a/csrc/custom_all_reduce.cu b/csrc/custom_all_reduce.cu
index 0b1d95848525a..9069a98b51ccf 100644
--- a/csrc/custom_all_reduce.cu
+++ b/csrc/custom_all_reduce.cu
@@ -145,3 +145,42 @@ void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
   auto fa = reinterpret_cast<vllm::CustomAllreduce*>(_fa);
   fa->register_graph_buffers(handles, offsets);
 }
+
+#ifdef USE_ROCM
+
+void free_meta_buffer(void* buffer) { hipFree(buffer); }
+
+std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp) {
+  std::vector<uint8_t> data_handle(sizeof(cudaIpcMemHandle_t), 0);
+  CUDACHECK(cudaIpcGetMemHandle((cudaIpcMemHandle_t*)data_handle.data(),
+                                inp.data_ptr()));
+  return data_handle;
+}
+
+torch::Tensor allocate_meta_buffer(int size) {
+  auto device_index = c10::cuda::current_device();
+  at::DeviceGuard device_guard(at::Device(at::DeviceType::CUDA, device_index));
+  void* buffer;
+  cudaStreamCaptureMode mode = cudaStreamCaptureModeRelaxed;
+  auto stream = c10::cuda::getCurrentCUDAStream().stream();
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+  AT_CUDA_CHECK(
+      hipExtMallocWithFlags((void**)&buffer, size, hipDeviceMallocUncached));
+  AT_CUDA_CHECK(cudaMemsetAsync(buffer, 0, size, stream));
+  AT_CUDA_CHECK(cudaStreamSynchronize(stream));
+  AT_CUDA_CHECK(cudaThreadExchangeStreamCaptureMode(&mode));
+  auto options = torch::TensorOptions()
+                     .dtype(torch::kI8)
+                     .device(torch::kCUDA, device_index);
+  return torch::from_blob(buffer, {size}, free_meta_buffer, options);
+}
+
+std::vector<uint8_t> get_device_bdf(int dev) {
+  char busIdStr[] = "0000:00:00.0";
+  std::vector<uint8_t> bdf(sizeof(busIdStr), 0);
+  CUDACHECK(cudaDeviceGetPCIBusId((char*)bdf.data(), sizeof(busIdStr), dev));
+  bdf.resize(bdf.size() - 1);  // remove trailing NULL
+  return bdf;
+}
+
+#endif
diff --git a/csrc/custom_all_reduce.cuh b/csrc/custom_all_reduce.cuh
index 1ed49b8aa9cae..c640b15a2346a 100644
--- a/csrc/custom_all_reduce.cuh
+++ b/csrc/custom_all_reduce.cuh
@@ -1,7 +1,12 @@
 #pragma once
 
 #include <cuda.h>
-#include <cuda_bf16.h>
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+typedef __hip_bfloat16 nv_bfloat16;
+#else
+  #include <cuda_bf16.h>
+#endif
 #include <cuda_fp16.h>
 #include <cuda_runtime.h>
 
@@ -29,9 +34,14 @@ constexpr int kMaxBlocks = 64;
 struct Signal {
   alignas(128) uint32_t start[kMaxBlocks][8];
   alignas(128) uint32_t end[kMaxBlocks][8];
+  alignas(128) uint32_t _flag[kMaxBlocks];  // incremental flags for each rank
 };
 
+#ifdef USE_ROCM
+struct __align__(16) RankData { const void* ptrs[8]; };
+#else
 struct __align__(16) RankData { const void* __restrict__ ptrs[8]; };
+#endif
 
 struct __align__(16) RankSignals { volatile Signal* signals[8]; };
 
@@ -130,6 +140,21 @@ DINLINE O downcast(array_t<float, O::size> val) {
 template <int ngpus>
 DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
                         int rank) {
+#ifdef USE_ROCM
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
+  if (threadIdx.x < ngpus) {
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    __atomic_store_n(&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag,
+                     __ATOMIC_RELAXED);
+    // wait until we got true from all ranks
+    while (__atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x],
+                           __ATOMIC_RELAXED) < flag);
+  }
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+#else
   if (threadIdx.x < ngpus) {
     // reset flag for next time
     self_sg->end[blockIdx.x][threadIdx.x] = 0;
@@ -140,6 +165,7 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->start[blockIdx.x][threadIdx.x]);
   }
   __syncthreads();
+#endif
 }
 
 // This function is meant to be used as the second or the final synchronization
@@ -148,6 +174,27 @@ DINLINE void start_sync(const RankSignals& sg, volatile Signal* self_sg,
 template <int ngpus, bool final_sync = false>
 DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
                       int rank) {
+#ifdef USE_ROCM
+  __syncthreads();
+  // eliminate the case that prior writes are not visible after signals become
+  // visible. Note that I did not managed to make this happen through a lot of
+  // testing. Might be the case that hardware provides stronger guarantee than
+  // the memory model.
+  uint32_t flag = self_sg->_flag[blockIdx.x] + 1;
+  if (threadIdx.x < ngpus) {
+    // simultaneously write to the corresponding flag of all ranks.
+    // Latency = 1 p2p write
+    __atomic_store_n(&sg.signals[threadIdx.x]->end[blockIdx.x][rank], flag,
+                     final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE);
+    // wait until we got true from all ranks
+    while (__atomic_load_n(&self_sg->end[blockIdx.x][threadIdx.x],
+                           final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE) <
+           flag);
+  }
+  __syncthreads();
+  // use one thread to update flag
+  if (threadIdx.x == 0) self_sg->_flag[blockIdx.x] = flag;
+#else
   __syncthreads();
   // eliminate the case that prior writes are not visible after signals become
   // visible. Note that I did not managed to make this happen through a lot of
@@ -164,6 +211,7 @@ DINLINE void end_sync(const RankSignals& sg, volatile Signal* self_sg,
     while (!self_sg->end[blockIdx.x][threadIdx.x]);
   }
   if constexpr (!final_sync) __syncthreads();
+#endif
 }
 
 template <typename P, int ngpus, typename A>
@@ -324,7 +372,11 @@ class CustomAllreduce {
     // note: must share the base address of each allocation, or we get wrong
     // address
     if (cuPointerGetAttribute(&base_ptr,
+#ifdef USE_ROCM
+                              HIP_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+#else
                               CU_POINTER_ATTRIBUTE_RANGE_START_ADDR,
+#endif
                               (CUdeviceptr)ptr) != CUDA_SUCCESS)
       throw std::runtime_error("failed to get pointer attr");
     CUDACHECK(cudaIpcGetMemHandle(
diff --git a/csrc/custom_all_reduce_test.cu b/csrc/custom_all_reduce_test.cu
index f7868233076cd..9b809caa6e045 100644
--- a/csrc/custom_all_reduce_test.cu
+++ b/csrc/custom_all_reduce_test.cu
@@ -20,9 +20,16 @@
 #include <vector>
 
 #include "cuda_profiler_api.h"
-#include "custom_all_reduce.cuh"
 #include "mpi.h"
-#include "nccl.h"
+#ifdef USE_ROCM
+  #include <hip/hip_bf16.h>
+typedef __hip_bfloat16 nv_bfloat16;
+  #include "rccl/rccl.h"
+  #include "custom_all_reduce_hip.cuh"
+#else
+  #include "nccl.h"
+  #include "custom_all_reduce.cuh"
+#endif
 
 #define MPICHECK(cmd) \
   do { \
@@ -44,7 +51,17 @@
   } while (0)
 
 __global__ void dummy_kernel() {
+#ifdef USE_ROCM
+  for (int i = 0; i < 100; i++) {
+    uint64_t start = wall_clock64();
+    uint64_t cycles_elapsed;
+    do {
+      cycles_elapsed = wall_clock64() - start;
+    } while (cycles_elapsed < 100);
+  }
+#else
   for (int i = 0; i < 100; i++) __nanosleep(1000000);  // 100ms
+#endif
 }
 
 template <typename T>
@@ -114,8 +131,14 @@ void run(int myRank, int nRanks, ncclComm_t& comm, int threads, int block_limit,
    * registration, they are allocated and registered together in the test for
    * convenience.
    */
+#ifdef USE_ROCM
+  CUDACHECK(hipExtMallocWithFlags(
+      (void**)&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal),
+      hipDeviceMallocUncached));
+#else
   CUDACHECK(
       cudaMalloc(&buffer, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
+#endif
   CUDACHECK(
       cudaMemset(buffer, 0, 2 * data_size * sizeof(T) + sizeof(vllm::Signal)));
   CUDACHECK(cudaMalloc(&self_data_copy, data_size * sizeof(T)));
diff --git a/csrc/ops.h b/csrc/ops.h
index aa015c3d5dc39..836119755cad9 100644
--- a/csrc/ops.h
+++ b/csrc/ops.h
@@ -131,7 +131,6 @@ void moe_align_block_size(torch::Tensor topk_ids, int num_experts,
                           torch::Tensor experts_ids,
                           torch::Tensor num_tokens_post_pad);
 
-#ifndef USE_ROCM
 using fptr_t = uint64_t;
 fptr_t init_custom_ar(torch::Tensor& meta, torch::Tensor& rank_data,
                       const std::vector<std::string>& handles,
@@ -151,4 +150,8 @@ std::pair<std::vector<uint8_t>, std::vector<int64_t>> get_graph_buffer_ipc_meta(
     fptr_t _fa);
 void register_graph_buffers(fptr_t _fa, const std::vector<std::string>& handles,
                             const std::vector<std::vector<int64_t>>& offsets);
+#ifdef USE_ROCM
+torch::Tensor allocate_meta_buffer(int size);
+std::vector<uint8_t> get_meta_buffer_ipc_handle(torch::Tensor inp);
+std::vector<uint8_t> get_device_bdf(int dev);
 #endif
diff --git a/csrc/pybind.cpp b/csrc/pybind.cpp
index a507af396bcf9..3bbfedd568267 100644
--- a/csrc/pybind.cpp
+++ b/csrc/pybind.cpp
@@ -98,7 +98,6 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           &get_max_shared_memory_per_block_device_attribute,
           "Gets the maximum shared memory per block device attribute.");
 
-#ifndef USE_ROCM
   // Custom all-reduce kernels
   pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
   custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
@@ -112,5 +111,11 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
                 "get_graph_buffer_ipc_meta");
   custom_ar.def("register_graph_buffers", &register_graph_buffers,
                 "register_graph_buffers");
+#ifdef USE_ROCM
+  custom_ar.def("allocate_meta_buffer", &allocate_meta_buffer,
+                "allocate_meta_buffer");
+  custom_ar.def("get_meta_buffer_ipc_handle", &get_meta_buffer_ipc_handle,
+                "get_meta_buffer_ipc_handle");
+  custom_ar.def("get_device_bdf", &get_device_bdf, "get_device_bdf");
 #endif
 }
diff --git a/vllm/config.py b/vllm/config.py
index a810b67777224..160ec63d4d501 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -623,17 +623,12 @@ def _verify_args(self) -> None:
             raise ValueError(
                 "Unrecognized distributed executor backend. Supported values "
                 "are 'ray' or 'mp' or 'torchrun'.")
-        if not self.disable_custom_all_reduce and self.world_size > 1:
-            if is_hip():
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported on AMD GPUs.")
-            elif self.pipeline_parallel_size > 1:
-                self.disable_custom_all_reduce = True
-                logger.info(
-                    "Disabled the custom all-reduce kernel because it is not "
-                    "supported with pipeline parallelism.")
+        if (not self.disable_custom_all_reduce and self.world_size > 1
+                and self.pipeline_parallel_size > 1):
+            self.disable_custom_all_reduce = True
+            logger.info(
+                "Disabled the custom all-reduce kernel because it is not "
+                "supported with pipeline parallelism.")
         if self.ray_workers_use_nsight and (
                 not self.distributed_executor_backend == "ray"):
             raise ValueError("Unable to use nsight profiling unless workers "
diff --git a/vllm/distributed/device_communicators/custom_all_reduce.py b/vllm/distributed/device_communicators/custom_all_reduce.py
index a3902aecb3793..28edeb69eb37b 100644
--- a/vllm/distributed/device_communicators/custom_all_reduce.py
+++ b/vllm/distributed/device_communicators/custom_all_reduce.py
@@ -11,19 +11,32 @@
 from vllm.distributed.parallel_state import (
     get_local_rank, get_tensor_model_parallel_cpu_group)
 from vllm.logger import init_logger
+from vllm.utils import is_hip
 
 try:
-    import pynvml
+    if is_hip():
+        from amdsmi import (AmdSmiException,
+                            amdsmi_get_processor_handle_from_bdf, amdsmi_init,
+                            amdsmi_shut_down, amdsmi_topo_get_link_type)
+    else:
+        import pynvml
 
     from vllm._C import custom_ar
 
     @contextmanager
     def _nvml():
-        try:
-            pynvml.nvmlInit()
-            yield
-        finally:
-            pynvml.nvmlShutdown()
+        if torch.version.hip:
+            try:
+                amdsmi_init()
+                yield
+            finally:
+                amdsmi_shut_down()
+        else:
+            try:
+                pynvml.nvmlInit()
+                yield
+            finally:
+                pynvml.nvmlShutdown()
 
 except ImportError:
     # For AMD GPUs
@@ -42,27 +55,49 @@ def _nvml():
 
 
 @_nvml()
-def _is_full_nvlink(device_ids: List[int]) -> bool:
+def _is_full_nvlink(device_ids: List[int], world_size) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     Note that `pynvml` is not affected by `CUDA_VISIBLE_DEVICES`,
     so it works on real physical device ids.
""" - handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] - for i, handle in enumerate(handles): - for j, peer_handle in enumerate(handles): - if i < j: - try: - p2p_status = pynvml.nvmlDeviceGetP2PStatus( - handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK) - if p2p_status != pynvml.NVML_P2P_STATUS_OK: + if is_hip(): + # get devices' BDF in order to get XGMI link info from amdsmi + bdf = custom_ar.get_device_bdf(torch.cuda.current_device()) + all_bdf = [0] * world_size + dist.all_gather_object(all_bdf, bdf) + hsmi = [None] * world_size + try: + for i in range(world_size): + bdf_str = str(bytes(all_bdf[i]).decode("utf-8")) + hsmi[i] = amdsmi_get_processor_handle_from_bdf(bdf_str) + for i in range(world_size): + if i != 0: + link_type = amdsmi_topo_get_link_type(hsmi[0], hsmi[i]) + # type is 2 for XGMI + if link_type['hops'] != 1 or link_type['type'] != 2: + return False + except AmdSmiException as e: + logger.warning(e) + return False + return True + else: + handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in device_ids] + for i, handle in enumerate(handles): + for j, peer_handle in enumerate(handles): + if i < j: + try: + p2p_status = pynvml.nvmlDeviceGetP2PStatus( + handle, peer_handle, + pynvml.NVML_P2P_CAPS_INDEX_NVLINK) + if p2p_status != pynvml.NVML_P2P_STATUS_OK: + return False + except pynvml.NVMLError as error: + logger.error( + "NVLink detection failed. This is normal if your" + " machine has no NVLink equipped.", + exc_info=error) return False - except pynvml.NVMLError as error: - logger.error( - "NVLink detection failed. This is normal if your" - " machine has no NVLink equipped.", - exc_info=error) - return False return True @@ -153,7 +188,7 @@ def __init__(self, # test nvlink first, this will filter out most of the cases # where custom allreduce is not supported # this checks hardware and driver support for NVLink - full_nvlink = _is_full_nvlink(physical_device_ids) + full_nvlink = _is_full_nvlink(physical_device_ids, world_size) if world_size > 2 and not full_nvlink: logger.warning( "Custom allreduce is disabled because it's not supported on" @@ -163,7 +198,8 @@ def __init__(self, # test P2P capability, this checks software/cudaruntime support # this is expensive to compute at the first time # then we cache the result - if not _can_p2p(rank, world_size): + # On AMD GPU, p2p is always enabled between XGMI connected GPUs + if not is_hip() and not _can_p2p(rank, world_size): logger.warning( "Custom allreduce is disabled because your platform lacks " "GPU P2P capability or P2P test failed. To silence this " @@ -175,9 +211,14 @@ def __init__(self, # meta data composes of two parts: meta data for synchronization # (256 bytes) and a temporary buffer for storing intermediate # allreduce results. - self.meta = torch.zeros(custom_ar.meta_size() + max_size, - dtype=torch.uint8, - device=self.device) + if is_hip(): + # meta data buffers need to be "uncached" for signal on MI200 + self.meta = custom_ar.allocate_meta_buffer(custom_ar.meta_size() + + max_size) + else: + self.meta = torch.zeros(custom_ar.meta_size() + max_size, + dtype=torch.uint8, + device=self.device) # This is a pre-registered IPC buffer. 
         # are first copied into this buffer before allreduce is performed
         self.buffer = torch.empty(max_size,
@@ -194,7 +235,17 @@ def __init__(self,
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
-        handles, offsets = self._get_ipc_meta(self.meta)
+        if is_hip():
+            # _share_cuda_() doesn't accept meta buffer not allocated from
+            # PyTorch cache allocator, use direct HIP call to get IPC handle
+            handle = custom_ar.get_meta_buffer_ipc_handle(self.meta)
+            shard_data = (
+                bytes(handle),  # ipc handle to base ptr
+                0,  # offset of base ptr
+            )
+            handles, offsets = self._gather_ipc_meta(shard_data)
+        else:
+            handles, offsets = self._get_ipc_meta(self.meta)
         self.full_nvlink = full_nvlink
         self._ptr = custom_ar.init_custom_ar(self.meta, self.rank_data,
                                              handles, offsets, rank,
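
With this patch, the custom all-reduce kernel can be selected on ROCm when tensor parallelism is in use (XGMI connectivity is verified through amdsmi, and the uncached meta buffer is shared via direct HIP IPC handles). Below is a minimal usage sketch; it is not part of the patch, and the model name and GPU count are placeholders to adjust for your setup.

    # Sketch: toggling the custom all-reduce path from the Python API on a
    # ROCm build that includes this patch (assumes >= 2 XGMI-connected GPUs).
    from vllm import LLM, SamplingParams

    llm = LLM(
        model="facebook/opt-125m",        # placeholder model
        tensor_parallel_size=2,           # custom all-reduce only matters for TP > 1
        disable_custom_all_reduce=False,  # default; set True to fall back to RCCL
    )
    out = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    print(out[0].outputs[0].text)

The same toggle is exposed by the benchmark touched above: running benchmarks/benchmark_latency.py with --tensor-parallel-size 2, with and without --disable_custom_all_reduce, gives a quick comparison of the custom kernel against the RCCL fallback.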