Skip to content

Commit

Permalink
Add MaxPool FP16 in XnnPack EP (microsoft#22258)
Browse files Browse the repository at this point in the history
### Description
Add support for FP16 kernels in the XnnPack execution provider for
MaxPool operations.
Fixes:
[AB#50332](https://aiinfra.visualstudio.com/6a833879-cd9b-44a4-a9de-adc2d818f13c/_workitems/edit/50332)

### Motivation and Context
The major purpose of this pull request is to add some common
vars/functions and setup a consistent style for adding FP16 kernels in
XnnPack EP.

---------
  • Loading branch information
mszhanyi authored Oct 3, 2024
1 parent c73e6af commit bbb5498
Show file tree
Hide file tree
Showing 4 changed files with 83 additions and 5 deletions.
46 changes: 42 additions & 4 deletions onnxruntime/core/providers/xnnpack/nn/max_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "core/graph/graph.h"
#include "core/providers/utils.h"
#include "core/providers/xnnpack/xnnpack_init.h"
#include "core/framework/tensorprotoutils.h"

// to sanity check output shape
Expand Down Expand Up @@ -54,6 +55,10 @@ bool MaxPool::IsOnnxNodeSupported(const NodeUnit& node_unit,
// input of maxpool could be fp16/fp32/fp64,i8/u8 according to ONNX
if (x_type == nullptr ||
(x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT &&
// because pool_fp16_op_test can be enabled by other preprocessor, for example, COREML_ENABLE_MLPROGRAM
#ifdef XNNPACK_FP16_SUPPORTED
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_FLOAT16 &&
#endif
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_UINT8 &&
x_type->tensor_type().elem_type() != ONNX_NAMESPACE::TensorProto_DataType_INT8)) {
break;
Expand Down Expand Up @@ -193,9 +198,19 @@ MaxPool::MaxPool(const OpKernelInfo& info)
stride_height, stride_width,
dilation_height, dilation_width,
output_min, output_max, flags, &p);
} else if (input_dtype == ONNX_NAMESPACE::TensorProto_DataType_FLOAT16) {
maxpool_type_ = OpComputeType::op_compute_type_fp16;
const float output_min = -65504.0;
const float output_max = 65504.0;
status = xnn_create_max_pooling2d_nhwc_f16(input_padding_top, input_padding_right,
input_padding_bottom, input_padding_left,
pooling_height, pooling_width,
stride_height, stride_width,
dilation_height, dilation_width,
output_min, output_max, flags, &p);
} else {
auto stype = DataTypeImpl::ToString(DataTypeImpl::TypeFromProto(*X_arg.TypeAsProto()));
ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8, but got ", stype);
ORT_THROW("unsupported Conv in maxpool, we have FLOAT|UINT8|FLOAT16, but got ", stype);
}
ORT_ENFORCE(status == xnn_status_success, "xnn_create_max_pooling2d_nhwc_",
OpTypeToString(maxpool_type_), "failed. Status:", status);
Expand Down Expand Up @@ -225,10 +240,12 @@ Status MaxPool::Compute(OpKernelContext* context) const {
pthreadpool_t threadpool = GetThreadPool();

auto reshape_fn = xnn_reshape_max_pooling2d_nhwc_f32;
if (maxpool_type_ == OpComputeType::op_compute_type_qu8)
if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
reshape_fn = xnn_reshape_max_pooling2d_nhwc_u8;
else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
} else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
reshape_fn = xnn_reshape_max_pooling2d_nhwc_s8;
} else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
reshape_fn = xnn_reshape_max_pooling2d_nhwc_f16;
}

auto status = reshape_fn(op0_.get(), N, H, W,
Expand All @@ -244,8 +261,10 @@ Status MaxPool::Compute(OpKernelContext* context) const {
status = xnn_setup_max_pooling2d_nhwc_f32(op0_.get(), X.Data<float>(), Y->MutableData<float>());
} else if (maxpool_type_ == OpComputeType::op_compute_type_qu8) {
status = xnn_setup_max_pooling2d_nhwc_u8(op0_.get(), X.Data<uint8_t>(), Y->MutableData<uint8_t>());
} else {
} else if (maxpool_type_ == OpComputeType::op_compute_type_qs8) {
status = xnn_setup_max_pooling2d_nhwc_s8(op0_.get(), X.Data<int8_t>(), Y->MutableData<int8_t>());
} else if (maxpool_type_ == OpComputeType::op_compute_type_fp16) {
status = xnn_setup_max_pooling2d_nhwc_f16(op0_.get(), X.Data<MLFloat16>(), Y->MutableData<MLFloat16>());
}

if (status != xnn_status_success) {
Expand Down Expand Up @@ -285,5 +304,24 @@ ONNX_OPERATOR_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, kXnnpackExecutionPro
DataTypeImpl::GetTensorType<uint8_t>(),
DataTypeImpl::GetTensorType<int8_t>()}),
MaxPool);

