Skip to content

Commit

Permalink
Fix the FP6 kernels compilation problem on non-Ampere GPUs. (microsof…
Browse files Browse the repository at this point in the history
…t#5333)

Refine the guards of FP6 kernel compilation. Fix the `undefined symbol`
problem of FP6 kernels on non-Ampere architectures.

Related issue: microsoft/DeepSpeed-MII#443.

---------

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Michael Wyatt <[email protected]>
  • Loading branch information
3 people authored and rraminen committed May 9, 2024
1 parent ea7e250 commit c5e3a5e
Show file tree
Hide file tree
Showing 16 changed files with 96 additions and 34 deletions.
4 changes: 2 additions & 2 deletions deepspeed/inference/v2/kernels/core_ops/core_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

#include "bias_activation.h"
#include "blas.h"
#include "cuda_linear_kernels.h"
#include "gated_activation_kernels.h"
#include "layer_norm.h"
#include "linear_kernels.h"
#include "rms_norm.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
Expand All @@ -35,7 +35,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
m.def("rms_norm", &rms_norm, "DeepSpeed rms norm in CUDA");
m.def("rms_pre_norm", &rms_pre_norm, "DeepSpeed rms pre norm in CUDA");

// cuda_linear_kernels.h
// linear_kernels.h
m.def("cuda_wf6af16_linear", &cuda_wf6af16_linear, "DeepSpeed Wf6Af16 linear in CUDA");
m.def(
"preprocess_weight", &preprocess_weight, "preprocess the FP16 weight to be 2bit and 4 bit");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef DEEPSPEED_CUDA_LINEAR_KERNEL_MATMUL_CUH
#define DEEPSPEED_CUDA_LINEAR_KERNEL_MATMUL_CUH

#include "configs.h"
#include "utils_core.cuh"
#include "utils_gmem.cuh"
Expand All @@ -26,6 +29,8 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
const size_t K_Global,
int Split_K)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 && __CUDA_ARCH__ < 900

#ifdef DEBUG_MODE
assert(K_Global % TilingConfig::TILE_K == 0);
assert(M_Global % TilingConfig::TILE_M == 0);
Expand Down Expand Up @@ -258,4 +263,10 @@ __global__ void QUANT_GEMM_Kernel(const uint4* Weight1,
else
BlockGlobalPTR[j + i * M_Global] = smem_CFrag[i][j];
}

#else
#warning "The FP6 functions are only available on Ampere GPUs."
#endif
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef DEEPSPEED_CUDA_LINEAR_KERNEL_REDUCTION_CUH
#define DEEPSPEED_CUDA_LINEAR_KERNEL_REDUCTION_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -36,3 +39,5 @@ __global__ void SplitK_Reduction(half* C,
#pragma unroll
for (int i = 0; i < HALF_PER_128BIT; i++) THREAD_GPTR_C[i] = __float2half_rn(Results[i]);
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef PTX_CP_ASYNC_CUH
#define PTX_CP_ASYNC_CUH
#ifndef DEEPSPEED_CUDA_LINEAR_PTX_CP_ASYNC_CUH
#define DEEPSPEED_CUDA_LINEAR_PTX_CP_ASYNC_CUH

#include <cuda.h>
#include <cuda_fp16.h>
Expand All @@ -17,6 +17,7 @@ __device__ __forceinline__ void cp_async(half* smem_ptr,
const half* global_ptr,
bool pred_guard = true)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
static_assert(SizeInBytes == 16, "Size is not supported");
unsigned smem_int_ptr = __cvta_generic_to_shared(smem_ptr);
asm volatile(
Expand All @@ -28,25 +29,43 @@ __device__ __forceinline__ void cp_async(half* smem_ptr,
"r"(smem_int_ptr),
"l"(global_ptr),
"n"(SizeInBytes));
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
#endif
}

/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block.
__device__ __forceinline__ void cp_async_group_commit()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.commit_group;\n" ::);
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
#endif
}

/// Blocks until all but <N> previous cp.async.commit_group operations have committed.
template <int N>
__device__ __forceinline__ void cp_async_wait_group()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_group %0;\n" ::"n"(N));
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
#endif
}

/// Blocks until all previous cp.async.commit_group operations have committed.
// cp.async.wait_all is equivalent to :
// cp.async.commit_group;
// cp.async.wait_group 0;
__device__ __forceinline__ void cp_async_wait_all() { asm volatile("cp.async.wait_all;\n" ::); }
__device__ __forceinline__ void cp_async_wait_all()
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile("cp.async.wait_all;\n" ::);
#else
#warning "The async copy functions are only supported on Ampere and newer architectures"
#endif
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef PTX_MMA_CUH
#define PTX_MMA_CUH
#ifndef DEEPSPEED_CUDA_LINEAR_PTX_MMA_CUH
#define DEEPSPEED_CUDA_LINEAR_PTX_MMA_CUH

