Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
fajin-corp committed Dec 11, 2024
1 parent bf4d3e1 commit 8716319
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 0 deletions.
2 changes: 2 additions & 0 deletions onnxruntime/contrib_ops/cpu/bert/gqa_attention_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,8 @@ class GQAAttentionBase {
math::GemmEx<float, ThreadPool>(CblasNoTrans, CblasTrans, sequence_length, total_seqlen, head_size, alpha, q,
static_cast<int>(head_size), k, static_cast<int>(head_size), 0.0f /*bata*/,
output, static_cast<int>(present_buffer_sequence_length), nullptr);
} else if (GetMlasPlatform().HasFP16Support()) {
// TODO: if kernel available, call MlasHGemmEx
} else {
size_t bytes = head_size * (sequence_length + total_seqlen) * sizeof(float);
auto q_k_fp32 = allocator->Alloc(bytes);
Expand Down
50 changes: 50 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1458,6 +1458,56 @@ MlasRotaryEmbedOneRow(
T* output
);

/**
 * @brief Check whether the current CPU supports half precision gemm (HGEMM)
 *        for the given transpose options.
 *
 * @param TransA Supplies the transpose operation requested for matrix A.
 * @param TransB Supplies the transpose operation requested for matrix B.
 * @return true if an HGEMM kernel is available on this CPU for the
 *         (TransA, TransB) combination; false otherwise.
 *         NOTE(review): per the MlasGemm FP16 overload's documentation, only
 *         (CblasNoTrans, CblasTrans) is currently expected to be supported —
 *         confirm against the dispatch implementation.
 */
bool
MLASCALL
MlasHGemmSupported(
CBLAS_TRANSPOSE TransA,
CBLAS_TRANSPOSE TransB);

/**
 * @brief half precision matrix/matrix multiply operation (HGEMM)
 *        C = alpha * op(A) * op(B) + beta * C
 *
 * Callers should first check MlasHGemmSupported() to verify that a kernel is
 * available for the requested transpose options.
 *
 * @param TransA Supplies the transpose operation for matrix A. Currently only support CblasNoTrans.
 * @param TransB Supplies the transpose operation for matrix B. Currently only support CblasTrans.
 * @param M Supplies the number of rows of matrix A and matrix C.
 * @param N Supplies the number of columns of matrix B and matrix C.
 * @param K Supplies the number of columns of matrix A and the number of rows of matrix B.
 * @param A Supplies the address of matrix A
 * @param lda Supplies the first dimension of matrix A.
 * @param B Supplies the address of matrix B
 * @param ldb Supplies the first dimension of matrix B.
 * @param C Supplies the address of matrix C
 * @param ldc Supplies the first dimension of matrix C.
 * @param alpha Supplies the scalar alpha multiplier (see GEMM definition)
 * @param beta Supplies the scalar beta multiplier (see GEMM definition)
 * @param ThreadPool Supplies the thread pool object to use, else nullptr if the base library threading support
 *                   should be used.
 */
// `inline` is required: this function is *defined* in a public header, so
// without it every translation unit that includes mlas.h would emit its own
// external definition and violate the ODR (duplicate-symbol link errors).
inline
void
MLASCALL
MlasGemm(
    CBLAS_TRANSPOSE TransA,
    CBLAS_TRANSPOSE TransB,
    size_t M,
    size_t N,
    size_t K,
    const MLAS_FP16* A,
    size_t lda,
    const MLAS_FP16* B,
    size_t ldb,
    MLAS_FP16* C,
    size_t ldc,
    MLAS_FP16 alpha,
    MLAS_FP16 beta,
    MLAS_THREADPOOL* ThreadPool
) {
    // Stub: the FP16 kernel path is not wired up yet. Suppress
    // unused-parameter warnings until the implementation lands.
    (void)TransA; (void)TransB;
    (void)M; (void)N; (void)K;
    (void)A; (void)lda;
    (void)B; (void)ldb;
    (void)C; (void)ldc;
    (void)alpha; (void)beta;
    (void)ThreadPool;
    // TODO: call MlasGemmBatch for hgemm
}

/**
* @brief Whether current CPU supports FP16 acceleration.
*/
Expand Down
33 changes: 33 additions & 0 deletions onnxruntime/core/mlas/lib/halfgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,36 @@ MlasHalfGemmGetDispatch()
return &MlasHalfGemmDispatchDefault;
#endif
}

/**
 * @brief Per-platform dispatch table for half precision GEMM kernels.
 *        A platform initializer fills in the kernel pointer(s); members left
 *        nullptr mean no specialized kernel is available.
 */
struct MLAS_HGEMM_DISPATCH {
    /**
     * @brief C = alpha * A * Transpose(B) + beta * C
     *
     * @param A first row of the A matrix segment. Row major.
     * @param B first column of the B matrix segment. Column major.
     * @param[out] C first element of the output matrix segment. Row major.
     * @param CountM the number of rows of A chunk.
     * @param CountN the number of columns of B chunk.
     * @param CountK the number of columns of A chunk and the number of rows of B chunk.
     * @param lda the leading dimension of A.
     * @param ldb the leading dimension of B.
     * @param ldc the leading dimension of C.
     * @param alpha the alpha scalar value.
     * @param beta the beta scalar value.
     */
    // Alias declaration (`using`) instead of `typedef` for the kernel
    // function type — same type, modern C++ idiom.
    using HGemmKernel_TransposeB_Fn = void(
        const MLAS_FP16* A,
        const MLAS_FP16* B,
        MLAS_FP16* C,
        size_t CountM,
        size_t CountN,
        size_t CountK,
        size_t lda,
        size_t ldb,
        size_t ldc,
        MLAS_FP16 alpha,
        MLAS_FP16 beta
    );

    // Kernel entry point for the A-row-major / B-transposed case;
    // nullptr until a platform-specific kernel is registered.
    HGemmKernel_TransposeB_Fn* HGemmKernel_TransposeB = nullptr;
};
7 changes: 7 additions & 0 deletions onnxruntime/core/mlas/lib/mlasi.h
Original file line number Diff line number Diff line change
Expand Up @@ -1055,6 +1055,12 @@ extern const MLAS_QNBIT_GEMM_DISPATCH MlasSQNBitGemmDispatchAvx512vnni;
struct MLAS_ROPE_DISPATCH;
extern const MLAS_ROPE_DISPATCH MlasRopeDispatchNeon;

//
// Half precision GEMM (HGEMM) dispatch structure.
// Forward-declared here; the full definition lives in core/mlas/lib/halfgemm.h.
// NOTE(review): the MLAS_PLATFORM member holding this pointer is spelled
// "HGemmDIspatch" (capital "I") — likely a typo for "HGemmDispatch"; worth
// fixing before other code starts referencing it.
//
struct MLAS_HGEMM_DISPATCH;
extern const MLAS_HGEMM_DISPATCH MlasHgemmDispatchNeon;


//
// Quantized depthwise convolution kernels.
Expand Down Expand Up @@ -1217,6 +1223,7 @@ struct MLAS_PLATFORM {
MLAS_CAST_F32_TO_F16_KERNEL* CastF32ToF16Kernel;

const MLAS_ROPE_DISPATCH* RopeDispatch{nullptr};
const MLAS_HGEMM_DISPATCH* HGemmDIspatch{nullptr};
};

inline
Expand Down

0 comments on commit 8716319

Please sign in to comment.