Skip to content

Commit

Permalink
added half function declarations to qnbitgemm.h
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Nov 14, 2024
1 parent 051daf2 commit 53e5fcf
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 23 deletions.
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/qnbitgemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ MlasIsQNBitGemmAvailable(
switch (Variant) {
case SQNBitGemmVariant_BitWidth4_CompFp32: {
return Dispatch->SQ4BitGemmM1Kernel_CompFp32 != nullptr &&
Dispatch->Q4BitBlkDequantBForSgemm_CompFp32 != nullptr;
Dispatch->SQ4BitBlkDequantBForSgemm_CompFp32 != nullptr;
}
case SQNBitGemmVariant_BitWidth4_CompInt8: { // SQ4BitGemmKernel_BlkSum_CompInt8
return
Expand Down Expand Up @@ -387,7 +387,7 @@ SQ4BitGemm_CompFp32(
float* c_blk = C + n;
const float* bias = (Bias == nullptr) ? nullptr : Bias + n;

GetMlasPlatform().QNBitGemmDispatch->Q4BitBlkDequantBForSgemm_CompFp32(
GetMlasPlatform().QNBitGemmDispatch->SQ4BitBlkDequantBForSgemm_CompFp32(
BlkLen,
dequant_b, b_col, b_col_scale, b_col_zp, CountN, K, k_blks
);
Expand Down
63 changes: 49 additions & 14 deletions onnxruntime/core/mlas/lib/qnbitgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,17 +91,17 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
//

/** Gets size of packed quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBDataSize(). */
typedef size_t(SQ4BitGemmPackQuantBDataSize_Fn)(
typedef size_t(Q4BitGemmPackQuantBDataSize_Fn)(
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;
Q4BitGemmPackQuantBDataSize_Fn* Q4BitGemmPackQuantBDataSize = nullptr;

/** Packs quantized B data containing 4-bit integers. See MlasQNBitGemmPackQuantBData(). */
typedef void(SQ4BitGemmPackQuantBData_Fn)(
typedef void(Q4BitGemmPackQuantBData_Fn)(
size_t N,
size_t K,
size_t BlkLen,
Expand All @@ -111,7 +111,8 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
MLAS_THREADPOOL* ThreadPool
);

SQ4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* SQ4BitGemmPackQuantBData = nullptr;
Q4BitGemmPackQuantBData_Fn* HQ4BitGemmPackQuantBData = nullptr;

typedef void(SQ4BitGemmPackQuantBDataAndSumBlk_Fn)(
size_t N,
Expand Down Expand Up @@ -142,28 +143,28 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceSize_Fn)(
typedef size_t(Q4BitGemmPerGemmWorkspaceSize_Fn)(
size_t M,
size_t N,
size_t K,
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;
Q4BitGemmPerGemmWorkspaceSize_Fn* Q4BitGemmPerGemmWorkspaceSize = nullptr;

/**
* @brief Gets the required byte alignment of the per-GEMM intermediate workspace.
*
* @param[in] BlkLen number of quantized values per block
* @param[in] ComputeType GEMM compute type (e.g., multiplying float or int8 values)
*/
typedef size_t(SQ4BitGemmPerGemmWorkspaceAlignment_Fn)(
typedef size_t(Q4BitGemmPerGemmWorkspaceAlignment_Fn)(
size_t BlkLen,
MLAS_QNBIT_GEMM_COMPUTE_TYPE ComputeType
);

SQ4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;
Q4BitGemmPerGemmWorkspaceAlignment_Fn* Q4BitGemmPerGemmWorkspaceAlignment = nullptr;

//
// SQNBIT_CompFp32 kernel function prototypes.
Expand Down Expand Up @@ -205,8 +206,9 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* B is a quantized 4-bit integer matrix that is block quantized and column major.
* This is equivalent to dequantizing B and then running MlasSgemmCopyPackB.
*
* @tparam T type of input A
* @param BlkLen Number of values in a block.
* @param[out] FpData Supplies the output buffer for the dequantized B float data.
 * @param[out] FpData Supplies the output buffer for the dequantized B data of type T.
* It should have enough space for
* (CountN + 16 - 1) / 16 * 16 * (CountK + BlkLen - 1) / BlkLen * BlkLen
* elements. Only the first (CountN + 16 - 1) / 16 * 16 * CountK elements are
Expand All @@ -218,18 +220,20 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
* @param CountK Number of rows of B.
* @param BlockStrideQuantB Number of blocks between adjacent columns of the quantized B matrix.
*/
typedef void(Q4BitBlkDequantBForSgemm_CompFp32_Fn)(
template<typename T>
using Q4BitBlkDequantBForGemm_Fn = std::function<void(
size_t BlkLen,
float* FpData,
T* FpData,
const std::byte* QuantBData,
const float* QuantBScale,
const T* QuantBScale,
const std::byte* QuantBZeroPoint,
size_t CountN,
size_t CountK,
size_t BlockStrideQuantB
);
)>;

Q4BitBlkDequantBForSgemm_CompFp32_Fn* Q4BitBlkDequantBForSgemm_CompFp32 = nullptr;
Q4BitBlkDequantBForGemm_Fn<float> SQ4BitBlkDequantBForSgemm_CompFp32 = nullptr;
Q4BitBlkDequantBForGemm_Fn<MLAS_FP16> HQ4BitBlkDequantBForHgemm_CompFp16 = nullptr;

//
// SQNBIT_CompInt8 kernel function prototypes.
Expand Down Expand Up @@ -338,4 +342,35 @@ struct MLAS_QNBIT_GEMM_DISPATCH {
float* AScaledGroupSum // scale_k * Sum_blklen(a_i)
);
QuantizeARowComputeBlkSum_CompInt8_Fn* QuantizeARowComputeBlkSum_CompInt8 = nullptr;

/**
* @brief Multiply fp16 matrix A rows with fp16 matrix B columns.
* Results are written to fp16 matrix C.
 * If bias is provided, the bias is added to the result.
*
* @param A first row of the A matrix segment. Row major.
* @param B first column of the B matrix segment. Column major.
* @param Bias the bias at the target column. Optional.
* @param[out] C first element of the output matrix segment. Row major.
 * @param CountM the number of rows of the A chunk.
 * @param CountN the number of columns of the B chunk.
* @param K the number of columns of A matrix and rows of B matrix.
* @param lda the leading dimension of A.
* @param ldb the leading dimension of B.
* @param ldc the leading dimension of C.
*/
using HQ4BitGemmKernel_CompFp16_Fn = std::function<void(
const MLAS_FP16* A,
const MLAS_FP16* B,
const MLAS_FP16* Bias,
MLAS_FP16* C,
size_t CountM,
size_t CountN,
size_t K,
size_t lda,
size_t ldb,
size_t ldc
)>;

HQ4BitGemmKernel_CompFp16_Fn HQ4BitGemmKernel_CompFp16 = nullptr;
};
4 changes: 2 additions & 2 deletions onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1341,7 +1341,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2 = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2;
Expand All @@ -1360,7 +1360,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx2vnni = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx2;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx2vnni;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx2;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512 = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32_avx512;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_avx512vnni.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -348,7 +348,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = Q4BitBlkDequantBForSgemm_CompFp32_avx2;

d.SQ4BitGemmKernel_BlkSum_CompInt8 = SQ4BitGemmKernel_BlkSum_CompInt8_avx512vnni;
d.QuantizeARowComputeBlkSum_CompInt8 = QuantizeARow_CompInt8_avx512;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchNeon = []() {
d.Q4BitGemmPerGemmWorkspaceAlignment = sqnbitgemm_neon::Q4BitGemmPerGemmWorkspaceAlignment;

d.SQ4BitGemmM1Kernel_CompFp32 = sqnbitgemm_neon::SQ4BitGemmM1Kernel_CompFp32;
d.Q4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::Q4BitBlkDequantBForSgemm_CompFp32;
d.SQ4BitBlkDequantBForSgemm_CompFp32 = sqnbitgemm_neon::SQ4BitBlkDequantBForSgemm_CompFp32;

d.SQ4BitGemmKernel_CompInt8 = sqnbitgemm_neon::SQ4BitGemmKernel_CompInt8;
d.QuantizeARow_CompInt8 = sqnbitgemm_neon::QuantizeARow_CompInt8;
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ SQ4BitGemmM1Kernel_CompFp32(
);

void
Q4BitBlkDequantBForSgemm_CompFp32(
SQ4BitBlkDequantBForSgemm_CompFp32(
size_t BlkLen,
float* FpData,
const std::byte* QuantBData,
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/mlas/lib/sqnbitgemm_kernel_neon_fp32.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -608,7 +608,7 @@ Q4BitBlkDequantBForSgemm_CompFp32_Impl(
} // namespace

void
Q4BitBlkDequantBForSgemm_CompFp32(
SQ4BitBlkDequantBForSgemm_CompFp32(
size_t BlkLen,
float* FpData,
const std::byte* QuantBData,
Expand Down

0 comments on commit 53e5fcf

Please sign in to comment.