forked from tohtana/DeepSpeed
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
FP [6,8,12] quantizer op (microsoft#5336)
Flexible-bit quantizer-dequantizer library with fp6/fp12/fp8 support Requires Ampere+ architecture, this is due to the initial focus of this op only on `bfloat16` input types. Co-authored-by: Reza Yazdani <[email protected]>
- Loading branch information
Showing
12 changed files
with
1,014 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
// DeepSpeed Team | ||
|
||
#pragma once | ||
|
||
#include <ATen/cuda/CUDAContext.h> | ||
#include <cuda_runtime_api.h> | ||
#include <cassert> | ||
#include <iostream> | ||
#include <vector> | ||
#include "cublas_v2.h" | ||
#include "cuda.h" | ||
#include "curand.h" | ||
|
||
#include <cuda.h> | ||
#include <cuda_runtime_api.h> | ||
#include <stdlib.h> | ||
#include <sys/time.h> | ||
#include <map> | ||
#include <memory> | ||
#include <stack> | ||
#include <string> | ||
#define WARP_SIZE 32 | ||
|
||
class FPContext { | ||
public: | ||
FPContext() : _seed(42) | ||
{ | ||
curandCreateGenerator(&_gen, CURAND_RNG_PSEUDO_DEFAULT); | ||
curandSetPseudoRandomGeneratorSeed(_gen, 123); | ||
} | ||
|
||
virtual ~FPContext() {} | ||
|
||
static FPContext& Instance() | ||
{ | ||
static FPContext _ctx; | ||
return _ctx; | ||
} | ||
|
||
curandGenerator_t& GetRandGenerator() { return _gen; } | ||
|
||
cudaStream_t GetCurrentStream() | ||
{ | ||
// get current pytorch stream. | ||
cudaStream_t stream = at::cuda::getCurrentCUDAStream(); | ||
return stream; | ||
} | ||
|
||
std::pair<uint64_t, uint64_t> IncrementOffset(uint64_t offset_inc) | ||
{ | ||
uint64_t offset = _curr_offset; | ||
_curr_offset += offset_inc; | ||
return std::pair<uint64_t, uint64_t>(_seed, offset); | ||
} | ||
|
||
void SetSeed(uint64_t new_seed) { _seed = new_seed; } | ||
|
||
private: | ||
curandGenerator_t _gen; | ||
cublasHandle_t _cublasHandle; | ||
uint64_t _seed; | ||
uint64_t _curr_offset; | ||
}; |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
// DeepSpeed Team | ||
|
||
#pragma once | ||
|
||
#include <cuda.h> | ||
#include <stdint.h> | ||
|
||
#include <cuda_fp16.h> | ||
|
||
#include <cuda_bf16.h> | ||
#include <cuda_runtime_api.h> | ||
#include <stdio.h> | ||
|
||
#define QUANT_SWITCH(Q_BITS, ...) \ | ||
[&] { \ | ||
if (12 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ | ||
constexpr int CONST_Q_BITS = 8; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 3; \ | ||
__VA_ARGS__(); \ | ||
} else if (13 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ | ||
constexpr int CONST_Q_BITS = 8; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 3; \ | ||
__VA_ARGS__(); \ | ||
} else if (10 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ | ||
constexpr int CONST_Q_BITS = 8; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
__VA_ARGS__(); \ | ||
} else if (11 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ | ||
constexpr int CONST_Q_BITS = 8; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
__VA_ARGS__(); \ | ||
} else if (28 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ | ||
constexpr int CONST_Q_BITS = 12; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 7; \ | ||
__VA_ARGS__(); \ | ||
} else if (29 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ | ||
constexpr int CONST_Q_BITS = 12; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 7; \ | ||
__VA_ARGS__(); \ | ||
} else if (6 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ | ||
constexpr int CONST_Q_BITS = 6; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
__VA_ARGS__(); \ | ||
} else if (7 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ | ||
constexpr int CONST_Q_BITS = 6; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
__VA_ARGS__(); \ | ||
} else if (2 == Q_BITS) { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 0; \ | ||
constexpr int CONST_Q_BITS = 4; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 1; \ | ||
__VA_ARGS__(); \ | ||
} else { \ | ||
constexpr int CONST_STOCHASTIC_ROUNDING = 1; \ | ||
constexpr int CONST_Q_BITS = 4; \ | ||
constexpr int CONST_Q_MANTISA_BITS = 1; \ | ||
__VA_ARGS__(); \ | ||
} \ | ||
}() | ||
|
||
#define DEQUANT_SWITCH(Q_MANTISA_EXPONENT_BITS, ...) \ | ||
[&] { \ | ||
if (12 == Q_MANTISA_EXPONENT_BITS) { \ | ||
constexpr int CONST_Q_MANTISA_BITS = 3; \ | ||
constexpr int CONST_Q_EXPONENT_BITS = 4; \ | ||
__VA_ARGS__(); \ | ||
} else if (10 == Q_MANTISA_EXPONENT_BITS) { \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
constexpr int CONST_Q_EXPONENT_BITS = 5; \ | ||
__VA_ARGS__(); \ | ||
} else if (28 == Q_MANTISA_EXPONENT_BITS) { \ | ||
constexpr int CONST_Q_MANTISA_BITS = 7; \ | ||
constexpr int CONST_Q_EXPONENT_BITS = 4; \ | ||
__VA_ARGS__(); \ | ||
} else if (6 == Q_MANTISA_EXPONENT_BITS) { \ | ||
constexpr int CONST_Q_MANTISA_BITS = 2; \ | ||
constexpr int CONST_Q_EXPONENT_BITS = 3; \ | ||
__VA_ARGS__(); \ | ||
} else { \ | ||
constexpr int CONST_Q_MANTISA_BITS = 1; \ | ||
constexpr int CONST_Q_EXPONENT_BITS = 2; \ | ||
__VA_ARGS__(); \ | ||
} \ | ||
}() | ||
|
||
template <typename T, int mantisa, int exponent> | ||
void launch_quantization(T* val, | ||
uint8_t* q_val, | ||
int num_groups, | ||
int group_size, | ||
cudaStream_t stream, | ||
float q_range, | ||
int q_bits, | ||
int q_mantisa_bits, | ||
int stochastic_rounding); | ||
|
||
template <typename T, int mantisa> | ||
void launch_dequantization(uint8_t* val, | ||
T* q_val, | ||
int num_groups, | ||
int group_size, | ||
int q_mantisa_bits, | ||
int q_exponent_bits, | ||
cudaStream_t stream); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
// Copyright (c) Microsoft Corporation. | ||
// SPDX-License-Identifier: Apache-2.0 | ||
|
||
// DeepSpeed Team | ||
|
||
#include "quantize.h" | ||
|
||
#include <c10/cuda/CUDAStream.h> | ||
#include <torch/extension.h> | ||
#include <vector> | ||
|
||
#define DISPATCH_QUANTIZE(T_TYPE, C_TYPE, mantisa, exponent) \ | ||
if (val.options().dtype() == torch::T_TYPE) { \ | ||
launch_quantization<C_TYPE, mantisa, exponent>((C_TYPE*)val.data_ptr(), \ | ||
(uint8_t*)out.data_ptr(), \ | ||
num_groups, \ | ||
group_size, \ | ||
at::cuda::getCurrentCUDAStream(), \ | ||
q_range, \ | ||
q_bits, \ | ||
q_mantisa_bits, \ | ||
stochastic_rounding); \ | ||
} | ||
|
||
at::Tensor quantize(torch::Tensor& val, | ||
int group_size, | ||
int stochastic_rounding, | ||
int q_bits, | ||
int q_mantisa_bits) | ||
{ | ||
int total_elems = at::numel(val); | ||
auto options = at::TensorOptions() | ||
.dtype(torch::kInt8) | ||
.layout(val.layout()) | ||
.device(val.device()) | ||
.requires_grad(false); | ||
float q_range = q_bits == 8 ? (q_mantisa_bits == 3 ? 480.0 : 114688.0) : // fp8 ranges | ||
(q_bits == 12 ? 510.0 : // fp12 range | ||
(q_bits == 6 ? 28.0 : // fp6 range | ||
6.0)); // fp4 range (using power 2); TODO (Reza): add the power-4 | ||
// in case accuracy is not matching! | ||
int num_groups = total_elems / group_size; | ||
auto out = torch::empty({num_groups, group_size * q_bits / 8 + 4}, options); | ||
|
||
DISPATCH_QUANTIZE(kHalf, __half, 23, 8); | ||
#ifdef BF16_AVAILABLE | ||
DISPATCH_QUANTIZE(kBFloat16, __nv_bfloat16, 23, 8); | ||
#endif | ||
|
||
return out; | ||
} | ||
|
||
#define DISPATCH_DEQUANTIZE(T_TYPE, C_TYPE, mantisa) \ | ||
if (val.options().dtype() == torch::T_TYPE) { \ | ||
launch_dequantization<C_TYPE, mantisa>((uint8_t*)val_q.data_ptr(), \ | ||
(C_TYPE*)val.data_ptr(), \ | ||
num_groups, \ | ||
group_size, \ | ||
q_mantisa_bits, \ | ||
q_exponent_bits, \ | ||
at::cuda::getCurrentCUDAStream()); \ | ||
return; \ | ||
} | ||
|
||
void dequantize(torch::Tensor& val, | ||
torch::Tensor& val_q, | ||
int group_size, | ||
int q_mantisa_bits, | ||
int q_exponent_bits) | ||
{ | ||
int total_elems = at::numel(val); | ||
|
||
int num_groups = total_elems / group_size; | ||
|
||
DISPATCH_DEQUANTIZE(kHalf, __half, 10); | ||
#ifdef BF16_AVAILABLE | ||
DISPATCH_DEQUANTIZE(kBFloat16, __nv_bfloat16, 7); | ||
#endif | ||
} | ||
|
||
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) | ||
{ | ||
m.def("quantize", &quantize, "quantize function"); | ||
m.def("dequantize", &dequantize, "dequantize function"); | ||
} |
Oops, something went wrong.