From 02f0af0d0801d6e613e16a86b4914a041a11e42f Mon Sep 17 00:00:00 2001 From: shiyi Date: Wed, 11 Dec 2024 07:48:16 +0800 Subject: [PATCH] [WebNN] Improve data type check of slice op (#22988) A follow-up of [[WebNN] Support negative steps for slice](https://github.com/microsoft/onnxruntime/pull/22871#discussion_r1847929774). Slice op is emulated by reverse+slice when steps < 0 so `SliceOpBuilder::HasSupportedInputsImpl()` should also check the supported data types of reverse. --------- Co-authored-by: Wanming Lin --- .../core/providers/webnn/builders/helper.cc | 27 +++++++++++++++---- .../core/providers/webnn/builders/helper.h | 7 +++++ .../webnn/builders/impl/base_op_builder.cc | 8 +++--- .../webnn/builders/impl/base_op_builder.h | 4 +-- .../webnn/builders/impl/binary_op_builder.cc | 8 +++--- .../webnn/builders/impl/cast_op_builder.cc | 8 +++--- .../webnn/builders/impl/concat_op_builder.cc | 8 +++--- .../webnn/builders/impl/conv_op_builder.cc | 8 +++--- .../webnn/builders/impl/einsum_op_builder.cc | 14 +++++----- .../impl/gatherElements_op_builder.cc | 7 ++--- .../builders/impl/gatherND_op_builder.cc | 8 +++--- .../webnn/builders/impl/gather_op_builder.cc | 8 +++--- .../webnn/builders/impl/gemm_op_builder.cc | 8 +++--- .../webnn/builders/impl/gru_op_builder.cc | 8 +++--- .../webnn/builders/impl/logical_op_builder.cc | 8 +++--- .../webnn/builders/impl/lstm_op_builder.cc | 8 +++--- .../webnn/builders/impl/max_min_op_builder.cc | 8 +++--- .../builders/impl/normalization_op_builder.cc | 7 ++--- .../webnn/builders/impl/qdq_op_builder.cc | 8 +++--- .../impl/scatterElements_op_builder.cc | 7 ++--- .../builders/impl/scatterND_op_builder.cc | 7 ++--- .../webnn/builders/impl/slice_op_builder.cc | 26 ++++++++++++++++++ .../webnn/builders/impl/ternary_op_builder.cc | 8 +++--- 23 files changed, 136 insertions(+), 82 deletions(-) diff --git a/onnxruntime/core/providers/webnn/builders/helper.cc b/onnxruntime/core/providers/webnn/builders/helper.cc index f36f8283e9bf6..45a87960126cd 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.cc +++ b/onnxruntime/core/providers/webnn/builders/helper.cc @@ -178,14 +178,31 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, if (!GetWebNNOpType(onnx_op_type, webnn_op_type)) return false; + return IsDataTypeSupportedByWebNNOp(onnx_op_type, webnn_op_type, onnx_data_type, wnn_limits, + webnn_input_output_name, onnx_input_output_name, logger); +} + +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger) { + if (wnn_limits[webnn_op_type].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] is not supported for now"; + return false; + } + if (wnn_limits[webnn_op_type][webnn_input_output_name].isUndefined()) { + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] WebNN op [" << webnn_op_type << "] doesn't have parameter [" + << webnn_input_output_name << "]"; + return false; + } if (!IsSupportedDataType(onnx_data_type, wnn_limits[webnn_op_type][webnn_input_output_name]["dataTypes"])) { - LOGS(logger, VERBOSE) << "[" << onnx_op_type - << "] " << onnx_input_output_name - << " type: [" << onnx_data_type - << "] is not supported for now"; + LOGS(logger, VERBOSE) << "[" << onnx_op_type << "] " << onnx_input_output_name << "'s data type: [" + << onnx_data_type << "] is not supported by WebNN op [" << webnn_op_type << "] for now"; return false; } - return true; } diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h index 7fdfc5aefa798..a06f46f1bdf0a 100644 --- a/onnxruntime/core/providers/webnn/builders/helper.h +++ b/onnxruntime/core/providers/webnn/builders/helper.h @@ -340,6 +340,13 @@ bool IsDataTypeSupportedByOp(const std::string& onnx_op_type, const std::string& webnn_input_output_name, const std::string& onnx_input_output_name, const logging::Logger& logger); +bool IsDataTypeSupportedByWebNNOp(const std::string& onnx_op_type, + const std::string& webnn_op_type, + const int32_t onnx_data_type, + const emscripten::val& wnn_limits, + const std::string& webnn_input_output_name, + const std::string& onnx_input_output_name, + const logging::Logger& logger); bool GetBidirectionalBroadcastShape(std::vector& shape_a, std::vector& shape_b, diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc index 70fa0f9516c5c..290d16a48dd83 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.cc @@ -29,7 +29,7 @@ Status BaseOpBuilder::AddToModelBuilder(ModelBuilder& model_builder, const Node& bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const emscripten::val& wnn_limits, const logging::Logger& logger) const { - if (!HasSupportedInputs(node, wnn_limits, logger)) + if (!HasSupportedInputs(initializers, node, wnn_limits, logger)) return false; if (!HasSupportedOutputs(node, wnn_limits, logger)) @@ -41,7 +41,7 @@ bool BaseOpBuilder::IsOpSupported(const InitializedTensorSet& initializers, cons return IsOpSupportedImpl(initializers, node, device_type, logger); } -bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, +bool BaseOpBuilder::HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto node_name = MakeString("Node [", node.Name(), "] type [", node.OpType(), "]"); for (const auto* input : node.InputDefs()) { @@ -50,10 +50,10 @@ bool BaseOpBuilder::HasSupportedInputs(const Node& node, const emscripten::val& } } - return HasSupportedInputsImpl(node, wnn_limits, logger); + return HasSupportedInputsImpl(initializers, node, wnn_limits, logger); } -bool BaseOpBuilder::HasSupportedInputsImpl(const Node& node, +bool BaseOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { // We only check the type of input 0 by default, specific op builder can override this. diff --git a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h index 9412fa8026fb3..0a4367a71add4 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h +++ b/onnxruntime/core/providers/webnn/builders/impl/base_op_builder.h @@ -40,7 +40,7 @@ class BaseOpBuilder : public IOpBuilder { return true; } - virtual bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + virtual bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; virtual bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; @@ -56,7 +56,7 @@ class BaseOpBuilder : public IOpBuilder { private: bool HasSupportedOpSet(const Node& node, const logging::Logger& logger) const; - bool HasSupportedInputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; + bool HasSupportedInputs(const InitializedTensorSet& initializers, const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; bool HasSupportedOutputs(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const; const bool allow_empty_tensor_as_input_; // Some operators can handle ignoring an empty tensor as input. diff --git a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc index af82a01b14de5..e14507e8f5aea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/binary_op_builder.cc @@ -22,8 +22,8 @@ class BinaryOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ bool BinaryOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers return true; } -bool BinaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool BinaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc index 70ebe18c85b86..4b2f04bed0eb1 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/cast_op_builder.cc @@ -21,8 +21,8 @@ class CastOpBuilder : public BaseOpBuilder { // Operator support related. private: - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -86,8 +86,8 @@ Status CastOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, } // Operator support related. -bool CastOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool CastOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc index 1a0d93ae7eada..bac528300e077 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/concat_op_builder.cc @@ -21,8 +21,8 @@ class ConcatOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -55,8 +55,8 @@ Status ConcatOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool ConcatOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConcatOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc index 52fcc39ae5418..81e688ea4f8ea 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/conv_op_builder.cc @@ -29,8 +29,8 @@ class ConvOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType device_type, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; void ConvOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const { @@ -397,8 +397,8 @@ bool ConvOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool ConvOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool ConvOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc index 931854d0f33c1..ef713f48b8135 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/einsum_op_builder.cc @@ -27,8 +27,8 @@ class EinsumOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Helper functions, thanks for DML EP's OperatorHelper. @@ -735,8 +735,8 @@ bool EinsumOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool EinsumOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -776,11 +776,11 @@ bool EinsumOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten: return false; } else if (recognized_operator_type == RecognizedOperatorType::Pairwise) { // Map to WebNN's gemm or matmul - return IsDataTypeSupportedByOp("MatMul", input0_type, wnn_limits, "a", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "matmul", input0_type, wnn_limits, "a", "inputs", logger); } else if (recognized_operator_type == RecognizedOperatorType::ReduceSum) { - return IsDataTypeSupportedByOp("ReduceSum", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "reduceSum", input0_type, wnn_limits, "input", "inputs", logger); } else { - return IsDataTypeSupportedByOp("Identity", input0_type, wnn_limits, "input", "inputs", logger); + return IsDataTypeSupportedByWebNNOp(op_type, "identity", input0_type, wnn_limits, "input", "inputs", logger); } } diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc index 225cfcdfc852c..cb7b7de74e121 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherElements_op_builder.cc @@ -20,8 +20,8 @@ class GatherElementsOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -49,7 +49,8 @@ Status GatherElementsOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builde // Operator support related. -bool GatherElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool GatherElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc index cb4f85a40ee12..002a1a6a63026 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gatherND_op_builder.cc @@ -22,8 +22,8 @@ class GatherNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -55,8 +55,8 @@ bool GatherNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initial return true; } -bool GatherNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc index ae9fe3e3f3bd1..88d22f103cadc 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gather_op_builder.cc @@ -22,8 +22,8 @@ class GatherOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -69,8 +69,8 @@ bool GatherOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool GatherOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GatherOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc index 252d49a2f4d4d..5f4e6de8fda98 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gemm_op_builder.cc @@ -25,8 +25,8 @@ class GemmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -215,8 +215,8 @@ bool GemmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializer return true; } -bool GemmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GemmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // A data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index ffb9b7fbf2e7a..b240e30d38b22 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,8 +26,8 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -187,8 +187,8 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool GruOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input_X_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc index d56fdbc08c677..91910f55f37c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/logical_op_builder.cc @@ -21,8 +21,8 @@ class LogicalOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -71,8 +71,8 @@ bool LogicalOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initiali return true; } -bool LogicalOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LogicalOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc index 6213b039fb2f9..33ba22ac3fb5b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc @@ -25,8 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; bool HasSupportedOutputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -198,8 +198,8 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } -bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool LstmOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc index e111ca412c6e9..40f94186e9ed6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/max_min_op_builder.cc @@ -22,8 +22,8 @@ class MaxMinOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -87,8 +87,8 @@ bool MaxMinOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initializ return true; } -bool MaxMinOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool MaxMinOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; diff --git a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc index 79ed0393e3044..50e49884bdfa9 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/normalization_op_builder.cc @@ -25,8 +25,8 @@ class NormalizationOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status NormalizationOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -228,7 +228,8 @@ bool NormalizationOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initi return true; } -bool NormalizationOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool NormalizationOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); diff --git a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc index ca15e123d0999..b71507a871bf6 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/qdq_op_builder.cc @@ -22,8 +22,8 @@ class QDQOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; // Operator support related. - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, @@ -118,8 +118,8 @@ Status QDQOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, return Status::OK(); } -bool QDQOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool QDQOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type = 0; // input data type diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc index c786aa468736c..8c70525835059 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterElements_op_builder.cc @@ -22,8 +22,8 @@ class ScatterElementsOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -65,7 +65,8 @@ bool ScatterElementsOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* return true; } -bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterElementsOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc index feb93cc14b7c4..8089b9706886f 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/scatterND_op_builder.cc @@ -22,8 +22,8 @@ class ScatterNDOpBuilder : public BaseOpBuilder { // Operator support related. bool IsOpSupportedImpl(const InitializedTensorSet& /* initializers */, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -57,7 +57,8 @@ bool ScatterNDOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& /* initia return true; } -bool ScatterNDOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, +bool ScatterNDOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& data = *node.InputDefs()[0]; const auto& indices = *node.InputDefs()[1]; diff --git a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc index d51297f19f1c2..41c66038c2694 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/slice_op_builder.cc @@ -27,6 +27,8 @@ class SliceOpBuilder : public BaseOpBuilder { const logging::Logger& logger) const override ORT_MUST_USE_RESULT; bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /* device_type */, const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; // TODO: Support Slice opset < 10, which uses attributes for starts and ends. int GetMinSupportedOpSet(const Node& /* node */) const override { return 10; } }; @@ -161,6 +163,30 @@ bool SliceOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool SliceOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& initializers, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& input = *input_defs[0]; + const auto& op_type = node.OpType(); + int32_t input_type; + if (!GetType(input, input_type, logger)) + return false; + + // If there is step < 0, check data type support of reverse. + if (input_defs.size() > 4 && input_defs[4]->Exists()) { + std::vector steps; + if (!ReadIntArrayFrom1DTensor(*initializers.at(input_defs[4]->Name()), steps, logger)) + return false; + if (std::any_of(steps.begin(), steps.end(), [](int64_t step) { return step < 0; })) { + if (!IsDataTypeSupportedByWebNNOp(op_type, "reverse", input_type, wnn_limits, "input", "data", logger)) { + return false; + } + } + } + + return IsDataTypeSupportedByOp(op_type, input_type, wnn_limits, "input", "data", logger); +} + void CreateSliceOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) { op_registrations.builders.push_back(std::make_unique()); op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get()); diff --git a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc index 4b6cf312074ba..c7b3129c0c85b 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/ternary_op_builder.cc @@ -18,8 +18,8 @@ class TernaryOpBuilder : public BaseOpBuilder { private: Status AddToModelBuilderImpl(ModelBuilder& model_builder, const Node& node, const logging::Logger& logger) const override ORT_MUST_USE_RESULT; - bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const override; + bool HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; // Add operator related. @@ -46,8 +46,8 @@ Status TernaryOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, cons return Status::OK(); } -bool TernaryOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, - const logging::Logger& logger) const { +bool TernaryOpBuilder::HasSupportedInputsImpl(const InitializedTensorSet& /* initializers */, const Node& node, + const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); int32_t input0_type; // condition data type