#include <cuda.h>
#include <cuda_fp16.h>
Expand All @@ -22,6 +22,7 @@ __device__ __forceinline__ void B_FromSharedToReg(
half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
int slice_id)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#ifdef DEBUG_MODE
static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
(TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));
Expand Down Expand Up @@ -54,6 +55,9 @@ __device__ __forceinline__ void B_FromSharedToReg(
smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
}
}
#else
#warning "The matrix load functions are only supported on Ampere and newer architectures"
#endif
}
#else
// Debug: Whether ldmatrix.trans is required???
Expand All @@ -64,6 +68,7 @@ __device__ __forceinline__ void B_FromSharedToReg(
half __restrict__ (*read_SPTR)[WARP_K + PADDING_SHARED_MEM_FOR_B_8],
int k_offset)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#ifdef DEBUG_MODE
static_assert((TilingConfig::WARP_COL_MMA_TENSORS == 1) ||
(TilingConfig::WARP_COL_MMA_TENSORS % 2 == 0));
Expand Down Expand Up @@ -96,13 +101,17 @@ __device__ __forceinline__ void B_FromSharedToReg(
smem_local_ptr += 16 * (WARP_K + PADDING_SHARED_MEM_FOR_B_8) * sizeof(half);
}
}
#else
#warning "The matrix load functions are only supported on Ampere and newer architectures"
#endif
}
#endif

__device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
uint32_t __restrict__* a,
uint32_t __restrict__* b)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm volatile(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{ %0, %1, %2, %3},"
Expand All @@ -120,6 +129,9 @@ __device__ __forceinline__ void MMA_FP16_M16N8K16(uint32_t __restrict__ c[],
"r"(c[1]),
"r"(c[2]),
"r"(c[3]));
#else
#warning "The mma functions are only implemented for Ampere and newer architectures"
#endif
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef UTILS_CORE_CUH
#define UTILS_CORE_CUH
#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_CORE_CUH
#define DEEPSPEED_CUDA_LINEAR_UTILS_CORE_CUH

#include <assert.h>

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef UTILS_GMEM_CUH
#define UTILS_GMEM_CUH
#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_GMEM_CUH
#define DEEPSPEED_CUDA_LINEAR_UTILS_GMEM_CUH

#include <assert.h>
#include "configs.h"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef UTILS_PARALLELDEQUANT_CUH
#define UTILS_PARALLELDEQUANT_CUH
#ifndef DEEPSPEED_CUDA_LINEAR_UTILS_PARALLELDEQUANT_CUH
#define DEEPSPEED_CUDA_LINEAR_UTILS_PARALLELDEQUANT_CUH

#include <cuda.h>
#include <cuda_fp16.h>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef DEEPSPEED_CUDA_LINEAR_WEIGHT_PREPACKING_H
#define DEEPSPEED_CUDA_LINEAR_WEIGHT_PREPACKING_H

#include <assert.h>
#include <stdio.h>
#include <vector>
Expand Down Expand Up @@ -202,3 +205,5 @@ void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K)
for (size_t i = 0; i < BytesPerThread_4bit * 32 / 4; i++)
BitInterleaving_4bit(Weight_4bit + 4 * i);
}

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <ATen/cuda/CUDAContext.h>

#include "cuda_linear_kernels.h"
#include "linear_kernels.h"

