diff --git a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc index 23cc7f1b11459..141b86b3988c7 100644 --- a/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc +++ b/onnxruntime/core/providers/webnn/builders/impl/gru_op_builder.cc @@ -26,7 +26,7 @@ class GruOpBuilder : public BaseOpBuilder { private: bool IsOpSupportedImpl(const InitializedTensorSet& initializers, const Node& node, const WebnnDeviceType /*device_type*/, const logging::Logger& logger) const override; - bool HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, + bool HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const override; }; @@ -185,7 +185,7 @@ bool GruOpBuilder::IsOpSupportedImpl(const InitializedTensorSet& initializers, c return true; } -bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceType device_type, +bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const emscripten::val& wnn_limits, const logging::Logger& logger) const { const auto& input_defs = node.InputDefs(); const auto& op_type = node.OpType(); @@ -208,20 +208,7 @@ bool GruOpBuilder::HasSupportedInputsImpl(const Node& node, const WebnnDeviceTyp return false; } - std::unordered_set supported_data_types; - if (device_type == WebnnDeviceType::CPU) { - // WebNN CPU backend only support float32 input data type. - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - }; - } else if (device_type == WebnnDeviceType::GPU) { - supported_data_types = { - ONNX_NAMESPACE::TensorProto_DataType_FLOAT, - ONNX_NAMESPACE::TensorProto_DataType_FLOAT16, - }; - } - - if (!IsSupportedDataType(input0_type, supported_data_types)) { + if (!IsSupportedDataType(input0_type, wnn_limits["gru"]["input"]["dataTypes"])) { LOGS(logger, VERBOSE) << "[" << op_type << "] Input type: [" << input0_type << "] is not supported for now";