#ifdef XNNPACK_FP16_SUPPORTED
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 8, 9, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MaxPool);

ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 10, 10, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MaxPool);

ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 11, 11, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MaxPool);

ONNX_OPERATOR_TYPED_KERNEL_EX(MaxPool, kMSInternalNHWCDomain, 12, MLFloat16, kXnnpackExecutionProvider,
KernelDefBuilder().TypeConstraint("T", DataTypeImpl::GetTensorType<MLFloat16>()),
MaxPool);
#endif

} // namespace xnnpack
} // namespace onnxruntime
28 changes: 28 additions & 0 deletions onnxruntime/core/providers/xnnpack/xnnpack_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
BuildKernelCreateInfo< \
ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Op)>

#define KERNEL_CREATE_INFO_VERSIONED_TYPED(Start, End, Type, Op, Domain) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, End, Type, Op)>

#define KERNEL_CREATE_INFO(Start, Op, Domain) \
BuildKernelCreateInfo< \
ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Op)>
Expand All @@ -39,6 +43,19 @@ KernelCreateInfo BuildKernelCreateInfo<void>() {
BuildKernelCreateInfo< \
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, Domain, Start, Type, Op)>

#ifdef XNNPACK_FP16_SUPPORTED
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name) \
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, \
startver, endver, MLFloat16, name)

#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name) \
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, startver, \
MLFloat16, name)
#else
#define CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(provider, domain, startver, endver, name)
#define CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(provider, domain, startver, name)
#endif

// Layout sensitive operators in NHWC domain
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 7, 9, AveragePool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, AveragePool);
Expand Down Expand Up @@ -68,6 +85,10 @@ class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSIn
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 8, 9, MaxPool);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 10, 10, MaxPool);
CLASS_ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 11, 11, MaxPool);
CLASS_ONNX_OPERATOR_KERNEL_CLASS_NAME_FP16(kXnnpackExecutionProvider, kMSInternalNHWCDomain, 12, MaxPool);

// ONNX operators
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kXnnpackExecutionProvider, kOnnxDomain, 7, 8, Gemm);
Expand Down Expand Up @@ -138,6 +159,13 @@ std::unique_ptr<KernelRegistry> RegisterKernels() {
KERNEL_CREATE_INFO_TYPED(10, int8_t, QLinearConv, kMSInternalNHWCDomain),

KERNEL_CREATE_INFO(1, QLinearSoftmax, kDynamicDomainByCreate),

#ifdef XNNPACK_FP16_SUPPORTED
KERNEL_CREATE_INFO_VERSIONED_TYPED(8, 9, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(10, 10, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_VERSIONED_TYPED(11, 11, MLFloat16, MaxPool, kMSInternalNHWCDomain),
KERNEL_CREATE_INFO_TYPED(12, MLFloat16, MaxPool, kMSInternalNHWCDomain),
#endif
};

for (auto& function_table_entry : function_table) {
Expand Down
12 changes: 12 additions & 0 deletions onnxruntime/core/providers/xnnpack/xnnpack_init.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,18 @@ namespace xnnpack {
#define XNN_ALLOCATION_ALIGNMENT 16
#endif

#if defined(__aarch64__) || defined(_M_ARM64) || defined(_M_ARM64EC)
#define XNN_ARCH_ARM64 1
#else
#define XNN_ARCH_ARM64 0
#endif

// fp16 support can vary on a kernel by kernel basis. Keep it simple and limit to arm64 for now.
// e.g. XNNPACK maxpool has x64 and arm64 fp16 kernels.
#if XNN_ARCH_ARM64
#define XNNPACK_FP16_SUPPORTED
#endif

std::pair<AllocatorPtr&, xnn_allocator*> GetStoredAllocator();

} // namespace xnnpack
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/test/providers/cpu/nn/pool_fp16_op_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

#include "core/mlas/inc/mlas.h"

#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM)
#if defined(MLAS_F16VEC_INTRINSICS_SUPPORTED) || defined(COREML_ENABLE_MLPROGRAM) || defined(XNNPACK_FP16_SUPPORTED)

#include "core/providers/cpu/nn/pool.h"
#include "gtest/gtest.h"
Expand Down

0 comments on commit bbb5498

Please sign in to comment.