namespace {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@

// DeepSpeed Team

#pragma once
#ifndef DEEPSPEED_CUDA_LINEAR_KERNELS_H
#define DEEPSPEED_CUDA_LINEAR_KERNELS_H

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include "ds_kernel_utils.h"

#include "fp6_linear.cuh"
#include "linear_kernels_cuda.h"

void cuda_wf6af16_linear(torch::Tensor& output,
torch::Tensor& hidden_states,
Expand All @@ -23,3 +24,5 @@ void cuda_wf6af16_linear(torch::Tensor& output,
int split_k);

std::vector<torch::Tensor> preprocess_weight(torch::Tensor& Weight);

#endif
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
#include <assert.h>
#include <stdio.h>

#include "linear_kernels_cuda.h"

template <typename TilingConfig, typename OutputDataType>
static void Kernel_Ex(cudaStream_t stream,
const uint4* Weight1,
Expand Down Expand Up @@ -50,7 +52,7 @@ static void Kernel_Ex(cudaStream_t stream,
size_t dimM = M_Global * Split_K / TilingConfig::TILE_M;
dim3 GridDim(dimN, dimM, 1);
dim3 BlockDim(WARP_SIZE * TilingConfig::BLOCK_WARPS, 1, 1);
//

#ifdef DEBUG_MODE
printf(
"GridDim.x: %d, GridDim.y: %d, GridDim.z: %d, BlockDim.x: %d, BlockDim.y: %d, BlockDim.z: "
Expand All @@ -64,6 +66,7 @@ static void Kernel_Ex(cudaStream_t stream,
SHMEM_SZ);
printf("\n");
#endif

QUANT_GEMM_Kernel<TilingConfig, OutputDataType><<<GridDim, BlockDim, SHMEM_SZ, stream>>>(
Weight1, Weight2, Scales, B, C, M_Global, N_Global, K_Global, Split_K);
}
Expand Down Expand Up @@ -121,7 +124,7 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream,
break;
default:
if (N_PowerOf2 % 128 != 0) {
printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
printf("QuantLLM_API Error: Unsupported N dimension %lu!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, half>(
Expand Down Expand Up @@ -192,7 +195,7 @@ cudaError_t fp6_linear_kernel(cudaStream_t stream,
break;
default:
if (N_PowerOf2 % 128 != 0) {
printf("QuantLLM_API Error: Unsupported N dimension %d!\n", N_PowerOf2);
printf("QuantLLM_API Error: Unsupported N dimension %lu!\n", N_PowerOf2);
return cudaErrorUnknown;
}
Kernel_Ex<TilingConfig<4, 1, 8>, float>(stream,
Expand Down Expand Up @@ -236,7 +239,7 @@ calling our CUDA kernel.
torch::Tensor fp6_linear_forward_cuda(torch::Tensor _in_feats,
torch::Tensor _weights,
torch::Tensor _scales,
int splitK = 1)
int splitK)
{
int num_in_feats = _in_feats.size(0);
int num_in_channels = _in_feats.size(1);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@

// This is a copy of FP6-LLM kernel code: https://arxiv.org/abs/2401.14112

#ifndef DEEPSPEED_CUDA_LINEAR_FP6_LINEAR_CUH
#define DEEPSPEED_CUDA_LINEAR_FP6_LINEAR_CUH

#include <cuda.h>
#include <cuda_fp16.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -44,3 +47,5 @@ void weight_matrix_prepacking(int* FP6Weights, size_t M, size_t K);
* Weight prepacking (Pytorch interface).
*/
torch::Tensor weight_matrix_prepacking_cpu(torch::Tensor fp6_tensor, size_t M, size_t K);

#endif
11 changes: 10 additions & 1 deletion deepspeed/inference/v2/modules/heuristics.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,17 @@ def instantiate_linear(linear_config: DSLinearConfig, engine_config: RaggedInfer
if quantization_mode is None:
config = ConfigBundle(name="blas_fp_linear", config=linear_config)
else:
# Currently, we only support ``quantized_wf6af16_linear``.
# Currently, we only support ``quantized_wf6af16_linear`` on NVIDIA Ampere GPUs.
if quantization_mode == "wf6af16":
import torch
if not torch.cuda.is_available(): #ignore-cuda
raise ValueError("WF6AF16 quantization is only supported on CUDA")
else:
is_rocm_pytorch = hasattr(torch.version, 'hip') and torch.version.hip is not None
if is_rocm_pytorch:
raise ValueError("WF6AF16 quantization is only supported on NVIDIA GPUs")
elif torch.cuda.get_device_properties(0).major != 8: #ignore-cuda
raise ValueError("WF6AF16 quantization is only supported on Ampere architectures")
config = ConfigBundle(name="quantized_wf6af16_linear", config=linear_config)
else:
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
Expand Down
13 changes: 2 additions & 11 deletions op_builder/inference_core_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,6 @@ def get_prefix(self):
return "deepspeed" if os.path.isdir(ds_path) else ".."

def sources(self):
import torch

sources = [
"inference/v2/kernels/core_ops/core_ops.cpp",
"inference/v2/kernels/core_ops/bias_activations/bias_activation.cpp",
Expand All @@ -69,17 +67,10 @@ def sources(self):
"inference/v2/kernels/core_ops/cuda_rms_norm/rms_norm_cuda.cu",
"inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels.cpp",
"inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu",
"inference/v2/kernels/core_ops/cuda_linear/linear_kernels.cpp",
"inference/v2/kernels/core_ops/cuda_linear/linear_kernels_cuda.cu",
]

# The source files with specific GPU architecture requirements.
if not self.is_rocm_pytorch() and torch.cuda.is_available(): #ignore-cuda
cuda_capability = torch.cuda.get_device_properties(0).major #ignore-cuda
if cuda_capability != 8:
self.warning("FP6 quantization kernel is only supported on Ampere architectures")
else:
sources.append("inference/v2/kernels/core_ops/cuda_linear/fp6_linear.cu")
sources.append("inference/v2/kernels/core_ops/cuda_linear/cuda_linear_kernels.cpp")

prefix = self.get_prefix()
sources = [os.path.join(prefix, src) for src in sources]
return sources
Expand Down
Loading

0 comments on commit c5e3a5e

Please sign in to comment.