Skip to content

Commit

Permalink
optimize qlinearsoftmax
Browse files Browse the repository at this point in the history
  • Loading branch information
wejoncy committed Sep 26, 2024
1 parent 7880342 commit 59a0bb8
Show file tree
Hide file tree
Showing 9 changed files with 910 additions and 6 deletions.
8 changes: 8 additions & 0 deletions cmake/onnxruntime_mlas.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,9 @@ function(setup_mlas_source_for_windows)
${mlas_platform_srcs_avx2}
${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse.cpp
${MLAS_SRC_DIR}/qgemm_kernel_sse41.cpp
${MLAS_SRC_DIR}/intrinsics/avx512/quantize_avx512f.cpp
Expand Down Expand Up @@ -616,6 +619,9 @@ endif()
${MLAS_SRC_DIR}/dgemm.cpp
${MLAS_SRC_DIR}/pooling_fp16.cpp
${MLAS_SRC_DIR}/qgemm_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp
${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp
${mlas_platform_srcs_sse2}
${mlas_platform_srcs_avx}
${mlas_platform_srcs_avx2}
Expand All @@ -630,6 +636,7 @@ endif()
${MLAS_SRC_DIR}/q4gemm_avx512.cpp
)
set_source_files_properties(${MLAS_SRC_DIR}/q4gemm_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/qsoftmax_kernel_avx512.cpp PROPERTIES COMPILE_FLAGS "-mfma -mavx512vnni -mavx512bw -mavx512dq -mavx512vl -mavx512f")
endif()
if(NOT APPLE)
set(mlas_platform_srcs
Expand All @@ -640,6 +647,7 @@ endif()
)
set_source_files_properties(${MLAS_SRC_DIR}/qgemm_kernel_amx.cpp PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/x86_64/QgemmU8S8KernelAmx.S PROPERTIES COMPILE_FLAGS "-mavx2 -mavx512bw -mavx512dq -mavx512vl -mavx512f")
set_source_files_properties(${MLAS_SRC_DIR}/qsoftmax_kernel_avx2.cpp PROPERTIES COMPILE_FLAGS "-mavx2")
endif()

if(ONNXRUNTIME_MLAS_MULTI_ARCH)
Expand Down
15 changes: 10 additions & 5 deletions onnxruntime/contrib_ops/cpu/quantization/qlinear_softmax.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "core/framework/tensorprotoutils.h"
#include "core/providers/common.h"
#include "core/providers/cpu/tensor/transpose.h"
#include "core/providers/cpu/math/softmax_shared.h"

#include "core/mlas/inc/mlas.h"
#include "core/platform/threadpool.h"
Expand All @@ -36,7 +37,7 @@ void QlinearBuildLookupTableUint32(gsl::span<QLinearSoftmax::EXP_OUT_DTYPE> tabl
for (int32_t i = 0; i < 256; i++) {
double scaled_exp_xi = exp((static_cast<double>(i) - 255 + bit_shift) * static_cast<double>(x_scale));
// we can't get the real max value of input tensor here, so we just assume 255-bit_shift.
// in the function of `QlinearSoftmaxCPU`,
// in the function of `QlinearSoftmaxCPUNaive`,
// all numbers will have a shift (255-bit_shift-max_value) if its max value is not 255
//
// if is_signed index = [1 2 3 ......126 127 -128 -127 ..... -3 -2 -1]
Expand Down Expand Up @@ -124,7 +125,7 @@ Status QLinearSoftmax::Compute(OpKernelContext* ctx) const {
}

template <typename T>
common::Status QlinearSoftmaxCPU(size_t N,
common::Status QlinearSoftmaxCPUNaive(size_t N,
size_t D,
const T* x_data,
T* y_data,
Expand All @@ -134,7 +135,7 @@ common::Status QlinearSoftmaxCPU(size_t N,
onnxruntime::concurrency::ThreadPool* thread_pool);

template <>
common::Status QlinearSoftmaxCPU<uint8_t>(size_t N,
common::Status QlinearSoftmaxCPUNaive<uint8_t>(size_t N,
size_t D,
const uint8_t* x_data,
uint8_t* y_data,
Expand Down Expand Up @@ -185,7 +186,7 @@ common::Status QlinearSoftmaxCPU<uint8_t>(size_t N,
const size_t vx = *x_t_cur++;
const QLinearSoftmax::EXP_OUT_DTYPE vt = shifted_lookuptable[vx];
// simulate round function, and re-quant to uint8
const uint32_t vq = static_cast<uint32_t>(std::nearbyintf(((vt * c_y_scale)) / vsum)) + c_y_zp;
const uint32_t vq = static_cast<uint32_t>(std::nearbyintf((vt * c_y_scale) / vsum)) + c_y_zp;
const uint8_t vy = vq > 255 ? static_cast<uint8_t>(255) : static_cast<uint8_t>(vq);
*y_t++ = vy;
} while (--elements_n != 0);
Expand All @@ -197,7 +198,7 @@ common::Status QlinearSoftmaxCPU<uint8_t>(size_t N,
}

template <>
common::Status QlinearSoftmaxCPU<int8_t>(size_t N,
common::Status QlinearSoftmaxCPUNaive<int8_t>(size_t N,
size_t D,
const int8_t* x_data,
int8_t* y_data,
Expand Down Expand Up @@ -280,11 +281,15 @@ Status QLinearSoftmax::ComputeInternal(OpKernelContext* context, const Tensor& i
if (is_signed_) {
using T = int8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
status = QlinearSoftmaxCPUNaive<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
} else {
using T = uint8_t;
const T Y_zp = Y_zp_tensor ? *(Y_zp_tensor->Data<T>()) : 0;
// status = QlinearSoftmaxCPUNaive<T>(N, D, input.Data<T>(), output.MutableData<T>(),
// lookup_table.data(), Y_scale, Y_zp, thread_pool);
status = QlinearSoftmaxCPU<T>(N, D, input.Data<T>(), output.MutableData<T>(),
lookup_table.data(), Y_scale, Y_zp, thread_pool);
}
Expand Down
14 changes: 14 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1018,6 +1018,20 @@ MlasComputeSoftmax(
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeQSoftmax(
const void* Input,
void* Output,
size_t N,
size_t D,
const float* LoopupTable,
float Scale,
float ZeroPoint,
bool is_signed,
MLAS_THREADPOOL* ThreadPool
);

void
MLASCALL
MlasComputeTanh(
Expand Down
215 changes: 215 additions & 0 deletions onnxruntime/core/mlas/lib/qsoftmax.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*++
Copyright (c) Microsoft Corporation. All rights reserved.
Licensed under the MIT License.
Module Name:
qsoftmax.cpp
Abstract:
This module implements miscellaneous computation routines.
Our usage requires building platform specific versions of the algorithm to
target different instruction sets. The implementation below targets the
base instruction set (typically SSE2) while assembly implementations target
newer instruction sets (such as FMA3).
--*/

#include "mlasi.h"


struct MLAS_QSOFTMAX_WORK_BLOCK {
const void* Input;
void* Output;
size_t N;
size_t D;
const float* LoopupTable;
float Scale;
float ZeroPoint;
size_t ThreadCountN;
bool is_signed;
};


extern void MlasQSoftmaxI8KernelAVX2(size_t N,
size_t D,
const int8_t* x_data,
int8_t* y_data,
const float* lookup_table,
float y_scale,
int8_t yzp,
float* tempaddr);

extern void MlasQSoftmaxU8KernelAVX2(size_t N,
size_t D,
const uint8_t* x_data,
uint8_t* y_data,
const float* lookup_table,
float y_scale,
uint8_t yzp,
float* tempaddr);

void
MlasComputeQSoftmaxThreaded(
void* Context,
ptrdiff_t Index
)
/*++
Routine Description:
This routine is invoked from a worker thread to execute a segment of a
softmax or log softmax operation.
Arguments:
Context - Supplies the pointer to the context for the threaded operation.
ThreadId - Supplies the current index of the threaded operation.
Return Value:
None.
--*/
{
const auto* WorkBlock = (MLAS_QSOFTMAX_WORK_BLOCK*)Context;

//
// Partition the operation along the N dimension.
//

size_t n;
size_t CountN;

MlasPartitionWork(Index, WorkBlock->ThreadCountN, WorkBlock->N, &n, &CountN);
size_t packBSize = (WorkBlock->D*sizeof(float) + ThreadedBufAlignment - 1) / ThreadedBufAlignment;
packBSize *= ThreadedBufAlignment;

MlasThreadedBufAlloc(packBSize);

float *tempaddr = reinterpret_cast <float*>(ThreadedBufHolder.get());

//
// Compute the softmax or log softmax function.
//

const size_t D = WorkBlock->D;
const float Scale = WorkBlock->Scale;
const float ZeroPoint = WorkBlock->ZeroPoint;
const float* LoopupTable = WorkBlock->LoopupTable;

const int8_t* Input = reinterpret_cast <const int8_t*>(WorkBlock->Input) + n * D;
int8_t* Output = reinterpret_cast <int8_t*>(WorkBlock->Output) + n * D;

#if defined(MLAS_SSE2_INTRINSICS)
// TODO: Use std::hardware_constructive_interference_size
constexpr size_t CacheLineSize = 64;
constexpr size_t ElementsPerCacheLine = CacheLineSize / sizeof(float);
#endif

while (CountN > 0) {
#if defined(MLAS_SSE2_INTRINSICS)
//
// Prefetch the next row of the input buffer.
//

for (size_t i = 0; i * ElementsPerCacheLine < D; i++) {
_mm_prefetch((char*)(Input + D) + i * CacheLineSize, _MM_HINT_T0);
}
#endif
if (WorkBlock->is_signed) {
MlasQSoftmaxI8KernelAVX2(1, D, (Input), Output, LoopupTable, Scale, ZeroPoint, tempaddr);
} else {
MlasQSoftmaxU8KernelAVX2(1, D, reinterpret_cast <const uint8_t*>(Input), reinterpret_cast <uint8_t*>(Output), LoopupTable, Scale, ZeroPoint, tempaddr);
}

Input += D;
Output += D;
CountN--;
}
}

void
MLASCALL
MlasComputeQSoftmax(
const void* Input,
void* Output,
size_t N,
size_t D,
const float* LoopupTable,
float Scale,
float ZeroPoint,
bool is_signed,
MLAS_THREADPOOL* ThreadPool
)
/*++
Routine Description:
This routine computes the quantized softmax function.
N.B. This implementation supports in place updates of the output buffer.
Arguments:
Input - Supplies the input buffer.
Output - Supplies the output buffer.
N - Supplies the number of rows to process.
D - Supplies the number of columns per row to process.
LoopupTable - Supplies lookup exp values.
Scale - quantization params.
ZeroPoint - quantization params.
is_signed - int8 or uint8.
ThreadPool - Supplies the thread pool object to use, else nullptr if the
base library threading support should be used.
Return Value:
None.
--*/
{

MLAS_QSOFTMAX_WORK_BLOCK WorkBlock;

//
// Capture the softmax parameters to the work block.
//

WorkBlock.LoopupTable = LoopupTable;
WorkBlock.Scale = Scale;
WorkBlock.ZeroPoint = ZeroPoint;
WorkBlock.Input = Input;
WorkBlock.Output = Output;
WorkBlock.N = N;
WorkBlock.D = D;
WorkBlock.is_signed = is_signed;

//
// Compute the number of target threads given the complexity of the softmax
// operation. Limit the number of threads to the number of rows and try to
// keep each thread processing a minimum number of elements before using
// another thread.
//

ptrdiff_t ThreadCountN = MlasGetMaximumThreadCount(ThreadPool);

if (size_t(ThreadCountN) > N) {
ThreadCountN = ptrdiff_t(N);
}

WorkBlock.ThreadCountN = ThreadCountN;

MlasExecuteThreaded(MlasComputeQSoftmaxThreaded, &WorkBlock, ThreadCountN, ThreadPool);
}
Loading

0 comments on commit 59a0bb8

Please sign in to comment.