FP [6,8,12] quantizer op (microsoft#5336)
Flexible-bit quantizer-dequantizer library with fp6/fp12/fp8 support

Requires an Ampere or newer GPU (sm_80+), since this op initially targets
`bfloat16` input types only.

Co-authored-by: Reza Yazdani <[email protected]>
2 people authored and umchand committed May 20, 2024
1 parent f6e9adb commit 6dcb50c
Showing 12 changed files with 1,014 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/nv-pre-compile-ops.yml
@@ -36,7 +36,7 @@ jobs:
#python -c "import torch; print('CUDA available:', torch.cuda.is_available())"
- name: Compile DeepSpeed Ops
run: |
-        DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
+        DS_ACCELERATOR=cuda DS_ENABLE_NINJA=1 TORCH_CUDA_ARCH_LIST="7.0;7.5;8.0" DS_BUILD_OPS=1 DS_BUILD_SPARSE_ATTN=0 DS_BUILD_FP_QUANTIZER=0 DS_BUILD_CUTLASS_OPS=0 DS_BUILD_RAGGED_DEVICE_OPS=0 DS_BUILD_EVOFORMER_ATTN=0 pip3 install .
- name: DS Report
run: |
ds_report
66 changes: 66 additions & 0 deletions csrc/fp_quantizer/includes/context.h
@@ -0,0 +1,66 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cuda.h>
#include <cuda_runtime_api.h>
#include <cassert>
#include <iostream>
#include <vector>
#include "cublas_v2.h"
#include "curand.h"

#include <stdlib.h>
#include <sys/time.h>
#include <map>
#include <memory>
#include <stack>
#include <string>
#define WARP_SIZE 32

class FPContext {
public:
FPContext() : _seed(42), _curr_offset(0)
{
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT);
curandSetPseudoRandomGeneratorSeed(_gen, _seed);
}

virtual ~FPContext() {}

static FPContext& Instance()
{
static FPContext _ctx;
return _ctx;
}

curandGenerator_t& GetRandGenerator() { return _gen; }

cudaStream_t GetCurrentStream()
{
// Get the current PyTorch stream.
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
return stream;
}

std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc)
{
uint64_t offset = _curr_offset;
_curr_offset += offset_inc;
return std::pair<uint64_t, uint64_t>(_seed, offset);
}

void SetSeed(uint64_t new_seed) { _seed = new_seed; }  // updates only the seed returned by IncrementOffset; the cuRAND generator is not re-seeded

private:
curandGenerator_t _gen;
cublasHandle_t _cublasHandle;
uint64_t _seed;
uint64_t _curr_offset;
};
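
For orientation, a minimal caller-side sketch (not part of this commit) of how launch code might use the singleton; `fill_uniform` is a hypothetical helper:

// Hypothetical helper: draw n uniform floats in (0, 1] on PyTorch's current stream.
void fill_uniform(float* device_buf, size_t n)
{
    curandGenerator_t& gen = FPContext::Instance().GetRandGenerator();
    curandSetStream(gen, FPContext::Instance().GetCurrentStream());
    curandGenerateUniform(gen, device_buf, n);
}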
115 changes: 115 additions & 0 deletions csrc/fp_quantizer/includes/quantize.h
@@ -0,0 +1,115 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#pragma once

#include <cuda.h>
#include <stdint.h>

#include <cuda_fp16.h>

#include <cuda_bf16.h>
#include <cuda_runtime_api.h>
#include <stdio.h>

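// The case labels below follow the pattern
//     key = q_mantisa_bits * q_exponent_bits + stochastic_rounding,
// i.e. fp8-e4m3 -> 12/13, fp8-e5m2 -> 10/11, fp12-e4m7 -> 28/29,
// fp6-e3m2 -> 6/7, and fp4-e2m1 -> 2/3 (the fallback case).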
#define QUANT_SWITCH(Q_BITS, ...) \
[&] { \
if (12 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (13 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 3; \
__VA_ARGS__(); \
} else if (10 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (11 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 8; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (28 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (29 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 12; \
constexpr int CONST_Q_MANTISA_BITS = 7; \
__VA_ARGS__(); \
} else if (6 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (7 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 6; \
constexpr int CONST_Q_MANTISA_BITS = 2; \
__VA_ARGS__(); \
} else if (2 == Q_BITS) { \
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \
constexpr int CONST_Q_BITS = 4; \
constexpr int CONST_Q_MANTISA_BITS = 1; \
__VA_ARGS__(); \
} \
}()

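// Same key pattern as above without the rounding bit:
//     key = q_mantisa_bits * q_exponent_bits.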
#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \
[&] { \
if (12 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 3; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (10 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 5; \
__VA_ARGS__(); \
} else if (28 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 7; \
constexpr int CONST_Q_EXPONENT_BITS = 4; \
__VA_ARGS__(); \
} else if (6 == Q_MANTISA_EXPONENT_BITS) { \
constexpr int CONST_Q_MANTISA_BITS = 2; \
constexpr int CONST_Q_EXPONENT_BITS = 3; \
__VA_ARGS__(); \
} else { \
constexpr int CONST_Q_MANTISA_BITS = 1; \
constexpr int CONST_Q_EXPONENT_BITS = 2; \
__VA_ARGS__(); \
} \
}()

template <typename T, int mantisa, int exponent>
void launch_quantization(T* val,
uint8_t* q_val,
int num_groups,
int group_size,
cudaStream_t stream,
float q_range,
int q_bits,
int q_mantisa_bits,
int stochastic_rounding);

template <typename T, int mantisa>
void launch_dequantization(uint8_t* val,
T* q_val,
int num_groups,
int group_size,
int q_mantisa_bits,
int q_exponent_bits,
cudaStream_t stream);
85 changes: 85 additions & 0 deletions csrc/fp_quantizer/quantize.cpp
@@ -0,0 +1,85 @@
// Copyright (c) Microsoft Corporation.
// SPDX-License-Identifier: Apache-2.0

// DeepSpeed Team

#include "quantize.h"

#include <c10/cuda/CUDAStream.h>
#include <torch/extension.h>
#include <vector>

#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \
(uint8_t*)out.data_ptr(), \
num_groups, \
group_size, \
at::cuda::getCurrentCUDAStream(), \
q_range, \
q_bits, \
q_mantisa_bits, \
stochastic_rounding); \
}

at::Tensor quantize(torch::Tensor& val,
int group_size,
int stochastic_rounding,
int q_bits,
int q_mantisa_bits)
{
int total_elems = at::numel(val);
auto options = at::TensorOptions()
.dtype(torch::kInt8)
.layout(val.layout())
.device(val.device())
.requires_grad(false);
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges
(q_bits == 12 ? 510.0 : // fp12 range
(q_bits == 6 ? 28.0 : // fp6 range
6.0)); // fp4 range (power-of-2 scaling); TODO (Reza): add the
// power-of-4 variant if accuracy does not match.
int num_groups = total_elems / group_size;
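// Each output row holds one group: group_size * q_bits / 8 bytes of packed
// values plus 4 trailing bytes, presumably a per-group fp32 scale.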
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options);

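// Only fp16 and (on Ampere+) bf16 inputs are dispatched; the <23, 8> template
// arguments match fp32's mantissa/exponent widths, suggesting the kernel
// scales values through single precision internally.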
DISPATCH_QUANTIZE(kHalf, __half, 23, 8);
#ifdef BF16_AVAILABLE
DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8);
#endif

return out;
}

#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \
if (val.options().dtype() == torch::T_TYPE) { \
launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \
(C_TYPE*)val.data_ptr(), \
num_groups, \
group_size, \
q_mantisa_bits, \
q_exponent_bits, \
at::cuda::getCurrentCUDAStream()); \
return; \
}

void dequantize(torch::Tensor& val,
torch::Tensor& val_q,
int group_size,
int q_mantisa_bits,
int q_exponent_bits)
{
int total_elems = at::numel(val);

int num_groups = total_elems / group_size;

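// The template mantissa argument is the destination type's mantissa width:
// 10 bits for __half (fp16), 7 bits for __nv_bfloat16.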
DISPATCH_DEQUANTIZE(kHalf, __half, 10);
#ifdef BF16_AVAILABLE
DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7);
#endif
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("quantize", &quantize, "quantize function");
m.def("dequantize", &dequantize, "dequantize function");
}
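
For reference, a minimal round-trip sketch (not part of this commit) of driving the two bound functions from C++; the tensor shape and the fp8-e4m3 settings are illustrative, with q_exponent_bits = 8 - 3 - 1 = 4:

// Round trip: fp16 -> packed fp8 (e4m3, 512-element groups) -> fp16.
torch::Tensor val = torch::randn({8, 512}, torch::dtype(torch::kHalf).device(torch::kCUDA));
torch::Tensor packed = quantize(val, /*group_size=*/512, /*stochastic_rounding=*/0,
                                /*q_bits=*/8, /*q_mantisa_bits=*/3);
torch::Tensor restored = torch::empty_like(val);
dequantize(restored, packed, /*group_size=*/512, /*q_mantisa_bits=*/3, /*q_exponent_bits=*/4);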