From bfe2e31a677ae1589e4d2c89b6f6a5a81252641c Mon Sep 17 00:00:00 2001
From: Shiyi Zou
Date: Wed, 18 Sep 2024 11:03:03 +0800
Subject: [PATCH] rebase, address comments

---
 js/web/docs/webnn-operators.md                |  2 +-
 .../core/providers/webnn/builders/helper.h    |  2 +-
 .../webnn/builders/impl/lstm_op_builder.cc    | 66 +++++++++++++++++--
 3 files changed, 61 insertions(+), 9 deletions(-)

diff --git a/js/web/docs/webnn-operators.md b/js/web/docs/webnn-operators.md
index 3709da364b805..6c50f3752737b 100644
--- a/js/web/docs/webnn-operators.md
+++ b/js/web/docs/webnn-operators.md
@@ -53,7 +53,7 @@ operators and the supported opset domain/versions in **WebNN EP** by ONNX Runtim
 | LessOrEqual | ai.onnx(12-15, 16+) | lesserOrEqual | ✓ | ✓ | |
 | Log | ai.onnx(7-12, 13+) | log | ✓ | ✓ | |
 | LpPool | ai.onnx(7-10, 11-17, 18+) | l2Pool2d | ✗ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'p' value is 2 |
-| LSTM | ai.onnx(7-13, 14+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
+| LSTM | ai.onnx(7-13, 14-21, 22+) | lstm | ✓ | ✓ | Only supports 'layout' == 0, 'input_forget' == 0. 'clip' is not supported. The activation functions in 'activations' must be one of 'Relu', 'Tanh', 'Sigmoid'. Forward and backward activations must be the same if bidirectional. 'sequence_lens' if present should be constant with values equal to the first dimension length of input 'X' |
 | MatMul | ai.onnx(7-8, 9-12, 13+) | matmul | ✓ | ✓ | |
 | Max | ai.onnx(7, 8-11, 12, 13+) | max | ✓ | ✓ | |
 | MaxPool | ai.onnx(7, 8-9, 10, 11, 12+) | maxPool2d | ✓ | ✓ | Only supports 4-D input, 2-D 'kernel_shape', 'storage_order' != 1, one output |
diff --git a/onnxruntime/core/providers/webnn/builders/helper.h b/onnxruntime/core/providers/webnn/builders/helper.h
index fd8a3973de052..7ba1d18fa1a76 100644
--- a/onnxruntime/core/providers/webnn/builders/helper.h
+++ b/onnxruntime/core/providers/webnn/builders/helper.h
@@ -195,7 +195,7 @@ static const InlinedHashMap<std::string, std::string> op_map = {
     {"LessOrEqual", "lesserOrEqual"},
     {"Log", "log"},
     {"LpPool", "l2Pool2d"},
-    {"Lstm", "lstm"},
+    {"LSTM", "lstm"},
     {"MatMul", "matmul"},
     {"MatMulInteger", "matmulInteger"},
     {"Max", "max"},
diff --git a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
index e24a9b2c68474..8d8b16387319c 100644
--- a/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
+++ b/onnxruntime/core/providers/webnn/builders/impl/lstm_op_builder.cc
@@ -25,6 +25,8 @@ class LstmOpBuilder : public BaseOpBuilder {
  private:
   bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node,
                          const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override;
+  bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits,
+                              const logging::Logger& logger) const override;
 };
 
 void LstmOpBuilder::AddInitializersToSkip(ModelBuilder& model_builder, const Node& node) const {
@@ -49,12 +51,14 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N
   emscripten::val recurrent_weight = model_builder.GetOperand(input_defs[2]->Name());
 
   emscripten::val options = emscripten::val::object();
+  options.set("label", node.Name());
   options.set("layout", emscripten::val("iofg"));
emscripten::val("iofg")); if (input_defs.size() > 3 && input_defs[3]->Exists()) { emscripten::val bias = model_builder.GetOperand(input_defs[3]->Name()); emscripten::val split_options = emscripten::val::object(); split_options.set("axis", 1); + split_options.set("label", node.Name() + "_split"); // Split it to bias and recurrentBias. emscripten::val splitted_biases = model_builder.GetBuilder().call("split", bias, /*splits*/ 2, split_options); @@ -84,22 +88,19 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool has_Y = output_defs.size() > 0 && output_defs[0]->Exists(); bool has_Y_h = output_defs.size() > 1 && output_defs[1]->Exists(); bool has_Y_c = output_defs.size() > 2 && output_defs[2]->Exists(); - if (has_Y) { - options.set("returnSequence", true); - } + options.set("returnSequence", has_Y); if (helper.HasAttr("activations")) { const auto activations = helper.Get("activations", std::vector{"Sigmoid", "Tanh", "Tanh"}); - emscripten::val opt_activations = emscripten::val::array(); for (size_t i = 0; i < 3; ++i) { const std::string& activation = activations[i]; if (activation == "Relu") { - opt_activations.call("push", model_builder.GetBuilder().call("relu")); + opt_activations.call("push", emscripten::val("relu")); } else if (activation == "Sigmoid") { - opt_activations.call("push", model_builder.GetBuilder().call("sigmoid")); + opt_activations.call("push", emscripten::val("sigmoid")); } else if (activation == "Tanh") { - opt_activations.call("push", model_builder.GetBuilder().call("tanh")); + opt_activations.call("push", emscripten::val("tanh")); } } @@ -125,6 +126,10 @@ Status LstmOpBuilder::AddToModelBuilderImpl(ModelBuilder& model_builder, const N bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); + if (input_defs.size() < 3) { + LOGS(logger, ERROR) << "LSTM: input size must be greater than or equal to 3"; + return false; + } std::vector input_shape; if (!GetShape(*input_defs[0], input_shape, logger)) { @@ -191,6 +196,53 @@ bool LstmOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, return true; } +bool LstmOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, + const logging::Logger& logger) const { + const auto& input_defs = node.InputDefs(); + const auto& op_type = node.OpType(); + int32_t input0_type = 0; // input data type + int32_t input1_type = 0; // weight data type + int32_t input2_type = 0; // recurrentWeight data type + int32_t input3_type = 0; // bias data type + // input4 sequence_lens is skipped. 
+  int32_t input5_type = 0;  // initialHiddenState data type
+  int32_t input6_type = 0;  // initialCellState data type
+  int32_t input7_type = 0;  // peepholeWeight data type
+  bool has_input3 = input_defs.size() > 3 && input_defs[3]->Exists();
+  bool has_input5 = input_defs.size() > 5 && input_defs[5]->Exists();
+  bool has_input6 = input_defs.size() > 6 && input_defs[6]->Exists();
+  bool has_input7 = input_defs.size() > 7 && input_defs[7]->Exists();
+
+  if (!GetType(*input_defs[0], input0_type, logger) ||
+      !GetType(*input_defs[1], input1_type, logger) ||
+      !GetType(*input_defs[2], input2_type, logger) ||
+      (has_input3 && !GetType(*input_defs[3], input3_type, logger)) ||
+      (has_input5 && !GetType(*input_defs[5], input5_type, logger)) ||
+      (has_input6 && !GetType(*input_defs[6], input6_type, logger)) ||
+      (has_input7 && !GetType(*input_defs[7], input7_type, logger))) {
+    return false;
+  }
+
+  InlinedVector<int32_t> input_types = {input0_type, input1_type, input2_type};
+  if (has_input3) {
+    input_types.push_back(input3_type);
+  }
+  if (has_input5) {
+    input_types.push_back(input5_type);
+  }
+  if (has_input6) {
+    input_types.push_back(input6_type);
+  }
+  if (has_input7) {
+    input_types.push_back(input7_type);
+  }
+  if (!AreInputDataTypesSame(op_type, input_types, logger)) {
+    return false;
+  }
+
+  return IsDataTypeSupportedByOp(op_type, input0_type, wnn_limits, "input", "X", logger);
+}
+
 void CreateLstmOpBuilder(const std::string& op_type, OpBuilderRegistrations& op_registrations) {
   op_registrations.builders.push_back(std::make_unique<LstmOpBuilder>());
   op_registrations.op_builder_map.emplace(op_type, op_registrations.builders.back().get());
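
Note on the activation change above (not part of the patch itself): WebNN's lstm() takes its
"activations" option as an array of activation-name string enums ("relu", "sigmoid", "tanh"),
which is why the patch now pushes emscripten::val("relu") etc. instead of building activation
operator objects through MLGraphBuilder. A minimal standalone sketch of the same
ONNX-to-WebNN activation mapping; the helper name MapOnnxActivationToWebnn is hypothetical
and introduced here only for illustration:

    #include <optional>
    #include <string>

    // Maps an ONNX LSTM activation name to the corresponding WebNN
    // activation string enum; returns std::nullopt for activations the
    // WebNN EP rejects (mirroring the three cases handled in the patch).
    std::optional<std::string> MapOnnxActivationToWebnn(const std::string& activation) {
      if (activation == "Relu") return "relu";
      if (activation == "Sigmoid") return "sigmoid";
      if (activation == "Tanh") return "tanh";
      return std::nullopt;
    }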